diff --git a/net/net.go b/net/net.go index 6d0cd27..21783d3 100644 --- a/net/net.go +++ b/net/net.go @@ -25,8 +25,8 @@ var ( func init() { transports = sync.Map{} - //RegisterPeerType("ws", WebSocketPeer) - //RegisterPeerType("wss", WebSocketPeer) + RegisterTransport("ws", WebSocketPeer{}) + RegisterTransport("wss", WebSocketPeer{}) } func RegisterTransport(peerType string, factory interface{}) { if _, ok := factory.(PeerFactory); !ok { diff --git a/net/peer.go b/net/peer.go index 0e84bda..d56c55e 100644 --- a/net/peer.go +++ b/net/peer.go @@ -9,15 +9,15 @@ import ( // EventCallback type for the callback function type EventCallback func(event []byte) error -// DoneCallback type for the onDone callback -type DoneCallback func() +// CloseCallback type for the OnClose callback +type CloseCallback func() // ErrorCallback type for the onError callback type ErrorCallback func(args ...interface{}) // ListenerOptions struct for options type ListenerOptions struct { - OnDone *DoneCallback + OnClose *CloseCallback OnError *ErrorCallback Logger *zap.Logger } @@ -28,15 +28,41 @@ type Peer interface { ListenForMessages(callback EventCallback, options ListenerOptions) End() error SetId(id *encoding.NodeId) - GetId() *encoding.NodeId + Id() *encoding.NodeId SetChallenge(challenge []byte) - GetChallenge() []byte + Challenge() []byte + SetSocket(socket interface{}) + Socket() interface{} } type BasePeer struct { - ConnectionURIs []url.URL - IsConnected bool + connectionURIs []*url.URL + isConnected bool challenge []byte - Socket interface{} - Id *encoding.NodeId + socket interface{} + id *encoding.NodeId +} + +func (b *BasePeer) Challenge() []byte { + return b.challenge +} + +func (b *BasePeer) SetChallenge(challenge []byte) { + b.challenge = challenge +} + +func (b *BasePeer) Socket() interface{} { + return b.socket +} + +func (b *BasePeer) SetSocket(socket interface{}) { + b.socket = socket +} + +func (b *BasePeer) Id() *encoding.NodeId { + return b.id +} + +func (b *BasePeer) SetId(id *encoding.NodeId) { + b.id = id } diff --git a/net/ws.go b/net/ws.go index bcf37e5..2c2def2 100644 --- a/net/ws.go +++ b/net/ws.go @@ -1,8 +1,16 @@ package net import ( + "context" "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "github.com/gorilla/websocket" + "net/url" + "nhooyr.io/websocket" +) + +var ( + _ PeerFactory = (*WebSocketPeer)(nil) + _ PeerStatic = (*WebSocketPeer)(nil) + _ Peer = (*WebSocketPeer)(nil) ) type WebSocketPeer struct { @@ -10,49 +18,74 @@ type WebSocketPeer struct { Socket *websocket.Conn } -func (p *WebSocketPeer) SendMessage(message []byte) { - err := p.Socket.WriteMessage(websocket.BinaryMessage, message) +func (p *WebSocketPeer) Connect(uri *url.URL) (interface{}, error) { + dial, _, err := websocket.Dial(context.Background(), uri.String(), nil) if err != nil { - return + return nil, err } + + return dial, nil +} + +func (p *WebSocketPeer) NewPeer(options *TransportPeerConfig) (Peer, error) { + peer := &WebSocketPeer{ + BasePeer: BasePeer{ + connectionURIs: options.Uris, + socket: options.Socket, + }, + Socket: options.Socket.(*websocket.Conn), + } + + return peer, nil +} + +func (p *WebSocketPeer) SendMessage(message []byte) error { + err := p.Socket.Write(context.Background(), websocket.MessageBinary, message) + if err != nil { + return err + } + + return nil } func (p *WebSocketPeer) RenderLocationURI() string { - return p.Socket.RemoteAddr().String() + return "WebSocket client" } -func (p *WebSocketPeer) ListenForMessages(callback EventCallback, onClose func(), onError func(error)) { +func (p *WebSocketPeer) ListenForMessages(callback EventCallback, options ListenerOptions) { for { - _, message, err := p.Socket.ReadMessage() + _, message, err := p.Socket.Read(context.Background()) if err != nil { - if onError != nil { - onError(err) + if options.OnError != nil { + (*options.OnError)(err) } break } err = callback(message) if err != nil { - if onError != nil { - onError(err) + if options.OnError != nil { + (*options.OnError)(err) } } } - if onClose != nil { - onClose() + if options.OnClose != nil { + (*options.OnClose)() } } -func (p *WebSocketPeer) End() { - err := p.Socket.Close() +func (p *WebSocketPeer) End() error { + err := p.Socket.Close(websocket.StatusNormalClosure, "") if err != nil { - return + return err } + + return nil } func (p *WebSocketPeer) SetId(id *encoding.NodeId) { - p.Id = id + p.id = id } func (p *WebSocketPeer) SetChallenge(challenge []byte) { diff --git a/protocol/hash_query.go b/protocol/hash_query.go index 0ce55ab..f985e65 100644 --- a/protocol/hash_query.go +++ b/protocol/hash_query.go @@ -94,13 +94,13 @@ func (h *HashQuery) HandleMessage(node interfaces.Node, peer *net.Peer, verifyId peers = peersVal.(*hashset.Set) - if exists := peers.Contains((*peer).GetId()); !exists { - peers.Add((*peer).GetId()) + if exists := peers.Contains((*peer).Id()); !exists { + peers.Add((*peer).Id()) } for _, val := range node.Services().P2P().Peers().Values() { peerVal := val.(net.Peer) - if !peerVal.GetId().Equals((*peer).GetId()) { + if !peerVal.Id().Equals((*peer).Id()) { err := peerVal.SendMessage(h.IncomingMessageImpl.Original()) if err != nil { node.Logger().Error("Failed to send message", zap.Error(err)) diff --git a/protocol/signed/handshake_done.go b/protocol/signed/handshake_done.go index 317d662..18493dc 100644 --- a/protocol/signed/handshake_done.go +++ b/protocol/signed/handshake_done.go @@ -40,7 +40,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer *net.Peer, verif return nil } - if !bytes.Equal((*peer).GetChallenge(), h.challenge) { + if !bytes.Equal((*peer).Challenge(), h.challenge) { return errors.New("Invalid challenge") } /* @@ -52,7 +52,7 @@ func (h HandshakeDone) HandleMessage(node interfaces.Node, peer *net.Peer, verif } } - peer.IsConnected = true + peer.isConnected = true supportedFeatures := data.UnpackInt() diff --git a/service/p2p.go b/service/p2p.go index c4cb00f..2d0df13 100644 --- a/service/p2p.go +++ b/service/p2p.go @@ -219,8 +219,8 @@ func (p *P2PImpl) OnNewPeer(peer *net.Peer, verifyId bool) error { return nil } func (p *P2PImpl) OnNewPeerListen(peer *net.Peer, verifyId bool) { - onDone := net.DoneCallback(func() { - peerId, err := (*peer).GetId().ToString() + onDone := net.CloseCallback(func() { + peerId, err := (*peer).Id().ToString() if err != nil { p.logger.Error("failed to get peer id", zap.Error(err)) return @@ -262,7 +262,7 @@ func (p *P2PImpl) OnNewPeerListen(peer *net.Peer, verifyId bool) { return nil }, net.ListenerOptions{ - OnDone: &onDone, + OnClose: &onDone, OnError: &onError, Logger: p.logger, })