From 3d41119f74514c8168193c0edfcae136282a17a8 Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Sat, 13 Jan 2024 11:22:01 -0500 Subject: [PATCH] fix: add a new property on messages and peers to prevent messages from being processed before the handshake is done --- net/peer.go | 11 +++++++++++ protocol/base/base.go | 2 ++ protocol/base/incoming_message.go | 21 +++++++++++++++------ protocol/handshake_open.go | 6 +++++- protocol/hash_query.go | 6 +++++- protocol/registry_entry.go | 6 +++++- protocol/registry_query.go | 6 +++++- protocol/signed/accounce_peers.go | 6 +++++- protocol/signed/handshake_done.go | 13 +++++++++++-- protocol/signed/signed_message.go | 10 +++++++++- protocol/storage_location.go | 6 +++++- service/p2p.go | 5 ++++- types/protocol.go | 2 +- 13 files changed, 83 insertions(+), 17 deletions(-) diff --git a/net/peer.go b/net/peer.go index 3ab2e0f..b9efc4a 100644 --- a/net/peer.go +++ b/net/peer.go @@ -43,6 +43,8 @@ type Peer interface { IsConnected() bool SetConnectionURIs(uris []*url.URL) ConnectionURIs() []*url.URL + IsHandshakeDone() bool + SetHandshakeDone(status bool) } type BasePeer struct { @@ -51,6 +53,7 @@ type BasePeer struct { challenge []byte socket interface{} id *encoding.NodeId + handshaked bool } func (b *BasePeer) IsConnected() bool { @@ -106,3 +109,11 @@ func (b *BasePeer) SetConnectionURIs(uris []*url.URL) { func (b *BasePeer) ConnectionURIs() []*url.URL { return b.connectionURIs } + +func (b *BasePeer) IsHandshakeDone() bool { + return b.handshaked +} + +func (b *BasePeer) SetHandshakeDone(status bool) { + b.handshaked = status +} diff --git a/protocol/base/base.go b/protocol/base/base.go index 3fbd3c2..d189d81 100644 --- a/protocol/base/base.go +++ b/protocol/base/base.go @@ -16,6 +16,8 @@ type IncomingMessage interface { SetSelf(self IncomingMessage) Original() []byte Kind() int + RequiresHandshake() bool + SetRequiresHandshake(value bool) msgpack.CustomDecoder } diff --git a/protocol/base/incoming_message.go b/protocol/base/incoming_message.go index 06628a2..a5dec51 100644 --- a/protocol/base/incoming_message.go +++ b/protocol/base/incoming_message.go @@ -18,12 +18,13 @@ var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil) type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error type IncomingMessageImpl struct { - kind int - data msgpack.RawMessage - original []byte - known bool - self IncomingMessage - incoming IncomingMessage + kind int + data msgpack.RawMessage + original []byte + known bool + self IncomingMessage + incoming IncomingMessage + requiresHandshake bool } func (i *IncomingMessageImpl) Self() IncomingMessage { @@ -126,3 +127,11 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error { i.data = raw return nil } + +func (i *IncomingMessageImpl) RequiresHandshake() bool { + return i.requiresHandshake +} + +func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) { + i.requiresHandshake = value +} diff --git a/protocol/handshake_open.go b/protocol/handshake_open.go index a4f9850..1042357 100644 --- a/protocol/handshake_open.go +++ b/protocol/handshake_open.go @@ -39,10 +39,14 @@ var ( ) func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen { - return &HandshakeOpen{ + ho := &HandshakeOpen{ challenge: challenge, networkId: networkId, } + + ho.SetRequiresHandshake(false) + + return ho } func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error { err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen)) diff --git a/protocol/hash_query.go b/protocol/hash_query.go index db5694c..cceda06 100644 --- a/protocol/hash_query.go +++ b/protocol/hash_query.go @@ -24,7 +24,11 @@ type HashQuery struct { } func NewHashQuery() *HashQuery { - return &HashQuery{} + hq := &HashQuery{} + + hq.SetRequiresHandshake(true) + + return hq } func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery { diff --git a/protocol/registry_entry.go b/protocol/registry_entry.go index 90e45ed..c311a94 100644 --- a/protocol/registry_entry.go +++ b/protocol/registry_entry.go @@ -18,7 +18,11 @@ type RegistryEntryRequest struct { } func NewEmptyRegistryEntryRequest() *RegistryEntryRequest { - return &RegistryEntryRequest{} + rer := &RegistryEntryRequest{} + + rer.SetRequiresHandshake(true) + + return rer } func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest { return &RegistryEntryRequest{sre: sre} diff --git a/protocol/registry_query.go b/protocol/registry_query.go index fa8a83b..d53d5ed 100644 --- a/protocol/registry_query.go +++ b/protocol/registry_query.go @@ -18,7 +18,11 @@ type RegistryQuery struct { } func NewEmptyRegistryQuery() *RegistryQuery { - return &RegistryQuery{} + rq := &RegistryQuery{} + + rq.SetRequiresHandshake(true) + + return rq } func NewRegistryQuery(pk []byte) *RegistryQuery { return &RegistryQuery{pk: pk} diff --git a/protocol/signed/accounce_peers.go b/protocol/signed/accounce_peers.go index 1b18a08..6c84872 100644 --- a/protocol/signed/accounce_peers.go +++ b/protocol/signed/accounce_peers.go @@ -34,7 +34,11 @@ func NewAnnounceRequest(peer net.Peer, peersToSend []net.Peer) *AnnouncePeers { } func NewAnnouncePeers() *AnnouncePeers { - return &AnnouncePeers{peer: nil, connectionUris: nil} + ap := &AnnouncePeers{peer: nil, connectionUris: nil} + + ap.SetRequiresHandshake(false) + + return ap } func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { diff --git a/protocol/signed/handshake_done.go b/protocol/signed/handshake_done.go index 57d174c..b342425 100644 --- a/protocol/signed/handshake_done.go +++ b/protocol/signed/handshake_done.go @@ -27,11 +27,15 @@ type HandshakeDone struct { } func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone { - return &HandshakeDone{ + ho := &HandshakeDone{ handshake: handshake, supportedFeatures: supportedFeatures, connectionUris: connectionUris, } + + ho.SetRequiresHandshake(false) + + return ho } func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error { @@ -67,7 +71,11 @@ func (m *HandshakeDone) SetNetworkId(networkId string) { } func NewHandshakeDone() *HandshakeDone { - return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1} + hn := &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1} + + hn.SetRequiresHandshake(false) + + return hn } func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { @@ -94,6 +102,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify } peer.SetConnected(true) + peer.SetHandshakeDone(true) if h.supportedFeatures != types.SupportedFeatures { return fmt.Errorf("Remote node does not support required features") diff --git a/protocol/signed/signed_message.go b/protocol/signed/signed_message.go index 091b146..dc98006 100644 --- a/protocol/signed/signed_message.go +++ b/protocol/signed/signed_message.go @@ -75,7 +75,11 @@ func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error { } func NewSignedMessage() *SignedMessage { - return &SignedMessage{} + sm := &SignedMessage{} + + sm.SetRequiresHandshake(false) + + return sm } func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { @@ -88,6 +92,10 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif if msgHandler, valid := GetMessageType(payload.kind); valid { node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)])) + if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() { + node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)])) + return nil + } msgHandler.SetIncomingMessage(s) msgHandler.SetSelf(msgHandler) err := msgpack.Unmarshal(payload.message, &msgHandler) diff --git a/protocol/storage_location.go b/protocol/storage_location.go index eb285eb..f74f190 100644 --- a/protocol/storage_location.go +++ b/protocol/storage_location.go @@ -30,7 +30,11 @@ type StorageLocation struct { } func NewStorageLocation() *StorageLocation { - return &StorageLocation{} + sl := &StorageLocation{} + + sl.SetRequiresHandshake(true) + + return sl } func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error { diff --git a/service/p2p.go b/service/p2p.go index 6191be9..01df1db 100644 --- a/service/p2p.go +++ b/service/p2p.go @@ -326,7 +326,10 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { handler, ok := protocol.GetMessageType(imsg.Kind()) if ok { - + if handler.RequiresHandshake() && !peer.IsHandshakeDone() { + p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())])) + return nil + } imsg.SetOriginal(message) handler.SetIncomingMessage(imsg) handler.SetSelf(handler) diff --git a/types/protocol.go b/types/protocol.go index 5b1a6f9..4e67e4a 100644 --- a/types/protocol.go +++ b/types/protocol.go @@ -16,7 +16,7 @@ const ( var ProtocolMethodMap = map[ProtocolMethod]string{ ProtocolMethodHandshakeOpen: "HandshakeOpen", - ProtocolMethodHandshakeDone: "HandshakeDone", + ProtocolMethodHandshakeDone: "IsHandshakeDone", ProtocolMethodSignedMessage: "SignedMessage", ProtocolMethodHashQuery: "HashQuery", ProtocolMethodAnnouncePeers: "AnnouncePeers",