fix: ensure we use int everywhere for kind to try and avoid any weird bitwise or implied conversions

This commit is contained in:
Derrick Hammer 2024-01-09 12:47:58 -05:00
parent 2622f2b9d0
commit 1458cbe1d9
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
3 changed files with 16 additions and 21 deletions

View File

@ -4,7 +4,6 @@ import (
"fmt"
"git.lumeweb.com/LumeWeb/libs5-go/interfaces"
"git.lumeweb.com/LumeWeb/libs5-go/net"
"git.lumeweb.com/LumeWeb/libs5-go/types"
"github.com/vmihailenco/msgpack/v5"
"io"
"net/url"
@ -19,7 +18,7 @@ var _ IncomingMessageTyped = (*IncomingMessageImpl)(nil)
type IncomingMessageHandler func(node interfaces.Node, peer *net.Peer, u *url.URL, verifyId bool) error
type IncomingMessageImpl struct {
kind types.ProtocolMethod
kind int
data msgpack.RawMessage
original []byte
known bool
@ -64,7 +63,7 @@ func (i *IncomingMessageImpl) IncomingMessage() IncomingMessage {
return i.incoming
}
func (i *IncomingMessageImpl) GetKind() types.ProtocolMethod {
func (i *IncomingMessageImpl) Kind() int {
return i.kind
}
@ -76,10 +75,6 @@ func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer,
panic("child class should implement this method")
}
func (i *IncomingMessageImpl) Kind() types.ProtocolMethod {
return i.kind
}
func (i *IncomingMessageImpl) Data() msgpack.RawMessage {
return i.data
}
@ -94,7 +89,7 @@ func NewIncomingMessageUnknown() *IncomingMessageImpl {
}
}
func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageImpl {
func NewIncomingMessageKnown(kind int, data msgpack.RawMessage) *IncomingMessageImpl {
return &IncomingMessageImpl{
kind: kind,
data: data,
@ -102,7 +97,7 @@ func NewIncomingMessageKnown(kind types.ProtocolMethod, data msgpack.RawMessage)
}
}
func NewIncomingMessageTyped(kind types.ProtocolMethod, data msgpack.RawMessage) *IncomingMessageTypedImpl {
func NewIncomingMessageTyped(kind int, data msgpack.RawMessage) *IncomingMessageTypedImpl {
known := NewIncomingMessageKnown(kind, data)
return &IncomingMessageTypedImpl{*known}
}
@ -120,7 +115,7 @@ func (i *IncomingMessageImpl) DecodeMsgpack(dec *msgpack.Decoder) error {
return err
}
i.kind = types.ProtocolMethod(kind)
i.kind = kind
raw, err := io.ReadAll(dec.Buffered())

View File

@ -19,26 +19,26 @@ func init() {
messageTypes = sync.Map{}
// Register factory functions instead of instances
RegisterMessageType(types.ProtocolMethodHandshakeOpen, func() base.IncomingMessage {
RegisterMessageType(int(types.ProtocolMethodHandshakeOpen), func() base.IncomingMessage {
return NewHandshakeOpen([]byte{}, "")
})
RegisterMessageType(types.ProtocolMethodHashQuery, func() base.IncomingMessage {
RegisterMessageType(int(types.ProtocolMethodHashQuery), func() base.IncomingMessage {
return NewHashQuery()
})
RegisterMessageType(types.ProtocolMethodSignedMessage, func() base.IncomingMessage {
RegisterMessageType(int(types.ProtocolMethodSignedMessage), func() base.IncomingMessage {
return signed.NewSignedMessage()
})
}
func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.IncomingMessage) {
func RegisterMessageType(messageType int, factoryFunc func() base.IncomingMessage) {
if factoryFunc == nil {
panic("factoryFunc cannot be nil")
}
messageTypes.Store(messageType, factoryFunc)
messageTypes.Store(int(messageType), factoryFunc)
}
func GetMessageType(kind types.ProtocolMethod) (base.IncomingMessage, bool) {
func GetMessageType(kind int) (base.IncomingMessage, bool) {
value, ok := messageTypes.Load(kind)
if !ok {
return nil, false

View File

@ -13,22 +13,22 @@ var (
func init() {
messageTypes = sync.Map{}
RegisterMessageType(types.ProtocolMethodHandshakeDone, func() base.SignedIncomingMessage {
RegisterMessageType(int(types.ProtocolMethodHandshakeDone), func() base.SignedIncomingMessage {
return NewHandshakeDone()
})
RegisterMessageType(types.ProtocolMethodAnnouncePeers, func() base.SignedIncomingMessage {
RegisterMessageType(int(types.ProtocolMethodAnnouncePeers), func() base.SignedIncomingMessage {
return NewAnnouncePeers()
})
}
func RegisterMessageType(messageType types.ProtocolMethod, factoryFunc func() base.SignedIncomingMessage) {
func RegisterMessageType(messageType int, factoryFunc func() base.SignedIncomingMessage) {
if factoryFunc == nil {
panic("factoryFunc cannot be nil")
}
messageTypes.Store(messageType, factoryFunc)
messageTypes.Store(int(messageType), factoryFunc)
}
func GetMessageType(kind types.ProtocolMethod) (base.SignedIncomingMessage, bool) {
func GetMessageType(kind int) (base.SignedIncomingMessage, bool) {
value, ok := messageTypes.Load(kind)
if !ok {
return nil, false