package net import ( "context" "git.lumeweb.com/LumeWeb/libs5-go/encoding" "net" "net/url" "nhooyr.io/websocket" "sync" ) var ( _ PeerFactory = (*WebSocketPeer)(nil) _ PeerStatic = (*WebSocketPeer)(nil) _ Peer = (*WebSocketPeer)(nil) ) type WebSocketPeer struct { BasePeer socket *websocket.Conn abuser bool } func (p *WebSocketPeer) Connect(uri *url.URL) (interface{}, error) { dial, _, err := websocket.Dial(context.Background(), uri.String(), nil) if err != nil { return nil, err } return dial, nil } func (p *WebSocketPeer) NewPeer(options *TransportPeerConfig) (Peer, error) { peer := &WebSocketPeer{ BasePeer: BasePeer{ connectionURIs: options.Uris, socket: options.Socket, }, socket: options.Socket.(*websocket.Conn), } return peer, nil } func (p *WebSocketPeer) SendMessage(message []byte) error { err := p.socket.Write(context.Background(), websocket.MessageBinary, message) if err != nil { return err } return nil } func (p *WebSocketPeer) RenderLocationURI() string { return "WebSocket client" } func (p *WebSocketPeer) ListenForMessages(callback EventCallback, options ListenerOptions) { errChan := make(chan error, 10) doneChan := make(chan struct{}) var wg sync.WaitGroup for { _, message, err := p.socket.Read(context.Background()) if err != nil { if options.OnError != nil { (*options.OnError)(err) } break } wg.Add(1) // Process each message in a separate goroutine go func(msg []byte) { defer wg.Done() // Call the callback and send any errors to the error channel if err := callback(msg); err != nil { select { case errChan <- err: case <-doneChan: // Stop sending errors if doneChan is closed } } }(message) // Non-blocking error check select { case err := <-errChan: if options.OnError != nil { (*options.OnError)(err) } default: } } if options.OnClose != nil { (*options.OnClose)() } // Close doneChan and wait for all goroutines to finish close(doneChan) wg.Wait() // Handle remaining errors close(errChan) for err := range errChan { if options.OnError != nil { (*options.OnError)(err) } } } func (p *WebSocketPeer) End() error { err := p.socket.Close(websocket.StatusNormalClosure, "") if err != nil { return err } return nil } func (p *WebSocketPeer) EndForAbuse() error { p.abuser = true err := p.socket.Close(websocket.StatusPolicyViolation, "") if err != nil { return err } return nil } func (p *WebSocketPeer) SetId(id *encoding.NodeId) { p.id = id } func (p *WebSocketPeer) SetChallenge(challenge []byte) { p.challenge = challenge } func (p *WebSocketPeer) GetChallenge() []byte { return p.challenge } func (p *WebSocketPeer) GetIP() net.Addr { ctx, cancel := context.WithCancel(context.Background()) netConn := websocket.NetConn(ctx, p.socket, websocket.MessageBinary) ipAddr := netConn.RemoteAddr() cancel() return ipAddr } func (b *WebSocketPeer) GetIPString() string { return b.GetIP().String() } func (p *WebSocketPeer) Abuser() bool { return p.abuser }