diff --git a/protocol/base/incoming_message.go b/protocol/base/incoming_message.go index 100b484..06628a2 100644 --- a/protocol/base/incoming_message.go +++ b/protocol/base/incoming_message.go @@ -4,7 +4,6 @@ import ( "fmt" "git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/net" - "git.lumeweb.com/LumeWeb/libs5-go/types" "github.com/vmihailenco/msgpack/v5" "io" "net/url" @@ -19,7 +18,7 @@ var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil) type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error type IncomingMessageImpl struct { - kind types.ProtocolMethod + kind int data msgpack.RawMessage original []byte known bool @@ -64,7 +63,7 @@ func (i *IncomingMessageImpl) IncomingMessage() IncomingMessage { return i.incoming } -func (i *IncomingMessageImpl) GetKind() types.ProtocolMethod { +func (i *IncomingMessageImpl) Kind() int { return i.kind } @@ -76,10 +75,6 @@ func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer, panic("child class should implement this method") } -func (i *IncomingMessageImpl) Kind() types.ProtocolMethod { - return i.kind -} - func (i *IncomingMessageImpl) Data() msgpack.RawMessage { return i.data } @@ -94,7 +89,7 @@ func NewIncomingMessageUnknown() *IncomingMessageImpl { } } -func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageImpl { +func NewIncomingMessageKnown(kind int, data msgpack.RawMessage) *IncomingMessageImpl { return &IncomingMessageImpl{ kind: kind, data: data, @@ -102,7 +97,7 @@ func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage) } } -func NewIncomingMessageTyped(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageTypedImpl { +func NewIncomingMessageTyped(kind int, data msgpack.RawMessage) *IncomingMessageTypedImpl { known := NewIncomingMessageKnown(kind, data) return &IncomingMessageTypedImpl{*known} } @@ -120,7 +115,7 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error { return err } - i.kind = types.ProtocolMethod(kind) + i.kind = kind raw, err := io.ReadAll(dec.Buffered()) diff --git a/protocol/message.go b/protocol/message.go index ab39edb..4e2d2a1 100644 --- a/protocol/message.go +++ b/protocol/message.go @@ -19,26 +19,26 @@ func init() { messageTypes = sync.Map{} // Register factory functions instead of instances - RegisterMessageType(types.ProtocolMethodHandshakeOpen, func() base.IncomingMessage { + RegisterMessageType(int(types.ProtocolMethodHandshakeOpen), func() base.IncomingMessage { return NewHandshakeOpen([]byte{}, "") }) - RegisterMessageType(types.ProtocolMethodHashQuery, func() base.IncomingMessage { + RegisterMessageType(int(types.ProtocolMethodHashQuery), func() base.IncomingMessage { return NewHashQuery() }) - RegisterMessageType(types.ProtocolMethodSignedMessage, func() base.IncomingMessage { + RegisterMessageType(int(types.ProtocolMethodSignedMessage), func() base.IncomingMessage { return signed.NewSignedMessage() }) } -func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.IncomingMessage) { +func RegisterMessageType(messageType int, factoryFunc func() base.IncomingMessage) { if factoryFunc == nil { panic("factoryFunc cannot be nil") } - messageTypes.Store(messageType, factoryFunc) + messageTypes.Store(int(messageType), factoryFunc) } -func GetMessageType(kind types.ProtocolMethod) (base.IncomingMessage, bool) { +func GetMessageType(kind int) (base.IncomingMessage, bool) { value, ok := messageTypes.Load(kind) if !ok { return nil, false diff --git a/protocol/signed/signed.go b/protocol/signed/signed.go index 21600c9..dc3c474 100644 --- a/protocol/signed/signed.go +++ b/protocol/signed/signed.go @@ -13,22 +13,22 @@ var ( func init() { messageTypes = sync.Map{} - RegisterMessageType(types.ProtocolMethodHandshakeDone, func() base.SignedIncomingMessage { + RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage { return NewHandshakeDone() }) - RegisterMessageType(types.ProtocolMethodAnnouncePeers, func() base.SignedIncomingMessage { + RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() base.SignedIncomingMessage { return NewAnnouncePeers() }) } -func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.SignedIncomingMessage) { +func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncomingMessage) { if factoryFunc == nil { panic("factoryFunc cannot be nil") } - messageTypes.Store(messageType, factoryFunc) + messageTypes.Store(int(messageType), factoryFunc) } -func GetMessageType(kind types.ProtocolMethod) (base.SignedIncomingMessage, bool) { +func GetMessageType(kind int) (base.SignedIncomingMessage, bool) { value, ok := messageTypes.Load(kind) if !ok { return nil, false