refactor: make oauth DI managed
This commit is contained in:
parent
6d20106bb5
commit
34898771ab
26
api/oauth.go
26
api/oauth.go
|
@ -11,7 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type oauth struct {
|
type Oauth struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
token *oauth2.Token
|
token *oauth2.Token
|
||||||
|
@ -19,19 +19,19 @@ type oauth struct {
|
||||||
keepAliveRunning bool
|
keepAliveRunning bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOauth(cfg *config.Config, logger *zap.Logger) *oauth {
|
func NewOauth(cfg *config.Config, logger *zap.Logger) *Oauth {
|
||||||
return &oauth{cfg: cfg, logger: logger}
|
return &Oauth{cfg: cfg, logger: logger}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o oauth) config() *oauth2.Config {
|
func (o Oauth) config() *oauth2.Config {
|
||||||
cfg := &oauth2.Config{
|
cfg := &oauth2.Config{
|
||||||
ClientID: o.cfg.Oauth.ClientId,
|
ClientID: o.cfg.Oauth.ClientId,
|
||||||
ClientSecret: o.cfg.Oauth.ClientSecret,
|
ClientSecret: o.cfg.Oauth.ClientSecret,
|
||||||
Scopes: []string{"admin"},
|
Scopes: []string{"admin"},
|
||||||
RedirectURL: fmt.Sprintf("https://%s/setup/callback", o.cfg.Domain),
|
RedirectURL: fmt.Sprintf("https://%s/setup/callback", o.cfg.Domain),
|
||||||
Endpoint: oauth2.Endpoint{
|
Endpoint: oauth2.Endpoint{
|
||||||
TokenURL: fmt.Sprintf("%s/login/oauth/access_token", o.cfg.GiteaUrl),
|
TokenURL: fmt.Sprintf("%s/login/Oauth/access_token", o.cfg.GiteaUrl),
|
||||||
AuthURL: fmt.Sprintf("%s/login/oauth/authorize", o.cfg.GiteaUrl),
|
AuthURL: fmt.Sprintf("%s/login/Oauth/authorize", o.cfg.GiteaUrl),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,11 +41,11 @@ func (o oauth) config() *oauth2.Config {
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o oauth) authUrl() string {
|
func (o Oauth) authUrl() string {
|
||||||
return o.config().AuthCodeURL("state")
|
return o.config().AuthCodeURL("state")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o oauth) loadToken(config *oauth2.Config) {
|
func (o Oauth) loadToken(config *oauth2.Config) {
|
||||||
token := &oauth2.Token{}
|
token := &oauth2.Token{}
|
||||||
|
|
||||||
if o.cfg.Oauth.Token != "" {
|
if o.cfg.Oauth.Token != "" {
|
||||||
|
@ -85,7 +85,7 @@ func (o oauth) loadToken(config *oauth2.Config) {
|
||||||
o.refresher = config.TokenSource(context.Background(), token)
|
o.refresher = config.TokenSource(context.Background(), token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o oauth) keepAlive() {
|
func (o Oauth) keepAlive() {
|
||||||
if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" {
|
if o.cfg.Oauth.Token == "" || o.cfg.Oauth.RefreshToken == "" {
|
||||||
o.logger.Error("No token or refresh token provided.")
|
o.logger.Error("No token or refresh token provided.")
|
||||||
return
|
return
|
||||||
|
@ -112,7 +112,7 @@ func (o oauth) keepAlive() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
func (o *oauth) isTokenValid() bool {
|
func (o *Oauth) isTokenValid() bool {
|
||||||
if o.token == nil {
|
if o.token == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -120,7 +120,7 @@ func (o *oauth) isTokenValid() bool {
|
||||||
return o.token.Valid()
|
return o.token.Valid()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *oauth) refreshToken() error {
|
func (o *Oauth) refreshToken() error {
|
||||||
o.logger.Info("Refreshing token...")
|
o.logger.Info("Refreshing token...")
|
||||||
|
|
||||||
token, err := o.refresher.Token()
|
token, err := o.refresher.Token()
|
||||||
|
@ -133,7 +133,7 @@ func (o *oauth) refreshToken() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *oauth) exchange(code string) (*oauth2.Token, error) {
|
func (o *Oauth) exchange(code string) (*oauth2.Token, error) {
|
||||||
cfg := o.config()
|
cfg := o.config()
|
||||||
token, err := o.config().Exchange(context.Background(), code)
|
token, err := o.config().Exchange(context.Background(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -153,7 +153,7 @@ func (o *oauth) exchange(code string) (*oauth2.Token, error) {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *oauth) client() *gitea.Client {
|
func (o *Oauth) client() *gitea.Client {
|
||||||
client, err := getClient(ClientParams{
|
client, err := getClient(ClientParams{
|
||||||
Config: o.cfg,
|
Config: o.cfg,
|
||||||
})
|
})
|
||||||
|
|
|
@ -16,6 +16,7 @@ type RouteParams struct {
|
||||||
Logger *zap.Logger
|
Logger *zap.Logger
|
||||||
R *mux.Router
|
R *mux.Router
|
||||||
WebhookManager *WebhookManager
|
WebhookManager *WebhookManager
|
||||||
|
Oauth *Oauth
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupRoutes(params RouteParams) {
|
func SetupRoutes(params RouteParams) {
|
||||||
|
|
|
@ -10,10 +10,10 @@ import (
|
||||||
type setupApi struct {
|
type setupApi struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
oauth *oauth
|
oauth *Oauth
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSetupApi(config *config.Config, logger *zap.Logger, oauth *oauth) *setupApi {
|
func newSetupApi(config *config.Config, logger *zap.Logger, oauth *Oauth) *setupApi {
|
||||||
return &setupApi{config: config, logger: logger, oauth: oauth}
|
return &setupApi{config: config, logger: logger, oauth: oauth}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ func setupApiRoutes(params RouteParams) {
|
||||||
setupRouter := r.PathPrefix("/setup").Subrouter()
|
setupRouter := r.PathPrefix("/setup").Subrouter()
|
||||||
setupRouter.Use(giteaOauthVerifyMiddleware(params.Config))
|
setupRouter.Use(giteaOauthVerifyMiddleware(params.Config))
|
||||||
|
|
||||||
setupApi := newSetupApi(params.Config, params.Logger, newOauth(params.Config, params.Logger))
|
setupApi := newSetupApi(params.Config, params.Logger, params.Oauth)
|
||||||
setupRouter.HandleFunc("", setupApi.setupHandler).Methods("GET")
|
setupRouter.HandleFunc("", setupApi.setupHandler).Methods("GET")
|
||||||
setupRouter.HandleFunc("/callback", setupApi.callbackHandler).Methods("GET")
|
setupRouter.HandleFunc("/callback", setupApi.callbackHandler).Methods("GET")
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ func main() {
|
||||||
}),
|
}),
|
||||||
config.Module,
|
config.Module,
|
||||||
db.Module,
|
db.Module,
|
||||||
|
fx.Provide(api.NewOauth),
|
||||||
fx.Provide(api.NewRouter),
|
fx.Provide(api.NewRouter),
|
||||||
fx.Provide(NewServer),
|
fx.Provide(NewServer),
|
||||||
fx.Provide(api.NewWebhookManager),
|
fx.Provide(api.NewWebhookManager),
|
||||||
|
|
Loading…
Reference in New Issue