refactor: major middleware refactor

This commit is contained in:
Derrick Hammer 2024-02-13 22:17:34 -05:00
parent 9f6f2c9c87
commit 0b3d54e7c5
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
4 changed files with 176 additions and 114 deletions

View File

@ -2,14 +2,29 @@ package account
import ( import (
"crypto/ed25519" "crypto/ed25519"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"strconv"
"time" "time"
) )
type JWTPurpose string type JWTPurpose string
type VerifyTokenFunc func(claim jwt.RegisteredClaims) error
var (
nopVerifyFunc VerifyTokenFunc = func(claim jwt.RegisteredClaims) error {
return nil
}
ErrJWTUnexpectedClaimsType = errors.New("unexpected claims type")
ErrJWTUnexpectedIssuer = errors.New("unexpected issuer")
ErrJWTInvalid = errors.New("invalid JWT")
)
const ( const (
JWTPurposeLogin JWTPurpose = "login" JWTPurposeLogin JWTPurpose = "login"
JWTPurpose2FA JWTPurpose = "2fa"
) )
func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) { func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, purpose JWTPurpose) (string, error) {
@ -17,13 +32,14 @@ func GenerateToken(domain string, privateKey ed25519.PrivateKey, userID uint, pu
} }
func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) { func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, userID uint, duration time.Duration, purpose JWTPurpose) (string, error) {
// Define the claims // Define the claims
claims := jwt.MapClaims{ claims := jwt.RegisteredClaims{
"iss": domain, Issuer: domain,
"sub": userID, Subject: strconv.Itoa(int(userID)),
"exp": time.Now().Add(duration).Unix(), ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)),
"iat": time.Now().Unix(), IssuedAt: jwt.NewNumericDate(time.Now()),
"aud": string(purpose), Audience: []string{string(purpose)},
} }
// Create the token // Create the token
@ -37,3 +53,37 @@ func GenerateTokenWithDuration(domain string, privateKey ed25519.PrivateKey, use
return tokenString, nil return tokenString, nil
} }
func VerifyToken(token string, domain string, privateKey ed25519.PrivateKey, verifyFunc VerifyTokenFunc) (*jwt.RegisteredClaims, error) {
validatedToken, err := jwt.ParseWithClaims(token, jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
publicKey := privateKey.Public()
return publicKey, nil
})
if err != nil {
return nil, err
}
if verifyFunc == nil {
verifyFunc = nopVerifyFunc
}
claim, ok := validatedToken.Claims.(jwt.RegisteredClaims)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedClaimsType, validatedToken.Claims)
}
if domain != claim.Issuer {
return nil, fmt.Errorf("%w: %s", ErrJWTUnexpectedIssuer, claim.Issuer)
}
err = verifyFunc(validatedToken.Claims.(jwt.RegisteredClaims))
return nil, err
}

View File

@ -1,17 +1,27 @@
package middleware package middleware
import ( import (
"context"
"crypto/ed25519"
"git.lumeweb.com/LumeWeb/portal/account"
"git.lumeweb.com/LumeWeb/portal/api/registry" "git.lumeweb.com/LumeWeb/portal/api/registry"
"github.com/golang-jwt/jwt/v5"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.sia.tech/jape" "go.sia.tech/jape"
"net/http" "net/http"
"slices"
"strconv"
"strings" "strings"
) )
const DEFAULT_AUTH_CONTEXT_KEY = "user_id"
type JapeMiddlewareFunc func(jape.Handler) jape.Handler type JapeMiddlewareFunc func(jape.Handler) jape.Handler
type HttpMiddlewareFunc func(http.Handler) http.Handler type HttpMiddlewareFunc func(http.Handler) http.Handler
type FindAuthTokenFunc func(r *http.Request) string
func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc { func AdaptMiddleware(mid func(http.Handler) http.Handler) JapeMiddlewareFunc {
return jape.Adapt(func(h http.Handler) http.Handler { return jape.Adapt(func(h http.Handler) http.Handler {
handler := mid(h) handler := mid(h)
@ -56,3 +66,93 @@ func RegisterProtocolSubdomain(config *viper.Viper, mux *httprouter.Router, name
(router)[name+"."+domain] = mux (router)[name+"."+domain] = mux
} }
func FindAuthToken(r *http.Request, cookieName string, queryParam string) string {
authHeader := ParseAuthTokenHeader(r.Header)
if authHeader != "" {
return authHeader
}
if cookie, err := r.Cookie(cookieName); cookie != nil && err == nil {
return cookie.Value
}
return r.FormValue(queryParam)
}
func ParseAuthTokenHeader(headers http.Header) string {
authHeader := headers.Get("Authorization")
if authHeader == "" {
return ""
}
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
return authHeader
}
type AuthMiddlewareOptions struct {
Identity ed25519.PrivateKey
Accounts *account.AccountServiceDefault
FindToken FindAuthTokenFunc
Purpose account.JWTPurpose
AuthContextKey string
Config *viper.Viper
}
func AuthMiddleware(options AuthMiddlewareOptions) func(http.Handler) http.Handler {
if options.AuthContextKey == "" {
options.AuthContextKey = DEFAULT_AUTH_CONTEXT_KEY
}
if options.Purpose == "" {
panic("purpose is missing")
}
domain := options.Config.GetString("core.domain")
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken := options.FindToken(r)
if authToken == "" {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
claim, err := account.VerifyToken(authToken, domain, options.Identity, func(claim jwt.RegisteredClaims) error {
aud, _ := claim.GetAudience()
if slices.Contains[jwt.ClaimStrings, string](aud, string(options.Purpose)) == false {
return account.ErrJWTInvalid
}
return nil
})
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
userId, err := strconv.ParseUint(claim.Subject, 10, 64)
if err != nil {
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
return
}
exists, _, err := options.Accounts.AccountExists(uint(userId))
if !exists || err != nil {
http.Error(w, account.ErrJWTInvalid.Error(), http.StatusBadRequest)
return
}
ctx := context.WithValue(r.Context(), options.AuthContextKey, uint(userId))
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,108 +0,0 @@
package middleware
import (
"context"
"crypto/ed25519"
"fmt"
"git.lumeweb.com/LumeWeb/portal/account"
"github.com/golang-jwt/jwt/v5"
"net/http"
"strings"
)
const (
S5AuthUserIDKey = "userID"
S5AuthCookieName = "s5-auth-token"
S5AuthQueryParam = "auth_token"
)
func FindAuthToken(r *http.Request) string {
authHeader := ParseAuthTokenHeader(r.Header)
if authHeader != "" {
return authHeader
}
for _, cookie := range r.Cookies() {
if cookie.Name == S5AuthCookieName {
return cookie.Value
}
}
return r.FormValue(S5AuthQueryParam)
}
func ParseAuthTokenHeader(headers http.Header) string {
authHeader := headers.Get("Authorization")
if authHeader == "" {
return ""
}
authHeader = strings.TrimPrefix(authHeader, "Bearer ")
return authHeader
}
func AuthMiddleware(identity ed25519.PrivateKey, accounts *account.AccountServiceDefault) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authToken := FindAuthToken(r)
if authToken == "" {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
token, err := jwt.Parse(authToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
publicKey := identity.Public()
return publicKey, nil
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if claim, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
subject, ok := claim["sub"]
if !ok {
http.Error(w, "Invalid User ID", http.StatusBadRequest)
return
}
var userID uint64
switch v := subject.(type) {
case uint64:
userID = v
case float64:
userID = uint64(v)
default:
// Handle the case where userID is of an unexpected type
http.Error(w, "Invalid User ID", http.StatusBadRequest)
return
}
exists, _ := accounts.AccountExists(userID)
if !exists {
http.Error(w, "Invalid User ID", http.StatusBadRequest)
return
}
ctx := context.WithValue(r.Context(), S5AuthUserIDKey, userID)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
} else {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
}
})
}
}

20
api/s5/middleware.go Normal file
View File

@ -0,0 +1,20 @@
package s5
import (
"git.lumeweb.com/LumeWeb/portal/api/middleware"
"net/http"
)
const (
authCookieName = "s5-auth-token"
authQueryParam = "auth_token"
)
func findToken(r *http.Request) string {
return middleware.FindAuthToken(r, authCookieName, authQueryParam)
}
func authMiddleware(options middleware.AuthMiddlewareOptions) middleware.HttpMiddlewareFunc {
options.FindToken = findToken
return middleware.AuthMiddleware(options)
}