feat: implement github jwt verification via middleware
This commit is contained in:
parent
d83df43411
commit
41532495bc
|
@ -3,10 +3,16 @@ package api
|
|||
import (
|
||||
"code.gitea.io/sdk/gitea"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"git.lumeweb.com/LumeWeb/gitea-github-proxy/config"
|
||||
"git.lumeweb.com/LumeWeb/gitea-github-proxy/db/model"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/gorilla/mux"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -62,6 +68,72 @@ func giteaOauthVerifyMiddleware(cfg *config.Config) mux.MiddlewareFunc {
|
|||
})
|
||||
}
|
||||
}
|
||||
func githubRestVerifyMiddleware(db *gorm.DB) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := findAuthToken(r)
|
||||
|
||||
if token != "" {
|
||||
parseToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := parseToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
appId, ok := claims["iss"].(uint)
|
||||
if !ok {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
appRecord := &model.Apps{}
|
||||
appRecord.ID = appId
|
||||
|
||||
if err := db.First(appRecord).Error; err != nil {
|
||||
http.Error(w, "Invalid JWT", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(appRecord.PrivateKey))
|
||||
if block == nil {
|
||||
// Handle error
|
||||
http.Error(w, "Invalid Private Key", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
http.Error(w, "Failed to parse Private Key", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
parseToken, err = jwt.ParseWithClaims(token, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// Return the RSA public key
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
addAuthStatusToRequestServ(true, r, w, next)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func storeWebhookDataMiddleware(logger *zap.Logger) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
@ -112,6 +184,22 @@ func requireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func githubRestRequireAuthMiddleware(cfg *config.Config) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
status := getAuthedStatusFromRequest(r)
|
||||
|
||||
if !status {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func loggingMiddleware(logger *zap.Logger) mux.MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
|
@ -143,6 +143,8 @@ func setupRestRoutes(params RouteParams) {
|
|||
|
||||
restApi := newRestApi(cfg, logger)
|
||||
restRouter := r.PathPrefix("/api").Subrouter()
|
||||
restRouter.Use(githubRestVerifyMiddleware(params.Db))
|
||||
restRouter.Use(githubRestRequireAuthMiddleware(params.Config))
|
||||
|
||||
restRouter.HandleFunc("/repos/{owner}/{repo}/pulls/{pull_number}/files", restApi.handlerGetPullRequestFiles).Methods("GET")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue