refactor: create generic AdaptMiddleware factory and change ApplyMiddlewares to take interfaces and handle multiple situations

This commit is contained in:
Derrick Hammer 2024-01-22 16:50:03 -05:00
parent dd857650e0
commit 527334f829
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
3 changed files with 68 additions and 44 deletions

View File

@ -0,0 +1,49 @@
package middleware
import (
"go.sia.tech/jape"
"net/http"
"strings"
)
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
type HttpMiddlewareFunc func(next http.Handler) http.Handler
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
return jape.Adapt(func(h http.Handler) http.Handler {
handler := mid(h)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
})
})
}
// proxyMiddleware creates a new HTTP middleware for handling X-Forwarded-For headers.
func proxyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ", ")
if len(ips) > 0 {
r.RemoteAddr = ips[0]
}
}
next.ServeHTTP(w, r)
})
}
func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Handler {
for i := len(middlewares) - 1; i >= 0; i-- {
switch middlewares[i].(type) {
case JapeMiddlewareFunc:
mid := middlewares[i].(JapeMiddlewareFunc)
handler = mid(handler)
case func(next http.Handler) http.Handler:
mid := middlewares[i].(HttpMiddlewareFunc)
handler = AdaptMiddleware(mid)(handler)
default:
panic("Invalid middleware type")
}
}
return handler
}

View File

@ -44,8 +44,8 @@ func parseAuthTokenHeader(headers http.Header) string {
return authHeader return authHeader
} }
func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler { func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler {
return jape.Adapt(func(h http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken := findAuthToken(r) authToken := findAuthToken(r)
@ -100,12 +100,12 @@ func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID) ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
r = r.WithContext(ctx) r = r.WithContext(ctx)
h.ServeHTTP(w, r) next.ServeHTTP(w, r)
} else { } else {
http.Error(w, "Invalid JWT", http.StatusUnauthorized) http.Error(w, "Invalid JWT", http.StatusUnauthorized)
} }
}) })
})(handler) }
} }
type tusJwtResponseWriter struct { type tusJwtResponseWriter struct {
@ -159,46 +159,36 @@ func replacePrefix(prefix string, h http.Handler) http.Handler {
} }
func BuildS5TusApi(portal interfaces.Portal) jape.Handler { func BuildS5TusApi(portal interfaces.Portal) jape.Handler {
// Wrapper function for AuthMiddleware to fit the MiddlewareFunc signature
authMiddlewareFunc := func(h jape.Handler) jape.Handler {
return AuthMiddleware(h, portal)
}
// Create a jape.Handler for your tusHandler // Create a jape.Handler for your tusHandler
tusJapeHandler := func(c jape.Context) { tusJapeHandler := func(c jape.Context) {
tusHandler := portal.Storage().Tus() tusHandler := portal.Storage().Tus()
tusHandler.ServeHTTP(c.ResponseWriter, c.Request) tusHandler.ServeHTTP(c.ResponseWriter, c.Request)
} }
protocolMiddleware := jape.Adapt(func(h http.Handler) http.Handler { protocolMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), "protocol", "s5") ctx := context.WithValue(r.Context(), "protocol", "s5")
h.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
})
stripPrefix := func(h jape.Handler) jape.Handler {
return jape.Adapt(func(h http.Handler) http.Handler {
return replacePrefix("/s5/upload/tus", h)
})(h)
} }
injectJwt := func(h jape.Handler) jape.Handler { stripPrefix := func(next http.Handler) http.Handler {
return jape.Adapt(func(h http.Handler) http.Handler { return replacePrefix("/s5/upload/tus", next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { }
res := w
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
}
h.ServeHTTP(res, r) injectJwt := func(next http.Handler) http.Handler {
}) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
})(h) res := w
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
}
next.ServeHTTP(res, r)
})
} }
// Apply the middlewares to the tusJapeHandler // Apply the middlewares to the tusJapeHandler
tusHandler := ApplyMiddlewares(tusJapeHandler, authMiddlewareFunc, injectJwt, protocolMiddleware, stripPrefix) tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(portal), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware)
return tusHandler return tusHandler
} }

View File

@ -1,15 +0,0 @@
package middleware
import (
"go.sia.tech/jape"
)
type MiddlewareFunc func(jape.Handler) jape.Handler
func ApplyMiddlewares(handler jape.Handler, middlewares ...MiddlewareFunc) jape.Handler {
// Apply each middleware in reverse order
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}