diff --git a/cmd/portal/init.go b/cmd/portal/init.go index d20c081..e1ffd16 100644 --- a/cmd/portal/init.go +++ b/cmd/portal/init.go @@ -72,7 +72,7 @@ func initAccess(p interfaces.Portal) error { } func initDatabase(p interfaces.Portal) error { - return p.Database().Init(p) + return p.DatabaseService().Init(p) } func initProtocols(p interfaces.Portal) error { diff --git a/cmd/portal/portal.go b/cmd/portal/portal.go index cc364be..d180504 100644 --- a/cmd/portal/portal.go +++ b/cmd/portal/portal.go @@ -31,9 +31,12 @@ type PortalImpl struct { accounts interfaces.AccountService } -func (p *PortalImpl) Database() interfaces.Database { +func (p *PortalImpl) DatabaseService() interfaces.Database { return p.database } +func (p *PortalImpl) Database() *gorm.DB { + return p.database.Get() +} func NewPortal() interfaces.Portal { portal := &PortalImpl{ diff --git a/db/db.go b/db/db.go index 2d8b543..f68b3b5 100644 --- a/db/db.go +++ b/db/db.go @@ -15,7 +15,7 @@ var ( ) type DatabaseImpl struct { - DB *gorm.DB + db *gorm.DB portal interfaces.Portal } @@ -43,14 +43,14 @@ func (d *DatabaseImpl) Init(p interfaces.Portal) error { if err != nil { p.Logger().Error("Failed to connect to database", zap.Error(err)) } - d.DB = db + d.db = db return nil } // Start performs any additional setup func (d *DatabaseImpl) Start() error { - return d.DB.AutoMigrate( + return d.db.AutoMigrate( &models.APIKey{}, &models.Blocklist{}, &models.Download{}, @@ -60,3 +60,7 @@ func (d *DatabaseImpl) Start() error { &models.User{}, ) } + +func (d *DatabaseImpl) Get() *gorm.DB { + return d.db +} diff --git a/interfaces/database.go b/interfaces/database.go index 537623c..b652b14 100644 --- a/interfaces/database.go +++ b/interfaces/database.go @@ -1,6 +1,9 @@ package interfaces +import "gorm.io/gorm" + type Database interface { Init(p Portal) error Start() error + Get() *gorm.DB } diff --git a/interfaces/portal.go b/interfaces/portal.go index 4bf4c15..ec1c3da 100644 --- a/interfaces/portal.go +++ b/interfaces/portal.go @@ -20,7 +20,8 @@ type Portal interface { Storage() StorageService SetIdentity(identity ed25519.PrivateKey) SetLogger(logger *zap.Logger) - Database() Database + Database() *gorm.DB + DatabaseService() Database Casbin() *casbin.Enforcer SetCasbin(e *casbin.Enforcer) Accounts() AccountService