Compare commits

..

2 Commits

1 changed files with 16 additions and 5 deletions

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"errors"
"net/http" "net/http"
"slices" "slices"
"strconv" "strconv"
@ -103,6 +104,7 @@ type AuthMiddlewareOptions struct {
AuthContextKey string AuthContextKey string
Config *config.Manager Config *config.Manager
EmptyAllowed bool EmptyAllowed bool
ExpiredAllowed bool
} }
func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler { func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler {
@ -128,17 +130,22 @@ func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handl
claim, err := account.JWTVerifyToken(authToken, domain, options.Identity, func(claim *jwt.RegisteredClaims) error { claim, err := account.JWTVerifyToken(authToken, domain, options.Identity, func(claim *jwt.RegisteredClaims) error {
aud, _ := claim.GetAudience() aud, _ := claim.GetAudience()
if options.Purpose != account.JWTPurposeNone && slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false { if options.Purpose != account.JWTPurposeNone && jwtPurposeEqual(aud, options.Purpose) == false {
if !options.EmptyAllowed {
return account.ErrJWTInvalid return account.ErrJWTInvalid
} }
}
return nil return nil
}) })
if err != nil { if err != nil {
unauthorized := true
if errors.Is(err, jwt.ErrTokenExpired) && options.ExpiredAllowed {
unauthorized = false
}
if unauthorized && jwtPurposeEqual(claim.Audience, options.Purpose) == true {
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
}
return return
} }
@ -215,3 +222,7 @@ func CtxAborted(ctx context.Context) bool {
return false return false
} }
} }
func jwtPurposeEqual(aud jwt.ClaimStrings, purpose account.JWTPurpose) bool {
return slices.Contains[jwt.ClaimStrings, string](aud, string(purpose))
}