fix: add a new property on messages and peers to prevent messages from being processed before the handshake is done

This commit is contained in:
Derrick Hammer 2024-01-13 11:22:01 -05:00
parent 36f087dc83
commit 3d41119f74
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
13 changed files with 83 additions and 17 deletions

View File

@ -43,6 +43,8 @@ type Peer interface {
IsConnected() bool IsConnected() bool
SetConnectionURIs(uris []*url.URL) SetConnectionURIs(uris []*url.URL)
ConnectionURIs() []*url.URL ConnectionURIs() []*url.URL
IsHandshakeDone() bool
SetHandshakeDone(status bool)
} }
type BasePeer struct { type BasePeer struct {
@ -51,6 +53,7 @@ type BasePeer struct {
challenge []byte challenge []byte
socket interface{} socket interface{}
id *encoding.NodeId id *encoding.NodeId
handshaked bool
} }
func (b *BasePeer) IsConnected() bool { func (b *BasePeer) IsConnected() bool {
@ -106,3 +109,11 @@ func (b *BasePeer) SetConnectionURIs(uris []*url.URL) {
func (b *BasePeer) ConnectionURIs() []*url.URL { func (b *BasePeer) ConnectionURIs() []*url.URL {
return b.connectionURIs return b.connectionURIs
} }
func (b *BasePeer) IsHandshakeDone() bool {
return b.handshaked
}
func (b *BasePeer) SetHandshakeDone(status bool) {
b.handshaked = status
}

View File

@ -16,6 +16,8 @@ type IncomingMessage interface {
SetSelf(self IncomingMessage) SetSelf(self IncomingMessage)
Original() []byte Original() []byte
Kind() int Kind() int
RequiresHandshake() bool
SetRequiresHandshake(value bool)
msgpack.CustomDecoder msgpack.CustomDecoder
} }

View File

@ -18,12 +18,13 @@ var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil)
type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error
type IncomingMessageImpl struct { type IncomingMessageImpl struct {
kind int kind int
data msgpack.RawMessage data msgpack.RawMessage
original []byte original []byte
known bool known bool
self IncomingMessage self IncomingMessage
incoming IncomingMessage incoming IncomingMessage
requiresHandshake bool
} }
func (i *IncomingMessageImpl) Self() IncomingMessage { func (i *IncomingMessageImpl) Self() IncomingMessage {
@ -126,3 +127,11 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
i.data = raw i.data = raw
return nil return nil
} }
func (i *IncomingMessageImpl) RequiresHandshake() bool {
return i.requiresHandshake
}
func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) {
i.requiresHandshake = value
}

View File

@ -39,10 +39,14 @@ var (
) )
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen { func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
return &HandshakeOpen{ ho := &HandshakeOpen{
challenge: challenge, challenge: challenge,
networkId: networkId, networkId: networkId,
} }
ho.SetRequiresHandshake(false)
return ho
} }
func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error { func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen)) err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen))

View File

@ -24,7 +24,11 @@ type HashQuery struct {
} }
func NewHashQuery() *HashQuery { func NewHashQuery() *HashQuery {
return &HashQuery{} hq := &HashQuery{}
hq.SetRequiresHandshake(true)
return hq
} }
func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery { func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery {

View File

@ -18,7 +18,11 @@ type RegistryEntryRequest struct {
} }
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest { func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
return &RegistryEntryRequest{} rer := &RegistryEntryRequest{}
rer.SetRequiresHandshake(true)
return rer
} }
func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest { func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest {
return &RegistryEntryRequest{sre: sre} return &RegistryEntryRequest{sre: sre}

View File

@ -18,7 +18,11 @@ type RegistryQuery struct {
} }
func NewEmptyRegistryQuery() *RegistryQuery { func NewEmptyRegistryQuery() *RegistryQuery {
return &RegistryQuery{} rq := &RegistryQuery{}
rq.SetRequiresHandshake(true)
return rq
} }
func NewRegistryQuery(pk []byte) *RegistryQuery { func NewRegistryQuery(pk []byte) *RegistryQuery {
return &RegistryQuery{pk: pk} return &RegistryQuery{pk: pk}

View File

@ -34,7 +34,11 @@ func NewAnnounceRequest(peer net.Peer, peersToSend []net.Peer) *AnnouncePeers {
} }
func NewAnnouncePeers() *AnnouncePeers { func NewAnnouncePeers() *AnnouncePeers {
return &AnnouncePeers{peer: nil, connectionUris: nil} ap := &AnnouncePeers{peer: nil, connectionUris: nil}
ap.SetRequiresHandshake(false)
return ap
} }
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {

View File

@ -27,11 +27,15 @@ type HandshakeDone struct {
} }
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone { func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
return &HandshakeDone{ ho := &HandshakeDone{
handshake: handshake, handshake: handshake,
supportedFeatures: supportedFeatures, supportedFeatures: supportedFeatures,
connectionUris: connectionUris, connectionUris: connectionUris,
} }
ho.SetRequiresHandshake(false)
return ho
} }
func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error { func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error {
@ -67,7 +71,11 @@ func (m *HandshakeDone) SetNetworkId(networkId string) {
} }
func NewHandshakeDone() *HandshakeDone { func NewHandshakeDone() *HandshakeDone {
return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1} hn := &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
hn.SetRequiresHandshake(false)
return hn
} }
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
@ -94,6 +102,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
} }
peer.SetConnected(true) peer.SetConnected(true)
peer.SetHandshakeDone(true)
if h.supportedFeatures != types.SupportedFeatures { if h.supportedFeatures != types.SupportedFeatures {
return fmt.Errorf("Remote node does not support required features") return fmt.Errorf("Remote node does not support required features")

View File

@ -75,7 +75,11 @@ func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error {
} }
func NewSignedMessage() *SignedMessage { func NewSignedMessage() *SignedMessage {
return &SignedMessage{} sm := &SignedMessage{}
sm.SetRequiresHandshake(false)
return sm
} }
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
@ -88,6 +92,10 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
if msgHandler, valid := GetMessageType(payload.kind); valid { if msgHandler, valid := GetMessageType(payload.kind); valid {
node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)])) node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
return nil
}
msgHandler.SetIncomingMessage(s) msgHandler.SetIncomingMessage(s)
msgHandler.SetSelf(msgHandler) msgHandler.SetSelf(msgHandler)
err := msgpack.Unmarshal(payload.message, &msgHandler) err := msgpack.Unmarshal(payload.message, &msgHandler)

View File

@ -30,7 +30,11 @@ type StorageLocation struct {
} }
func NewStorageLocation() *StorageLocation { func NewStorageLocation() *StorageLocation {
return &StorageLocation{} sl := &StorageLocation{}
sl.SetRequiresHandshake(true)
return sl
} }
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error { func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error {

View File

@ -326,7 +326,10 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
handler, ok := protocol.GetMessageType(imsg.Kind()) handler, ok := protocol.GetMessageType(imsg.Kind())
if ok { if ok {
if handler.RequiresHandshake() && !peer.IsHandshakeDone() {
p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())]))
return nil
}
imsg.SetOriginal(message) imsg.SetOriginal(message)
handler.SetIncomingMessage(imsg) handler.SetIncomingMessage(imsg)
handler.SetSelf(handler) handler.SetSelf(handler)

View File

@ -16,7 +16,7 @@ const (
var ProtocolMethodMap = map[ProtocolMethod]string{ var ProtocolMethodMap = map[ProtocolMethod]string{
ProtocolMethodHandshakeOpen: "HandshakeOpen", ProtocolMethodHandshakeOpen: "HandshakeOpen",
ProtocolMethodHandshakeDone: "HandshakeDone", ProtocolMethodHandshakeDone: "IsHandshakeDone",
ProtocolMethodSignedMessage: "SignedMessage", ProtocolMethodSignedMessage: "SignedMessage",
ProtocolMethodHashQuery: "HashQuery", ProtocolMethodHashQuery: "HashQuery",
ProtocolMethodAnnouncePeers: "AnnouncePeers", ProtocolMethodAnnouncePeers: "AnnouncePeers",