diff --git a/api/middleware.go b/api/middleware.go index e7e00cb..686ba4f 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -3,10 +3,16 @@ package api import ( "code.gitea.io/sdk/gitea" "context" + "crypto/x509" "encoding/json" + "encoding/pem" + "fmt" "git.lumeweb.com/LumeWeb/gitea-github-proxy/config" + "git.lumeweb.com/LumeWeb/gitea-github-proxy/db/model" + "github.com/golang-jwt/jwt" "github.com/gorilla/mux" "go.uber.org/zap" + "gorm.io/gorm" "io" "net/http" "strings" @@ -62,6 +68,72 @@ func giteaOauthVerifyMiddleware(cfg *config.Config) mux.MiddlewareFunc { }) } } +func githubRestVerifyMiddleware(db *gorm.DB) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := findAuthToken(r) + + if token != "" { + parseToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + if err != nil { + http.Error(w, "Invalid JWT", http.StatusUnauthorized) + return + } + + claims, ok := parseToken.Claims.(jwt.MapClaims) + if !ok { + http.Error(w, "Invalid JWT", http.StatusUnauthorized) + return + } + + appId, ok := claims["iss"].(uint) + if !ok { + http.Error(w, "Invalid JWT", http.StatusUnauthorized) + return + } + + appRecord := &model.Apps{} + appRecord.ID = appId + + if err := db.First(appRecord).Error; err != nil { + http.Error(w, "Invalid JWT", http.StatusUnauthorized) + return + } + + block, _ := pem.Decode([]byte(appRecord.PrivateKey)) + if block == nil { + // Handle error + http.Error(w, "Invalid Private Key", http.StatusInternalServerError) + return + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + // Handle error + http.Error(w, "Failed to parse Private Key", http.StatusInternalServerError) + return + } + + publicKey := &privateKey.PublicKey + + parseToken, err = jwt.ParseWithClaims(token, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + // Return the RSA public key + return publicKey, nil + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + + addAuthStatusToRequestServ(true, r, w, next) + }) + } +} func storeWebhookDataMiddleware(logger *zap.Logger) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { @@ -112,6 +184,22 @@ func requireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc { }) } } + +func githubRestRequireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + status := getAuthedStatusFromRequest(r) + + if !status { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } +} + func loggingMiddleware(logger *zap.Logger) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/api/routes_rest_api.go b/api/routes_rest_api.go index 92f73a5..dc2ed05 100644 --- a/api/routes_rest_api.go +++ b/api/routes_rest_api.go @@ -143,6 +143,8 @@ func setupRestRoutes(params RouteParams) { restApi := newRestApi(cfg, logger) restRouter := r.PathPrefix("/api").Subrouter() + restRouter.Use(githubRestVerifyMiddleware(params.Db)) + restRouter.Use(githubRestRequireAuthMiddleware(params.Config)) restRouter.HandleFunc("/repos/{owner}/{repo}/pulls/{pull_number}/files", restApi.handlerGetPullRequestFiles).Methods("GET") }