fix: ensure we use int everywhere for kind to try and avoid any weird bitwise or implied conversions

This commit is contained in:
Derrick Hammer 2024-01-09 12:47:58 -05:00
parent 2622f2b9d0
commit 1458cbe1d9
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
3 changed files with 16 additions and 21 deletions

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces" "git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net" "git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
"io" "io"
"net/url" "net/url"
@ -19,7 +18,7 @@ var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil)
type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error
type IncomingMessageImpl struct { type IncomingMessageImpl struct {
kind types.ProtocolMethod kind int
data msgpack.RawMessage data msgpack.RawMessage
original []byte original []byte
known bool known bool
@ -64,7 +63,7 @@ func (i *IncomingMessageImpl) IncomingMessage() IncomingMessage {
return i.incoming return i.incoming
} }
func (i *IncomingMessageImpl) GetKind() types.ProtocolMethod { func (i *IncomingMessageImpl) Kind() int {
return i.kind return i.kind
} }
@ -76,10 +75,6 @@ func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer,
panic("child class should implement this method") panic("child class should implement this method")
} }
func (i *IncomingMessageImpl) Kind() types.ProtocolMethod {
return i.kind
}
func (i *IncomingMessageImpl) Data() msgpack.RawMessage { func (i *IncomingMessageImpl) Data() msgpack.RawMessage {
return i.data return i.data
} }
@ -94,7 +89,7 @@ func NewIncomingMessageUnknown() *IncomingMessageImpl {
} }
} }
func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageImpl { func NewIncomingMessageKnown(kind int, data msgpack.RawMessage) *IncomingMessageImpl {
return &IncomingMessageImpl{ return &IncomingMessageImpl{
kind: kind, kind: kind,
data: data, data: data,
@ -102,7 +97,7 @@ func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage)
} }
} }
func NewIncomingMessageTyped(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageTypedImpl { func NewIncomingMessageTyped(kind int, data msgpack.RawMessage) *IncomingMessageTypedImpl {
known := NewIncomingMessageKnown(kind, data) known := NewIncomingMessageKnown(kind, data)
return &IncomingMessageTypedImpl{*known} return &IncomingMessageTypedImpl{*known}
} }
@ -120,7 +115,7 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
return err return err
} }
i.kind = types.ProtocolMethod(kind) i.kind = kind
raw, err := io.ReadAll(dec.Buffered()) raw, err := io.ReadAll(dec.Buffered())

View File

@ -19,26 +19,26 @@ func init() {
messageTypes = sync.Map{} messageTypes = sync.Map{}
// Register factory functions instead of instances // Register factory functions instead of instances
RegisterMessageType(types.ProtocolMethodHandshakeOpen, func() base.IncomingMessage { RegisterMessageType(int(types.ProtocolMethodHandshakeOpen), func() base.IncomingMessage {
return NewHandshakeOpen([]byte{}, "") return NewHandshakeOpen([]byte{}, "")
}) })
RegisterMessageType(types.ProtocolMethodHashQuery, func() base.IncomingMessage { RegisterMessageType(int(types.ProtocolMethodHashQuery), func() base.IncomingMessage {
return NewHashQuery() return NewHashQuery()
}) })
RegisterMessageType(types.ProtocolMethodSignedMessage, func() base.IncomingMessage { RegisterMessageType(int(types.ProtocolMethodSignedMessage), func() base.IncomingMessage {
return signed.NewSignedMessage() return signed.NewSignedMessage()
}) })
} }
func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.IncomingMessage) { func RegisterMessageType(messageType int, factoryFunc func() base.IncomingMessage) {
if factoryFunc == nil { if factoryFunc == nil {
panic("factoryFunc cannot be nil") panic("factoryFunc cannot be nil")
} }
messageTypes.Store(messageType, factoryFunc) messageTypes.Store(int(messageType), factoryFunc)
} }
func GetMessageType(kind types.ProtocolMethod) (base.IncomingMessage, bool) { func GetMessageType(kind int) (base.IncomingMessage, bool) {
value, ok := messageTypes.Load(kind) value, ok := messageTypes.Load(kind)
if !ok { if !ok {
return nil, false return nil, false

View File

@ -13,22 +13,22 @@ var (
func init() { func init() {
messageTypes = sync.Map{} messageTypes = sync.Map{}
RegisterMessageType(types.ProtocolMethodHandshakeDone, func() base.SignedIncomingMessage { RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage {
return NewHandshakeDone() return NewHandshakeDone()
}) })
RegisterMessageType(types.ProtocolMethodAnnouncePeers, func() base.SignedIncomingMessage { RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() base.SignedIncomingMessage {
return NewAnnouncePeers() return NewAnnouncePeers()
}) })
} }
func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.SignedIncomingMessage) { func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncomingMessage) {
if factoryFunc == nil { if factoryFunc == nil {
panic("factoryFunc cannot be nil") panic("factoryFunc cannot be nil")
} }
messageTypes.Store(messageType, factoryFunc) messageTypes.Store(int(messageType), factoryFunc)
} }
func GetMessageType(kind types.ProtocolMethod) (base.SignedIncomingMessage, bool) { func GetMessageType(kind int) (base.SignedIncomingMessage, bool) {
value, ok := messageTypes.Load(kind) value, ok := messageTypes.Load(kind)
if !ok { if !ok {
return nil, false return nil, false