fix: add a new property on messages and peers to prevent messages from being processed before the handshake is done

This commit is contained in:
Derrick Hammer 2024-01-13 11:22:01 -05:00
parent 36f087dc83
commit 3d41119f74
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
13 changed files with 83 additions and 17 deletions

View File

@ -43,6 +43,8 @@ type Peer interface {
IsConnected() bool
SetConnectionURIs(uris []*url.URL)
ConnectionURIs() []*url.URL
IsHandshakeDone() bool
SetHandshakeDone(status bool)
}
type BasePeer struct {
@ -51,6 +53,7 @@ type BasePeer struct {
challenge []byte
socket interface{}
id *encoding.NodeId
handshaked bool
}
func (b *BasePeer) IsConnected() bool {
@ -106,3 +109,11 @@ func (b *BasePeer) SetConnectionURIs(uris []*url.URL) {
func (b *BasePeer) ConnectionURIs() []*url.URL {
return b.connectionURIs
}
func (b *BasePeer) IsHandshakeDone() bool {
return b.handshaked
}
func (b *BasePeer) SetHandshakeDone(status bool) {
b.handshaked = status
}

View File

@ -16,6 +16,8 @@ type IncomingMessage interface {
SetSelf(self IncomingMessage)
Original() []byte
Kind() int
RequiresHandshake() bool
SetRequiresHandshake(value bool)
msgpack.CustomDecoder
}

View File

@ -24,6 +24,7 @@ type IncomingMessageImpl struct {
known bool
self IncomingMessage
incoming IncomingMessage
requiresHandshake bool
}
func (i *IncomingMessageImpl) Self() IncomingMessage {
@ -126,3 +127,11 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
i.data = raw
return nil
}
func (i *IncomingMessageImpl) RequiresHandshake() bool {
return i.requiresHandshake
}
func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) {
i.requiresHandshake = value
}

View File

@ -39,10 +39,14 @@ var (
)
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
return &HandshakeOpen{
ho := &HandshakeOpen{
challenge: challenge,
networkId: networkId,
}
ho.SetRequiresHandshake(false)
return ho
}
func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen))

View File

@ -24,7 +24,11 @@ type HashQuery struct {
}
func NewHashQuery() *HashQuery {
return &HashQuery{}
hq := &HashQuery{}
hq.SetRequiresHandshake(true)
return hq
}
func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery {

View File

@ -18,7 +18,11 @@ type RegistryEntryRequest struct {
}
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
return &RegistryEntryRequest{}
rer := &RegistryEntryRequest{}
rer.SetRequiresHandshake(true)
return rer
}
func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest {
return &RegistryEntryRequest{sre: sre}

View File

@ -18,7 +18,11 @@ type RegistryQuery struct {
}
func NewEmptyRegistryQuery() *RegistryQuery {
return &RegistryQuery{}
rq := &RegistryQuery{}
rq.SetRequiresHandshake(true)
return rq
}
func NewRegistryQuery(pk []byte) *RegistryQuery {
return &RegistryQuery{pk: pk}

View File

@ -34,7 +34,11 @@ func NewAnnounceRequest(peer net.Peer, peersToSend []net.Peer) *AnnouncePeers {
}
func NewAnnouncePeers() *AnnouncePeers {
return &AnnouncePeers{peer: nil, connectionUris: nil}
ap := &AnnouncePeers{peer: nil, connectionUris: nil}
ap.SetRequiresHandshake(false)
return ap
}
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {

View File

@ -27,11 +27,15 @@ type HandshakeDone struct {
}
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
return &HandshakeDone{
ho := &HandshakeDone{
handshake: handshake,
supportedFeatures: supportedFeatures,
connectionUris: connectionUris,
}
ho.SetRequiresHandshake(false)
return ho
}
func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error {
@ -67,7 +71,11 @@ func (m *HandshakeDone) SetNetworkId(networkId string) {
}
func NewHandshakeDone() *HandshakeDone {
return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
hn := &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
hn.SetRequiresHandshake(false)
return hn
}
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
@ -94,6 +102,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
}
peer.SetConnected(true)
peer.SetHandshakeDone(true)
if h.supportedFeatures != types.SupportedFeatures {
return fmt.Errorf("Remote node does not support required features")

View File

@ -75,7 +75,11 @@ func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error {
}
func NewSignedMessage() *SignedMessage {
return &SignedMessage{}
sm := &SignedMessage{}
sm.SetRequiresHandshake(false)
return sm
}
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
@ -88,6 +92,10 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
if msgHandler, valid := GetMessageType(payload.kind); valid {
node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
return nil
}
msgHandler.SetIncomingMessage(s)
msgHandler.SetSelf(msgHandler)
err := msgpack.Unmarshal(payload.message, &msgHandler)

View File

@ -30,7 +30,11 @@ type StorageLocation struct {
}
func NewStorageLocation() *StorageLocation {
return &StorageLocation{}
sl := &StorageLocation{}
sl.SetRequiresHandshake(true)
return sl
}
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error {

View File

@ -326,7 +326,10 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
handler, ok := protocol.GetMessageType(imsg.Kind())
if ok {
if handler.RequiresHandshake() && !peer.IsHandshakeDone() {
p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())]))
return nil
}
imsg.SetOriginal(message)
handler.SetIncomingMessage(imsg)
handler.SetSelf(handler)

View File

@ -16,7 +16,7 @@ const (
var ProtocolMethodMap = map[ProtocolMethod]string{
ProtocolMethodHandshakeOpen: "HandshakeOpen",
ProtocolMethodHandshakeDone: "HandshakeDone",
ProtocolMethodHandshakeDone: "IsHandshakeDone",
ProtocolMethodSignedMessage: "SignedMessage",
ProtocolMethodHashQuery: "HashQuery",
ProtocolMethodAnnouncePeers: "AnnouncePeers",