From 0b3d54e7c5a076bc00036dee82c971ed17e6e0ee Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Tue, 13 Feb 2024 22:17:34 -0500 Subject: [PATCH] refactor: major middleware refactor --- account/jwt.go | 62 ++++++++++++++++++-- api/middleware/middleware.go | 100 ++++++++++++++++++++++++++++++++ api/middleware/s5.go | 108 ----------------------------------- api/s5/middleware.go | 20 +++++++ 4 files changed, 176 insertions(+), 114 deletions(-) delete mode 100644 api/middleware/s5.go create mode 100644 api/s5/middleware.go diff --git a/account/jwt.go b/account/jwt.go index b1f4b92..cf6654f 100644 --- a/account/jwt.go +++ b/account/jwt.go @@ -2,14 +2,29 @@ package account import ( "crypto/ed25519" + "errors" + "fmt" "github.com/golang-jwt/jwt/v5" + "strconv" "time" ) type JWTPurpose string +type VerifyTokenFunc func(claim jwt.RegisteredClaims) error + +var ( + nopVerifyFunc VerifyTokenFunc = func(claim jwt.RegisteredClaims) error { + return nil + } + + ErrJWTUnexpectedClaimsType = errors.New("unexpected claims type") + ErrJWTUnexpectedIssuer = errors.New("unexpected issuer") + ErrJWTInvalid = errors.New("invalid JWT") +) const ( JWTPurposeLogin JWTPurpose = "login" + JWTPurpose2FA JWTPurpose = "2fa" ) func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) { @@ -17,13 +32,14 @@ func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, pu } func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) { + // Define the claims - claims := jwt.MapClaims{ - "iss": domain, - "sub": userID, - "exp": time.Now().Add(duration).Unix(), - "iat": time.Now().Unix(), - "aud": string(purpose), + claims := jwt.RegisteredClaims{ + Issuer: domain, + Subject: strconv.Itoa(int(userID)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Audience: []string{string(purpose)}, } // Create the token @@ -37,3 +53,37 @@ func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, use return tokenString, nil } + +func VerifyToken(token string, domain string, privateKey ed25519.PrivateKey, verifyFunc VerifyTokenFunc) (*jwt.RegisteredClaims, error) { + validatedToken, err := jwt.ParseWithClaims(token, jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + publicKey := privateKey.Public() + + return publicKey, nil + }) + + if err != nil { + return nil, err + } + + if verifyFunc == nil { + verifyFunc = nopVerifyFunc + } + + claim, ok := validatedToken.Claims.(jwt.RegisteredClaims) + + if !ok { + return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedClaimsType, validatedToken.Claims) + } + + if domain != claim.Issuer { + return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedIssuer, claim.Issuer) + } + + err = verifyFunc(validatedToken.Claims.(jwt.RegisteredClaims)) + + return nil, err +} diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index 90e740c..34f543f 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -1,17 +1,27 @@ package middleware import ( + "context" + "crypto/ed25519" + "git.lumeweb.com/LumeWeb/portal/account" "git.lumeweb.com/LumeWeb/portal/api/registry" + "github.com/golang-jwt/jwt/v5" "github.com/julienschmidt/httprouter" "github.com/spf13/viper" "go.sia.tech/jape" "net/http" + "slices" + "strconv" "strings" ) +const DEFAULT_AUTH_CONTEXT_KEY = "user_id" + type JapeMiddlewareFunc func(jape.Handler) jape.Handler type HttpMiddlewareFunc func(http.Handler) http.Handler +type FindAuthTokenFunc func(r *http.Request) string + func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc { return jape.Adapt(func(h http.Handler) http.Handler { handler := mid(h) @@ -56,3 +66,93 @@ func RegisterProtocolSubdomain(config *viper.Viper, mux *httprouter.Router, name (router)[name+"."+domain] = mux } + +func FindAuthToken(r *http.Request, cookieName string, queryParam string) string { + authHeader := ParseAuthTokenHeader(r.Header) + + if authHeader != "" { + return authHeader + } + + if cookie, err := r.Cookie(cookieName); cookie != nil && err == nil { + return cookie.Value + } + + return r.FormValue(queryParam) +} + +func ParseAuthTokenHeader(headers http.Header) string { + authHeader := headers.Get("Authorization") + if authHeader == "" { + return "" + } + + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + + return authHeader +} + +type AuthMiddlewareOptions struct { + Identity ed25519.PrivateKey + Accounts *account.AccountServiceDefault + FindToken FindAuthTokenFunc + Purpose account.JWTPurpose + AuthContextKey string + Config *viper.Viper +} + +func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler { + if options.AuthContextKey == "" { + options.AuthContextKey = DEFAULT_AUTH_CONTEXT_KEY + } + if options.Purpose == "" { + panic("purpose is missing") + } + + domain := options.Config.GetString("core.domain") + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authToken := options.FindToken(r) + + if authToken == "" { + http.Error(w, "Invalid JWT", http.StatusUnauthorized) + return + } + + claim, err := account.VerifyToken(authToken, domain, options.Identity, func(claim jwt.RegisteredClaims) error { + aud, _ := claim.GetAudience() + + if slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false { + return account.ErrJWTInvalid + } + + return nil + }) + + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + userId, err := strconv.ParseUint(claim.Subject, 10, 64) + + if err != nil { + http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest) + return + } + + exists, _, err := options.Accounts.AccountExists(uint(userId)) + + if !exists || err != nil { + http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest) + return + } + + ctx := context.WithValue(r.Context(), options.AuthContextKey, uint(userId)) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + } +} diff --git a/api/middleware/s5.go b/api/middleware/s5.go deleted file mode 100644 index ee00c8e..0000000 --- a/api/middleware/s5.go +++ /dev/null @@ -1,108 +0,0 @@ -package middleware - -import ( - "context" - "crypto/ed25519" - "fmt" - "git.lumeweb.com/LumeWeb/portal/account" - "github.com/golang-jwt/jwt/v5" - "net/http" - "strings" -) - -const ( - S5AuthUserIDKey = "userID" - S5AuthCookieName = "s5-auth-token" - S5AuthQueryParam = "auth_token" -) - -func FindAuthToken(r *http.Request) string { - authHeader := ParseAuthTokenHeader(r.Header) - - if authHeader != "" { - return authHeader - } - - for _, cookie := range r.Cookies() { - if cookie.Name == S5AuthCookieName { - return cookie.Value - } - } - - return r.FormValue(S5AuthQueryParam) -} - -func ParseAuthTokenHeader(headers http.Header) string { - authHeader := headers.Get("Authorization") - if authHeader == "" { - return "" - } - - authHeader = strings.TrimPrefix(authHeader, "Bearer ") - - return authHeader -} - -func AuthMiddleware(identity ed25519.PrivateKey, accounts *account.AccountServiceDefault) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authToken := FindAuthToken(r) - - if authToken == "" { - http.Error(w, "Invalid JWT", http.StatusUnauthorized) - return - } - - token, err := jwt.Parse(authToken, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - - publicKey := identity.Public() - - return publicKey, nil - }) - - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - if claim, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - subject, ok := claim["sub"] - - if !ok { - http.Error(w, "Invalid User ID", http.StatusBadRequest) - return - } - - var userID uint64 - - switch v := subject.(type) { - case uint64: - userID = v - case float64: - userID = uint64(v) - default: - // Handle the case where userID is of an unexpected type - http.Error(w, "Invalid User ID", http.StatusBadRequest) - return - } - - exists, _ := accounts.AccountExists(userID) - - if !exists { - http.Error(w, "Invalid User ID", http.StatusBadRequest) - return - } - - ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID) - r = r.WithContext(ctx) - - next.ServeHTTP(w, r) - } else { - http.Error(w, "Invalid JWT", http.StatusUnauthorized) - } - }) - } -} diff --git a/api/s5/middleware.go b/api/s5/middleware.go new file mode 100644 index 0000000..94e418f --- /dev/null +++ b/api/s5/middleware.go @@ -0,0 +1,20 @@ +package s5 + +import ( + "git.lumeweb.com/LumeWeb/portal/api/middleware" + "net/http" +) + +const ( + authCookieName = "s5-auth-token" + authQueryParam = "auth_token" +) + +func findToken(r *http.Request) string { + return middleware.FindAuthToken(r, authCookieName, authQueryParam) +} + +func authMiddleware(options middleware.AuthMiddlewareOptions) middleware.HttpMiddlewareFunc { + options.FindToken = findToken + return middleware.AuthMiddleware(options) +}