refactor: major middleware refactor
This commit is contained in:
parent
9f6f2c9c87
commit
0b3d54e7c5
|
@ -2,14 +2,29 @@ package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWTPurpose string
|
type JWTPurpose string
|
||||||
|
type VerifyTokenFunc func(claim jwt.RegisteredClaims) error
|
||||||
|
|
||||||
|
var (
|
||||||
|
nopVerifyFunc VerifyTokenFunc = func(claim jwt.RegisteredClaims) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ErrJWTUnexpectedClaimsType = errors.New("unexpected claims type")
|
||||||
|
ErrJWTUnexpectedIssuer = errors.New("unexpected issuer")
|
||||||
|
ErrJWTInvalid = errors.New("invalid JWT")
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
JWTPurposeLogin JWTPurpose = "login"
|
JWTPurposeLogin JWTPurpose = "login"
|
||||||
|
JWTPurpose2FA JWTPurpose = "2fa"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) {
|
func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) {
|
||||||
|
@ -17,13 +32,14 @@ func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, pu
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) {
|
func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) {
|
||||||
|
|
||||||
// Define the claims
|
// Define the claims
|
||||||
claims := jwt.MapClaims{
|
claims := jwt.RegisteredClaims{
|
||||||
"iss": domain,
|
Issuer: domain,
|
||||||
"sub": userID,
|
Subject: strconv.Itoa(int(userID)),
|
||||||
"exp": time.Now().Add(duration).Unix(),
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)),
|
||||||
"iat": time.Now().Unix(),
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
"aud": string(purpose),
|
Audience: []string{string(purpose)},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the token
|
// Create the token
|
||||||
|
@ -37,3 +53,37 @@ func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, use
|
||||||
|
|
||||||
return tokenString, nil
|
return tokenString, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func VerifyToken(token string, domain string, privateKey ed25519.PrivateKey, verifyFunc VerifyTokenFunc) (*jwt.RegisteredClaims, error) {
|
||||||
|
validatedToken, err := jwt.ParseWithClaims(token, jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := privateKey.Public()
|
||||||
|
|
||||||
|
return publicKey, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if verifyFunc == nil {
|
||||||
|
verifyFunc = nopVerifyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
claim, ok := validatedToken.Claims.(jwt.RegisteredClaims)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedClaimsType, validatedToken.Claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
if domain != claim.Issuer {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedIssuer, claim.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = verifyFunc(validatedToken.Claims.(jwt.RegisteredClaims))
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
|
@ -1,17 +1,27 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"git.lumeweb.com/LumeWeb/portal/account"
|
||||||
"git.lumeweb.com/LumeWeb/portal/api/registry"
|
"git.lumeweb.com/LumeWeb/portal/api/registry"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"go.sia.tech/jape"
|
"go.sia.tech/jape"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DEFAULT_AUTH_CONTEXT_KEY = "user_id"
|
||||||
|
|
||||||
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
|
type JapeMiddlewareFunc func(jape.Handler) jape.Handler
|
||||||
type HttpMiddlewareFunc func(http.Handler) http.Handler
|
type HttpMiddlewareFunc func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
type FindAuthTokenFunc func(r *http.Request) string
|
||||||
|
|
||||||
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
|
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
|
||||||
return jape.Adapt(func(h http.Handler) http.Handler {
|
return jape.Adapt(func(h http.Handler) http.Handler {
|
||||||
handler := mid(h)
|
handler := mid(h)
|
||||||
|
@ -56,3 +66,93 @@ func RegisterProtocolSubdomain(config *viper.Viper, mux *httprouter.Router, name
|
||||||
|
|
||||||
(router)[name+"."+domain] = mux
|
(router)[name+"."+domain] = mux
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FindAuthToken(r *http.Request, cookieName string, queryParam string) string {
|
||||||
|
authHeader := ParseAuthTokenHeader(r.Header)
|
||||||
|
|
||||||
|
if authHeader != "" {
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie, err := r.Cookie(cookieName); cookie != nil && err == nil {
|
||||||
|
return cookie.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.FormValue(queryParam)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseAuthTokenHeader(headers http.Header) string {
|
||||||
|
authHeader := headers.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthMiddlewareOptions struct {
|
||||||
|
Identity ed25519.PrivateKey
|
||||||
|
Accounts *account.AccountServiceDefault
|
||||||
|
FindToken FindAuthTokenFunc
|
||||||
|
Purpose account.JWTPurpose
|
||||||
|
AuthContextKey string
|
||||||
|
Config *viper.Viper
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler {
|
||||||
|
if options.AuthContextKey == "" {
|
||||||
|
options.AuthContextKey = DEFAULT_AUTH_CONTEXT_KEY
|
||||||
|
}
|
||||||
|
if options.Purpose == "" {
|
||||||
|
panic("purpose is missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
domain := options.Config.GetString("core.domain")
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authToken := options.FindToken(r)
|
||||||
|
|
||||||
|
if authToken == "" {
|
||||||
|
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claim, err := account.VerifyToken(authToken, domain, options.Identity, func(claim jwt.RegisteredClaims) error {
|
||||||
|
aud, _ := claim.GetAudience()
|
||||||
|
|
||||||
|
if slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false {
|
||||||
|
return account.ErrJWTInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId, err := strconv.ParseUint(claim.Subject, 10, 64)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, _, err := options.Accounts.AccountExists(uint(userId))
|
||||||
|
|
||||||
|
if !exists || err != nil {
|
||||||
|
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), options.AuthContextKey, uint(userId))
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,108 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"fmt"
|
|
||||||
"git.lumeweb.com/LumeWeb/portal/account"
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
S5AuthUserIDKey = "userID"
|
|
||||||
S5AuthCookieName = "s5-auth-token"
|
|
||||||
S5AuthQueryParam = "auth_token"
|
|
||||||
)
|
|
||||||
|
|
||||||
func FindAuthToken(r *http.Request) string {
|
|
||||||
authHeader := ParseAuthTokenHeader(r.Header)
|
|
||||||
|
|
||||||
if authHeader != "" {
|
|
||||||
return authHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cookie := range r.Cookies() {
|
|
||||||
if cookie.Name == S5AuthCookieName {
|
|
||||||
return cookie.Value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.FormValue(S5AuthQueryParam)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseAuthTokenHeader(headers http.Header) string {
|
|
||||||
authHeader := headers.Get("Authorization")
|
|
||||||
if authHeader == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
|
|
||||||
|
|
||||||
return authHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthMiddleware(identity ed25519.PrivateKey, accounts *account.AccountServiceDefault) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
authToken := FindAuthToken(r)
|
|
||||||
|
|
||||||
if authToken == "" {
|
|
||||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := jwt.Parse(authToken, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKey := identity.Public()
|
|
||||||
|
|
||||||
return publicKey, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if claim, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
||||||
subject, ok := claim["sub"]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var userID uint64
|
|
||||||
|
|
||||||
switch v := subject.(type) {
|
|
||||||
case uint64:
|
|
||||||
userID = v
|
|
||||||
case float64:
|
|
||||||
userID = uint64(v)
|
|
||||||
default:
|
|
||||||
// Handle the case where userID is of an unexpected type
|
|
||||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
exists, _ := accounts.AccountExists(userID)
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
http.Error(w, "Invalid User ID", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
} else {
|
|
||||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
package s5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.lumeweb.com/LumeWeb/portal/api/middleware"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
authCookieName = "s5-auth-token"
|
||||||
|
authQueryParam = "auth_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
func findToken(r *http.Request) string {
|
||||||
|
return middleware.FindAuthToken(r, authCookieName, authQueryParam)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authMiddleware(options middleware.AuthMiddlewareOptions) middleware.HttpMiddlewareFunc {
|
||||||
|
options.FindToken = findToken
|
||||||
|
return middleware.AuthMiddleware(options)
|
||||||
|
}
|
Loading…
Reference in New Issue