refactor: create generic AdaptMiddleware factory and change ApplyMiddlewares to take interfaces and handle multiple situations

This commit is contained in:
Derrick Hammer 2024-01-22 16:50:03 -05:00
parent dd857650e0
commit 527334f829
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
3 changed files with 68 additions and 44 deletions

View File

@ -0,0 +1,49 @@
package middleware
import (
"go.sia.tech/jape"
"net/http"
"strings"
)
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
type HttpMiddlewareFunc func(next http.Handler) http.Handler
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
return jape.Adapt(func(h http.Handler) http.Handler {
handler := mid(h)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
})
})
}
// proxyMiddleware creates a new HTTP middleware for handling X-Forwarded-For headers.
func proxyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ", ")
if len(ips) > 0 {
r.RemoteAddr = ips[0]
}
}
next.ServeHTTP(w, r)
})
}
func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Handler {
for i := len(middlewares) - 1; i >= 0; i-- {
switch middlewares[i].(type) {
case JapeMiddlewareFunc:
mid := middlewares[i].(JapeMiddlewareFunc)
handler = mid(handler)
case func(next http.Handler) http.Handler:
mid := middlewares[i].(HttpMiddlewareFunc)
handler = AdaptMiddleware(mid)(handler)
default:
panic("Invalid middleware type")
}
}
return handler
}

View File

@ -44,8 +44,8 @@ func parseAuthTokenHeader(headers http.Header) string {
return authHeader
}
func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler {
return jape.Adapt(func(h http.Handler) http.Handler {
func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken := findAuthToken(r)
@ -100,12 +100,12 @@ func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
}
})
})(handler)
}
}
type tusJwtResponseWriter struct {
@ -159,46 +159,36 @@ func replacePrefix(prefix string, h http.Handler) http.Handler {
}
func BuildS5TusApi(portal interfaces.Portal) jape.Handler {
// Wrapper function for AuthMiddleware to fit the MiddlewareFunc signature
authMiddlewareFunc := func(h jape.Handler) jape.Handler {
return AuthMiddleware(h, portal)
}
// Create a jape.Handler for your tusHandler
tusJapeHandler := func(c jape.Context) {
tusHandler := portal.Storage().Tus()
tusHandler.ServeHTTP(c.ResponseWriter, c.Request)
}
protocolMiddleware := jape.Adapt(func(h http.Handler) http.Handler {
protocolMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), "protocol", "s5")
h.ServeHTTP(w, r.WithContext(ctx))
next.ServeHTTP(w, r.WithContext(ctx))
})
})
stripPrefix := func(h jape.Handler) jape.Handler {
return jape.Adapt(func(h http.Handler) http.Handler {
return replacePrefix("/s5/upload/tus", h)
})(h)
}
injectJwt := func(h jape.Handler) jape.Handler {
return jape.Adapt(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
res := w
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
}
stripPrefix := func(next http.Handler) http.Handler {
return replacePrefix("/s5/upload/tus", next)
}
h.ServeHTTP(res, r)
})
})(h)
injectJwt := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
res := w
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
}
next.ServeHTTP(res, r)
})
}
// Apply the middlewares to the tusJapeHandler
tusHandler := ApplyMiddlewares(tusJapeHandler, authMiddlewareFunc, injectJwt, protocolMiddleware, stripPrefix)
tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(portal), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware)
return tusHandler
}

View File

@ -1,15 +0,0 @@
package middleware
import (
"go.sia.tech/jape"
)
type MiddlewareFunc func(jape.Handler) jape.Handler
func ApplyMiddlewares(handler jape.Handler, middlewares ...MiddlewareFunc) jape.Handler {
// Apply each middleware in reverse order
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}