From 527334f829c95676833bb14bfb51a419ddc3c5ae Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Mon, 22 Jan 2024 16:50:03 -0500 Subject: [PATCH] refactor: create generic AdaptMiddleware factory and change ApplyMiddlewares to take interfaces and handle multiple situations --- api/middleware/middleware.go | 49 ++++++++++++++++++++++++++++++++++++ api/middleware/s5.go | 48 ++++++++++++++--------------------- api/middleware/util.go | 15 ----------- 3 files changed, 68 insertions(+), 44 deletions(-) create mode 100644 api/middleware/middleware.go delete mode 100644 api/middleware/util.go diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go new file mode 100644 index 0000000..40eb60d --- /dev/null +++ b/api/middleware/middleware.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "go.sia.tech/jape" + "net/http" + "strings" +) + +type JapeMiddlewareFunc func(jape.Handler) jape.Handler +type HttpMiddlewareFunc func(next http.Handler) http.Handler + +func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc { + return jape.Adapt(func(h http.Handler) http.Handler { + handler := mid(h) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler.ServeHTTP(w, r) + }) + }) +} + +// proxyMiddleware creates a new HTTP middleware for handling X-Forwarded-For headers. +func proxyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + ips := strings.Split(xff, ", ") + if len(ips) > 0 { + r.RemoteAddr = ips[0] + } + } + next.ServeHTTP(w, r) + }) +} + +func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + switch middlewares[i].(type) { + case JapeMiddlewareFunc: + mid := middlewares[i].(JapeMiddlewareFunc) + handler = mid(handler) + case func(next http.Handler) http.Handler: + mid := middlewares[i].(HttpMiddlewareFunc) + handler = AdaptMiddleware(mid)(handler) + + default: + panic("Invalid middleware type") + } + } + return handler +} diff --git a/api/middleware/s5.go b/api/middleware/s5.go index 6a0aa98..887526a 100644 --- a/api/middleware/s5.go +++ b/api/middleware/s5.go @@ -44,8 +44,8 @@ func parseAuthTokenHeader(headers http.Header) string { return authHeader } -func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler { - return jape.Adapt(func(h http.Handler) http.Handler { +func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authToken := findAuthToken(r) @@ -100,12 +100,12 @@ func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID) r = r.WithContext(ctx) - h.ServeHTTP(w, r) + next.ServeHTTP(w, r) } else { http.Error(w, "Invalid JWT", http.StatusUnauthorized) } }) - })(handler) + } } type tusJwtResponseWriter struct { @@ -159,46 +159,36 @@ func replacePrefix(prefix string, h http.Handler) http.Handler { } func BuildS5TusApi(portal interfaces.Portal) jape.Handler { - - // Wrapper function for AuthMiddleware to fit the MiddlewareFunc signature - authMiddlewareFunc := func(h jape.Handler) jape.Handler { - return AuthMiddleware(h, portal) - } - // Create a jape.Handler for your tusHandler tusJapeHandler := func(c jape.Context) { tusHandler := portal.Storage().Tus() tusHandler.ServeHTTP(c.ResponseWriter, c.Request) } - protocolMiddleware := jape.Adapt(func(h http.Handler) http.Handler { + protocolMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), "protocol", "s5") - h.ServeHTTP(w, r.WithContext(ctx)) + next.ServeHTTP(w, r.WithContext(ctx)) }) - }) - - stripPrefix := func(h jape.Handler) jape.Handler { - return jape.Adapt(func(h http.Handler) http.Handler { - return replacePrefix("/s5/upload/tus", h) - })(h) } - injectJwt := func(h jape.Handler) jape.Handler { - return jape.Adapt(func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - res := w - if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" { - res = &tusJwtResponseWriter{ResponseWriter: w, req: r} - } + stripPrefix := func(next http.Handler) http.Handler { + return replacePrefix("/s5/upload/tus", next) + } - h.ServeHTTP(res, r) - }) - })(h) + injectJwt := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + res := w + if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" { + res = &tusJwtResponseWriter{ResponseWriter: w, req: r} + } + + next.ServeHTTP(res, r) + }) } // Apply the middlewares to the tusJapeHandler - tusHandler := ApplyMiddlewares(tusJapeHandler, authMiddlewareFunc, injectJwt, protocolMiddleware, stripPrefix) + tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(portal), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware) return tusHandler } diff --git a/api/middleware/util.go b/api/middleware/util.go deleted file mode 100644 index 2dcc28c..0000000 --- a/api/middleware/util.go +++ /dev/null @@ -1,15 +0,0 @@ -package middleware - -import ( - "go.sia.tech/jape" -) - -type MiddlewareFunc func(jape.Handler) jape.Handler - -func ApplyMiddlewares(handler jape.Handler, middlewares ...MiddlewareFunc) jape.Handler { - // Apply each middleware in reverse order - for i := len(middlewares) - 1; i >= 0; i-- { - handler = middlewares[i](handler) - } - return handler -}