fix: add a new property on messages and peers to prevent messages from being processed before the handshake is done
This commit is contained in:
parent
36f087dc83
commit
3d41119f74
11
net/peer.go
11
net/peer.go
|
@ -43,6 +43,8 @@ type Peer interface {
|
||||||
IsConnected() bool
|
IsConnected() bool
|
||||||
SetConnectionURIs(uris []*url.URL)
|
SetConnectionURIs(uris []*url.URL)
|
||||||
ConnectionURIs() []*url.URL
|
ConnectionURIs() []*url.URL
|
||||||
|
IsHandshakeDone() bool
|
||||||
|
SetHandshakeDone(status bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type BasePeer struct {
|
type BasePeer struct {
|
||||||
|
@ -51,6 +53,7 @@ type BasePeer struct {
|
||||||
challenge []byte
|
challenge []byte
|
||||||
socket interface{}
|
socket interface{}
|
||||||
id *encoding.NodeId
|
id *encoding.NodeId
|
||||||
|
handshaked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BasePeer) IsConnected() bool {
|
func (b *BasePeer) IsConnected() bool {
|
||||||
|
@ -106,3 +109,11 @@ func (b *BasePeer) SetConnectionURIs(uris []*url.URL) {
|
||||||
func (b *BasePeer) ConnectionURIs() []*url.URL {
|
func (b *BasePeer) ConnectionURIs() []*url.URL {
|
||||||
return b.connectionURIs
|
return b.connectionURIs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BasePeer) IsHandshakeDone() bool {
|
||||||
|
return b.handshaked
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BasePeer) SetHandshakeDone(status bool) {
|
||||||
|
b.handshaked = status
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,8 @@ type IncomingMessage interface {
|
||||||
SetSelf(self IncomingMessage)
|
SetSelf(self IncomingMessage)
|
||||||
Original() []byte
|
Original() []byte
|
||||||
Kind() int
|
Kind() int
|
||||||
|
RequiresHandshake() bool
|
||||||
|
SetRequiresHandshake(value bool)
|
||||||
msgpack.CustomDecoder
|
msgpack.CustomDecoder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ type IncomingMessageImpl struct {
|
||||||
known bool
|
known bool
|
||||||
self IncomingMessage
|
self IncomingMessage
|
||||||
incoming IncomingMessage
|
incoming IncomingMessage
|
||||||
|
requiresHandshake bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IncomingMessageImpl) Self() IncomingMessage {
|
func (i *IncomingMessageImpl) Self() IncomingMessage {
|
||||||
|
@ -126,3 +127,11 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
|
||||||
i.data = raw
|
i.data = raw
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IncomingMessageImpl) RequiresHandshake() bool {
|
||||||
|
return i.requiresHandshake
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IncomingMessageImpl) SetRequiresHandshake(value bool) {
|
||||||
|
i.requiresHandshake = value
|
||||||
|
}
|
||||||
|
|
|
@ -39,10 +39,14 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
|
func NewHandshakeOpen(challenge []byte, networkId string) *HandshakeOpen {
|
||||||
return &HandshakeOpen{
|
ho := &HandshakeOpen{
|
||||||
challenge: challenge,
|
challenge: challenge,
|
||||||
networkId: networkId,
|
networkId: networkId,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ho.SetRequiresHandshake(false)
|
||||||
|
|
||||||
|
return ho
|
||||||
}
|
}
|
||||||
func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
|
func (h HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error {
|
||||||
err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen))
|
err := enc.EncodeUint(uint64(types.ProtocolMethodHandshakeOpen))
|
||||||
|
|
|
@ -24,7 +24,11 @@ type HashQuery struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHashQuery() *HashQuery {
|
func NewHashQuery() *HashQuery {
|
||||||
return &HashQuery{}
|
hq := &HashQuery{}
|
||||||
|
|
||||||
|
hq.SetRequiresHandshake(true)
|
||||||
|
|
||||||
|
return hq
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery {
|
func NewHashRequest(hash *encoding.Multihash, kinds []types.StorageLocationType) *HashQuery {
|
||||||
|
|
|
@ -18,7 +18,11 @@ type RegistryEntryRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
|
func NewEmptyRegistryEntryRequest() *RegistryEntryRequest {
|
||||||
return &RegistryEntryRequest{}
|
rer := &RegistryEntryRequest{}
|
||||||
|
|
||||||
|
rer.SetRequiresHandshake(true)
|
||||||
|
|
||||||
|
return rer
|
||||||
}
|
}
|
||||||
func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest {
|
func NewRegistryEntryRequest(sre interfaces.SignedRegistryEntry) *RegistryEntryRequest {
|
||||||
return &RegistryEntryRequest{sre: sre}
|
return &RegistryEntryRequest{sre: sre}
|
||||||
|
|
|
@ -18,7 +18,11 @@ type RegistryQuery struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEmptyRegistryQuery() *RegistryQuery {
|
func NewEmptyRegistryQuery() *RegistryQuery {
|
||||||
return &RegistryQuery{}
|
rq := &RegistryQuery{}
|
||||||
|
|
||||||
|
rq.SetRequiresHandshake(true)
|
||||||
|
|
||||||
|
return rq
|
||||||
}
|
}
|
||||||
func NewRegistryQuery(pk []byte) *RegistryQuery {
|
func NewRegistryQuery(pk []byte) *RegistryQuery {
|
||||||
return &RegistryQuery{pk: pk}
|
return &RegistryQuery{pk: pk}
|
||||||
|
|
|
@ -34,7 +34,11 @@ func NewAnnounceRequest(peer net.Peer, peersToSend []net.Peer) *AnnouncePeers {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAnnouncePeers() *AnnouncePeers {
|
func NewAnnouncePeers() *AnnouncePeers {
|
||||||
return &AnnouncePeers{peer: nil, connectionUris: nil}
|
ap := &AnnouncePeers{peer: nil, connectionUris: nil}
|
||||||
|
|
||||||
|
ap.SetRequiresHandshake(false)
|
||||||
|
|
||||||
|
return ap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {
|
func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error {
|
||||||
|
|
|
@ -27,11 +27,15 @@ type HandshakeDone struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
|
func NewHandshakeDoneRequest(handshake []byte, supportedFeatures int, connectionUris []*url.URL) *HandshakeDone {
|
||||||
return &HandshakeDone{
|
ho := &HandshakeDone{
|
||||||
handshake: handshake,
|
handshake: handshake,
|
||||||
supportedFeatures: supportedFeatures,
|
supportedFeatures: supportedFeatures,
|
||||||
connectionUris: connectionUris,
|
connectionUris: connectionUris,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ho.SetRequiresHandshake(false)
|
||||||
|
|
||||||
|
return ho
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error {
|
func (m HandshakeDone) EncodeMsgpack(enc *msgpack.Encoder) error {
|
||||||
|
@ -67,7 +71,11 @@ func (m *HandshakeDone) SetNetworkId(networkId string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshakeDone() *HandshakeDone {
|
func NewHandshakeDone() *HandshakeDone {
|
||||||
return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
|
hn := &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1}
|
||||||
|
|
||||||
|
hn.SetRequiresHandshake(false)
|
||||||
|
|
||||||
|
return hn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
||||||
|
@ -94,6 +102,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verify
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.SetConnected(true)
|
peer.SetConnected(true)
|
||||||
|
peer.SetHandshakeDone(true)
|
||||||
|
|
||||||
if h.supportedFeatures != types.SupportedFeatures {
|
if h.supportedFeatures != types.SupportedFeatures {
|
||||||
return fmt.Errorf("Remote node does not support required features")
|
return fmt.Errorf("Remote node does not support required features")
|
||||||
|
|
|
@ -75,7 +75,11 @@ func (s *signedMessagePayoad) DecodeMsgpack(dec *msgpack.Decoder) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSignedMessage() *SignedMessage {
|
func NewSignedMessage() *SignedMessage {
|
||||||
return &SignedMessage{}
|
sm := &SignedMessage{}
|
||||||
|
|
||||||
|
sm.SetRequiresHandshake(false)
|
||||||
|
|
||||||
|
return sm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error {
|
||||||
|
@ -88,6 +92,10 @@ func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verif
|
||||||
|
|
||||||
if msgHandler, valid := GetMessageType(payload.kind); valid {
|
if msgHandler, valid := GetMessageType(payload.kind); valid {
|
||||||
node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
node.Logger().Debug("SignedMessage", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
||||||
|
if msgHandler.RequiresHandshake() && !peer.IsHandshakeDone() {
|
||||||
|
node.Logger().Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(payload.kind)]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
msgHandler.SetIncomingMessage(s)
|
msgHandler.SetIncomingMessage(s)
|
||||||
msgHandler.SetSelf(msgHandler)
|
msgHandler.SetSelf(msgHandler)
|
||||||
err := msgpack.Unmarshal(payload.message, &msgHandler)
|
err := msgpack.Unmarshal(payload.message, &msgHandler)
|
||||||
|
|
|
@ -30,7 +30,11 @@ type StorageLocation struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStorageLocation() *StorageLocation {
|
func NewStorageLocation() *StorageLocation {
|
||||||
return &StorageLocation{}
|
sl := &StorageLocation{}
|
||||||
|
|
||||||
|
sl.SetRequiresHandshake(true)
|
||||||
|
|
||||||
|
return sl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error {
|
func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error {
|
||||||
|
|
|
@ -326,7 +326,10 @@ func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) {
|
||||||
handler, ok := protocol.GetMessageType(imsg.Kind())
|
handler, ok := protocol.GetMessageType(imsg.Kind())
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
|
if handler.RequiresHandshake() && !peer.IsHandshakeDone() {
|
||||||
|
p.logger.Debug("Peer is not handshake done, ignoring message", zap.Any("type", types.ProtocolMethodMap[types.ProtocolMethod(imsg.Kind())]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
imsg.SetOriginal(message)
|
imsg.SetOriginal(message)
|
||||||
handler.SetIncomingMessage(imsg)
|
handler.SetIncomingMessage(imsg)
|
||||||
handler.SetSelf(handler)
|
handler.SetSelf(handler)
|
||||||
|
|
|
@ -16,7 +16,7 @@ const (
|
||||||
|
|
||||||
var ProtocolMethodMap = map[ProtocolMethod]string{
|
var ProtocolMethodMap = map[ProtocolMethod]string{
|
||||||
ProtocolMethodHandshakeOpen: "HandshakeOpen",
|
ProtocolMethodHandshakeOpen: "HandshakeOpen",
|
||||||
ProtocolMethodHandshakeDone: "HandshakeDone",
|
ProtocolMethodHandshakeDone: "IsHandshakeDone",
|
||||||
ProtocolMethodSignedMessage: "SignedMessage",
|
ProtocolMethodSignedMessage: "SignedMessage",
|
||||||
ProtocolMethodHashQuery: "HashQuery",
|
ProtocolMethodHashQuery: "HashQuery",
|
||||||
ProtocolMethodAnnouncePeers: "AnnouncePeers",
|
ProtocolMethodAnnouncePeers: "AnnouncePeers",
|
||||||
|
|
Loading…
Reference in New Issue