From 2e9b07c6bdc1f9e01227ba1fb4b694257b38d8db Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Sun, 7 Jan 2024 06:35:41 -0500 Subject: [PATCH] refactor: dont use pointers with interfaces --- interfaces/p2p.go | 4 ++-- net/net.go | 4 ++-- net/peer.go | 20 ++++++++++++++++++++ protocol/base/base.go | 2 +- protocol/base/incoming_message.go | 2 +- protocol/handshake_open.go | 2 +- protocol/hash_query.go | 10 +++++----- protocol/signed/accounce_peers.go | 2 +- protocol/signed/handshake_done.go | 8 ++++---- protocol/signed/signed_message.go | 2 +- protocol/storage_location.go | 2 +- service/p2p.go | 15 +++++++-------- 12 files changed, 46 insertions(+), 27 deletions(-) diff --git a/interfaces/p2p.go b/interfaces/p2p.go index 433d8de..58871e2 100644 --- a/interfaces/p2p.go +++ b/interfaces/p2p.go @@ -14,8 +14,8 @@ type P2PService interface { Stop() error Init() error ConnectToNode(connectionUris []*url.URL, retried bool) error - OnNewPeer(peer *net.Peer, verifyId bool) error - OnNewPeerListen(peer *net.Peer, verifyId bool) + OnNewPeer(peer net.Peer, verifyId bool) error + OnNewPeerListen(peer net.Peer, verifyId bool) ReadNodeScore(nodeId *encoding.NodeId) (NodeVotes, error) GetNodeScore(nodeId *encoding.NodeId) (float64, error) SortNodesByScore(nodes []*encoding.NodeId) ([]*encoding.NodeId, error) diff --git a/net/net.go b/net/net.go index 21783d3..865ae42 100644 --- a/net/net.go +++ b/net/net.go @@ -51,7 +51,7 @@ func CreateTransportSocket(peerType string, uri *url.URL) (interface{}, error) { return &t, err } -func CreateTransportPeer(peerType string, options *TransportPeerConfig) (*Peer, error) { +func CreateTransportPeer(peerType string, options *TransportPeerConfig) (Peer, error) { factory, ok := transports.Load(peerType) if !ok { return nil, errors.New("no factory registered for type: " + peerType) @@ -59,5 +59,5 @@ func CreateTransportPeer(peerType string, options *TransportPeerConfig) (*Peer, t, err := factory.(PeerFactory).NewPeer(options) - return &t, err + return t, err } diff --git a/net/peer.go b/net/peer.go index d56c55e..5ff7171 100644 --- a/net/peer.go +++ b/net/peer.go @@ -6,6 +6,10 @@ import ( "net/url" ) +var ( + _ Peer = (*BasePeer)(nil) +) + // EventCallback type for the callback function type EventCallback func(event []byte) error @@ -43,6 +47,22 @@ type BasePeer struct { id *encoding.NodeId } +func (b *BasePeer) SendMessage(message []byte) error { + panic("must implement in child class") +} + +func (b *BasePeer) RenderLocationURI() string { + panic("must implement in child class") +} + +func (b *BasePeer) ListenForMessages(callback EventCallback, options ListenerOptions) { + panic("must implement in child class") +} + +func (b *BasePeer) End() error { + panic("must implement in child class") +} + func (b *BasePeer) Challenge() []byte { return b.challenge } diff --git a/protocol/base/base.go b/protocol/base/base.go index bfc4b87..3b283d6 100644 --- a/protocol/base/base.go +++ b/protocol/base/base.go @@ -7,7 +7,7 @@ import ( ) type IncomingMessage interface { - HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error + HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error SetIncomingMessage(msg IncomingMessage) msgpack.CustomDecoder } diff --git a/protocol/base/incoming_message.go b/protocol/base/incoming_message.go index 8f2372e..20f5be1 100644 --- a/protocol/base/incoming_message.go +++ b/protocol/base/incoming_message.go @@ -41,7 +41,7 @@ func (i *IncomingMessageImpl) ToMessage() (message []byte, err error) { return msgpack.Marshal(i) } -func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (i *IncomingMessageImpl) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { panic("child class should implement this method") } diff --git a/protocol/handshake_open.go b/protocol/handshake_open.go index bcab306..8c183e5 100644 --- a/protocol/handshake_open.go +++ b/protocol/handshake_open.go @@ -56,7 +56,7 @@ func (m HandshakeOpen) EncodeMsgpack(enc *msgpack.Encoder) error { return nil } -func (m *HandshakeOpen) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (m *HandshakeOpen) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { return nil } diff --git a/protocol/hash_query.go b/protocol/hash_query.go index f985e65..33837bd 100644 --- a/protocol/hash_query.go +++ b/protocol/hash_query.go @@ -48,7 +48,7 @@ func (h *HashQuery) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (h *HashQuery) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (h *HashQuery) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { mapLocations, err := node.GetCachedStorageLocations(h.hash, h.kinds) if err != nil { log.Printf("Error getting cached storage locations: %v", err) @@ -79,7 +79,7 @@ func (h *HashQuery) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId entry, exists := mapLocations[sortedNodeId] if exists { - err := (*peer).SendMessage(entry.ProviderMessage()) + err := peer.SendMessage(entry.ProviderMessage()) if err != nil { return err } @@ -94,13 +94,13 @@ func (h *HashQuery) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId peers = peersVal.(*hashset.Set) - if exists := peers.Contains((*peer).Id()); !exists { - peers.Add((*peer).Id()) + if exists := peers.Contains(peer.Id()); !exists { + peers.Add(peer.Id()) } for _, val := range node.Services().P2P().Peers().Values() { peerVal := val.(net.Peer) - if !peerVal.Id().Equals((*peer).Id()) { + if !peerVal.Id().Equals(peer.Id()) { err := peerVal.SendMessage(h.IncomingMessageImpl.Original()) if err != nil { node.Logger().Error("Failed to send message", zap.Error(err)) diff --git a/protocol/signed/accounce_peers.go b/protocol/signed/accounce_peers.go index 832d047..a7cfb3e 100644 --- a/protocol/signed/accounce_peers.go +++ b/protocol/signed/accounce_peers.go @@ -60,7 +60,7 @@ func (a *AnnouncePeers) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (a AnnouncePeers) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (a AnnouncePeers) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { if len(a.connectionUris) > 0 { firstUrl := a.connectionUris[0] uri := new(url.URL) diff --git a/protocol/signed/handshake_done.go b/protocol/signed/handshake_done.go index 18493dc..323d434 100644 --- a/protocol/signed/handshake_done.go +++ b/protocol/signed/handshake_done.go @@ -31,21 +31,21 @@ func NewHandshakeDone() *HandshakeDone { return &HandshakeDone{challenge: nil, networkId: "", supportedFeatures: -1} } -func (h HandshakeDone) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (h HandshakeDone) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { if !node.IsStarted() { - err := (*peer).End() + err := peer.End() if err != nil { return nil } return nil } - if !bytes.Equal((*peer).Challenge(), h.challenge) { + if !bytes.Equal(peer.Challenge(), h.challenge) { return errors.New("Invalid challenge") } /* if !verifyId { - (*peer).SetId(h) + peer.SetId(h) } else { if !peer.ID.Equals(pId) { return errInvalidChallenge diff --git a/protocol/signed/signed_message.go b/protocol/signed/signed_message.go index 975ca82..b90877f 100644 --- a/protocol/signed/signed_message.go +++ b/protocol/signed/signed_message.go @@ -54,7 +54,7 @@ func NewSignedMessage() *SignedMessage { return &SignedMessage{} } -func (s *SignedMessage) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (s *SignedMessage) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { var payload signedMessagePayoad err := msgpack.Unmarshal(s.message, &payload) diff --git a/protocol/storage_location.go b/protocol/storage_location.go index 3e5bc1f..51186e3 100644 --- a/protocol/storage_location.go +++ b/protocol/storage_location.go @@ -41,7 +41,7 @@ func (s *StorageLocation) DecodeMessage(dec *msgpack.Decoder) error { return nil } -func (s *StorageLocation) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId bool) error { +func (s *StorageLocation) HandleMessage(node interfaces.Node, peer net.Peer, verifyId bool) error { hash := encoding.NewMultihash(s.raw[1:34]) // Replace NewMultihash with appropriate function fmt.Println("Hash:", hash) diff --git a/service/p2p.go b/service/p2p.go index 2d0df13..f3ee15c 100644 --- a/service/p2p.go +++ b/service/p2p.go @@ -194,15 +194,14 @@ func (p *P2PImpl) ConnectToNode(connectionUris []*url.URL, retried bool) error { return err } - (*peer).SetId(id) + peer.SetId(id) return p.OnNewPeer(peer, true) } -func (p *P2PImpl) OnNewPeer(peer *net.Peer, verifyId bool) error { +func (p *P2PImpl) OnNewPeer(peer net.Peer, verifyId bool) error { challenge := protocol.GenerateChallenge() - pd := *peer - pd.SetChallenge(challenge) + peer.SetChallenge(challenge) p.OnNewPeerListen(peer, verifyId) @@ -212,15 +211,15 @@ func (p *P2PImpl) OnNewPeer(peer *net.Peer, verifyId bool) error { return err } - err = pd.SendMessage(handshakeOpenMsg) + err = peer.SendMessage(handshakeOpenMsg) if err != nil { return err } return nil } -func (p *P2PImpl) OnNewPeerListen(peer *net.Peer, verifyId bool) { +func (p *P2PImpl) OnNewPeerListen(peer net.Peer, verifyId bool) { onDone := net.CloseCallback(func() { - peerId, err := (*peer).Id().ToString() + peerId, err := peer.Id().ToString() if err != nil { p.logger.Error("failed to get peer id", zap.Error(err)) return @@ -236,7 +235,7 @@ func (p *P2PImpl) OnNewPeerListen(peer *net.Peer, verifyId bool) { p.logger.Error("peer error", zap.Any("args", args)) }) - (*peer).ListenForMessages(func(message []byte) error { + peer.ListenForMessages(func(message []byte) error { imsg := base.NewIncomingMessageUnknown() err := msgpack.Unmarshal(message, imsg)