refactor: create generic AdaptMiddleware factory and change ApplyMiddlewares to take interfaces and handle multiple situations
This commit is contained in:
parent
dd857650e0
commit
527334f829
|
@ -0,0 +1,49 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go.sia.tech/jape"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
|
||||||
|
type HttpMiddlewareFunc func(next http.Handler) http.Handler
|
||||||
|
|
||||||
|
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
|
||||||
|
return jape.Adapt(func(h http.Handler) http.Handler {
|
||||||
|
handler := mid(h)
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyMiddleware creates a new HTTP middleware for handling X-Forwarded-For headers.
|
||||||
|
func proxyMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
ips := strings.Split(xff, ", ")
|
||||||
|
if len(ips) > 0 {
|
||||||
|
r.RemoteAddr = ips[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Handler {
|
||||||
|
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||||
|
switch middlewares[i].(type) {
|
||||||
|
case JapeMiddlewareFunc:
|
||||||
|
mid := middlewares[i].(JapeMiddlewareFunc)
|
||||||
|
handler = mid(handler)
|
||||||
|
case func(next http.Handler) http.Handler:
|
||||||
|
mid := middlewares[i].(HttpMiddlewareFunc)
|
||||||
|
handler = AdaptMiddleware(mid)(handler)
|
||||||
|
|
||||||
|
default:
|
||||||
|
panic("Invalid middleware type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handler
|
||||||
|
}
|
|
@ -44,8 +44,8 @@ func parseAuthTokenHeader(headers http.Header) string {
|
||||||
return authHeader
|
return authHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler {
|
func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler {
|
||||||
return jape.Adapt(func(h http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
authToken := findAuthToken(r)
|
authToken := findAuthToken(r)
|
||||||
|
|
||||||
|
@ -100,12 +100,12 @@ func AuthMiddleware(handler jape.Handler, portal interfaces.Portal) jape.Handler
|
||||||
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
|
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
h.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
} else {
|
} else {
|
||||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})(handler)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type tusJwtResponseWriter struct {
|
type tusJwtResponseWriter struct {
|
||||||
|
@ -159,46 +159,36 @@ func replacePrefix(prefix string, h http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildS5TusApi(portal interfaces.Portal) jape.Handler {
|
func BuildS5TusApi(portal interfaces.Portal) jape.Handler {
|
||||||
|
|
||||||
// Wrapper function for AuthMiddleware to fit the MiddlewareFunc signature
|
|
||||||
authMiddlewareFunc := func(h jape.Handler) jape.Handler {
|
|
||||||
return AuthMiddleware(h, portal)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a jape.Handler for your tusHandler
|
// Create a jape.Handler for your tusHandler
|
||||||
tusJapeHandler := func(c jape.Context) {
|
tusJapeHandler := func(c jape.Context) {
|
||||||
tusHandler := portal.Storage().Tus()
|
tusHandler := portal.Storage().Tus()
|
||||||
tusHandler.ServeHTTP(c.ResponseWriter, c.Request)
|
tusHandler.ServeHTTP(c.ResponseWriter, c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
protocolMiddleware := jape.Adapt(func(h http.Handler) http.Handler {
|
protocolMiddleware := func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := context.WithValue(r.Context(), "protocol", "s5")
|
ctx := context.WithValue(r.Context(), "protocol", "s5")
|
||||||
h.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
})
|
|
||||||
|
|
||||||
stripPrefix := func(h jape.Handler) jape.Handler {
|
|
||||||
return jape.Adapt(func(h http.Handler) http.Handler {
|
|
||||||
return replacePrefix("/s5/upload/tus", h)
|
|
||||||
})(h)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
injectJwt := func(h jape.Handler) jape.Handler {
|
stripPrefix := func(next http.Handler) http.Handler {
|
||||||
return jape.Adapt(func(h http.Handler) http.Handler {
|
return replacePrefix("/s5/upload/tus", next)
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
}
|
||||||
res := w
|
|
||||||
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
|
|
||||||
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
|
|
||||||
}
|
|
||||||
|
|
||||||
h.ServeHTTP(res, r)
|
injectJwt := func(next http.Handler) http.Handler {
|
||||||
})
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
})(h)
|
res := w
|
||||||
|
if r.Method == http.MethodPost && r.URL.Path == "/s5/upload/tus" {
|
||||||
|
res = &tusJwtResponseWriter{ResponseWriter: w, req: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(res, r)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply the middlewares to the tusJapeHandler
|
// Apply the middlewares to the tusJapeHandler
|
||||||
tusHandler := ApplyMiddlewares(tusJapeHandler, authMiddlewareFunc, injectJwt, protocolMiddleware, stripPrefix)
|
tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(portal), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware)
|
||||||
|
|
||||||
return tusHandler
|
return tusHandler
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"go.sia.tech/jape"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MiddlewareFunc func(jape.Handler) jape.Handler
|
|
||||||
|
|
||||||
func ApplyMiddlewares(handler jape.Handler, middlewares ...MiddlewareFunc) jape.Handler {
|
|
||||||
// Apply each middleware in reverse order
|
|
||||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
|
||||||
handler = middlewares[i](handler)
|
|
||||||
}
|
|
||||||
return handler
|
|
||||||
}
|
|
Loading…
Reference in New Issue