fix: add a new property on messages and peers to prevent messages from being processed before the handshake is done
This commit is contained in:
parent
36f087dc83
commit
3d41119f74
11
net/peer.go
11
net/peer.go
|
@ -43,6 +43,8 @@ type Peer interface {
|
|||
IsConnected() bool
|
||||
SetConnectionURIs(uris []*url.URL)
|
||||
ConnectionURIs() []*url.URL
|
||||
IsHandshakeDone() bool
|
||||
SetHandshakeDone(status bool)
|
||||
}
|
||||
|
||||
type BasePeer struct {
|
||||
|
@ -51,6 +53,7 @@ type BasePeer struct {
|
|||
challenge []byte
|
||||
socket interface{}
|
||||
id *encoding.NodeId
|
||||
handshaked bool
|
||||
}
|
||||
|
||||
func (b *BasePeer) IsConnected() bool {
|
||||
|
@ -106,3 +109,11 @@ func (b *BasePeer) SetConnectionURIs(uris []*url.URL) {
|
|||
func (b *BasePeer) ConnectionURIs() []*url.URL {
|
||||
return b.connectionURIs
|
||||
}
|
||||
|
||||
func (b *BasePeer) IsHandshakeDone() bool {
|
||||
return b.handshaked
|
||||
}
|
||||
|
||||
func (b *BasePeer) SetHandshakeDone(status bool) {
|
||||
b.handshaked = status
|
||||
}
|
||||
|
|
|
@ -16,6 +16,8 @@ type IncomingMessage interface {
|
|||
SetSelf(self IncomingMessage)
|
||||
Original() []byte
|
||||
Kind() int
|
||||
RequiresHandshake() bool
|
||||
SetRequiresHandshake(value bool)
|
||||
msgpack.CustomDecoder
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ type IncomingMessageImpl struct {
|
|||
known bool
|
||||
self IncomingMessage
|
||||
incoming IncomingMessage
|
||||
requiresHandshake bool
|
||||
}
|
||||
|
||||
func (i *IncomingMessageImpl) Self() IncomingMessage {
|
||||
|
@ -126,3 +127,11 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
|
|||
i.data = raw
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *IncomingMessageImpl) RequiresHandshake() bool {
|
||||
return i.requiresHandshake
|
||||
}
|
||||
|
||||
func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) {
|
||||
i.requiresHandshake = value
|
||||
}
|
||||
|
|
|
@ -39,10 +39,14 @@ var (
|
|||
)
|
||||
|
||||
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
|
||||
return &HandshakeOpen{
|
||||
ho := &HandshakeOpen{
|
||||
challenge: challenge,
|
||||
networkId: networkId,
|
||||
}
|
||||
|
||||
ho.SetRequiresHandshake(false)
|
||||
|
||||
return ho
|
||||
}
|
||||
func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
|
||||
err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen))
|
||||
|
|
|
@ -24,7 +24,11 @@ type HashQuery struct {
|
|||
}
|
||||
|
||||
func NewHashQuery() *HashQuery {
|
||||
return &HashQuery{}
|
||||
hq := &HashQuery{}
|
||||
|
||||
hq.SetRequiresHandshake(true)
|
||||
|
||||
return hq
|
||||
}
|
||||
|
||||
func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery {
|
||||
|
|
|
@ -18,7 +18,11 @@ type RegistryEntryRequest struct {
|
|||
}
|
||||
|
||||
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
|
||||
return &RegistryEntryRequest{}
|
||||
rer := &RegistryEntryRequest{}
|
||||
|
||||
rer.SetRequiresHandshake(true)
|
||||
|
||||
return rer
|
||||
}
|
||||
func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest {
|
||||
return &RegistryEntryRequest{sre: sre}
|
||||
|
|
|
@ -18,7 +18,11 @@ type RegistryQuery struct {
|
|||
}
|
||||
|
||||
func NewEmptyRegistryQuery() *RegistryQuery {
|
||||
return &RegistryQuery{}
|
||||
rq := &RegistryQuery{}
|
||||
|
||||
rq.SetRequiresHandshake(true)
|
||||
|
||||
return rq
|
||||
}
|
||||
func NewRegistryQuery(pk []byte) *RegistryQuery {
|
||||
return &RegistryQuery{pk: pk}
|
||||
|
|
|
@ -34,7 +34,11 @@ func NewAnnounceRequest(peer net.Peer, peersToSend []net.Peer) *AnnouncePeers {
|
|||
}
|
||||
|
||||
func NewAnnouncePeers() *AnnouncePeers {
|
||||
return &AnnouncePeers{peer: nil, connectionUris: nil}
|
||||
ap := &AnnouncePeers{peer: nil, connectionUris: nil}
|
||||
|
||||
ap.SetRequiresHandshake(false)
|
||||
|
||||
return ap
|
||||
}
|
||||
|
||||
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {
|
||||
|
|
|
@ -27,11 +27,15 @@ type HandshakeDone struct {
|
|||
}
|
||||
|
||||
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
|
||||
return &HandshakeDone{
|
||||
ho := &HandshakeDone{
|
||||
handshake: handshake,
|
||||
supportedFeatures: supportedFeatures,
|
||||
connectionUris: connectionUris,
|
||||
}
|
||||
|
||||
ho.SetRequiresHandshake(false)
|
||||
|
||||
return ho
|
||||
}
|
||||
|
||||
func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error {
|
||||
|
@ -67,7 +71,11 @@ func (m *HandshakeDone) SetNetworkId(networkId string) {
|
|||
}
|
||||
|
||||
func NewHandshakeDone() *HandshakeDone {
|
||||
return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
|
||||
hn := &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
|
||||
|
||||
hn.SetRequiresHandshake(false)
|
||||
|
||||
return hn
|
||||
}
|
||||
|
||||
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
||||
|
@ -94,6 +102,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
|
|||
}
|
||||
|
||||
peer.SetConnected(true)
|
||||
peer.SetHandshakeDone(true)
|
||||
|
||||
if h.supportedFeatures != types.SupportedFeatures {
|
||||
return fmt.Errorf("Remote node does not support required features")
|
||||
|
|
|
@ -75,7 +75,11 @@ func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error {
|
|||
}
|
||||
|
||||
func NewSignedMessage() *SignedMessage {
|
||||
return &SignedMessage{}
|
||||
sm := &SignedMessage{}
|
||||
|
||||
sm.SetRequiresHandshake(false)
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
||||
|
@ -88,6 +92,10 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
|
|||
|
||||
if msgHandler, valid := GetMessageType(payload.kind); valid {
|
||||
node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
||||
if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
|
||||
node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
||||
return nil
|
||||
}
|
||||
msgHandler.SetIncomingMessage(s)
|
||||
msgHandler.SetSelf(msgHandler)
|
||||
err := msgpack.Unmarshal(payload.message, &msgHandler)
|
||||
|
|
|
@ -30,7 +30,11 @@ type StorageLocation struct {
|
|||
}
|
||||
|
||||
func NewStorageLocation() *StorageLocation {
|
||||
return &StorageLocation{}
|
||||
sl := &StorageLocation{}
|
||||
|
||||
sl.SetRequiresHandshake(true)
|
||||
|
||||
return sl
|
||||
}
|
||||
|
||||
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error {
|
||||
|
|
|
@ -326,7 +326,10 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
|
|||
handler, ok := protocol.GetMessageType(imsg.Kind())
|
||||
|
||||
if ok {
|
||||
|
||||
if handler.RequiresHandshake() && !peer.IsHandshakeDone() {
|
||||
p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())]))
|
||||
return nil
|
||||
}
|
||||
imsg.SetOriginal(message)
|
||||
handler.SetIncomingMessage(imsg)
|
||||
handler.SetSelf(handler)
|
||||
|
|
|
@ -16,7 +16,7 @@ const (
|
|||
|
||||
var ProtocolMethodMap = map[ProtocolMethod]string{
|
||||
ProtocolMethodHandshakeOpen: "HandshakeOpen",
|
||||
ProtocolMethodHandshakeDone: "HandshakeDone",
|
||||
ProtocolMethodHandshakeDone: "IsHandshakeDone",
|
||||
ProtocolMethodSignedMessage: "SignedMessage",
|
||||
ProtocolMethodHashQuery: "HashQuery",
|
||||
ProtocolMethodAnnouncePeers: "AnnouncePeers",
|
||||
|
|
Loading…
Reference in New Issue