diff --git a/api/oauth.go b/api/oauth.go index db96924..19b706a 100644 --- a/api/oauth.go +++ b/api/oauth.go @@ -1,15 +1,22 @@ package api import ( + "code.gitea.io/sdk/gitea" + "context" "fmt" "git.lumeweb.com/LumeWeb/gitea-github-proxy/config" + "github.com/golang-jwt/jwt" "go.uber.org/zap" "golang.org/x/oauth2" + "time" ) type oauth struct { - cfg *config.Config - logger *zap.Logger + cfg *config.Config + logger *zap.Logger + token *oauth2.Token + refresher oauth2.TokenSource + keepAliveRunning bool } func newOauth(cfg *config.Config, logger *zap.Logger) *oauth { @@ -17,7 +24,7 @@ func newOauth(cfg *config.Config, logger *zap.Logger) *oauth { } func (o oauth) config() *oauth2.Config { - return &oauth2.Config{ + cfg := &oauth2.Config{ ClientID: o.cfg.Oauth.ClientId, ClientSecret: o.cfg.Oauth.ClientSecret, Scopes: []string{"admin"}, @@ -27,8 +34,134 @@ func (o oauth) config() *oauth2.Config { AuthURL: fmt.Sprintf("%s/login/oauth/authorize", o.cfg.GiteaUrl), }, } + + o.loadToken(cfg) + o.keepAlive() + + return cfg } func (o oauth) authUrl() string { return o.config().AuthCodeURL("state") } + +func (o oauth) loadToken(config *oauth2.Config) { + token := &oauth2.Token{} + + if o.cfg.Oauth.Token != "" { + o.token = &oauth2.Token{AccessToken: o.cfg.Oauth.Token} + } + + if o.cfg.Oauth.RefreshToken != "" { + o.token.RefreshToken = o.cfg.Oauth.RefreshToken + } + + if o.token != nil { + parsedToken, err := jwt.Parse(o.cfg.Oauth.Token, func(token *jwt.Token) (interface{}, error) { + return nil, nil + }) + + if err != nil { + o.logger.Fatal("Error parsing token", zap.Error(err)) + } + + valid := false + + if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid { + if exp, ok := claims["exp"].(float64); ok { + expirationTime := time.Unix(int64(exp), 0) + if time.Now().Before(expirationTime) { + valid = true + o.token.Expiry = expirationTime + } + } + } + + if valid { + token = o.token + } else { + o.logger.Info("Token is expired, ignoring") + } + } + + o.refresher = config.TokenSource(context.Background(), token) +} + +func (o oauth) keepAlive() { + if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" { + o.logger.Error("No token or refresh token provided.") + return + } + + if o.keepAliveRunning { + return + } + + ticker := time.NewTicker(30 * time.Minute) + + o.keepAliveRunning = true + + go func() { + for { + select { + case <-ticker.C: + if !o.isTokenValid() { + if err := o.refreshToken(); err != nil { + o.logger.Error("Error refreshing token", zap.Error(err)) + } + } + } + } + }() +} +func (o *oauth) isTokenValid() bool { + if o.token == nil { + return false + } + + return o.token.Valid() +} + +func (o *oauth) refreshToken() error { + o.logger.Info("Refreshing token...") + + token, err := o.refresher.Token() + if err != nil { + return err + } + + o.token = token + + return nil +} + +func (o *oauth) exchange(code string) (*oauth2.Token, error) { + cfg := o.config() + token, err := o.config().Exchange(context.Background(), code) + if err != nil { + return nil, err + } + + o.cfg.Oauth.Token = token.AccessToken + o.cfg.Oauth.RefreshToken = token.RefreshToken + err = config.SaveConfig(o.cfg) + if err != nil { + return nil, err + } + + o.loadToken(cfg) + o.keepAlive() + + return token, nil +} + +func (o *oauth) client() *gitea.Client { + client, err := getClient(ClientParams{ + Config: o.cfg, + }) + if err != nil { + o.logger.Fatal("Error creating gitea client", zap.Error(err)) + } + + return client +}