refactor: major rewrite of message structure and wiring, reducing complexity

This commit is contained in:
Derrick Hammer 2024-01-28 23:39:40 -05:00
parent 6b9a4fb7dc
commit 31ccfb8c0b
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
15 changed files with 184 additions and 293 deletions

View File

@ -1,27 +1,71 @@
package base package base
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/interfaces" "context"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/node"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
"io"
) )
//go:generate mockgen -source=base.go -destination=../mocks/base/base.go -package=base //go:generate mockgen -source=base.go -destination=../mocks/base/base.go -package=base
var (
_ msgpack.CustomDecoder = (*IncomingMessageReader)(nil)
)
type IncomingMessage interface { type IncomingMessage interface {
HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error HandleMessage(message IncomingMessageData) error
SetIncomingMessage(msg IncomingMessage) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageData) error
IncomingMessage() IncomingMessage HandshakeRequirer
Self() IncomingMessage
SetSelf(self IncomingMessage)
Original() []byte
Kind() int
RequiresHandshake() bool
SetRequiresHandshake(value bool)
msgpack.CustomDecoder
} }
type IncomingMessageTyped interface { type IncomingMessageData struct {
DecodeMessage(dec *msgpack.Decoder) error Original []byte
IncomingMessage Data []byte
Ctx context.Context
Node *node.NodeImpl
Peer net.Peer
VerifyId bool
}
type IncomingMessageReader struct {
Kind int
Data []byte
}
func (i *IncomingMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error {
kind, err := dec.DecodeInt()
if err != nil {
return err
}
i.Kind = kind
raw, err := io.ReadAll(dec.Buffered())
if err != nil {
return err
}
i.Data = raw
return nil
}
type HandshakeRequirer interface {
RequiresHandshake() bool
SetRequiresHandshake(value bool)
}
type HandshakeRequirement struct {
requiresHandshake bool
}
func (hr *HandshakeRequirement) RequiresHandshake() bool {
return hr.requiresHandshake
}
func (hr *HandshakeRequirement) SetRequiresHandshake(value bool) {
hr.requiresHandshake = value
} }

View File

@ -2,19 +2,8 @@ package base
import "github.com/vmihailenco/msgpack/v5" import "github.com/vmihailenco/msgpack/v5"
var (
_ EncodeableMessage = (*EncodeableMessageImpl)(nil)
)
//go:generate mockgen -source=encodeable_message.go -destination=../mocks/base/encodeable_message.go -package=base //go:generate mockgen -source=encodeable_message.go -destination=../mocks/base/encodeable_message.go -package=base
type EncodeableMessage interface { type EncodeableMessage interface {
msgpack.CustomEncoder msgpack.CustomEncoder
} }
type EncodeableMessageImpl struct {
}
func (e EncodeableMessageImpl) EncodeMsgpack(encoder *msgpack.Encoder) error {
panic("this method should be implemented by the child class")
}

View File

@ -1,137 +0,0 @@
package base
import (
"fmt"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"github.com/vmihailenco/msgpack/v5"
"io"
"net/url"
)
//go:generate mockgen -source=incoming_message.go -destination=../../mocks/base/incoming_message.go -package=base
var _ msgpack.CustomDecoder = (*IncomingMessageImpl)(nil)
var _ IncomingMessage = (*IncomingMessageImpl)(nil)
var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil)
type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error
type IncomingMessageImpl struct {
kind int
data msgpack.RawMessage
original []byte
known bool
self IncomingMessage
incoming IncomingMessage
requiresHandshake bool
}
func (i *IncomingMessageImpl) Self() IncomingMessage {
return i.self
}
func (i *IncomingMessageImpl) SetSelf(self IncomingMessage) {
i.self = self
}
func (i *IncomingMessageImpl) DecodeMessage(dec *msgpack.Decoder) error {
panic("child class should implement this method")
}
func (i *IncomingMessageImpl) Known() bool {
return i.known
}
func (i *IncomingMessageImpl) SetKnown(known bool) {
i.known = known
}
func (i *IncomingMessageImpl) SetOriginal(original []byte) {
i.original = original
}
func (i *IncomingMessageImpl) Original() []byte {
return i.original
}
func (i *IncomingMessageImpl) SetIncomingMessage(msg IncomingMessage) {
i.incoming = msg
i.known = true
}
func (i *IncomingMessageImpl) IncomingMessage() IncomingMessage {
return i.incoming
}
func (i *IncomingMessageImpl) Kind() int {
return i.kind
}
func (i *IncomingMessageImpl) ToMessage() (message []byte, err error) {
return msgpack.Marshal(i)
}
func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
panic("child class should implement this method")
}
func (i *IncomingMessageImpl) Data() msgpack.RawMessage {
return i.data
}
type IncomingMessageTypedImpl struct {
IncomingMessageImpl
}
func NewIncomingMessageUnknown() *IncomingMessageImpl {
return &IncomingMessageImpl{
known: false,
}
}
func NewIncomingMessageKnown(kind int, data msgpack.RawMessage) *IncomingMessageImpl {
return &IncomingMessageImpl{
kind: kind,
data: data,
known: true,
}
}
func NewIncomingMessageTyped(kind int, data msgpack.RawMessage) *IncomingMessageTypedImpl {
known := NewIncomingMessageKnown(kind, data)
return &IncomingMessageTypedImpl{*known}
}
func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
if i.known {
if msgTyped, ok := interface{}(i.Self()).(IncomingMessageTyped); ok {
return msgTyped.DecodeMessage(dec)
}
return fmt.Errorf("type assertion to IncomingMessageTyped failed")
}
kind, err := dec.DecodeInt()
if err != nil {
return err
}
i.kind = kind
raw, err := io.ReadAll(dec.Buffered())
if err != nil {
return err
}
i.data = raw
return nil
}
func (i *IncomingMessageImpl) RequiresHandshake() bool {
return i.requiresHandshake
}
func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) {
i.requiresHandshake = value
}

View File

@ -1,7 +0,0 @@
package base
//go:generate mockgen -source=signed.go -destination=../../mocks/base/signed.go -package=base -aux_files=git.lumeweb.com/LumeWeb/libs5-go/protocol/base=base.go
type SignedIncomingMessage interface {
IncomingMessage
}

View File

@ -1,24 +1,21 @@
package protocol package protocol
import ( import (
"errors"
"fmt" "fmt"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/signed" "git.lumeweb.com/LumeWeb/libs5-go/protocol/signed"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
) )
var _ base.IncomingMessageTyped = (*HandshakeOpen)(nil) var _ base.EncodeableMessage = (*HandshakeOpen)(nil)
var _ base.IncomingMessage = (*HandshakeOpen)(nil)
type HandshakeOpen struct { type HandshakeOpen struct {
challenge []byte challenge []byte
networkId string networkId string
handshake []byte handshake []byte
base.IncomingMessageTypedImpl base.HandshakeRequirement
base.IncomingMessageHandler
} }
func (h *HandshakeOpen) SetHandshake(handshake []byte) { func (h *HandshakeOpen) SetHandshake(handshake []byte) {
@ -34,9 +31,6 @@ func (h HandshakeOpen) NetworkId() string {
} }
var _ base.EncodeableMessage = (*HandshakeOpen)(nil) var _ base.EncodeableMessage = (*HandshakeOpen)(nil)
var (
errInvalidChallenge = errors.New("Invalid challenge")
)
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen { func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
ho := &HandshakeOpen{ ho := &HandshakeOpen{
@ -68,7 +62,7 @@ func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
return nil return nil
} }
func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder) error { func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
handshake, err := dec.DecodeBytes() handshake, err := dec.DecodeBytes()
if err != nil { if err != nil {
@ -99,19 +93,22 @@ func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder) error {
return nil return nil
} }
func (h *HandshakeOpen) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (h *HandshakeOpen) HandleMessage(message base.IncomingMessageData) error {
node := message.Node
peer := message.Peer
if h.networkId != node.NetworkId() { if h.networkId != node.NetworkId() {
return fmt.Errorf("Peer is in different network: %s", h.networkId) return fmt.Errorf("Peer is in different network: %s", h.networkId)
} }
handshake := signed.NewHandshakeDoneRequest(h.handshake, types.SupportedFeatures, node.Services().P2P().SelfConnectionUris()) handshake := signed.NewHandshakeDoneRequest(h.handshake, types.SupportedFeatures, node.Services().P2P().SelfConnectionUris())
message, err := msgpack.Marshal(handshake) hsMessage, err := msgpack.Marshal(handshake)
if err != nil { if err != nil {
return err return err
} }
secureMessage, err := node.Services().P2P().SignMessageSimple(message) secureMessage, err := node.Services().P2P().SignMessageSimple(hsMessage)
if err != nil { if err != nil {
return err return err

View File

@ -2,7 +2,6 @@ package protocol
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
@ -12,15 +11,13 @@ import (
"log" "log"
) )
var _ base.IncomingMessageTyped = (*HashQuery)(nil)
var _ base.EncodeableMessage = (*HashQuery)(nil) var _ base.EncodeableMessage = (*HashQuery)(nil)
var _ base.IncomingMessage = (*HashQuery)(nil)
type HashQuery struct { type HashQuery struct {
hash *encoding.Multihash hash *encoding.Multihash
kinds []types.StorageLocationType kinds []types.StorageLocationType
base.HandshakeRequirement
base.IncomingMessageTypedImpl
base.IncomingMessageHandler
} }
func NewHashQuery() *HashQuery { func NewHashQuery() *HashQuery {
@ -49,7 +46,7 @@ func (h HashQuery) Kinds() []types.StorageLocationType {
return h.kinds return h.kinds
} }
func (h *HashQuery) DecodeMessage(dec *msgpack.Decoder) error { func (h *HashQuery) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
hash, err := dec.DecodeBytes() hash, err := dec.DecodeBytes()
if err != nil { if err != nil {
@ -90,7 +87,10 @@ func (h HashQuery) EncodeMsgpack(enc *msgpack.Encoder) error {
return nil return nil
} }
func (h *HashQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (h *HashQuery) HandleMessage(message base.IncomingMessageData) error {
node := message.Node
peer := message.Peer
mapLocations, err := node.GetCachedStorageLocations(h.hash, h.kinds) mapLocations, err := node.GetCachedStorageLocations(h.hash, h.kinds)
if err != nil { if err != nil {
log.Printf("Error getting cached storage locations: %v", err) log.Printf("Error getting cached storage locations: %v", err)
@ -173,7 +173,7 @@ func (h *HashQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId
for _, val := range node.Services().P2P().Peers().Values() { for _, val := range node.Services().P2P().Peers().Values() {
peerVal := val.(net.Peer) peerVal := val.(net.Peer)
if !peerVal.Id().Equals(peer.Id()) { if !peerVal.Id().Equals(peer.Id()) {
err := peerVal.SendMessage(h.IncomingMessage().Original()) err := peerVal.SendMessage(message.Original)
if err != nil { if err != nil {
node.Logger().Error("Failed to send message", zap.Error(err)) node.Logger().Error("Failed to send message", zap.Error(err))
} }

View File

@ -10,10 +10,6 @@ var (
messageTypes map[int]func() base.IncomingMessage messageTypes map[int]func() base.IncomingMessage
) )
var (
_ base.IncomingMessage = (*base.IncomingMessageImpl)(nil)
)
func Init() { func Init() {
messageTypes = make(map[int]func() base.IncomingMessage) messageTypes = make(map[int]func() base.IncomingMessage)

View File

@ -2,19 +2,17 @@ package protocol
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
) )
var _ base.IncomingMessageTyped = (*RegistryEntryRequest)(nil) var _ base.IncomingMessage = (*RegistryEntryRequest)(nil)
var _ base.EncodeableMessage = (*RegistryEntryRequest)(nil) var _ base.EncodeableMessage = (*RegistryEntryRequest)(nil)
type RegistryEntryRequest struct { type RegistryEntryRequest struct {
sre interfaces.SignedRegistryEntry sre interfaces.SignedRegistryEntry
base.IncomingMessageTypedImpl base.HandshakeRequirement
base.IncomingMessageHandler
} }
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest { func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
@ -42,10 +40,8 @@ func (s *RegistryEntryRequest) EncodeMsgpack(enc *msgpack.Encoder) error {
return nil return nil
} }
func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder) error { func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
data := s.IncomingMessage().Original() sre, err := UnmarshalSignedRegistryEntry(message.Data)
sre, err := UnmarshalSignedRegistryEntry(data)
if err != nil { if err != nil {
return err return err
} }
@ -55,6 +51,8 @@ func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder) error {
return nil return nil
} }
func (s *RegistryEntryRequest) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (s *RegistryEntryRequest) HandleMessage(message base.IncomingMessageData) error {
node := message.Node
peer := message.Peer
return node.Services().Registry().Set(s.sre, false, peer) return node.Services().Registry().Set(s.sre, false, peer)
} }

View File

@ -1,20 +1,17 @@
package protocol package protocol
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
) )
var _ base.IncomingMessageTyped = (*RegistryQuery)(nil) var _ base.IncomingMessage = (*RegistryQuery)(nil)
var _ base.EncodeableMessage = (*RegistryQuery)(nil) var _ base.EncodeableMessage = (*RegistryQuery)(nil)
type RegistryQuery struct { type RegistryQuery struct {
pk []byte pk []byte
base.IncomingMessageTypedImpl base.HandshakeRequirement
base.IncomingMessageHandler
} }
func NewEmptyRegistryQuery() *RegistryQuery { func NewEmptyRegistryQuery() *RegistryQuery {
@ -42,7 +39,7 @@ func (s *RegistryQuery) EncodeMsgpack(enc *msgpack.Encoder) error {
return nil return nil
} }
func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder) error { func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
pk, err := dec.DecodeBytes() pk, err := dec.DecodeBytes()
if err != nil { if err != nil {
return err return err
@ -53,7 +50,9 @@ func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder) error {
return nil return nil
} }
func (s *RegistryQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (s *RegistryQuery) HandleMessage(message base.IncomingMessageData) error {
node := message.Node
peer := message.Peer
sre, err := node.Services().Registry().Get(s.pk) sre, err := node.Services().Registry().Get(s.pk)
if err != nil { if err != nil {
return err return err

View File

@ -2,7 +2,6 @@ package signed
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
@ -11,14 +10,14 @@ import (
) )
var ( var (
_ base.IncomingMessageTyped = (*AnnouncePeers)(nil) _ IncomingMessageSigned = (*AnnouncePeers)(nil)
) )
type AnnouncePeers struct { type AnnouncePeers struct {
peer net.Peer peer net.Peer
connectionUris []*url.URL connectionUris []*url.URL
peersToSend []net.Peer peersToSend []net.Peer
base.IncomingMessageTypedImpl base.HandshakeRequirement
} }
func (a *AnnouncePeers) PeersToSend() []net.Peer { func (a *AnnouncePeers) PeersToSend() []net.Peer {
@ -41,7 +40,7 @@ func NewAnnouncePeers() *AnnouncePeers {
return ap return ap
} }
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error {
// CIDFromString the number of peers. // CIDFromString the number of peers.
numPeers, err := dec.DecodeInt() numPeers, err := dec.DecodeInt()
if err != nil { if err != nil {
@ -106,7 +105,9 @@ func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {
return nil return nil
} }
func (a AnnouncePeers) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (a AnnouncePeers) HandleMessage(message IncomingMessageDataSigned) error {
node := message.Node
peer := message.Peer
if len(a.connectionUris) > 0 { if len(a.connectionUris) > 0 {
err := node.Services().P2P().ConnectToNode([]*url.URL{a.connectionUris[0]}, false, peer) err := node.Services().P2P().ConnectToNode([]*url.URL{a.connectionUris[0]}, false, peer)
if err != nil { if err != nil {

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
@ -13,17 +12,16 @@ import (
"net/url" "net/url"
) )
var _ base.IncomingMessageTyped = (*HandshakeDone)(nil) var _ IncomingMessageSigned = (*HandshakeDone)(nil)
var _ base.EncodeableMessage = (*HandshakeDone)(nil) var _ base.EncodeableMessage = (*HandshakeDone)(nil)
type HandshakeDone struct { type HandshakeDone struct {
challenge []byte challenge []byte
networkId string networkId string
base.IncomingMessageTypedImpl
base.IncomingMessageHandler
supportedFeatures int supportedFeatures int
connectionUris []*url.URL connectionUris []*url.URL
handshake []byte handshake []byte
base.HandshakeRequirement
} }
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone { func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
@ -78,7 +76,12 @@ func NewHandshakeDone() *HandshakeDone {
return hn return hn
} }
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (h HandshakeDone) HandleMessage(message IncomingMessageDataSigned) error {
node := message.Node
peer := message.Peer
verifyId := message.VerifyId
nodeId := message.NodeId
if !node.IsStarted() { if !node.IsStarted() {
err := peer.End() err := peer.End()
if err != nil { if err != nil {
@ -91,8 +94,6 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
return errors.New("Invalid challenge") return errors.New("Invalid challenge")
} }
nodeId := h.IncomingMessage().(*SignedMessage).NodeId()
if !verifyId { if !verifyId {
peer.SetId(nodeId) peer.SetId(nodeId)
} else { } else {
@ -130,7 +131,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
return nil return nil
} }
func (h *HandshakeDone) DecodeMessage(dec *msgpack.Decoder) error { func (h *HandshakeDone) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error {
challenge, err := dec.DecodeBytes() challenge, err := dec.DecodeBytes()
if err != nil { if err != nil {
return err return err

View File

@ -1,33 +1,46 @@
package signed package signed
import ( import (
"git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5"
) )
type IncomingMessageDataSigned struct {
base.IncomingMessageData
NodeId *encoding.NodeId
}
type IncomingMessageSigned interface {
HandleMessage(message IncomingMessageDataSigned) error
DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error
base.HandshakeRequirer
}
var ( var (
messageTypes map[int]func() base.SignedIncomingMessage messageTypes map[int]func() IncomingMessageSigned
) )
func Init() { func Init() {
messageTypes = make(map[int]func() base.SignedIncomingMessage) messageTypes = make(map[int]func() IncomingMessageSigned)
RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage { RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() IncomingMessageSigned {
return NewHandshakeDone() return NewHandshakeDone()
}) })
RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() base.SignedIncomingMessage { RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() IncomingMessageSigned {
return NewAnnouncePeers() return NewAnnouncePeers()
}) })
} }
func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncomingMessage) { func RegisterMessageType(messageType int, factoryFunc func() IncomingMessageSigned) {
if factoryFunc == nil { if factoryFunc == nil {
panic("factoryFunc cannot be nil") panic("factoryFunc cannot be nil")
} }
messageTypes[messageType] = factoryFunc messageTypes[messageType] = factoryFunc
} }
func GetMessageType(kind int) (base.SignedIncomingMessage, bool) { func GetMessageType(kind int) (IncomingMessageSigned, bool) {
value, ok := messageTypes[kind] value, ok := messageTypes[kind]
if !ok { if !ok {
return nil, false return nil, false
@ -35,12 +48,3 @@ func GetMessageType(kind int) (base.SignedIncomingMessage, bool) {
return value(), true return value(), true
} }
var (
_ base.SignedIncomingMessage = (*IncomingMessageImpl)(nil)
)
type IncomingMessageImpl struct {
base.IncomingMessageImpl
message []byte
}

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types" "git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
@ -14,8 +13,8 @@ import (
) )
var ( var (
_ base.IncomingMessageTyped = (*SignedMessage)(nil) _ base.IncomingMessage = (*SignedMessage)(nil)
_ msgpack.CustomDecoder = (*signedMessagePayoad)(nil) _ msgpack.CustomDecoder = (*signedMessageReader)(nil)
_ msgpack.CustomEncoder = (*SignedMessage)(nil) _ msgpack.CustomEncoder = (*SignedMessage)(nil)
) )
@ -27,7 +26,7 @@ type SignedMessage struct {
nodeId *encoding.NodeId nodeId *encoding.NodeId
signature []byte signature []byte
message []byte message []byte
base.IncomingMessageTypedImpl base.HandshakeRequirement
} }
func (s *SignedMessage) NodeId() *encoding.NodeId { func (s *SignedMessage) NodeId() *encoding.NodeId {
@ -50,12 +49,12 @@ func NewSignedMessageRequest(message []byte) *SignedMessage {
return &SignedMessage{message: message} return &SignedMessage{message: message}
} }
type signedMessagePayoad struct { type signedMessageReader struct {
kind int kind int
message msgpack.RawMessage message msgpack.RawMessage
} }
func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error { func (s *signedMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error {
kind, err := dec.DecodeInt() kind, err := dec.DecodeInt()
if err != nil { if err != nil {
return err return err
@ -82,8 +81,10 @@ func NewSignedMessage() *SignedMessage {
return sm return sm
} }
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (s *SignedMessage) HandleMessage(message base.IncomingMessageData) error {
var payload signedMessagePayoad var payload signedMessageReader
node := message.Node
peer := message.Peer
err := msgpack.Unmarshal(s.message, &payload) err := msgpack.Unmarshal(s.message, &payload)
if err != nil { if err != nil {
@ -96,14 +97,17 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)])) node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
return nil return nil
} }
msgHandler.SetIncomingMessage(s)
msgHandler.SetSelf(msgHandler)
err := msgpack.Unmarshal(payload.message, &msgHandler) err := msgpack.Unmarshal(payload.message, &msgHandler)
if err != nil { if err != nil {
return err return err
} }
err = msgHandler.HandleMessage(node, peer, verifyId) data := IncomingMessageDataSigned{
IncomingMessageData: message,
NodeId: s.nodeId,
}
err = msgHandler.HandleMessage(data)
if err != nil { if err != nil {
return err return err
} }
@ -112,7 +116,7 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
return nil return nil
} }
func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder) error { func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
nodeId, err := dec.DecodeBytes() nodeId, err := dec.DecodeBytes()
if err != nil { if err != nil {
return err return err
@ -127,12 +131,12 @@ func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder) error {
s.signature = signature s.signature = signature
message, err := dec.DecodeBytes() signedMessage, err := dec.DecodeBytes()
if err != nil { if err != nil {
return err return err
} }
s.message = message s.message = signedMessage
if !ed25519.Verify(s.nodeId.Raw()[1:], s.message, s.signature) { if !ed25519.Verify(s.nodeId.Raw()[1:], s.message, s.signature) {
return errInvalidSignature return errInvalidSignature

View File

@ -4,7 +4,6 @@ import (
"crypto/ed25519" "crypto/ed25519"
"fmt" "fmt"
"git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/storage" "git.lumeweb.com/LumeWeb/libs5-go/storage"
@ -15,7 +14,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
var _ base.IncomingMessageTyped = (*StorageLocation)(nil) var _ base.IncomingMessage = (*StorageLocation)(nil)
type StorageLocation struct { type StorageLocation struct {
hash *encoding.Multihash hash *encoding.Multihash
@ -24,9 +23,7 @@ type StorageLocation struct {
parts []string parts []string
publicKey []byte publicKey []byte
signature []byte signature []byte
base.HandshakeRequirement
base.IncomingMessageTypedImpl
base.IncomingMessageHandler
} }
func NewStorageLocation() *StorageLocation { func NewStorageLocation() *StorageLocation {
@ -37,12 +34,14 @@ func NewStorageLocation() *StorageLocation {
return sl return sl
} }
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error { func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error {
// nop, we use the incoming message -> original already stored // nop, we use the incoming message -> original already stored
return nil return nil
} }
func (s *StorageLocation) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (s *StorageLocation) HandleMessage(message base.IncomingMessageData) error {
msg := s.IncomingMessage().Original() msg := message.Original
node := message.Node
peer := message.Peer
hash := encoding.NewMultihash(msg[1:34]) // Replace NewMultihash with appropriate function hash := encoding.NewMultihash(msg[1:34]) // Replace NewMultihash with appropriate function

View File

@ -1,6 +1,8 @@
package service package service
import ( import (
"bytes"
"context"
ed25519p "crypto/ed25519" ed25519p "crypto/ed25519"
"errors" "errors"
"fmt" "fmt"
@ -8,6 +10,7 @@ import (
"git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/encoding"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/node"
"git.lumeweb.com/LumeWeb/libs5-go/protocol" "git.lumeweb.com/LumeWeb/libs5-go/protocol"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/signed" "git.lumeweb.com/LumeWeb/libs5-go/protocol/signed"
@ -398,15 +401,6 @@ func (p *P2PImpl) OnNewPeer(peer net.Peer, verifyId bool) error {
return nil return nil
} }
func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
var pid string
if peer.Id() != nil {
pid, _ = peer.Id().ToString()
} else {
pid = "unknown"
}
onDone := net.CloseCallback(func() { onDone := net.CloseCallback(func() {
if peer.Id() != nil { if peer.Id() != nil {
pid, err := peer.Id().ToString() pid, err := peer.Id().ToString()
@ -431,33 +425,43 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
}) })
peer.ListenForMessages(func(message []byte) error { peer.ListenForMessages(func(message []byte) error {
imsg := base.NewIncomingMessageUnknown() var reader base.IncomingMessageReader
err := msgpack.Unmarshal(message, imsg) err := msgpack.Unmarshal(message, &reader)
p.logger.Debug("ListenForMessages", zap.Any("message", imsg), zap.String("peer", pid))
if err != nil { if err != nil {
p.logger.Error("Error decoding basic message info", zap.Error(err))
return err return err
} }
handler, ok := protocol.GetMessageType(imsg.Kind()) // Now, get the specific message handler based on the message kind
handler, ok := protocol.GetMessageType(reader.Kind)
if !ok {
p.logger.Error("Unknown message type", zap.Int("type", reader.Kind))
return fmt.Errorf("unknown message type: %d", reader.Kind)
}
if ok { data := base.IncomingMessageData{
if handler.RequiresHandshake() && !peer.IsHandshakeDone() { Original: message,
p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())])) Data: reader.Data,
return nil Ctx: context.Background(),
Node: p.node.(*node.NodeImpl),
Peer: peer,
VerifyId: verifyId,
} }
imsg.SetOriginal(message)
handler.SetIncomingMessage(imsg) dec := msgpack.NewDecoder(bytes.NewReader(reader.Data))
handler.SetSelf(handler)
err := msgpack.Unmarshal(imsg.Data(), handler) err = handler.DecodeMessage(dec, data)
if err != nil { if err != nil {
p.logger.Error("Error decoding message", zap.Error(err))
return err return err
} }
err = handler.HandleMessage(p.node, peer, verifyId)
if err != nil { // Directly decode and handle the specific message type
if err := handler.HandleMessage(data); err != nil {
p.logger.Error("Error handling message", zap.Error(err))
return err return err
} }
}
return nil return nil
}, net.ListenerOptions{ }, net.ListenerOptions{
@ -465,7 +469,6 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
OnError: &onError, OnError: &onError,
Logger: p.logger, Logger: p.logger,
}) })
} }
func (p *P2PImpl) readNodeVotes(nodeId *encoding.NodeId) (interfaces.NodeVotes, error) { func (p *P2PImpl) readNodeVotes(nodeId *encoding.NodeId) (interfaces.NodeVotes, error) {