From 72c3167e5fc35f8c624c0f213fb5c7ef3cf6b80b Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Wed, 17 Jan 2024 16:46:13 -0500 Subject: [PATCH] feat: implement POST /s5/registry/subscription --- api/s5.go | 5 +-- api/s5/http.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/api/s5.go b/api/s5.go index 0f40426..81445a1 100644 --- a/api/s5.go +++ b/api/s5.go @@ -49,7 +49,8 @@ func getRoutes(h *s5.HttpHandler, portal interfaces.Portal) map[string]jape.Hand "/s5/debug/download_urls/:cid": s5.AuthMiddleware(h.DebugDownloadUrls, portal), //Registry API - "GET /s5/registry": s5.AuthMiddleware(h.RegistryQuery, portal), - "POST /s5/registry": s5.AuthMiddleware(h.RegistrySet, portal), + "GET /s5/registry": s5.AuthMiddleware(h.RegistryQuery, portal), + "POST /s5/registry": s5.AuthMiddleware(h.RegistrySet, portal), + "GET /s5/registry/subscription": s5.AuthMiddleware(h.RegistrySubscription, portal), } } diff --git a/api/s5/http.go b/api/s5/http.go index 89c4e99..aec1367 100644 --- a/api/s5/http.go +++ b/api/s5/http.go @@ -2,6 +2,7 @@ package s5 import ( "bytes" + "context" "crypto/ed25519" "crypto/rand" "encoding/base64" @@ -26,6 +27,7 @@ import ( "math" "mime/multipart" "net/http" + "nhooyr.io/websocket" "strings" "time" ) @@ -1002,6 +1004,88 @@ func (h *HttpHandler) RegistrySet(jc jape.Context) { } } +func (h *HttpHandler) RegistrySubscription(jc jape.Context) { + // Create a context for the WebSocket operations + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var listeners []func() + + // Accept the WebSocket connection + c, err := websocket.Accept(jc.ResponseWriter, jc.Request, nil) + if err != nil { + h.portal.Logger().Error("error accepting websocket connection", zap.Error(err)) + return + } + defer func(c *websocket.Conn, code websocket.StatusCode, reason string) { + err := c.Close(code, reason) + if err != nil { + h.portal.Logger().Error("error closing websocket connection", zap.Error(err)) + } + + for _, listener := range listeners { + listener() + } + }(c, websocket.StatusNormalClosure, "connection closed") + + // Main loop for reading messages + for { + // Read a message (the actual reading and unpacking is skipped here) + _, data, err := c.Read(ctx) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + // Normal closure + h.portal.Logger().Info("websocket connection closed normally") + } else { + // Handle different types of errors + h.portal.Logger().Error("error in websocket connection", zap.Error(err)) + } + break + } + + decoder := msgpack.NewDecoder(bytes.NewReader(data)) + + method, err := decoder.DecodeInt() + + if err != nil { + h.portal.Logger().Error("error decoding method", zap.Error(err)) + break + } + + if method != 2 { + h.portal.Logger().Error("invalid method", zap.Int64("method", int64(method))) + break + } + + sre, err := decoder.DecodeBytes() + + if err != nil { + h.portal.Logger().Error("error decoding sre", zap.Error(err)) + break + } + + off, err := h.getNode().Services().Registry().Listen(sre, func(entry s5interfaces.SignedRegistryEntry) { + encoded, err := msgpack.Marshal(entry) + if err != nil { + h.portal.Logger().Error("error encoding entry", zap.Error(err)) + return + } + + err = c.Write(ctx, websocket.MessageBinary, encoded) + + if err != nil { + h.portal.Logger().Error("error writing to websocket", zap.Error(err)) + } + }) + if err != nil { + h.portal.Logger().Error("error listening to registry", zap.Error(err)) + break + } + + listeners = append(listeners, off) + } +} + func (h *HttpHandler) getNode() s5interfaces.Node { proto, _ := h.portal.ProtocolRegistry().Get("s5") protoInstance := proto.(*protocols.S5Protocol)