gitea-github-proxy/api/middleware.go

292 lines
7.2 KiB
Go
Raw Normal View History

package api
import (
"code.gitea.io/sdk/gitea"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"git.lumeweb.com/LumeWeb/gitea-github-proxy/config"
"git.lumeweb.com/LumeWeb/gitea-github-proxy/db/model"
"github.com/golang-jwt/jwt"
"github.com/gorilla/mux"
"go.uber.org/zap"
"gorm.io/gorm"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// Request-context keys and cookie/form names shared by the middleware in
// this file. They are untyped string constants.
const (
	// AUTHED_CONTEXT_KEY is the context key under which the bool auth status
	// is stored by addAuthStatusToRequestServ.
	AUTHED_CONTEXT_KEY = "authed"
	// REDIRECT_AFTER_AUTH is the cookie name requireAuthMiddleware sets before
	// redirecting an unauthenticated browser to /setup.
	REDIRECT_AFTER_AUTH = "redirect-after-auth"
	// WEBHOOK_CONTEXT_KEY is the context key under which
	// storeWebhookDataMiddleware stores the raw webhook body.
	WEBHOOK_CONTEXT_KEY = "webhook"
	// AuthCookieName names both the cookie and the form field that
	// findAuthToken consults for an auth token.
	AuthCookieName = "auth-token"
)
// Compile-time check that *standardClaims satisfies jwt.Claims.
var _ = jwt.Claims(&standardClaims{})

// standardClaims is the claim set used when verifying app JWTs.
//
// It shadows the "iss" and "exp" claims of the embedded jwt.StandardClaims
// with untyped fields. This lets tokens that carry them in non-standard forms
// (for example a numeric issuer, or an RFC 3339 string instead of a Unix
// timestamp for the expiry) still decode instead of failing unmarshalling.
type standardClaims struct {
	Issuer    any `json:"iss,omitempty"`
	ExpiresAt any `json:"exp,omitempty"`
	jwt.StandardClaims
}

// Valid normalises the shadowed "exp" claim to Unix seconds and then runs the
// time-based checks of the embedded jwt.StandardClaims.
//
// Because the outer ExpiresAt field owns the "exp" JSON key (encoding/json
// prefers the shallower field), decoding never populates
// StandardClaims.ExpiresAt. The normalised value is copied there explicitly
// so that expiry is actually enforced.
//
// Accepted forms of "exp":
//   - an RFC 3339 string (fractional seconds optional)
//   - a JSON number (float64, or json.Number when the parser uses
//     UseJSONNumber)
//   - an int64, which a previous call to Valid leaves behind
//
// An absent claim leaves expiry unchecked. Any other type is rejected.
func (s *standardClaims) Valid() error {
	var exp int64
	switch v := s.ExpiresAt.(type) {
	case nil:
		return s.StandardClaims.Valid()
	case string:
		t, err := time.Parse(time.RFC3339Nano, v)
		if err != nil {
			return err
		}
		exp = t.Unix()
	case float64:
		exp = int64(v)
	case json.Number:
		f, err := v.Float64()
		if err != nil {
			return fmt.Errorf("invalid exp claim %q: %w", v.String(), err)
		}
		exp = int64(f)
	case int64:
		exp = v
	default:
		return fmt.Errorf("invalid exp claim of type %T", v)
	}

	s.ExpiresAt = exp
	// Without this, StandardClaims.Valid sees ExpiresAt == 0 and skips the
	// expiry check entirely.
	s.StandardClaims.ExpiresAt = exp
	return s.StandardClaims.Valid()
}
// findAuthToken returns the first auth token the request carries. It tries
// these sources in order and stops at the first non-empty value:
//  1. the Authorization header, via parseAuthTokenHeader
//  2. the AuthCookieName cookie, via getCookie
//  3. a form/query value named AuthCookieName
//
// It returns "" when none of them supplies a token.
func findAuthToken(r *http.Request) string {
	if fromHeader := parseAuthTokenHeader(r.Header); fromHeader != "" {
		return fromHeader
	}
	if fromCookie := getCookie(r, AuthCookieName); fromCookie != "" {
		return fromCookie
	}
	// Only reached when neither header nor cookie had a token, so the form is
	// parsed lazily as before.
	return r.FormValue(AuthCookieName)
}
// giteaOauthVerifyMiddleware marks a request as authenticated when its auth
// token lets a Gitea client (built by getClient from cfg) successfully call
// the admin user-listing endpoint.
//
// Outcomes:
//   - No token, or a failed admin call: the request continues marked
//     unauthenticated.
//   - Client construction fails: the middleware replies 500 with the error
//     text and stops.
func giteaOauthVerifyMiddleware(cfg *config.Config) mux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		handler := func(w http.ResponseWriter, r *http.Request) {
			authToken := findAuthToken(r)
			if authToken == "" {
				addAuthStatusToRequestServ(false, r, w, next)
				return
			}

			giteaClient, clientErr := getClient(ClientParams{
				Config:    cfg,
				AuthToken: authToken,
			})
			if clientErr != nil {
				http.Error(w, clientErr.Error(), http.StatusInternalServerError)
				return
			}

			// The admin listing succeeds only for tokens with admin rights; its
			// result is discarded, only the error matters.
			_, _, listErr := giteaClient.AdminListUsers(gitea.AdminListUsersOptions{})
			addAuthStatusToRequestServ(listErr == nil, r, w, next)
		}
		return http.HandlerFunc(handler)
	}
}
func githubRestVerifyMiddleware(db *gorm.DB) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := findAuthToken(r)
if token != "" {
parseToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
claims, ok := parseToken.Claims.(jwt.MapClaims)
if !ok {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
2024-02-12 05:47:29 +00:00
var appId string
switch v := claims["iss"].(type) {
case string:
appId = v
case float64:
appId = strconv.FormatFloat(v, 'f', -1, 64)
default:
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
appIdInt, err := strconv.Atoi(appId)
if err != nil {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
appRecord := &model.Apps{}
appRecord.ID = uint(appIdInt)
if err := db.First(appRecord).Error; err != nil {
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
return
}
block, _ := pem.Decode([]byte(appRecord.PrivateKey))
if block == nil {
// Handle error
http.Error(w, "Invalid Private Key", http.StatusInternalServerError)
return
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
// Handle error
http.Error(w, "Failed to parse Private Key", http.StatusInternalServerError)
return
}
publicKey := &privateKey.PublicKey
parseToken, err = jwt.ParseWithClaims(token, &standardClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// Return the RSA public key
return publicKey, nil
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if mux.CurrentRoute(r).GetName() == "app-install-get-access-token" {
installId := mux.Vars(r)["installation_id"]
installIdInt, err := strconv.Atoi(installId)
if err != nil {
http.Error(w, "Invalid Install", http.StatusUnauthorized)
return
}
if appIdInt != installIdInt {
http.Error(w, "Invalid Install", http.StatusUnauthorized)
return
}
}
}
addAuthStatusToRequestServ(true, r, w, next)
})
}
}
// storeWebhookDataMiddleware buffers the request body and checks that it
// decodes to a non-empty JSON object. It then stores the raw bytes in the
// request context under WEBHOOK_CONTEXT_KEY for downstream handlers. The body
// is replaced with a fresh reader over a copy of the same data, so later
// handlers can read it again.
//
// Responses on failure:
//   - 500 when the body cannot be read.
//   - 400 when the body is not valid JSON, or decodes to an empty object/null.
func storeWebhookDataMiddleware(logger *zap.Logger) mux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				logger.Error("Failed to read request body", zap.Error(err))
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			// Decoded only to validate the payload's shape; downstream handlers
			// receive the raw bytes.
			var webhook map[string]any
			if err := json.Unmarshal(body, &webhook); err != nil {
				// Malformed JSON is a client error, not a server failure.
				logger.Error("Failed to unmarshal webhook", zap.Error(err))
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			if len(webhook) == 0 {
				logger.Error("Webhook data is empty")
				w.WriteHeader(http.StatusBadRequest)
				return
			}

			ctx := context.WithValue(r.Context(), WEBHOOK_CONTEXT_KEY, body)
			r = r.WithContext(ctx)
			r.Body = io.NopCloser(strings.NewReader(string(body)))
			next.ServeHTTP(w, r)
		})
	}
}
// requireAuthMiddleware passes a request on only when an earlier middleware
// marked it as authenticated.
//
// Otherwise it sets the REDIRECT_AFTER_AUTH cookie to the request's Referer
// (via setCookie, passing cfg.Domain) and redirects to /setup with 302 Found.
func requireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if getAuthedStatusFromRequest(r) {
				next.ServeHTTP(w, r)
				return
			}

			setCookie(w, REDIRECT_AFTER_AUTH, cfg.Domain, r.Referer(), 0, http.SameSiteLaxMode)
			http.Redirect(w, r, "/setup", http.StatusFound)
		})
	}
}
// githubRestRequireAuthMiddleware rejects requests that no earlier middleware
// marked as authenticated, replying with a plain-text 401 "Unauthorized".
// Authenticated requests go to next unchanged. cfg is currently unused.
func githubRestRequireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if authed := getAuthedStatusFromRequest(r); authed {
				next.ServeHTTP(w, r)
				return
			}
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
		})
	}
}
// loggingMiddleware writes a debug-level "Request" entry containing the HTTP
// method and raw request URI, then hands the request to the next handler.
func loggingMiddleware(logger *zap.Logger) mux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		logRequest := func(w http.ResponseWriter, r *http.Request) {
			logger.Debug("Request",
				zap.String("method", r.Method),
				zap.String("url", r.RequestURI),
			)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(logRequest)
	}
}
// addAuthStatusToRequestServ stores status under AUTHED_CONTEXT_KEY in a
// context derived from r's. It then serves the resulting request copy with
// next. The caller's *http.Request is not modified.
func addAuthStatusToRequestServ(status bool, r *http.Request, w http.ResponseWriter, next http.Handler) {
	authedCtx := context.WithValue(r.Context(), AUTHED_CONTEXT_KEY, status)
	next.ServeHTTP(w, r.WithContext(authedCtx))
}
// parseAuthTokenHeader extracts the credential from the Authorization header.
//
// A leading "Bearer" or "Token" auth scheme is removed. The scheme is matched
// case-insensitively, as RFC 7235 specifies for auth-schemes, and surrounding
// whitespace is trimmed from the credential. Only one scheme is stripped.
//
// A header with any other scheme, or with no space at all (a bare token), is
// returned unchanged. The result is "" when the header is absent or nothing
// follows the scheme.
func parseAuthTokenHeader(headers http.Header) string {
	authHeader := headers.Get("Authorization")
	if authHeader == "" {
		return ""
	}

	scheme, credentials, found := strings.Cut(authHeader, " ")
	if !found {
		return authHeader
	}
	if strings.EqualFold(scheme, "Bearer") || strings.EqualFold(scheme, "Token") {
		return strings.TrimSpace(credentials)
	}
	return authHeader
}
// getAuthedStatusFromRequest reports whether a bool true is stored under
// AUTHED_CONTEXT_KEY in the request's context. A missing or non-bool value
// counts as unauthenticated.
func getAuthedStatusFromRequest(r *http.Request) bool {
	// The failed-assertion result is the zero value false, which is exactly
	// the "not authenticated" answer.
	authed, _ := r.Context().Value(AUTHED_CONTEXT_KEY).(bool)
	return authed
}