portal/api/middleware/middleware.go

261 lines
6.0 KiB
Go
Raw Permalink Normal View History

package middleware
import (
2024-02-14 03:17:34 +00:00
"context"
"crypto/ed25519"
"errors"
2024-02-17 08:04:27 +00:00
"net/http"
"slices"
"strconv"
"strings"
"git.lumeweb.com/LumeWeb/portal/config"
2024-02-14 03:17:34 +00:00
"git.lumeweb.com/LumeWeb/portal/account"
"github.com/golang-jwt/jwt/v5"
"go.sia.tech/jape"
)
2024-02-14 03:17:34 +00:00
const DEFAULT_AUTH_CONTEXT_KEY = "user_id"
const AUTH_TOKEN_CONTEXT_KEY = "auth_token"
2024-02-14 03:17:34 +00:00
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
type HttpMiddlewareFunc func(http.Handler) http.Handler
2024-02-14 03:17:34 +00:00
type FindAuthTokenFunc func(r *http.Request) string
// AdaptMiddleware converts a standard net/http middleware into a
// JapeMiddlewareFunc using jape.Adapt.
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
	toHTTP := func(next http.Handler) http.Handler {
		inner := mid(next)
		serve := func(rw http.ResponseWriter, req *http.Request) {
			inner.ServeHTTP(rw, req)
		}
		return http.HandlerFunc(serve)
	}
	return jape.Adapt(toHTTP)
}
// ProxyMiddleware replaces r.RemoteAddr with the leftmost address listed in
// the X-Forwarded-For header, when that header is present and its first entry
// is non-empty. Requests without the header are passed through unchanged.
//
// SECURITY: X-Forwarded-For is supplied by the client. Mount this middleware
// only behind a reverse proxy that overwrites or sanitizes the header;
// otherwise any client can spoof its apparent address.
func ProxyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
			// Entries are comma-separated and the space after the comma is
			// optional, so cut on the bare comma and trim the result.
			first, _, _ := strings.Cut(xff, ",")
			if first = strings.TrimSpace(first); first != "" {
				r.RemoteAddr = first
			}
		}
		next.ServeHTTP(w, r)
	})
}
// ApplyMiddlewares wraps handler with middlewares so that middlewares[0] is
// the outermost layer and the last element is the one closest to handler.
//
// Each middleware may be any of the following:
//   - a JapeMiddlewareFunc;
//   - a plain func(jape.Handler) jape.Handler, such as the value jape.Adapt
//     returns;
//   - an HttpMiddlewareFunc;
//   - a plain func(http.Handler) http.Handler.
//
// HTTP middlewares are converted with AdaptMiddleware. Any other value panics,
// because it indicates a programming error in route setup.
func ApplyMiddlewares(handler jape.Handler, middlewares ...any) jape.Handler {
	// Wrap from the innermost (last) middleware outwards.
	for i := len(middlewares) - 1; i >= 0; i-- {
		switch mid := middlewares[i].(type) {
		case JapeMiddlewareFunc:
			handler = mid(handler)
		case func(jape.Handler) jape.Handler:
			handler = mid(handler)
		case HttpMiddlewareFunc:
			handler = AdaptMiddleware(mid)(handler)
		case func(http.Handler) http.Handler:
			handler = AdaptMiddleware(mid)(handler)
		default:
			panic("Invalid middleware type")
		}
	}
	return handler
}
// FindAuthToken returns the first auth token found on r. It checks, in order:
//  1. the Authorization header (via ParseAuthTokenHeader);
//  2. the cookie named cookieName;
//  3. the account.AUTH_COOKIE_NAME cookie;
//  4. the URL query parameter queryParam.
//
// It returns "" when none of these is present.
//
// Only the URL query is consulted, never the request body. Calling
// r.FormValue here would parse, and thereby consume, form and multipart
// upload bodies before the real handler could read them.
func FindAuthToken(r *http.Request, cookieName string, queryParam string) string {
	if token := ParseAuthTokenHeader(r.Header); token != "" {
		return token
	}

	// The caller-named cookie takes precedence over the account service's own
	// auth cookie. An empty name cannot identify a cookie, so it is skipped.
	for _, name := range [...]string{cookieName, account.AUTH_COOKIE_NAME} {
		if name == "" {
			continue
		}
		if cookie, err := r.Cookie(name); err == nil {
			return cookie.Value
		}
	}

	if queryParam == "" || r.URL == nil {
		return ""
	}
	return r.URL.Query().Get(queryParam)
}
func ParseAuthTokenHeader(headers http.Header) string {
authHeader := headers.Get("Authorization")
if authHeader == "" {
return ""
}
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
2024-02-18 00:48:28 +00:00
authHeader = strings.TrimPrefix(authHeader, "bearer ")
2024-02-14 03:17:34 +00:00
return authHeader
}
type AuthMiddlewareOptions struct {
Identity ed25519.PrivateKey
Accounts *account.AccountServiceDefault
FindToken FindAuthTokenFunc
Purpose account.JWTPurpose
AuthContextKey string
Config *config.Manager
EmptyAllowed bool
ExpiredAllowed bool
2024-02-14 03:17:34 +00:00
}
// AuthMiddleware returns net/http middleware that authenticates requests with
// a JWT located by options.FindToken.
//
// Request handling, in order:
//   - No token: respond 401 "Invalid JWT", or, when options.EmptyAllowed is
//     set, continue to next without an authenticated user.
//   - Verify the token with account.JWTVerifyToken against the configured core
//     domain and options.Identity. Unless options.Purpose is
//     account.JWTPurposeNone, the token's audience must contain that purpose.
//   - A verification error responds 401 with the error text. An expiry error
//     may instead be tolerated when options.ExpiredAllowed is set (see the
//     notes in the body for the exact rules).
//   - Parse the token subject as a numeric user ID and require that account
//     to exist. A bad subject or unknown account gets 400 with
//     account.ErrJWTInvalid; a failed account lookup gets 500.
//   - Store the user ID under options.AuthContextKey and the raw token under
//     AUTH_TOKEN_CONTEXT_KEY in the request context, then call next.
//
// options.AuthContextKey defaults to DEFAULT_AUTH_CONTEXT_KEY. The domain is
// read from options.Config once, when AuthMiddleware is called.
func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler {
	if options.AuthContextKey == "" {
		options.AuthContextKey = DEFAULT_AUTH_CONTEXT_KEY
	}
	domain := options.Config.Config().Core.Domain

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authToken := options.FindToken(r)
			if authToken == "" {
				if !options.EmptyAllowed {
					http.Error(w, "Invalid JWT", http.StatusUnauthorized)
					return
				}
				next.ServeHTTP(w, r)
				return
			}

			// audienceChecked records whether the verification callback ran,
			// i.e. whether the audience was inspected during verification.
			audienceChecked := false
			claim, err := account.JWTVerifyToken(authToken, domain, options.Identity, func(rc *jwt.RegisteredClaims) error {
				aud, _ := rc.GetAudience()
				audienceChecked = true
				if options.Purpose != account.JWTPurposeNone && !jwtPurposeEqual(aud, options.Purpose) {
					return account.ErrJWTInvalid
				}
				return nil
			})

			if err != nil {
				// Every verification failure is fatal except an expiry error
				// when the caller opted into ExpiredAllowed.
				unauthorized := !(options.ExpiredAllowed && errors.Is(err, jwt.ErrTokenExpired))

				if !unauthorized && !audienceChecked {
					// The callback never ran, so the audience was not inspected;
					// read it from the unverified token instead.
					matches, perr := unverifiedAudienceHasPurpose(authToken, options.Purpose)
					if perr != nil {
						http.Error(w, perr.Error(), http.StatusInternalServerError)
						return
					}
					// NOTE(review): this rejects an expired token whose audience
					// DOES contain options.Purpose. An expired token for any other
					// purpose is allowed to continue. That looks inverted relative
					// to the purpose check in the callback above. Behavior is
					// kept as-is; confirm the intent.
					if matches {
						unauthorized = true
					}
				}

				if unauthorized {
					http.Error(w, err.Error(), http.StatusUnauthorized)
					return
				}
			}

			// NOTE(review): if JWTVerifyToken can return a non-nil claim along
			// with an expiry error, execution continues below and the request
			// is authenticated as that expired subject. Confirm this is intended.
			if claim == nil && options.ExpiredAllowed {
				next.ServeHTTP(w, r)
				return
			}

			// Parse with the platform int size so the uint conversion below
			// cannot silently truncate on 32-bit targets.
			userID, err := strconv.ParseUint(claim.Subject, 10, strconv.IntSize)
			if err != nil {
				http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
				return
			}

			exists, _, err := options.Accounts.AccountExists(uint(userID))
			if err != nil {
				// A failed lookup is a server-side problem, not a bad token;
				// don't tell the client its JWT is invalid.
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !exists {
				http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
				return
			}

			ctx := context.WithValue(r.Context(), options.AuthContextKey, uint(userID))
			ctx = context.WithValue(ctx, AUTH_TOKEN_CONTEXT_KEY, authToken)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// unverifiedAudienceHasPurpose parses authToken WITHOUT verifying its
// signature or claims and reports whether its audience contains purpose.
// AuthMiddleware calls it only after verification has already failed with an
// expiry error, so the token's audience can still be inspected.
func unverifiedAudienceHasPurpose(authToken string, purpose account.JWTPurpose) (bool, error) {
	var claims jwt.RegisteredClaims
	token, _, err := jwt.NewParser().ParseUnverified(authToken, &claims)
	if err != nil {
		return false, err
	}
	aud, err := token.Claims.GetAudience()
	if err != nil {
		return false, err
	}
	return jwtPurposeEqual(aud, purpose), nil
}
func MergeRoutes(routes ...map[string]jape.Handler) map[string]jape.Handler {
merged := make(map[string]jape.Handler)
for _, route := range routes {
for k, v := range route {
merged[k] = v
}
}
return merged
}
2024-02-14 04:22:36 +00:00
// GetUserFromContext returns the user ID that AuthMiddleware stored in ctx.
//
// An optional key overrides the context key. A missing or empty key means
// DEFAULT_AUTH_CONTEXT_KEY. It panics if no uint value is stored under that
// key.
func GetUserFromContext(ctx context.Context, key ...string) uint {
	lookup := DEFAULT_AUTH_CONTEXT_KEY
	if len(key) > 0 && key[0] != "" {
		lookup = key[0]
	}

	if userID, ok := ctx.Value(lookup).(uint); ok {
		return userID
	}
	panic("user id stored in context is not of type uint")
}
// GetAuthTokenFromContext returns the raw auth token that AuthMiddleware
// stored in ctx under AUTH_TOKEN_CONTEXT_KEY. It panics if no string value is
// stored there.
func GetAuthTokenFromContext(ctx context.Context) string {
	if token, ok := ctx.Value(AUTH_TOKEN_CONTEXT_KEY).(string); ok {
		return token
	}
	panic("auth token stored in context is not of type string")
}
// CtxAborted reports, without blocking, whether ctx's Done channel has
// already been closed.
func CtxAborted(ctx context.Context) bool {
	aborted := false
	select {
	case <-ctx.Done():
		aborted = true
	default:
	}
	return aborted
}
// jwtPurposeEqual reports whether the audience list aud contains purpose.
func jwtPurposeEqual(aud jwt.ClaimStrings, purpose account.JWTPurpose) bool {
	want := string(purpose)
	return slices.Contains(aud, want)
}