refactor: switch to using a normal map

This commit is contained in:
Derrick Hammer 2024-01-09 13:57:35 -05:00
parent 6b6e7d4fc4
commit ed97c03d16
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
2 changed files with 10 additions and 22 deletions

View File

@ -4,11 +4,10 @@ import (
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/protocol/signed"
"git.lumeweb.com/LumeWeb/libs5-go/types"
"sync"
)
var (
messageTypes sync.Map
messageTypes map[int]func() base.IncomingMessage
)
var (
@ -16,7 +15,7 @@ var (
)
func Init() {
messageTypes = sync.Map{}
messageTypes = make(map[int]func() base.IncomingMessage)
// Register factory functions instead of instances
RegisterMessageType(int(types.ProtocolMethodHandshakeOpen), func() base.IncomingMessage {
@ -35,19 +34,14 @@ func RegisterMessageType(messageType int, factoryFunc func() base.IncomingMessag
if factoryFunc == nil {
panic("factoryFunc cannot be nil")
}
messageTypes.Store(int(messageType), factoryFunc)
messageTypes[messageType] = factoryFunc
}
func GetMessageType(kind int) (base.IncomingMessage, bool) {
value, ok := messageTypes.Load(kind)
value, ok := messageTypes[kind]
if !ok {
return nil, false
}
factoryFunc, ok := value.(func() base.IncomingMessage)
if !ok {
return nil, false
}
return factoryFunc(), true
return value(), true
}

View File

@ -3,15 +3,14 @@ package signed
import (
"git.lumeweb.com/LumeWeb/libs5-go/protocol/base"
"git.lumeweb.com/LumeWeb/libs5-go/types"
"sync"
)
var (
messageTypes sync.Map
messageTypes map[int]func() base.SignedIncomingMessage
)
func Init() {
messageTypes = sync.Map{}
messageTypes = make(map[int]func() base.SignedIncomingMessage)
RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage {
return NewHandshakeDone()
@ -25,21 +24,16 @@ func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncoming
if factoryFunc == nil {
panic("factoryFunc cannot be nil")
}
messageTypes.Store(int(messageType), factoryFunc)
messageTypes[messageType] = factoryFunc
}
func GetMessageType(kind int) (base.SignedIncomingMessage, bool) {
value, ok := messageTypes.Load(kind)
value, ok := messageTypes[kind]
if !ok {
return nil, false
}
factoryFunc, ok := value.(func() base.SignedIncomingMessage)
if !ok {
return nil, false
}
return factoryFunc(), true
return value(), true
}
var (