From 31ccfb8c0bea0337db827bffdd97a18ad3e54203 Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Sun, 28 Jan 2024 23:39:40 -0500 Subject: [PATCH] refactor: major rewrite of message structure and wiring, reducing complexity --- protocol/base/base.go | 72 ++++++++++++--- protocol/base/encodeable_message.go | 11 --- protocol/base/incoming_message.go | 137 ---------------------------- protocol/base/signed.go | 7 -- protocol/handshake_open.go | 23 ++--- protocol/hash_query.go | 16 ++-- protocol/message.go | 4 - protocol/registry_entry.go | 16 ++-- protocol/registry_query.go | 13 ++- protocol/signed/announce_peers.go | 11 ++- protocol/signed/handshake_done.go | 21 +++-- protocol/signed/signed.go | 34 ++++--- protocol/signed/signed_message.go | 34 ++++--- protocol/storage_location.go | 15 ++- service/p2p.go | 63 +++++++------ 15 files changed, 184 insertions(+), 293 deletions(-) delete mode 100644 protocol/base/incoming_message.go delete mode 100644 protocol/base/signed.go diff --git a/protocol/base/base.go b/protocol/base/base.go index d189d81..a096b99 100644 --- a/protocol/base/base.go +++ b/protocol/base/base.go @@ -1,27 +1,71 @@ package base import ( - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" + "context" "git.lumeweb.com/LumeWeb/libs5-go/net" + "git.lumeweb.com/LumeWeb/libs5-go/node" "github.com/vmihailenco/msgpack/v5" + "io" ) //go:generate mockgen -source=base.go -destination=../mocks/base/base.go -package=base +var ( + _ msgpack.CustomDecoder = (*IncomingMessageReader)(nil) +) + type IncomingMessage interface { - HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error - SetIncomingMessage(msg IncomingMessage) - IncomingMessage() IncomingMessage - Self() IncomingMessage - SetSelf(self IncomingMessage) - Original() []byte - Kind() int - RequiresHandshake() bool - SetRequiresHandshake(value bool) - msgpack.CustomDecoder + HandleMessage(message IncomingMessageData) error + DecodeMessage(dec *msgpack.Decoder, message IncomingMessageData) error + HandshakeRequirer } -type IncomingMessageTyped interface { - DecodeMessage(dec *msgpack.Decoder) error - IncomingMessage +type IncomingMessageData struct { + Original []byte + Data []byte + Ctx context.Context + Node *node.NodeImpl + Peer net.Peer + VerifyId bool +} + +type IncomingMessageReader struct { + Kind int + Data []byte +} + +func (i *IncomingMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error { + kind, err := dec.DecodeInt() + if err != nil { + return err + } + + i.Kind = kind + + raw, err := io.ReadAll(dec.Buffered()) + + if err != nil { + return err + } + + i.Data = raw + + return nil +} + +type HandshakeRequirer interface { + RequiresHandshake() bool + SetRequiresHandshake(value bool) +} + +type HandshakeRequirement struct { + requiresHandshake bool +} + +func (hr *HandshakeRequirement) RequiresHandshake() bool { + return hr.requiresHandshake +} + +func (hr *HandshakeRequirement) SetRequiresHandshake(value bool) { + hr.requiresHandshake = value } diff --git a/protocol/base/encodeable_message.go b/protocol/base/encodeable_message.go index b780bba..20e2c77 100644 --- a/protocol/base/encodeable_message.go +++ b/protocol/base/encodeable_message.go @@ -2,19 +2,8 @@ package base import "github.com/vmihailenco/msgpack/v5" -var ( - _ EncodeableMessage = (*EncodeableMessageImpl)(nil) -) - //go:generate mockgen -source=encodeable_message.go -destination=../mocks/base/encodeable_message.go -package=base type EncodeableMessage interface { msgpack.CustomEncoder } - -type EncodeableMessageImpl struct { -} - -func (e EncodeableMessageImpl) EncodeMsgpack(encoder *msgpack.Encoder) error { - panic("this method should be implemented by the child class") -} diff --git a/protocol/base/incoming_message.go b/protocol/base/incoming_message.go deleted file mode 100644 index a5dec51..0000000 --- a/protocol/base/incoming_message.go +++ /dev/null @@ -1,137 +0,0 @@ -package base - -import ( - "fmt" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" - "git.lumeweb.com/LumeWeb/libs5-go/net" - "github.com/vmihailenco/msgpack/v5" - "io" - "net/url" -) - -//go:generate mockgen -source=incoming_message.go -destination=../../mocks/base/incoming_message.go -package=base - -var _ msgpack.CustomDecoder = (*IncomingMessageImpl)(nil) -var _ IncomingMessage = (*IncomingMessageImpl)(nil) -var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil) - -type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error - -type IncomingMessageImpl struct { - kind int - data msgpack.RawMessage - original []byte - known bool - self IncomingMessage - incoming IncomingMessage - requiresHandshake bool -} - -func (i *IncomingMessageImpl) Self() IncomingMessage { - return i.self -} - -func (i *IncomingMessageImpl) SetSelf(self IncomingMessage) { - i.self = self -} - -func (i *IncomingMessageImpl) DecodeMessage(dec *msgpack.Decoder) error { - panic("child class should implement this method") -} - -func (i *IncomingMessageImpl) Known() bool { - return i.known -} - -func (i *IncomingMessageImpl) SetKnown(known bool) { - i.known = known -} - -func (i *IncomingMessageImpl) SetOriginal(original []byte) { - i.original = original -} - -func (i *IncomingMessageImpl) Original() []byte { - return i.original -} - -func (i *IncomingMessageImpl) SetIncomingMessage(msg IncomingMessage) { - i.incoming = msg - i.known = true -} - -func (i *IncomingMessageImpl) IncomingMessage() IncomingMessage { - return i.incoming -} - -func (i *IncomingMessageImpl) Kind() int { - return i.kind -} - -func (i *IncomingMessageImpl) ToMessage() (message []byte, err error) { - return msgpack.Marshal(i) -} - -func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { - panic("child class should implement this method") -} - -func (i *IncomingMessageImpl) Data() msgpack.RawMessage { - return i.data -} - -type IncomingMessageTypedImpl struct { - IncomingMessageImpl -} - -func NewIncomingMessageUnknown() *IncomingMessageImpl { - return &IncomingMessageImpl{ - known: false, - } -} - -func NewIncomingMessageKnown(kind int, data msgpack.RawMessage) *IncomingMessageImpl { - return &IncomingMessageImpl{ - kind: kind, - data: data, - known: true, - } -} - -func NewIncomingMessageTyped(kind int, data msgpack.RawMessage) *IncomingMessageTypedImpl { - known := NewIncomingMessageKnown(kind, data) - return &IncomingMessageTypedImpl{*known} -} - -func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error { - if i.known { - if msgTyped, ok := interface{}(i.Self()).(IncomingMessageTyped); ok { - return msgTyped.DecodeMessage(dec) - } - return fmt.Errorf("type assertion to IncomingMessageTyped failed") - } - - kind, err := dec.DecodeInt() - if err != nil { - return err - } - - i.kind = kind - - raw, err := io.ReadAll(dec.Buffered()) - - if err != nil { - return err - } - - i.data = raw - return nil -} - -func (i *IncomingMessageImpl) RequiresHandshake() bool { - return i.requiresHandshake -} - -func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) { - i.requiresHandshake = value -} diff --git a/protocol/base/signed.go b/protocol/base/signed.go deleted file mode 100644 index 3432ceb..0000000 --- a/protocol/base/signed.go +++ /dev/null @@ -1,7 +0,0 @@ -package base - -//go:generate mockgen -source=signed.go -destination=../../mocks/base/signed.go -package=base -aux_files=git.lumeweb.com/LumeWeb/libs5-go/protocol/base=base.go - -type SignedIncomingMessage interface { - IncomingMessage -} diff --git a/protocol/handshake_open.go b/protocol/handshake_open.go index 1042357..d08b9c7 100644 --- a/protocol/handshake_open.go +++ b/protocol/handshake_open.go @@ -1,24 +1,21 @@ package protocol import ( - "errors" "fmt" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" - "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/signed" "git.lumeweb.com/LumeWeb/libs5-go/types" "github.com/vmihailenco/msgpack/v5" ) -var _ base.IncomingMessageTyped = (*HandshakeOpen)(nil) +var _ base.EncodeableMessage = (*HandshakeOpen)(nil) +var _ base.IncomingMessage = (*HandshakeOpen)(nil) type HandshakeOpen struct { challenge []byte networkId string handshake []byte - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + base.HandshakeRequirement } func (h *HandshakeOpen) SetHandshake(handshake []byte) { @@ -34,9 +31,6 @@ func (h HandshakeOpen) NetworkId() string { } var _ base.EncodeableMessage = (*HandshakeOpen)(nil) -var ( - errInvalidChallenge = errors.New("Invalid challenge") -) func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen { ho := &HandshakeOpen{ @@ -68,7 +62,7 @@ func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error { return nil } -func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder) error { +func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { handshake, err := dec.DecodeBytes() if err != nil { @@ -99,19 +93,22 @@ func (h *HandshakeOpen) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (h *HandshakeOpen) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (h *HandshakeOpen) HandleMessage(message base.IncomingMessageData) error { + node := message.Node + peer := message.Peer + if h.networkId != node.NetworkId() { return fmt.Errorf("Peer is in different network: %s", h.networkId) } handshake := signed.NewHandshakeDoneRequest(h.handshake, types.SupportedFeatures, node.Services().P2P().SelfConnectionUris()) - message, err := msgpack.Marshal(handshake) + hsMessage, err := msgpack.Marshal(handshake) if err != nil { return err } - secureMessage, err := node.Services().P2P().SignMessageSimple(message) + secureMessage, err := node.Services().P2P().SignMessageSimple(hsMessage) if err != nil { return err diff --git a/protocol/hash_query.go b/protocol/hash_query.go index 6a221ee..edbffcd 100644 --- a/protocol/hash_query.go +++ b/protocol/hash_query.go @@ -2,7 +2,6 @@ package protocol import ( "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" @@ -12,15 +11,13 @@ import ( "log" ) -var _ base.IncomingMessageTyped = (*HashQuery)(nil) var _ base.EncodeableMessage = (*HashQuery)(nil) +var _ base.IncomingMessage = (*HashQuery)(nil) type HashQuery struct { hash *encoding.Multihash kinds []types.StorageLocationType - - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + base.HandshakeRequirement } func NewHashQuery() *HashQuery { @@ -49,7 +46,7 @@ func (h HashQuery) Kinds() []types.StorageLocationType { return h.kinds } -func (h *HashQuery) DecodeMessage(dec *msgpack.Decoder) error { +func (h *HashQuery) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { hash, err := dec.DecodeBytes() if err != nil { @@ -90,7 +87,10 @@ func (h HashQuery) EncodeMsgpack(enc *msgpack.Encoder) error { return nil } -func (h *HashQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (h *HashQuery) HandleMessage(message base.IncomingMessageData) error { + node := message.Node + peer := message.Peer + mapLocations, err := node.GetCachedStorageLocations(h.hash, h.kinds) if err != nil { log.Printf("Error getting cached storage locations: %v", err) @@ -173,7 +173,7 @@ func (h *HashQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId for _, val := range node.Services().P2P().Peers().Values() { peerVal := val.(net.Peer) if !peerVal.Id().Equals(peer.Id()) { - err := peerVal.SendMessage(h.IncomingMessage().Original()) + err := peerVal.SendMessage(message.Original) if err != nil { node.Logger().Error("Failed to send message", zap.Error(err)) } diff --git a/protocol/message.go b/protocol/message.go index 902e22d..8426e81 100644 --- a/protocol/message.go +++ b/protocol/message.go @@ -10,10 +10,6 @@ var ( messageTypes map[int]func() base.IncomingMessage ) -var ( - _ base.IncomingMessage = (*base.IncomingMessageImpl)(nil) -) - func Init() { messageTypes = make(map[int]func() base.IncomingMessage) diff --git a/protocol/registry_entry.go b/protocol/registry_entry.go index c311a94..e9912aa 100644 --- a/protocol/registry_entry.go +++ b/protocol/registry_entry.go @@ -2,19 +2,17 @@ package protocol import ( "git.lumeweb.com/LumeWeb/libs5-go/interfaces" - "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" "github.com/vmihailenco/msgpack/v5" ) -var _ base.IncomingMessageTyped = (*RegistryEntryRequest)(nil) +var _ base.IncomingMessage = (*RegistryEntryRequest)(nil) var _ base.EncodeableMessage = (*RegistryEntryRequest)(nil) type RegistryEntryRequest struct { sre interfaces.SignedRegistryEntry - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + base.HandshakeRequirement } func NewEmptyRegistryEntryRequest() *RegistryEntryRequest { @@ -42,10 +40,8 @@ func (s *RegistryEntryRequest) EncodeMsgpack(enc *msgpack.Encoder) error { return nil } -func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder) error { - data := s.IncomingMessage().Original() - - sre, err := UnmarshalSignedRegistryEntry(data) +func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { + sre, err := UnmarshalSignedRegistryEntry(message.Data) if err != nil { return err } @@ -55,6 +51,8 @@ func (s *RegistryEntryRequest) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (s *RegistryEntryRequest) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (s *RegistryEntryRequest) HandleMessage(message base.IncomingMessageData) error { + node := message.Node + peer := message.Peer return node.Services().Registry().Set(s.sre, false, peer) } diff --git a/protocol/registry_query.go b/protocol/registry_query.go index d53d5ed..5b975ab 100644 --- a/protocol/registry_query.go +++ b/protocol/registry_query.go @@ -1,20 +1,17 @@ package protocol import ( - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" - "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" "github.com/vmihailenco/msgpack/v5" ) -var _ base.IncomingMessageTyped = (*RegistryQuery)(nil) +var _ base.IncomingMessage = (*RegistryQuery)(nil) var _ base.EncodeableMessage = (*RegistryQuery)(nil) type RegistryQuery struct { pk []byte - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + base.HandshakeRequirement } func NewEmptyRegistryQuery() *RegistryQuery { @@ -42,7 +39,7 @@ func (s *RegistryQuery) EncodeMsgpack(enc *msgpack.Encoder) error { return nil } -func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder) error { +func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { pk, err := dec.DecodeBytes() if err != nil { return err @@ -53,7 +50,9 @@ func (s *RegistryQuery) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (s *RegistryQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (s *RegistryQuery) HandleMessage(message base.IncomingMessageData) error { + node := message.Node + peer := message.Peer sre, err := node.Services().Registry().Get(s.pk) if err != nil { return err diff --git a/protocol/signed/announce_peers.go b/protocol/signed/announce_peers.go index 4115dea..f53f600 100644 --- a/protocol/signed/announce_peers.go +++ b/protocol/signed/announce_peers.go @@ -2,7 +2,6 @@ package signed import ( "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" @@ -11,14 +10,14 @@ import ( ) var ( - _ base.IncomingMessageTyped = (*AnnouncePeers)(nil) + _ IncomingMessageSigned = (*AnnouncePeers)(nil) ) type AnnouncePeers struct { peer net.Peer connectionUris []*url.URL peersToSend []net.Peer - base.IncomingMessageTypedImpl + base.HandshakeRequirement } func (a *AnnouncePeers) PeersToSend() []net.Peer { @@ -41,7 +40,7 @@ func NewAnnouncePeers() *AnnouncePeers { return ap } -func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { +func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error { // CIDFromString the number of peers. numPeers, err := dec.DecodeInt() if err != nil { @@ -106,7 +105,9 @@ func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (a AnnouncePeers) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (a AnnouncePeers) HandleMessage(message IncomingMessageDataSigned) error { + node := message.Node + peer := message.Peer if len(a.connectionUris) > 0 { err := node.Services().P2P().ConnectToNode([]*url.URL{a.connectionUris[0]}, false, peer) if err != nil { diff --git a/protocol/signed/handshake_done.go b/protocol/signed/handshake_done.go index b342425..52e13fb 100644 --- a/protocol/signed/handshake_done.go +++ b/protocol/signed/handshake_done.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" @@ -13,17 +12,16 @@ import ( "net/url" ) -var _ base.IncomingMessageTyped = (*HandshakeDone)(nil) +var _ IncomingMessageSigned = (*HandshakeDone)(nil) var _ base.EncodeableMessage = (*HandshakeDone)(nil) type HandshakeDone struct { - challenge []byte - networkId string - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + challenge []byte + networkId string supportedFeatures int connectionUris []*url.URL handshake []byte + base.HandshakeRequirement } func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone { @@ -78,7 +76,12 @@ func NewHandshakeDone() *HandshakeDone { return hn } -func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { +func (h HandshakeDone) HandleMessage(message IncomingMessageDataSigned) error { + node := message.Node + peer := message.Peer + verifyId := message.VerifyId + nodeId := message.NodeId + if !node.IsStarted() { err := peer.End() if err != nil { @@ -91,8 +94,6 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify return errors.New("Invalid challenge") } - nodeId := h.IncomingMessage().(*SignedMessage).NodeId() - if !verifyId { peer.SetId(nodeId) } else { @@ -130,7 +131,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify return nil } -func (h *HandshakeDone) DecodeMessage(dec *msgpack.Decoder) error { +func (h *HandshakeDone) DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error { challenge, err := dec.DecodeBytes() if err != nil { return err diff --git a/protocol/signed/signed.go b/protocol/signed/signed.go index 261eb1e..5565698 100644 --- a/protocol/signed/signed.go +++ b/protocol/signed/signed.go @@ -1,33 +1,46 @@ package signed import ( + "git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" + "github.com/vmihailenco/msgpack/v5" ) +type IncomingMessageDataSigned struct { + base.IncomingMessageData + NodeId *encoding.NodeId +} + +type IncomingMessageSigned interface { + HandleMessage(message IncomingMessageDataSigned) error + DecodeMessage(dec *msgpack.Decoder, message IncomingMessageDataSigned) error + base.HandshakeRequirer +} + var ( - messageTypes map[int]func() base.SignedIncomingMessage + messageTypes map[int]func() IncomingMessageSigned ) func Init() { - messageTypes = make(map[int]func() base.SignedIncomingMessage) + messageTypes = make(map[int]func() IncomingMessageSigned) - RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage { + RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() IncomingMessageSigned { return NewHandshakeDone() }) - RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() base.SignedIncomingMessage { + RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() IncomingMessageSigned { return NewAnnouncePeers() }) } -func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncomingMessage) { +func RegisterMessageType(messageType int, factoryFunc func() IncomingMessageSigned) { if factoryFunc == nil { panic("factoryFunc cannot be nil") } messageTypes[messageType] = factoryFunc } -func GetMessageType(kind int) (base.SignedIncomingMessage, bool) { +func GetMessageType(kind int) (IncomingMessageSigned, bool) { value, ok := messageTypes[kind] if !ok { return nil, false @@ -35,12 +48,3 @@ func GetMessageType(kind int) (base.SignedIncomingMessage, bool) { return value(), true } - -var ( - _ base.SignedIncomingMessage = (*IncomingMessageImpl)(nil) -) - -type IncomingMessageImpl struct { - base.IncomingMessageImpl - message []byte -} diff --git a/protocol/signed/signed_message.go b/protocol/signed/signed_message.go index dc98006..2565328 100644 --- a/protocol/signed/signed_message.go +++ b/protocol/signed/signed_message.go @@ -5,7 +5,6 @@ import ( "errors" "git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/interfaces" - "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" "github.com/vmihailenco/msgpack/v5" @@ -14,9 +13,9 @@ import ( ) var ( - _ base.IncomingMessageTyped = (*SignedMessage)(nil) - _ msgpack.CustomDecoder = (*signedMessagePayoad)(nil) - _ msgpack.CustomEncoder = (*SignedMessage)(nil) + _ base.IncomingMessage = (*SignedMessage)(nil) + _ msgpack.CustomDecoder = (*signedMessageReader)(nil) + _ msgpack.CustomEncoder = (*SignedMessage)(nil) ) var ( @@ -27,7 +26,7 @@ type SignedMessage struct { nodeId *encoding.NodeId signature []byte message []byte - base.IncomingMessageTypedImpl + base.HandshakeRequirement } func (s *SignedMessage) NodeId() *encoding.NodeId { @@ -50,12 +49,12 @@ func NewSignedMessageRequest(message []byte) *SignedMessage { return &SignedMessage{message: message} } -type signedMessagePayoad struct { +type signedMessageReader struct { kind int message msgpack.RawMessage } -func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error { +func (s *signedMessageReader) DecodeMsgpack(dec *msgpack.Decoder) error { kind, err := dec.DecodeInt() if err != nil { return err @@ -82,8 +81,10 @@ func NewSignedMessage() *SignedMessage { return sm } -func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { - var payload signedMessagePayoad +func (s *SignedMessage) HandleMessage(message base.IncomingMessageData) error { + var payload signedMessageReader + node := message.Node + peer := message.Peer err := msgpack.Unmarshal(s.message, &payload) if err != nil { @@ -96,14 +97,17 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)])) return nil } - msgHandler.SetIncomingMessage(s) - msgHandler.SetSelf(msgHandler) err := msgpack.Unmarshal(payload.message, &msgHandler) if err != nil { return err } - err = msgHandler.HandleMessage(node, peer, verifyId) + data := IncomingMessageDataSigned{ + IncomingMessageData: message, + NodeId: s.nodeId, + } + + err = msgHandler.HandleMessage(data) if err != nil { return err } @@ -112,7 +116,7 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif return nil } -func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder) error { +func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { nodeId, err := dec.DecodeBytes() if err != nil { return err @@ -127,12 +131,12 @@ func (s *SignedMessage) DecodeMessage(dec *msgpack.Decoder) error { s.signature = signature - message, err := dec.DecodeBytes() + signedMessage, err := dec.DecodeBytes() if err != nil { return err } - s.message = message + s.message = signedMessage if !ed25519.Verify(s.nodeId.Raw()[1:], s.message, s.signature) { return errInvalidSignature diff --git a/protocol/storage_location.go b/protocol/storage_location.go index 536dca5..c76e9fb 100644 --- a/protocol/storage_location.go +++ b/protocol/storage_location.go @@ -4,7 +4,6 @@ import ( "crypto/ed25519" "fmt" "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/storage" @@ -15,7 +14,7 @@ import ( "go.uber.org/zap" ) -var _ base.IncomingMessageTyped = (*StorageLocation)(nil) +var _ base.IncomingMessage = (*StorageLocation)(nil) type StorageLocation struct { hash *encoding.Multihash @@ -24,9 +23,7 @@ type StorageLocation struct { parts []string publicKey []byte signature []byte - - base.IncomingMessageTypedImpl - base.IncomingMessageHandler + base.HandshakeRequirement } func NewStorageLocation() *StorageLocation { @@ -37,12 +34,14 @@ func NewStorageLocation() *StorageLocation { return sl } -func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error { +func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder, message base.IncomingMessageData) error { // nop, we use the incoming message -> original already stored return nil } -func (s *StorageLocation) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { - msg := s.IncomingMessage().Original() +func (s *StorageLocation) HandleMessage(message base.IncomingMessageData) error { + msg := message.Original + node := message.Node + peer := message.Peer hash := encoding.NewMultihash(msg[1:34]) // Replace NewMultihash with appropriate function diff --git a/service/p2p.go b/service/p2p.go index 0c9cfd1..9ce07d0 100644 --- a/service/p2p.go +++ b/service/p2p.go @@ -1,6 +1,8 @@ package service import ( + "bytes" + "context" ed25519p "crypto/ed25519" "errors" "fmt" @@ -8,6 +10,7 @@ import ( "git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" + "git.lumeweb.com/LumeWeb/libs5-go/node" "git.lumeweb.com/LumeWeb/libs5-go/protocol" "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/signed" @@ -398,15 +401,6 @@ func (p *P2PImpl) OnNewPeer(peer net.Peer, verifyId bool) error { return nil } func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { - - var pid string - - if peer.Id() != nil { - pid, _ = peer.Id().ToString() - } else { - pid = "unknown" - } - onDone := net.CloseCallback(func() { if peer.Id() != nil { pid, err := peer.Id().ToString() @@ -431,32 +425,42 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { }) peer.ListenForMessages(func(message []byte) error { - imsg := base.NewIncomingMessageUnknown() + var reader base.IncomingMessageReader - err := msgpack.Unmarshal(message, imsg) - p.logger.Debug("ListenForMessages", zap.Any("message", imsg), zap.String("peer", pid)) + err := msgpack.Unmarshal(message, &reader) if err != nil { + p.logger.Error("Error decoding basic message info", zap.Error(err)) return err } - handler, ok := protocol.GetMessageType(imsg.Kind()) + // Now, get the specific message handler based on the message kind + handler, ok := protocol.GetMessageType(reader.Kind) + if !ok { + p.logger.Error("Unknown message type", zap.Int("type", reader.Kind)) + return fmt.Errorf("unknown message type: %d", reader.Kind) + } - if ok { - if handler.RequiresHandshake() && !peer.IsHandshakeDone() { - p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())])) - return nil - } - imsg.SetOriginal(message) - handler.SetIncomingMessage(imsg) - handler.SetSelf(handler) - err := msgpack.Unmarshal(imsg.Data(), handler) - if err != nil { - return err - } - err = handler.HandleMessage(p.node, peer, verifyId) - if err != nil { - return err - } + data := base.IncomingMessageData{ + Original: message, + Data: reader.Data, + Ctx: context.Background(), + Node: p.node.(*node.NodeImpl), + Peer: peer, + VerifyId: verifyId, + } + + dec := msgpack.NewDecoder(bytes.NewReader(reader.Data)) + + err = handler.DecodeMessage(dec, data) + if err != nil { + p.logger.Error("Error decoding message", zap.Error(err)) + return err + } + + // Directly decode and handle the specific message type + if err := handler.HandleMessage(data); err != nil { + p.logger.Error("Error handling message", zap.Error(err)) + return err } return nil @@ -465,7 +469,6 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { OnError: &onError, Logger: p.logger, }) - } func (p *P2PImpl) readNodeVotes(nodeId *encoding.NodeId) (interfaces.NodeVotes, error) {