gitea-github-proxy/api/oauth.go

177 lines
3.5 KiB
Go
Raw Normal View History

package api
import (
2024-02-11 19:58:04 +00:00
"code.gitea.io/sdk/gitea"
"context"
"fmt"
"git.lumeweb.com/LumeWeb/gitea-github-proxy/config"
2024-02-11 19:58:04 +00:00
"github.com/golang-jwt/jwt"
2024-02-11 21:32:19 +00:00
"go.uber.org/fx"
"go.uber.org/zap"
"golang.org/x/oauth2"
2024-02-11 19:58:04 +00:00
"time"
)
2024-02-11 21:30:03 +00:00
type Oauth struct {
2024-02-11 19:58:04 +00:00
cfg *config.Config
logger *zap.Logger
token *oauth2.Token
refresher oauth2.TokenSource
keepAliveRunning bool
}
2024-02-11 21:32:19 +00:00
// NewOauth constructs the Oauth helper and registers an fx start hook.
// When the application starts, the hook calls config(), which is invoked only
// for its side effects (it calls loadToken and keepAlive); the returned
// oauth2 config is discarded. The hook never fails.
func NewOauth(lc fx.Lifecycle, cfg *config.Config, logger *zap.Logger) *Oauth {
	o := &Oauth{
		cfg:    cfg,
		logger: logger,
	}

	onStart := func(_ context.Context) error {
		o.config()
		return nil
	}
	lc.Append(fx.Hook{OnStart: onStart})

	return o
}
2024-02-11 21:30:03 +00:00
func (o Oauth) config() *oauth2.Config {
2024-02-11 19:58:04 +00:00
cfg := &oauth2.Config{
ClientID: o.cfg.Oauth.ClientId,
ClientSecret: o.cfg.Oauth.ClientSecret,
Scopes: []string{"admin"},
RedirectURL: fmt.Sprintf("https://%s/setup/callback", o.cfg.Domain),
Endpoint: oauth2.Endpoint{
2024-02-11 21:30:03 +00:00
TokenURL: fmt.Sprintf("%s/login/Oauth/access_token", o.cfg.GiteaUrl),
AuthURL: fmt.Sprintf("%s/login/Oauth/authorize", o.cfg.GiteaUrl),
},
}
2024-02-11 19:58:04 +00:00
o.loadToken(cfg)
o.keepAlive()
return cfg
}
2024-02-11 21:30:03 +00:00
func (o Oauth) authUrl() string {
return o.config().AuthCodeURL("state")
}
2024-02-11 19:58:04 +00:00
2024-02-11 21:30:03 +00:00
func (o Oauth) loadToken(config *oauth2.Config) {
2024-02-11 19:58:04 +00:00
token := &oauth2.Token{}
if o.cfg.Oauth.Token != "" {
o.token = &oauth2.Token{AccessToken: o.cfg.Oauth.Token}
}
if o.cfg.Oauth.RefreshToken != "" {
o.token.RefreshToken = o.cfg.Oauth.RefreshToken
}
if o.token != nil {
valid := false
parseToken, _, err := new(jwt.Parser).ParseUnverified(o.cfg.Oauth.Token, jwt.MapClaims{})
2024-02-11 19:58:04 +00:00
if err != nil {
o.logger.Error("Error parsing token", zap.Error(err))
} else {
// Assert the token's claims to the desired type (MapClaims in this case)
if claims, ok := parseToken.Claims.(jwt.MapClaims); ok {
if exp, ok := claims["exp"].(float64); ok {
expirationTime := time.Unix(int64(exp), 0)
if time.Now().Before(expirationTime) {
valid = true
o.token.Expiry = expirationTime
}
2024-02-11 19:58:04 +00:00
}
}
if valid {
token = o.token
} else {
o.logger.Info("Token is expired, ignoring")
}
2024-02-11 19:58:04 +00:00
}
}
o.refresher = config.TokenSource(context.Background(), token)
}
2024-02-11 21:30:03 +00:00
func (o Oauth) keepAlive() {
2024-02-11 19:58:04 +00:00
if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" {
o.logger.Error("No token or refresh token provided.")
return
}
if o.keepAliveRunning {
return
}
ticker := time.NewTicker(30 * time.Minute)
o.keepAliveRunning = true
go func() {
for {
select {
case <-ticker.C:
if !o.isTokenValid() {
if err := o.refreshToken(); err != nil {
o.logger.Error("Error refreshing token", zap.Error(err))
}
}
}
}
}()
}
2024-02-11 21:30:03 +00:00
func (o *Oauth) isTokenValid() bool {
2024-02-11 19:58:04 +00:00
if o.token == nil {
return false
}
return o.token.Valid()
}
2024-02-11 21:30:03 +00:00
func (o *Oauth) refreshToken() error {
2024-02-11 19:58:04 +00:00
o.logger.Info("Refreshing token...")
token, err := o.refresher.Token()
if err != nil {
return err
}
o.token = token
return nil
}
2024-02-11 21:30:03 +00:00
func (o *Oauth) exchange(code string) (*oauth2.Token, error) {
2024-02-11 19:58:04 +00:00
cfg := o.config()
token, err := o.config().Exchange(context.Background(), code)
if err != nil {
return nil, err
}
o.cfg.Oauth.Token = token.AccessToken
o.cfg.Oauth.RefreshToken = token.RefreshToken
err = config.SaveConfig(o.cfg)
if err != nil {
return nil, err
}
o.loadToken(cfg)
o.keepAlive()
return token, nil
}
2024-02-11 21:30:03 +00:00
func (o *Oauth) client() *gitea.Client {
2024-02-11 19:58:04 +00:00
client, err := getClient(ClientParams{
Config: o.cfg,
})
if err != nil {
o.logger.Fatal("Error creating gitea client", zap.Error(err))
}
return client
}