diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index df87d90..7d0c2c1 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -3,6 +3,7 @@ package middleware import ( "context" "crypto/ed25519" + "errors" "net/http" "slices" "strconv" @@ -103,6 +104,7 @@ type AuthMiddlewareOptions struct { AuthContextKey string Config *config.Manager EmptyAllowed bool + ExpiredAllowed bool } func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler { @@ -128,7 +130,7 @@ func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handl claim, err := account.JWTVerifyToken(authToken, domain, options.Identity, func(claim *jwt.RegisteredClaims) error { aud, _ := claim.GetAudience() - if options.Purpose != account.JWTPurposeNone && slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false { + if options.Purpose != account.JWTPurposeNone && jwtPurposeEqual(aud, options.Purpose) == false { return account.ErrJWTInvalid } @@ -136,7 +138,14 @@ func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handl }) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + unauthorized := true + if errors.Is(err, jwt.ErrTokenExpired) && options.ExpiredAllowed { + unauthorized = false + } + + if unauthorized && jwtPurposeEqual(claim.Audience, options.Purpose) == true { + http.Error(w, err.Error(), http.StatusUnauthorized) + } return } @@ -213,3 +222,7 @@ func CtxAborted(ctx context.Context) bool { return false } } + +func jwtPurposeEqual(aud jwt.ClaimStrings, purpose account.JWTPurpose) bool { + return slices.Contains[jwt.ClaimStrings, string](aud, string(purpose)) +}