186 lines
3.5 KiB
Go
186 lines
3.5 KiB
Go
package protocol
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"errors"
|
|
"git.lumeweb.com/LumeWeb/libs5-go/config"
|
|
"git.lumeweb.com/LumeWeb/libs5-go/encoding"
|
|
"git.lumeweb.com/LumeWeb/libs5-go/types"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
"go.uber.org/zap"
|
|
"io"
|
|
)
|
|
|
|
var (
|
|
_ IncomingMessage = (*SignedMessage)(nil)
|
|
_ msgpack.CustomDecoder = (*signedMessageReader)(nil)
|
|
_ msgpack.CustomEncoder = (*SignedMessage)(nil)
|
|
)
|
|
|
|
var (
|
|
errInvalidSignature = errors.New("Invalid signature found")
|
|
)
|
|
|
|
type SignedMessage struct {
|
|
nodeId *encoding.NodeId
|
|
signature []byte
|
|
message []byte
|
|
HandshakeRequirement
|
|
}
|
|
|
|
func (s *SignedMessage) NodeId() *encoding.NodeId {
|
|
return s.nodeId
|
|
}
|
|
|
|
func (s *SignedMessage) SetNodeId(nodeId *encoding.NodeId) {
|
|
s.nodeId = nodeId
|
|
}
|
|
|
|
func (s *SignedMessage) SetSignature(signature []byte) {
|
|
s.signature = signature
|
|
}
|
|
|
|
func (s *SignedMessage) SetMessage(message []byte) {
|
|
s.message = message
|
|
}
|
|
|
|
func NewSignedMessageRequest(message []byte) *SignedMessage {
|
|
return &SignedMessage{message: message}
|
|
}
|
|
|
|
type signedMessageReader struct {
|
|
kind int
|
|
message msgpack.RawMessage
|
|
}
|
|
|
|
func (s *signedMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error {
|
|
kind, err := dec.DecodeInt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.kind = kind
|
|
|
|
message, err := io.ReadAll(dec.Buffered())
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.message = message
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewSignedMessage() *SignedMessage {
|
|
sm := &SignedMessage{}
|
|
|
|
sm.SetRequiresHandshake(false)
|
|
|
|
return sm
|
|
}
|
|
|
|
func (s *SignedMessage) HandleMessage(message IncomingMessageData) error {
|
|
var payload signedMessageReader
|
|
peer := message.Peer
|
|
logger := message.Logger
|
|
|
|
err := msgpack.Unmarshal(s.message, &payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if msgHandler, valid := GetSignedMessageType(payload.kind); valid {
|
|
logger.Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
|
if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
|
|
logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
|
return nil
|
|
}
|
|
err := msgpack.Unmarshal(payload.message, &msgHandler)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
data := IncomingMessageDataSigned{
|
|
IncomingMessageData: message,
|
|
NodeId: s.nodeId,
|
|
}
|
|
|
|
err = msgHandler.HandleMessage(data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageData) error {
|
|
nodeId, err := dec.DecodeBytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.nodeId = encoding.NewNodeId(nodeId)
|
|
|
|
signature, err := dec.DecodeBytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.signature = signature
|
|
|
|
signedMessage, err := dec.DecodeBytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.message = signedMessage
|
|
|
|
if !ed25519.Verify(s.nodeId.Raw()[1:], s.message, s.signature) {
|
|
return errInvalidSignature
|
|
}
|
|
|
|
return nil
|
|
}
|
|
func (s *SignedMessage) EncodeMsgpack(enc *msgpack.Encoder) error {
|
|
err := enc.EncodeInt(int64(types.ProtocolMethodSignedMessage))
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = enc.EncodeBytes(s.nodeId.Raw())
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = enc.EncodeBytes(s.signature)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = enc.EncodeBytes(s.message)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
func (s *SignedMessage) Sign(cfg *config.NodeConfig) error {
|
|
if s.nodeId == nil {
|
|
panic("nodeId is nil")
|
|
}
|
|
|
|
if s.message == nil {
|
|
panic("message is nil")
|
|
}
|
|
|
|
s.signature = ed25519.Sign(cfg.KeyPair.ExtractBytes(), s.message)
|
|
|
|
return nil
|
|
}
|