diff --git a/protocol/message.go b/protocol/message.go index 2a9b312..290a164 100644 --- a/protocol/message.go +++ b/protocol/message.go @@ -4,11 +4,10 @@ import ( "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/protocol/signed" "git.lumeweb.com/LumeWeb/libs5-go/types" - "sync" ) var ( - messageTypes sync.Map + messageTypes map[int]func() base.IncomingMessage ) var ( @@ -16,7 +15,7 @@ var ( ) func Init() { - messageTypes = sync.Map{} + messageTypes = make(map[int]func() base.IncomingMessage) // Register factory functions instead of instances RegisterMessageType(int(types.ProtocolMethodHandshakeOpen), func() base.IncomingMessage { @@ -35,19 +34,14 @@ func RegisterMessageType(messageType int, factoryFunc func() base.IncomingMessag if factoryFunc == nil { panic("factoryFunc cannot be nil") } - messageTypes.Store(int(messageType), factoryFunc) + messageTypes[messageType] = factoryFunc } func GetMessageType(kind int) (base.IncomingMessage, bool) { - value, ok := messageTypes.Load(kind) + value, ok := messageTypes[kind] if !ok { return nil, false } - factoryFunc, ok := value.(func() base.IncomingMessage) - if !ok { - return nil, false - } - - return factoryFunc(), true + return value(), true } diff --git a/protocol/signed/signed.go b/protocol/signed/signed.go index 8d4f529..261eb1e 100644 --- a/protocol/signed/signed.go +++ b/protocol/signed/signed.go @@ -3,15 +3,14 @@ package signed import ( "git.lumeweb.com/LumeWeb/libs5-go/protocol/base" "git.lumeweb.com/LumeWeb/libs5-go/types" - "sync" ) var ( - messageTypes sync.Map + messageTypes map[int]func() base.SignedIncomingMessage ) func Init() { - messageTypes = sync.Map{} + messageTypes = make(map[int]func() base.SignedIncomingMessage) RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage { return NewHandshakeDone() @@ -25,21 +24,16 @@ func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncoming if factoryFunc == nil { panic("factoryFunc cannot be nil") } - messageTypes.Store(int(messageType), factoryFunc) + messageTypes[messageType] = factoryFunc } func GetMessageType(kind int) (base.SignedIncomingMessage, bool) { - value, ok := messageTypes.Load(kind) + value, ok := messageTypes[kind] if !ok { return nil, false } - factoryFunc, ok := value.(func() base.SignedIncomingMessage) - if !ok { - return nil, false - } - - return factoryFunc(), true + return value(), true } var (