libs5-go/protocol/signed_message.go

195 lines
3.8 KiB
Go

package protocol
import (
"bytes"
"crypto/ed25519"
"errors"
"git.lumeweb.com/LumeWeb/libs5-go/config"
"git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5"
"go.uber.org/zap"
"io"
)
// Compile-time checks that the types in this file satisfy the interfaces they
// are used through. They produce no code at run time.
var (
	// SignedMessage is both an incoming message and its own msgpack encoder.
	_ IncomingMessage       = (*SignedMessage)(nil)
	_ msgpack.CustomEncoder = (*SignedMessage)(nil)

	// signedMessageReader decodes the inner (signed) payload.
	_ msgpack.CustomDecoder = (*signedMessageReader)(nil)
)
// errInvalidSignature is returned by SignedMessage.DecodeMessage when the
// ed25519 signature does not verify against the sender's node id.
var errInvalidSignature = errors.New("Invalid signature found")
// SignedMessage is a protocol message that carries an opaque payload together
// with the node id of its signer and an ed25519 signature over that payload.
type SignedMessage struct {
	// nodeId identifies the signer. Everything after the first byte of its
	// raw form is used as the ed25519 public key when verifying.
	nodeId *encoding.NodeId

	// signature is the ed25519 signature over message; message is the signed
	// payload, itself msgpack data beginning with an integer message kind.
	signature, message []byte

	// HandshakeRequirement is embedded; it is declared elsewhere in the package.
	HandshakeRequirement
}
// NodeId returns the node id currently attached to the message. It is nil
// until SetNodeId or DecodeMessage has populated it.
func (s *SignedMessage) NodeId() *encoding.NodeId {
	id := s.nodeId
	return id
}
// SetNodeId replaces the node id attached to the message. The pointer is
// stored as given; no copy is made.
func (s *SignedMessage) SetNodeId(nodeId *encoding.NodeId) {
	s.nodeId = nodeId
}
// SetSignature replaces the signature bytes. The slice is stored as given
// (not copied), so later writes by the caller are visible to the message.
func (s *SignedMessage) SetSignature(signature []byte) {
	s.signature = signature
}
// SetMessage replaces the payload that is signed and transmitted. The slice is
// stored as given (not copied). The signature is not recomputed here.
func (s *SignedMessage) SetMessage(message []byte) {
	s.message = message
}
// NewSignedMessageRequest builds a SignedMessage whose payload is message.
// Every other field, including node id and signature, is left at its zero
// value.
func NewSignedMessageRequest(message []byte) *SignedMessage {
	sm := new(SignedMessage)
	sm.message = message
	return sm
}
// signedMessageReader is the decode target for the payload carried inside a
// SignedMessage: an integer kind followed by the still-encoded body.
type signedMessageReader struct {
	// kind is the integer read first from the payload; HandleMessage uses it
	// to look up the handler.
	kind int

	// message holds the bytes that followed kind, left undecoded for the
	// handler.
	message msgpack.RawMessage
}
// DecodeMsgpack implements msgpack.CustomDecoder. It reads the leading integer
// into s.kind and then captures whatever the decoder still has buffered as the
// raw, undecoded body.
func (s *signedMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error {
	msgKind, err := dec.DecodeInt()
	if err != nil {
		return err
	}
	s.kind = msgKind

	// The body's schema depends on kind, so it is kept as bytes rather than
	// decoded here.
	body, err := io.ReadAll(dec.Buffered())
	if err != nil {
		return err
	}
	s.message = msgpack.RawMessage(body)

	return nil
}
// NewSignedMessage returns an empty SignedMessage with its handshake
// requirement set to false.
func NewSignedMessage() *SignedMessage {
	var sm SignedMessage
	sm.SetRequiresHandshake(false)
	return &sm
}
// HandleMessage unpacks the signed payload, looks up the handler registered
// for its kind and lets that handler decode and process the body.
//
// A payload whose kind has no registered handler is ignored and nil is
// returned. A payload whose handler requires a completed handshake is likewise
// dropped (with a debug log) when the peer has not finished handshaking.
// Errors from unpacking, decoding or handling are returned to the caller.
func (s *SignedMessage) HandleMessage(message IncomingMessageData) error {
	peer, logger := message.Peer, message.Logger

	var payload signedMessageReader
	if err := msgpack.Unmarshal(s.message, &payload); err != nil {
		return err
	}

	msgHandler, valid := GetSignedMessageType(payload.kind)
	if !valid {
		// Unknown kind: nothing to dispatch to.
		return nil
	}

	methodName := types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]
	logger.Debug("SignedMessage", zap.Any("type", methodName))

	if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
		nid, err := s.nodeId.ToString()
		if err != nil {
			return err
		}
		logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", methodName), zap.String("peerIP", peer.GetIPString()), zap.String("nodeId", nid))
		return nil
	}

	// Hand the inner handler the original message data plus the signer's id.
	signedMsg := IncomingMessageDataSigned{
		IncomingMessageData: message,
		NodeId:              s.nodeId,
	}
	signedDec := msgpack.NewDecoder(bytes.NewReader(payload.message))

	if err := msgHandler.DecodeMessage(signedDec, signedMsg); err != nil {
		logger.Error("Error decoding signed message", zap.Error(err))
		return err
	}

	if err := msgHandler.HandleMessage(signedMsg); err != nil {
		logger.Error("Error handling signed message", zap.Error(err))
		return err
	}

	return nil
}
// DecodeMessage reads, in order, the signer's node id, the signature and the
// signed payload from dec, stores them on s and verifies the signature.
//
// It returns the decoder's error if any of the three byte strings cannot be
// read, and errInvalidSignature if the node id does not contain a well-formed
// ed25519 public key or the signature does not verify. The message argument is
// not used.
func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageData) error {
	nodeId, err := dec.DecodeBytes()
	if err != nil {
		return err
	}
	s.nodeId = encoding.NewNodeId(nodeId)

	signature, err := dec.DecodeBytes()
	if err != nil {
		return err
	}
	s.signature = signature

	signedMessage, err := dec.DecodeBytes()
	if err != nil {
		return err
	}
	s.message = signedMessage

	// The node id is one leading byte followed by the ed25519 public key.
	// These bytes come from the remote peer: slicing an empty id panics, and
	// ed25519.Verify panics when the key is not exactly PublicKeySize bytes.
	// Reject malformed ids instead of letting a peer crash the process; a key
	// of the wrong size cannot have produced a valid signature.
	rawId := s.nodeId.Raw()
	if len(rawId) != ed25519.PublicKeySize+1 {
		return errInvalidSignature
	}

	if !ed25519.Verify(rawId[1:], s.message, s.signature) {
		return errInvalidSignature
	}
	return nil
}
// EncodeMsgpack implements msgpack.CustomEncoder. It writes four values in
// order: the ProtocolMethodSignedMessage method id as an integer, then the raw
// node id, the signature and the payload as byte strings. The first encoding
// error aborts the sequence and is returned.
func (s *SignedMessage) EncodeMsgpack(enc *msgpack.Encoder) error {
	if err := enc.EncodeInt(int64(types.ProtocolMethodSignedMessage)); err != nil {
		return err
	}
	if err := enc.EncodeBytes(s.nodeId.Raw()); err != nil {
		return err
	}
	if err := enc.EncodeBytes(s.signature); err != nil {
		return err
	}
	return enc.EncodeBytes(s.message)
}
// Sign computes the ed25519 signature of the message payload using the private
// key bytes from cfg.KeyPair and stores it on s.
//
// The node id and the payload must already be set. An error is returned,
// rather than a panic raised, when either is missing, when cfg is nil, or when
// the key is not ed25519.PrivateKeySize bytes long (ed25519.Sign would panic
// on such a key).
func (s *SignedMessage) Sign(cfg *config.NodeConfig) error {
	switch {
	case s.nodeId == nil:
		return errors.New("nodeId is nil")
	case s.message == nil:
		return errors.New("message is nil")
	case cfg == nil:
		return errors.New("config is nil")
	}

	key := cfg.KeyPair.ExtractBytes()
	if len(key) != ed25519.PrivateKeySize {
		return errors.New("invalid private key size")
	}

	s.signature = ed25519.Sign(key, s.message)
	return nil
}