refactor: major middleware refactor
This commit is contained in:
parent
9f6f2c9c87
commit
0b3d54e7c5
|
@ -2,14 +2,29 @@ package account
|
|||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type JWTPurpose string
|
||||
type VerifyTokenFunc func(claim jwt.RegisteredClaims) error
|
||||
|
||||
var (
|
||||
nopVerifyFunc VerifyTokenFunc = func(claim jwt.RegisteredClaims) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ErrJWTUnexpectedClaimsType = errors.New("unexpected claims type")
|
||||
ErrJWTUnexpectedIssuer = errors.New("unexpected issuer")
|
||||
ErrJWTInvalid = errors.New("invalid JWT")
|
||||
)
|
||||
|
||||
const (
|
||||
JWTPurposeLogin JWTPurpose = "login"
|
||||
JWTPurpose2FA JWTPurpose = "2fa"
|
||||
)
|
||||
|
||||
func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) {
|
||||
|
@ -17,13 +32,14 @@ func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, pu
|
|||
}
|
||||
|
||||
func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) {
|
||||
|
||||
// Define the claims
|
||||
claims := jwt.MapClaims{
|
||||
"iss": domain,
|
||||
"sub": userID,
|
||||
"exp": time.Now().Add(duration).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"aud": string(purpose),
|
||||
claims := jwt.RegisteredClaims{
|
||||
Issuer: domain,
|
||||
Subject: strconv.Itoa(int(userID)),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: []string{string(purpose)},
|
||||
}
|
||||
|
||||
// Create the token
|
||||
|
@ -37,3 +53,37 @@ func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, use
|
|||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
func VerifyToken(token string, domain string, privateKey ed25519.PrivateKey, verifyFunc VerifyTokenFunc) (*jwt.RegisteredClaims, error) {
|
||||
validatedToken, err := jwt.ParseWithClaims(token, jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
publicKey := privateKey.Public()
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if verifyFunc == nil {
|
||||
verifyFunc = nopVerifyFunc
|
||||
}
|
||||
|
||||
claim, ok := validatedToken.Claims.(jwt.RegisteredClaims)
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedClaimsType, validatedToken.Claims)
|
||||
}
|
||||
|
||||
if domain != claim.Issuer {
|
||||
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedIssuer, claim.Issuer)
|
||||
}
|
||||
|
||||
err = verifyFunc(validatedToken.Claims.(jwt.RegisteredClaims))
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,17 +1,27 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"git.lumeweb.com/LumeWeb/portal/account"
|
||||
"git.lumeweb.com/LumeWeb/portal/api/registry"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"github.com/spf13/viper"
|
||||
"go.sia.tech/jape"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const DEFAULT_AUTH_CONTEXT_KEY = "user_id"
|
||||
|
||||
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
|
||||
type HttpMiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
type FindAuthTokenFunc func(r *http.Request) string
|
||||
|
||||
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
|
||||
return jape.Adapt(func(h http.Handler) http.Handler {
|
||||
handler := mid(h)
|
||||
|
@ -56,3 +66,93 @@ func RegisterProtocolSubdomain(config *viper.Viper, mux *httprouter.Router, name
|
|||
|
||||
(router)[name+"."+domain] = mux
|
||||
}
|
||||
|
||||
func FindAuthToken(r *http.Request, cookieName string, queryParam string) string {
|
||||
authHeader := ParseAuthTokenHeader(r.Header)
|
||||
|
||||
if authHeader != "" {
|
||||
return authHeader
|
||||
}
|
||||
|
||||
if cookie, err := r.Cookie(cookieName); cookie != nil && err == nil {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
return r.FormValue(queryParam)
|
||||
}
|
||||
|
||||
func ParseAuthTokenHeader(headers http.Header) string {
|
||||
authHeader := headers.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
return authHeader
|
||||
}
|
||||
|
||||
type AuthMiddlewareOptions struct {
|
||||
Identity ed25519.PrivateKey
|
||||
Accounts *account.AccountServiceDefault
|
||||
FindToken FindAuthTokenFunc
|
||||
Purpose account.JWTPurpose
|
||||
AuthContextKey string
|
||||
Config *viper.Viper
|
||||
}
|
||||
|
||||
func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler {
|
||||
if options.AuthContextKey == "" {
|
||||
options.AuthContextKey = DEFAULT_AUTH_CONTEXT_KEY
|
||||
}
|
||||
if options.Purpose == "" {
|
||||
panic("purpose is missing")
|
||||
}
|
||||
|
||||
domain := options.Config.GetString("core.domain")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authToken := options.FindToken(r)
|
||||
|
||||
if authToken == "" {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
claim, err := account.VerifyToken(authToken, domain, options.Identity, func(claim jwt.RegisteredClaims) error {
|
||||
aud, _ := claim.GetAudience()
|
||||
|
||||
if slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false {
|
||||
return account.ErrJWTInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userId, err := strconv.ParseUint(claim.Subject, 10, 64)
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
exists, _, err := options.Accounts.AccountExists(uint(userId))
|
||||
|
||||
if !exists || err != nil {
|
||||
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), options.AuthContextKey, uint(userId))
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"git.lumeweb.com/LumeWeb/portal/account"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
S5AuthUserIDKey = "userID"
|
||||
S5AuthCookieName = "s5-auth-token"
|
||||
S5AuthQueryParam = "auth_token"
|
||||
)
|
||||
|
||||
func FindAuthToken(r *http.Request) string {
|
||||
authHeader := ParseAuthTokenHeader(r.Header)
|
||||
|
||||
if authHeader != "" {
|
||||
return authHeader
|
||||
}
|
||||
|
||||
for _, cookie := range r.Cookies() {
|
||||
if cookie.Name == S5AuthCookieName {
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
return r.FormValue(S5AuthQueryParam)
|
||||
}
|
||||
|
||||
func ParseAuthTokenHeader(headers http.Header) string {
|
||||
authHeader := headers.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
return authHeader
|
||||
}
|
||||
|
||||
func AuthMiddleware(identity ed25519.PrivateKey, accounts *account.AccountServiceDefault) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authToken := FindAuthToken(r)
|
||||
|
||||
if authToken == "" {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(authToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
publicKey := identity.Public()
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if claim, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
subject, ok := claim["sub"]
|
||||
|
||||
if !ok {
|
||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var userID uint64
|
||||
|
||||
switch v := subject.(type) {
|
||||
case uint64:
|
||||
userID = v
|
||||
case float64:
|
||||
userID = uint64(v)
|
||||
default:
|
||||
// Handle the case where userID is of an unexpected type
|
||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
exists, _ := accounts.AccountExists(userID)
|
||||
|
||||
if !exists {
|
||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package s5
|
||||
|
||||
import (
|
||||
"git.lumeweb.com/LumeWeb/portal/api/middleware"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
authCookieName = "s5-auth-token"
|
||||
authQueryParam = "auth_token"
|
||||
)
|
||||
|
||||
func findToken(r *http.Request) string {
|
||||
return middleware.FindAuthToken(r, authCookieName, authQueryParam)
|
||||
}
|
||||
|
||||
func authMiddleware(options middleware.AuthMiddlewareOptions) middleware.HttpMiddlewareFunc {
|
||||
options.FindToken = findToken
|
||||
return middleware.AuthMiddleware(options)
|
||||
}
|
Loading…
Reference in New Issue