refactor: make oauth DI managed

This commit is contained in:
Derrick Hammer 2024-02-11 16:30:03 -05:00
parent 6d20106bb5
commit 34898771ab
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
4 changed files with 18 additions and 16 deletions

View File

@ -11,7 +11,7 @@ import (
"time" "time"
) )
type oauth struct { type Oauth struct {
cfg *config.Config cfg *config.Config
logger *zap.Logger logger *zap.Logger
token *oauth2.Token token *oauth2.Token
@ -19,19 +19,19 @@ type oauth struct {
keepAliveRunning bool keepAliveRunning bool
} }
func newOauth(cfg *config.Config, logger *zap.Logger) *oauth { func NewOauth(cfg *config.Config, logger *zap.Logger) *Oauth {
return &oauth{cfg: cfg, logger: logger} return &Oauth{cfg: cfg, logger: logger}
} }
func (o oauth) config() *oauth2.Config { func (o Oauth) config() *oauth2.Config {
cfg := &oauth2.Config{ cfg := &oauth2.Config{
ClientID: o.cfg.Oauth.ClientId, ClientID: o.cfg.Oauth.ClientId,
ClientSecret: o.cfg.Oauth.ClientSecret, ClientSecret: o.cfg.Oauth.ClientSecret,
Scopes: []string{"admin"}, Scopes: []string{"admin"},
RedirectURL: fmt.Sprintf("https://%s/setup/callback", o.cfg.Domain), RedirectURL: fmt.Sprintf("https://%s/setup/callback", o.cfg.Domain),
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
TokenURL: fmt.Sprintf("%s/login/oauth/access_token", o.cfg.GiteaUrl), TokenURL: fmt.Sprintf("%s/login/Oauth/access_token", o.cfg.GiteaUrl),
AuthURL: fmt.Sprintf("%s/login/oauth/authorize", o.cfg.GiteaUrl), AuthURL: fmt.Sprintf("%s/login/Oauth/authorize", o.cfg.GiteaUrl),
}, },
} }
@ -41,11 +41,11 @@ func (o oauth) config() *oauth2.Config {
return cfg return cfg
} }
func (o oauth) authUrl() string { func (o Oauth) authUrl() string {
return o.config().AuthCodeURL("state") return o.config().AuthCodeURL("state")
} }
func (o oauth) loadToken(config *oauth2.Config) { func (o Oauth) loadToken(config *oauth2.Config) {
token := &oauth2.Token{} token := &oauth2.Token{}
if o.cfg.Oauth.Token != "" { if o.cfg.Oauth.Token != "" {
@ -85,7 +85,7 @@ func (o oauth) loadToken(config *oauth2.Config) {
o.refresher = config.TokenSource(context.Background(), token) o.refresher = config.TokenSource(context.Background(), token)
} }
func (o oauth) keepAlive() { func (o Oauth) keepAlive() {
if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" { if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" {
o.logger.Error("No token or refresh token provided.") o.logger.Error("No token or refresh token provided.")
return return
@ -112,7 +112,7 @@ func (o oauth) keepAlive() {
} }
}() }()
} }
func (o *oauth) isTokenValid() bool { func (o *Oauth) isTokenValid() bool {
if o.token == nil { if o.token == nil {
return false return false
} }
@ -120,7 +120,7 @@ func (o *oauth) isTokenValid() bool {
return o.token.Valid() return o.token.Valid()
} }
func (o *oauth) refreshToken() error { func (o *Oauth) refreshToken() error {
o.logger.Info("Refreshing token...") o.logger.Info("Refreshing token...")
token, err := o.refresher.Token() token, err := o.refresher.Token()
@ -133,7 +133,7 @@ func (o *oauth) refreshToken() error {
return nil return nil
} }
func (o *oauth) exchange(code string) (*oauth2.Token, error) { func (o *Oauth) exchange(code string) (*oauth2.Token, error) {
cfg := o.config() cfg := o.config()
token, err := o.config().Exchange(context.Background(), code) token, err := o.config().Exchange(context.Background(), code)
if err != nil { if err != nil {
@ -153,7 +153,7 @@ func (o *oauth) exchange(code string) (*oauth2.Token, error) {
return token, nil return token, nil
} }
func (o *oauth) client() *gitea.Client { func (o *Oauth) client() *gitea.Client {
client, err := getClient(ClientParams{ client, err := getClient(ClientParams{
Config: o.cfg, Config: o.cfg,
}) })

View File

@ -16,6 +16,7 @@ type RouteParams struct {
Logger *zap.Logger Logger *zap.Logger
R *mux.Router R *mux.Router
WebhookManager *WebhookManager WebhookManager *WebhookManager
Oauth *Oauth
} }
func SetupRoutes(params RouteParams) { func SetupRoutes(params RouteParams) {

View File

@ -10,10 +10,10 @@ import (
type setupApi struct { type setupApi struct {
config *config.Config config *config.Config
logger *zap.Logger logger *zap.Logger
oauth *oauth oauth *Oauth
} }
func newSetupApi(config *config.Config, logger *zap.Logger, oauth *oauth) *setupApi { func newSetupApi(config *config.Config, logger *zap.Logger, oauth *Oauth) *setupApi {
return &setupApi{config: config, logger: logger, oauth: oauth} return &setupApi{config: config, logger: logger, oauth: oauth}
} }
@ -64,7 +64,7 @@ func setupApiRoutes(params RouteParams) {
setupRouter := r.PathPrefix("/setup").Subrouter() setupRouter := r.PathPrefix("/setup").Subrouter()
setupRouter.Use(giteaOauthVerifyMiddleware(params.Config)) setupRouter.Use(giteaOauthVerifyMiddleware(params.Config))
setupApi := newSetupApi(params.Config, params.Logger, newOauth(params.Config, params.Logger)) setupApi := newSetupApi(params.Config, params.Logger, params.Oauth)
setupRouter.HandleFunc("", setupApi.setupHandler).Methods("GET") setupRouter.HandleFunc("", setupApi.setupHandler).Methods("GET")
setupRouter.HandleFunc("/callback", setupApi.callbackHandler).Methods("GET") setupRouter.HandleFunc("/callback", setupApi.callbackHandler).Methods("GET")
} }

View File

@ -24,6 +24,7 @@ func main() {
}), }),
config.Module, config.Module,
db.Module, db.Module,
fx.Provide(api.NewOauth),
fx.Provide(api.NewRouter), fx.Provide(api.NewRouter),
fx.Provide(NewServer), fx.Provide(NewServer),
fx.Provide(api.NewWebhookManager), fx.Provide(api.NewWebhookManager),