feat: add token monitoring and refresh
This commit is contained in:
parent
5036d0bca4
commit
150c5c6cb2
135
api/oauth.go
135
api/oauth.go
|
@ -1,15 +1,22 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"code.gitea.io/sdk/gitea"
|
||||
"context"
|
||||
"fmt"
|
||||
"git.lumeweb.com/LumeWeb/gitea-github-proxy/config"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/oauth2"
|
||||
"time"
|
||||
)
|
||||
|
||||
type oauth struct {
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
token *oauth2.Token
|
||||
refresher oauth2.TokenSource
|
||||
keepAliveRunning bool
|
||||
}
|
||||
|
||||
func newOauth(cfg *config.Config, logger *zap.Logger) *oauth {
|
||||
|
@ -17,7 +24,7 @@ func newOauth(cfg *config.Config, logger *zap.Logger) *oauth {
|
|||
}
|
||||
|
||||
func (o oauth) config() *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: o.cfg.Oauth.ClientId,
|
||||
ClientSecret: o.cfg.Oauth.ClientSecret,
|
||||
Scopes: []string{"admin"},
|
||||
|
@ -27,8 +34,134 @@ func (o oauth) config() *oauth2.Config {
|
|||
AuthURL: fmt.Sprintf("%s/login/oauth/authorize", o.cfg.GiteaUrl),
|
||||
},
|
||||
}
|
||||
|
||||
o.loadToken(cfg)
|
||||
o.keepAlive()
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (o oauth) authUrl() string {
|
||||
return o.config().AuthCodeURL("state")
|
||||
}
|
||||
|
||||
func (o oauth) loadToken(config *oauth2.Config) {
|
||||
token := &oauth2.Token{}
|
||||
|
||||
if o.cfg.Oauth.Token != "" {
|
||||
o.token = &oauth2.Token{AccessToken: o.cfg.Oauth.Token}
|
||||
}
|
||||
|
||||
if o.cfg.Oauth.RefreshToken != "" {
|
||||
o.token.RefreshToken = o.cfg.Oauth.RefreshToken
|
||||
}
|
||||
|
||||
if o.token != nil {
|
||||
parsedToken, err := jwt.Parse(o.cfg.Oauth.Token, func(token *jwt.Token) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
o.logger.Fatal("Error parsing token", zap.Error(err))
|
||||
}
|
||||
|
||||
valid := false
|
||||
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid {
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expirationTime := time.Unix(int64(exp), 0)
|
||||
if time.Now().Before(expirationTime) {
|
||||
valid = true
|
||||
o.token.Expiry = expirationTime
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if valid {
|
||||
token = o.token
|
||||
} else {
|
||||
o.logger.Info("Token is expired, ignoring")
|
||||
}
|
||||
}
|
||||
|
||||
o.refresher = config.TokenSource(context.Background(), token)
|
||||
}
|
||||
|
||||
func (o oauth) keepAlive() {
|
||||
if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" {
|
||||
o.logger.Error("No token or refresh token provided.")
|
||||
return
|
||||
}
|
||||
|
||||
if o.keepAliveRunning {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
|
||||
o.keepAliveRunning = true
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if !o.isTokenValid() {
|
||||
if err := o.refreshToken(); err != nil {
|
||||
o.logger.Error("Error refreshing token", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
func (o *oauth) isTokenValid() bool {
|
||||
if o.token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return o.token.Valid()
|
||||
}
|
||||
|
||||
func (o *oauth) refreshToken() error {
|
||||
o.logger.Info("Refreshing token...")
|
||||
|
||||
token, err := o.refresher.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
o.token = token
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *oauth) exchange(code string) (*oauth2.Token, error) {
|
||||
cfg := o.config()
|
||||
token, err := o.config().Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.cfg.Oauth.Token = token.AccessToken
|
||||
o.cfg.Oauth.RefreshToken = token.RefreshToken
|
||||
err = config.SaveConfig(o.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.loadToken(cfg)
|
||||
o.keepAlive()
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (o *oauth) client() *gitea.Client {
|
||||
client, err := getClient(ClientParams{
|
||||
Config: o.cfg,
|
||||
})
|
||||
if err != nil {
|
||||
o.logger.Fatal("Error creating gitea client", zap.Error(err))
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue