diff --git a/account/account.go b/account/account.go index 81d9860..d0e8133 100644 --- a/account/account.go +++ b/account/account.go @@ -1,35 +1,48 @@ package account import ( + "crypto/ed25519" "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" + "github.com/spf13/viper" + "go.uber.org/fx" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) -var ( - _ interfaces.AccountService = (*AccountServiceImpl)(nil) +type AccountServiceParams struct { + fx.In + Db *gorm.DB + Config *viper.Viper + Identity ed25519.PrivateKey +} + +var Module = fx.Module("account", + fx.Options( + fx.Provide(NewAccountService), + ), ) type AccountServiceImpl struct { - portal interfaces.Portal + db *gorm.DB + config *viper.Viper + identity ed25519.PrivateKey } -func NewAccountService(portal interfaces.Portal) interfaces.AccountService { - return &AccountServiceImpl{portal: portal} +func NewAccountService(params AccountServiceParams) *AccountServiceImpl { + return &AccountServiceImpl{db: params.Db, config: params.Config, identity: params.Identity} } func (s AccountServiceImpl) EmailExists(email string) (bool, models.User) { var user models.User - result := s.portal.Database().Model(&models.User{}).Where(&models.User{Email: email}).First(&user) + result := s.db.Model(&models.User{}).Where(&models.User{Email: email}).First(&user) return result.RowsAffected > 0, user } func (s AccountServiceImpl) PubkeyExists(pubkey string) (bool, models.PublicKey) { var model models.PublicKey - result := s.portal.Database().Model(&models.PublicKey{}).Where(&models.PublicKey{Key: pubkey}).First(&model) + result := s.db.Model(&models.PublicKey{}).Where(&models.PublicKey{Key: pubkey}).First(&model) return result.RowsAffected > 0, model } @@ -37,7 +50,7 @@ func (s AccountServiceImpl) PubkeyExists(pubkey string) (bool, models.PublicKey) func (s AccountServiceImpl) AccountExists(id uint64) (bool, models.User) { var model models.User - result := s.portal.Database().Model(&models.User{}).First(&model, id) + result := s.db.Model(&models.User{}).First(&model, id) return result.RowsAffected > 0, model } @@ -52,7 +65,7 @@ func (s AccountServiceImpl) CreateAccount(email string, password string) (*model user.Email = email user.PasswordHash = string(bytes) - result := s.portal.Database().Create(&user) + result := s.db.Create(&user) if result.Error != nil { return nil, result.Error @@ -66,7 +79,7 @@ func (s AccountServiceImpl) AddPubkeyToAccount(user models.User, pubkey string) model.Key = pubkey model.UserID = user.ID - result := s.portal.Database().Create(&model) + result := s.db.Create(&model) if result.Error != nil { return result.Error @@ -77,7 +90,7 @@ func (s AccountServiceImpl) AddPubkeyToAccount(user models.User, pubkey string) func (s AccountServiceImpl) LoginPassword(email string, password string) (string, error) { var user models.User - result := s.portal.Database().Model(&models.User{}).Where(&models.User{Email: email}).First(&user) + result := s.db.Model(&models.User{}).Where(&models.User{Email: email}).First(&user) if result.RowsAffected == 0 || result.Error != nil { return "", result.Error @@ -88,7 +101,7 @@ func (s AccountServiceImpl) LoginPassword(email string, password string) (string return "", err } - token, err := GenerateToken(s.portal.Identity(), user.ID) + token, err := GenerateToken(s.identity, user.ID) if err != nil { return "", err } @@ -99,13 +112,13 @@ func (s AccountServiceImpl) LoginPassword(email string, password string) (string func (s AccountServiceImpl) LoginPubkey(pubkey string) (string, error) { var model models.PublicKey - result := s.portal.Database().Model(&models.PublicKey{}).Where(&models.PublicKey{Key: pubkey}).First(&model) + result := s.db.Model(&models.PublicKey{}).Where(&models.PublicKey{Key: pubkey}).First(&model) if result.RowsAffected == 0 || result.Error != nil { return "", result.Error } - token, err := GenerateToken(s.portal.Identity(), model.UserID) + token, err := GenerateToken(s.identity, model.UserID) if err != nil { return "", err } @@ -116,7 +129,7 @@ func (s AccountServiceImpl) LoginPubkey(pubkey string) (string, error) { func (s AccountServiceImpl) AccountPins(id uint64, createdAfter uint64) ([]models.Pin, error) { var pins []models.Pin - result := s.portal.Database().Model(&models.Pin{}). + result := s.db.Model(&models.Pin{}). Preload("Upload"). // Preload the related Upload for each Pin Where(&models.Pin{UserID: uint(id)}). Where("created_at > ?", createdAfter). @@ -136,7 +149,7 @@ func (s AccountServiceImpl) DeletePinByHash(hash string, accountID uint) error { // Retrieve the upload ID for the given hash var uploadID uint - result := s.portal.Database(). + result := s.db. Model(&models.Upload{}). Where(&uploadQuery). Select("id"). @@ -152,7 +165,7 @@ func (s AccountServiceImpl) DeletePinByHash(hash string, accountID uint) error { // Delete pins with the retrieved upload ID and matching account ID pinQuery := models.Pin{UploadID: uploadID, UserID: accountID} - result = s.portal.Database(). + result = s.db. Where(&pinQuery). Delete(&models.Pin{}) @@ -168,7 +181,7 @@ func (s AccountServiceImpl) PinByHash(hash string, accountID uint) error { // Retrieve the upload ID for the given hash var uploadID uint - result := s.portal.Database(). + result := s.db. Model(&models.Upload{}). Where(&uploadQuery). First(&uploadID) @@ -181,7 +194,7 @@ func (s AccountServiceImpl) PinByHash(hash string, accountID uint) error { } func (s AccountServiceImpl) PinByID(uploadId uint, accountID uint) error { - result := s.portal.Database().Model(&models.Pin{}).Where(&models.Pin{UploadID: uploadId, UserID: accountID}).First(&models.Pin{}) + result := s.db.Model(&models.Pin{}).Where(&models.Pin{UploadID: uploadId, UserID: accountID}).First(&models.Pin{}) if result.Error != nil && result.Error != gorm.ErrRecordNotFound { return result.Error @@ -193,7 +206,7 @@ func (s AccountServiceImpl) PinByID(uploadId uint, accountID uint) error { // Create a pin with the retrieved upload ID and matching account ID pinQuery := models.Pin{UploadID: uploadId, UserID: accountID} - result = s.portal.Database().Create(&pinQuery) + result = s.db.Create(&pinQuery) if result.Error != nil { return result.Error diff --git a/api/api.go b/api/api.go index 50dbceb..d5a47b2 100644 --- a/api/api.go +++ b/api/api.go @@ -1,19 +1,59 @@ package api import ( - "git.lumeweb.com/LumeWeb/portal/interfaces" - "github.com/julienschmidt/httprouter" + "context" + "git.lumeweb.com/LumeWeb/portal/api/registry" + "git.lumeweb.com/LumeWeb/portal/api/router" + "github.com/spf13/viper" + "go.uber.org/fx" ) -func Init(router interfaces.APIRegistry) error { - router.Register("s5", NewS5()) - return nil +func RegisterApis() { + registry.Register(registry.APIEntry{ + Key: "s5", + Module: S5Module, + }) } -func registerProtocolSubdomain(portal interfaces.Portal, mux *httprouter.Router, name string) { - - router := portal.ApiRegistry().Router() - domain := portal.Config().GetString("core.domain") - - (*router)[name+"."+domain] = mux +func getModulesBasedOnConfig() []fx.Option { + var modules []fx.Option + for _, entry := range registry.GetRegistry() { + if viper.GetBool("protocols." + entry.Key + ".enabled") { + modules = append(modules, entry.Module) + } + } + return modules +} + +func BuildApis(config *viper.Viper) fx.Option { + var options []fx.Option + for _, entry := range registry.GetRegistry() { + if config.GetBool("protocols." + entry.Key + ".enabled") { + options = append(options, entry.Module) + if entry.InitFunc != nil { + options = append(options, fx.Invoke(entry.InitFunc)) + } + } + } + + return fx.Module("api", fx.Options(options...), fx.Provide(func() router.ProtocolRouter { + return registry.GetRouter() + })) +} + +func SetupLifecycles(lifecycle fx.Lifecycle, protocols []registry.API) { + for _, entry := range registry.GetRegistry() { + for _, protocol := range protocols { + if protocol.Name() == entry.Key { + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return protocol.Start(ctx) + }, + OnStop: func(ctx context.Context) error { + return protocol.Stop(ctx) + }, + }) + } + } + } } diff --git a/api/casbin.go b/api/casbin.go index ebc6376..e0ae12d 100644 --- a/api/casbin.go +++ b/api/casbin.go @@ -9,7 +9,7 @@ import ( "sync" ) -func GetCasbin(logger *zap.Logger) *casbin.Enforcer { +func NewCasbin(logger *zap.Logger) *casbin.Enforcer { m := model.NewModel() m.AddDef("r", "r", "sub, obj, act") m.AddDef("p", "p", "sub, obj, act") diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index 2003fc6..e625f2b 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -1,6 +1,9 @@ package middleware import ( + "git.lumeweb.com/LumeWeb/portal/api/registry" + "github.com/julienschmidt/httprouter" + "github.com/spf13/viper" "go.sia.tech/jape" "net/http" "strings" @@ -47,3 +50,9 @@ func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Han } return handler } +func RegisterProtocolSubdomain(config *viper.Viper, mux *httprouter.Router, name string) { + router := registry.GetRouter() + domain := config.GetString("core.domain") + + (router)[name+"."+domain] = mux +} diff --git a/api/middleware/s5.go b/api/middleware/s5.go index d5803b5..c8bb88e 100644 --- a/api/middleware/s5.go +++ b/api/middleware/s5.go @@ -2,8 +2,10 @@ package middleware import ( "context" + "crypto/ed25519" "fmt" - "git.lumeweb.com/LumeWeb/portal/interfaces" + "git.lumeweb.com/LumeWeb/portal/account" + "git.lumeweb.com/LumeWeb/portal/storage" "github.com/golang-jwt/jwt/v5" "go.sia.tech/jape" "net/http" @@ -44,7 +46,7 @@ func parseAuthTokenHeader(headers http.Header) string { return authHeader } -func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler { +func AuthMiddleware(identity ed25519.PrivateKey, accounts *account.AccountServiceImpl) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authToken := findAuthToken(r) @@ -59,7 +61,7 @@ func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - publicKey := portal.Identity().Public() + publicKey := identity.Public() return publicKey, nil }) @@ -90,7 +92,7 @@ func AuthMiddleware(portal interfaces.Portal) func(http.Handler) http.Handler { return } - exists, _ := portal.Accounts().AccountExists(userID) + exists, _ := accounts.AccountExists(userID) if !exists { http.Error(w, "Invalid User ID", http.StatusBadRequest) @@ -134,10 +136,10 @@ func (w *tusJwtResponseWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } -func BuildS5TusApi(portal interfaces.Portal) jape.Handler { +func BuildS5TusApi(identity ed25519.PrivateKey, accounts *account.AccountServiceImpl, storage *storage.StorageServiceImpl) jape.Handler { // Create a jape.Handler for your tusHandler tusJapeHandler := func(c jape.Context) { - tusHandler := portal.Storage().Tus() + tusHandler := storage.Tus() tusHandler.ServeHTTP(c.ResponseWriter, c.Request) } @@ -164,7 +166,7 @@ func BuildS5TusApi(portal interfaces.Portal) jape.Handler { } // Apply the middlewares to the tusJapeHandler - tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(portal), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware) + tusHandler := ApplyMiddlewares(tusJapeHandler, AuthMiddleware(identity, accounts), injectJwt, protocolMiddleware, stripPrefix, proxyMiddleware) return tusHandler } diff --git a/api/registry.go b/api/registry.go deleted file mode 100644 index cf96285..0000000 --- a/api/registry.go +++ /dev/null @@ -1,48 +0,0 @@ -package api - -import ( - "errors" - "git.lumeweb.com/LumeWeb/portal/api/router" - "git.lumeweb.com/LumeWeb/portal/interfaces" -) - -var ( - _ interfaces.APIRegistry = (*APIRegistryImpl)(nil) -) - -type APIRegistryImpl struct { - apis map[string]interfaces.API - router *router.ProtocolRouter -} - -func NewRegistry() interfaces.APIRegistry { - return &APIRegistryImpl{ - apis: make(map[string]interfaces.API), - router: &router.ProtocolRouter{}, - } -} - -func (r *APIRegistryImpl) Register(name string, APIRegistry interfaces.API) { - if _, exists := r.apis[name]; exists { - panic("api already registered") - } - r.apis[name] = APIRegistry -} - -func (r *APIRegistryImpl) Get(name string) (interfaces.API, error) { - APIRegistry, exists := r.apis[name] - if !exists { - return nil, errors.New("api not found") - } - return APIRegistry, nil -} -func (r *APIRegistryImpl) Router() *router.ProtocolRouter { - return r.router -} -func (r *APIRegistryImpl) All() map[string]interfaces.API { - aMap := make(map[string]interfaces.API) - for key, value := range r.apis { - aMap[key] = value - } - return aMap -} diff --git a/api/registry/registry.go b/api/registry/registry.go new file mode 100644 index 0000000..15e49dc --- /dev/null +++ b/api/registry/registry.go @@ -0,0 +1,35 @@ +package registry + +import ( + "context" + router2 "git.lumeweb.com/LumeWeb/portal/api/router" + "go.uber.org/fx" +) + +type API interface { + Name() string + Init() error + Start(ctx context.Context) error + Stop(ctx context.Context) error +} + +type APIEntry struct { + Key string + Module fx.Option + InitFunc interface{} +} + +var apiRegistry []APIEntry +var router router2.ProtocolRouter + +func Register(entry APIEntry) { + apiRegistry = append(apiRegistry, entry) +} + +func GetRegistry() []APIEntry { + return apiRegistry +} + +func GetRouter() router2.ProtocolRouter { + return router +} diff --git a/api/s5.go b/api/s5.go index 51a64f2..fed1825 100644 --- a/api/s5.go +++ b/api/s5.go @@ -1,50 +1,111 @@ package api import ( + "context" + "crypto/ed25519" + "fmt" + "git.lumeweb.com/LumeWeb/portal/account" "git.lumeweb.com/LumeWeb/portal/api/middleware" + "git.lumeweb.com/LumeWeb/portal/api/registry" "git.lumeweb.com/LumeWeb/portal/api/s5" - "git.lumeweb.com/LumeWeb/portal/interfaces" "git.lumeweb.com/LumeWeb/portal/protocols" + protoRegistry "git.lumeweb.com/LumeWeb/portal/protocols/registry" + "git.lumeweb.com/LumeWeb/portal/storage" "github.com/rs/cors" + "github.com/spf13/viper" "go.sia.tech/jape" + "go.uber.org/fx" ) var ( - _ interfaces.API = (*S5API)(nil) + _ registry.API = (*S5API)(nil) ) type S5API struct { + config *viper.Viper + identity ed25519.PrivateKey + accounts *account.AccountServiceImpl + storage *storage.StorageServiceImpl + protocols []protoRegistry.Protocol + httpHandler s5.HttpHandler + protocol *protocols.S5Protocol } -func NewS5() *S5API { - return &S5API{} +type S5ApiParams struct { + fx.In + Config *viper.Viper + Identity ed25519.PrivateKey + Accounts *account.AccountServiceImpl + Storage *storage.StorageServiceImpl + Protocols []protoRegistry.Protocol + HttpHandler s5.HttpHandler } -func (s S5API) Initialize(portal interfaces.Portal, protocol interfaces.Protocol) error { - s5protocol := protocol.(*protocols.S5Protocol) - s5http := s5.NewHttpHandler(portal) - registerProtocolSubdomain(portal, s5protocol.Node().Services().HTTP().GetHttpRouter(getRoutes(s5http, portal)), "s5") +type S5ApiResult struct { + fx.Out + Protocol registry.API `group:"api"` +} + +func NewS5(params S5ApiParams) (S5ApiResult, error) { + return S5ApiResult{ + Protocol: &S5API{ + config: params.Config, + identity: params.Identity, + accounts: params.Accounts, + storage: params.Storage, + protocols: params.Protocols, + httpHandler: params.HttpHandler, + }, + }, nil +} + +var S5Module = fx.Module("s5_api", + fx.Provide(NewS5), + fx.Provide(s5.NewHttpHandler), +) + +func (s *S5API) Init() error { + s5protocol := protoRegistry.FindProtocolByName("s5", s.protocols) + if s5protocol == nil { + return fmt.Errorf("s5 protocol not found") + } + + s5protocolInstance := s5protocol.(*protocols.S5Protocol) + s.protocol = s5protocolInstance + router := s5protocolInstance.Node().Services().HTTP().GetHttpRouter(getRoutes(s)) + middleware.RegisterProtocolSubdomain(s.config, router, "s5") return nil } -func getRoutes(h *s5.HttpHandler, portal interfaces.Portal) map[string]jape.Handler { +func (s S5API) Name() string { + return "s5" +} - tusHandler := middleware.BuildS5TusApi(portal) +func (s S5API) Start(ctx context.Context) error { + return s.protocol.Node().Start() +} + +func (s S5API) Stop(ctx context.Context) error { + return nil +} + +func getRoutes(s *S5API) map[string]jape.Handler { + tusHandler := middleware.BuildS5TusApi(s.identity, s.accounts, s.storage) return map[string]jape.Handler{ // Account API - "GET /s5/account/register": h.AccountRegisterChallenge, - "POST /s5/account/register": h.AccountRegister, - "GET /s5/account/login": h.AccountLoginChallenge, - "POST /s5/account/login": h.AccountLogin, - "GET /s5/account": middleware.ApplyMiddlewares(h.AccountInfo, middleware.AuthMiddleware(portal)), - "GET /s5/account/stats": middleware.ApplyMiddlewares(h.AccountStats, middleware.AuthMiddleware(portal)), - "GET /s5/account/pins.bin": middleware.ApplyMiddlewares(h.AccountPins, middleware.AuthMiddleware(portal)), + "GET /s5/account/register": s.httpHandler.AccountRegisterChallenge, + "POST /s5/account/register": s.httpHandler.AccountRegister, + "GET /s5/account/login": s.httpHandler.AccountLoginChallenge, + "POST /s5/account/login": s.httpHandler.AccountLogin, + "GET /s5/account": middleware.ApplyMiddlewares(s.httpHandler.AccountInfo, middleware.AuthMiddleware(s.identity, s.accounts)), + "GET /s5/account/stats": middleware.ApplyMiddlewares(s.httpHandler.AccountStats, middleware.AuthMiddleware(s.identity, s.accounts)), + "GET /s5/account/pins.bin": middleware.ApplyMiddlewares(s.httpHandler.AccountPins, middleware.AuthMiddleware(s.identity, s.accounts)), // Upload API - "POST /s5/upload": middleware.ApplyMiddlewares(h.SmallFileUpload, middleware.AuthMiddleware(portal)), - "POST /s5/upload/directory": middleware.ApplyMiddlewares(h.DirectoryUpload, middleware.AuthMiddleware(portal)), + "POST /s5/upload": middleware.ApplyMiddlewares(s.httpHandler.SmallFileUpload, middleware.AuthMiddleware(s.identity, s.accounts)), + "POST /s5/upload/directory": middleware.ApplyMiddlewares(s.httpHandler.DirectoryUpload, middleware.AuthMiddleware(s.identity, s.accounts)), // Tus API "POST /s5/upload/tus": tusHandler, @@ -53,22 +114,22 @@ func getRoutes(h *s5.HttpHandler, portal interfaces.Portal) map[string]jape.Hand "PATCH /s5/upload/tus/:id": tusHandler, // Download API - "GET /s5/blob/:cid": middleware.ApplyMiddlewares(h.DownloadBlob, middleware.AuthMiddleware(portal)), - "GET /s5/metadata/:cid": h.DownloadMetadata, - // "GET /s5/download/:cid": middleware.ApplyMiddlewares(h.DownloadFile, middleware.AuthMiddleware(portal)), - "GET /s5/download/:cid": middleware.ApplyMiddlewares(h.DownloadFile, cors.Default().Handler), + "GET /s5/blob/:cid": middleware.ApplyMiddlewares(s.httpHandler.DownloadBlob, middleware.AuthMiddleware(s.identity, s.accounts)), + "GET /s5/metadata/:cid": s.httpHandler.DownloadMetadata, + // "GET /s5/download/:cid": middleware.ApplyMiddlewares(s.httpHandler.DownloadFile, middleware.AuthMiddleware(portal)), + "GET /s5/download/:cid": middleware.ApplyMiddlewares(s.httpHandler.DownloadFile, cors.Default().Handler), // Pins API - "POST /s5/pin/:cid": middleware.ApplyMiddlewares(h.AccountPin, middleware.AuthMiddleware(portal)), - "DELETE /s5/delete/:cid": middleware.ApplyMiddlewares(h.AccountPinDelete, middleware.AuthMiddleware(portal)), + "POST /s5/pin/:cid": middleware.ApplyMiddlewares(s.httpHandler.AccountPin, middleware.AuthMiddleware(s.identity, s.accounts)), + "DELETE /s5/delete/:cid": middleware.ApplyMiddlewares(s.httpHandler.AccountPinDelete, middleware.AuthMiddleware(s.identity, s.accounts)), // Debug API - "GET /s5/debug/download_urls/:cid": middleware.ApplyMiddlewares(h.DebugDownloadUrls, middleware.AuthMiddleware(portal)), - "GET /s5/debug/storage_locations/:hash": middleware.ApplyMiddlewares(h.DebugStorageLocations, middleware.AuthMiddleware(portal)), + "GET /s5/debug/download_urls/:cid": middleware.ApplyMiddlewares(s.httpHandler.DebugDownloadUrls, middleware.AuthMiddleware(s.identity, s.accounts)), + "GET /s5/debug/storage_locations/:hash": middleware.ApplyMiddlewares(s.httpHandler.DebugStorageLocations, middleware.AuthMiddleware(s.identity, s.accounts)), // Registry API - "GET /s5/registry": middleware.ApplyMiddlewares(h.RegistryQuery, middleware.AuthMiddleware(portal)), - "POST /s5/registry": middleware.ApplyMiddlewares(h.RegistrySet, middleware.AuthMiddleware(portal)), - "GET /s5/registry/subscription": middleware.ApplyMiddlewares(h.RegistrySubscription, middleware.AuthMiddleware(portal)), + "GET /s5/registry": middleware.ApplyMiddlewares(s.httpHandler.RegistryQuery, middleware.AuthMiddleware(s.identity, s.accounts)), + "POST /s5/registry": middleware.ApplyMiddlewares(s.httpHandler.RegistrySet, middleware.AuthMiddleware(s.identity, s.accounts)), + "GET /s5/registry/subscription": middleware.ApplyMiddlewares(s.httpHandler.RegistrySubscription, middleware.AuthMiddleware(s.identity, s.accounts)), } } diff --git a/api/s5/http.go b/api/s5/http.go index 57b2fe2..616ed78 100644 --- a/api/s5/http.go +++ b/api/s5/http.go @@ -15,15 +15,19 @@ import ( s5protocol "git.lumeweb.com/LumeWeb/libs5-go/protocol" s5storage "git.lumeweb.com/LumeWeb/libs5-go/storage" "git.lumeweb.com/LumeWeb/libs5-go/types" + "git.lumeweb.com/LumeWeb/portal/account" "git.lumeweb.com/LumeWeb/portal/api/middleware" "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" "git.lumeweb.com/LumeWeb/portal/protocols" + "git.lumeweb.com/LumeWeb/portal/storage" emailverifier "github.com/AfterShip/email-verifier" "github.com/samber/lo" + "github.com/spf13/viper" "github.com/vmihailenco/msgpack/v5" "go.sia.tech/jape" + "go.uber.org/fx" "go.uber.org/zap" + "gorm.io/gorm" "io" "math" "mime/multipart" @@ -73,12 +77,27 @@ var ( ) type HttpHandler struct { - portal interfaces.Portal verifier *emailverifier.Verifier + config *viper.Viper + logger *zap.Logger + storage *storage.StorageServiceImpl + db *gorm.DB + accounts *account.AccountServiceImpl + protocol *protocols.S5Protocol } -func NewHttpHandler(portal interfaces.Portal) *HttpHandler { +type HttpHandlerParams struct { + fx.In + Config *viper.Viper + Logger *zap.Logger + Storage *storage.StorageServiceImpl + Db *gorm.DB + Accounts *account.AccountServiceImpl + Protocol *protocols.S5Protocol +} + +func NewHttpHandler(params HttpHandlerParams) *HttpHandler { verifier := emailverifier.NewVerifier() verifier.DisableSMTPCheck() @@ -87,8 +106,13 @@ func NewHttpHandler(portal interfaces.Portal) *HttpHandler { verifier.DisableAutoUpdateDisposable() return &HttpHandler{ - portal: portal, verifier: verifier, + config: params.Config, + logger: params.Logger, + storage: params.Storage, + db: params.Db, + accounts: params.Accounts, + protocol: params.Protocol, } } @@ -101,23 +125,23 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { if strings.HasPrefix(contentType, "multipart/form-data") { // Parse the multipart form - err := r.ParseMultipartForm(h.portal.Config().GetInt64("core.post-upload-limit")) + err := r.ParseMultipartForm(h.config.GetInt64("core.post-upload-limit")) if jc.Check(errMultiformParse, err) != nil { - h.portal.Logger().Error(errMultiformParse, zap.Error(err)) + h.logger.Error(errMultiformParse, zap.Error(err)) return } // Retrieve the file from the form data file, _, err := r.FormFile("file") if jc.Check(errRetrievingFile, err) != nil { - h.portal.Logger().Error(errRetrievingFile, zap.Error(err)) + h.logger.Error(errRetrievingFile, zap.Error(err)) return } defer func(file multipart.File) { err := file.Close() if err != nil { - h.portal.Logger().Error(errClosingStream, zap.Error(err)) + h.logger.Error(errClosingStream, zap.Error(err)) } }(file) @@ -125,7 +149,7 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { } else { data, err := io.ReadAll(r.Body) if jc.Check(errReadFile, err) != nil { - h.portal.Logger().Error(errReadFile, zap.Error(err)) + h.logger.Error(errReadFile, zap.Error(err)) return } @@ -136,37 +160,37 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { defer func(Body io.ReadCloser) { err := Body.Close() if err != nil { - h.portal.Logger().Error(errClosingStream, zap.Error(err)) + h.logger.Error(errClosingStream, zap.Error(err)) } }(r.Body) } - hash, err := h.portal.Storage().GetHashSmall(rs) + hash, err := h.storage.GetHashSmall(rs) _, err = rs.Seek(0, io.SeekStart) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - if exists, upload := h.portal.Storage().FileExists(hash); exists { + if exists, upload := h.storage.FileExists(hash); exists { cid, err := encoding.CIDFromHash(hash, upload.Size, types.CIDTypeRaw, types.HashTypeBlake3) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } cidStr, err := cid.ToString() if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - err = h.portal.Accounts().PinByID(upload.ID, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) + err = h.accounts.PinByID(upload.ID, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } @@ -176,21 +200,21 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { return } - hash, err = h.portal.Storage().PutFileSmall(rs, "s5", false) + hash, err = h.storage.PutFileSmall(rs, "s5", false) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - h.portal.Logger().Info("Hash", zap.String("hash", hex.EncodeToString(hash))) + h.logger.Info("Hash", zap.String("hash", hex.EncodeToString(hash))) cid, err := encoding.CIDFromHash(hash, uint64(bufferSize), types.CIDTypeRaw, types.HashTypeBlake3) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } @@ -198,16 +222,16 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - h.portal.Logger().Info("CID", zap.String("cidStr", cidStr)) + h.logger.Info("CID", zap.String("cidStr", cidStr)) _, err = rs.Seek(0, io.SeekStart) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } @@ -217,22 +241,22 @@ func (h *HttpHandler) SmallFileUpload(jc jape.Context) { _, err = rs.Read(mimeBytes[:]) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } mimeType := http.DetectContentType(mimeBytes[:]) - upload, err := h.portal.Storage().CreateUpload(hash, mimeType, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)), jc.Request.RemoteAddr, uint64(bufferSize), "s5") + upload, err := h.storage.CreateUpload(hash, mimeType, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)), jc.Request.RemoteAddr, uint64(bufferSize), "s5") if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) } - err = h.portal.Accounts().PinByID(upload.ID, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) + err = h.accounts.PinByID(upload.ID, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) } jc.Encode(&SmallUploadResponse{ @@ -251,7 +275,7 @@ func (h *HttpHandler) AccountRegisterChallenge(jc jape.Context) { _, err := rand.Read(challenge) if err != nil { _ = jc.Error(errAccountGenerateChallengeErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountGenerateChallenge, zap.Error(err)) + h.logger.Error(errAccountGenerateChallenge, zap.Error(err)) return } @@ -259,17 +283,17 @@ func (h *HttpHandler) AccountRegisterChallenge(jc jape.Context) { if err != nil { _ = jc.Error(errAccountGenerateChallengeErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountGenerateChallenge, zap.Error(err)) + h.logger.Error(errAccountGenerateChallenge, zap.Error(err)) return } if len(decodedKey) != 33 && int(decodedKey[0]) != int(types.HashTypeEd25519) { _ = jc.Error(errAccountGenerateChallengeErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountGenerateChallenge, zap.Error(err)) + h.logger.Error(errAccountGenerateChallenge, zap.Error(err)) return } - result := h.portal.Database().Create(&models.S5Challenge{ + result := h.db.Create(&models.S5Challenge{ Pubkey: pubkey, Challenge: base64.RawURLEncoding.EncodeToString(challenge), Type: "register", @@ -277,7 +301,7 @@ func (h *HttpHandler) AccountRegisterChallenge(jc jape.Context) { if result.Error != nil { _ = jc.Error(errAccountGenerateChallengeErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountGenerateChallenge, zap.Error(err)) + h.logger.Error(errAccountGenerateChallenge, zap.Error(err)) return } @@ -295,7 +319,7 @@ func (h *HttpHandler) AccountRegister(jc jape.Context) { errored := func(err error) { _ = jc.Error(errAccountRegisterErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountRegister, zap.Error(err)) + h.logger.Error(errAccountRegister, zap.Error(err)) } decodedKey, err := base64.RawURLEncoding.DecodeString(request.Pubkey) @@ -312,7 +336,7 @@ func (h *HttpHandler) AccountRegister(jc jape.Context) { var challenge models.S5Challenge - result := h.portal.Database().Model(&models.S5Challenge{}).Where(&models.S5Challenge{Pubkey: request.Pubkey, Type: "register"}).First(&challenge) + result := h.db.Model(&models.S5Challenge{}).Where(&models.S5Challenge{Pubkey: request.Pubkey, Type: "register"}).First(&challenge) if result.RowsAffected == 0 || result.Error != nil { errored(err) @@ -371,14 +395,14 @@ func (h *HttpHandler) AccountRegister(jc jape.Context) { return } - accountExists, _ := h.portal.Accounts().EmailExists(request.Email) + accountExists, _ := h.accounts.EmailExists(request.Email) if accountExists { errored(errEmailAlreadyExists) return } - pubkeyExists, _ := h.portal.Accounts().PubkeyExists(hex.EncodeToString(decodedKey[1:])) + pubkeyExists, _ := h.accounts.PubkeyExists(hex.EncodeToString(decodedKey[1:])) if pubkeyExists { errored(errPubkeyAlreadyExists) @@ -394,7 +418,7 @@ func (h *HttpHandler) AccountRegister(jc jape.Context) { return } - newAccount, err := h.portal.Accounts().CreateAccount(request.Email, string(passwd)) + newAccount, err := h.accounts.CreateAccount(request.Email, string(passwd)) if err != nil { errored(errAccountRegisterErr) return @@ -402,19 +426,19 @@ func (h *HttpHandler) AccountRegister(jc jape.Context) { rawPubkey := hex.EncodeToString(decodedKey[1:]) - err = h.portal.Accounts().AddPubkeyToAccount(*newAccount, rawPubkey) + err = h.accounts.AddPubkeyToAccount(*newAccount, rawPubkey) if err != nil { errored(errAccountRegisterErr) return } - jwt, err := h.portal.Accounts().LoginPubkey(rawPubkey) + jwt, err := h.accounts.LoginPubkey(rawPubkey) if err != nil { errored(errAccountRegisterErr) return } - result = h.portal.Database().Delete(&challenge) + result = h.db.Delete(&challenge) if result.Error != nil { errored(errAccountRegisterErr) @@ -432,7 +456,7 @@ func (h *HttpHandler) AccountLoginChallenge(jc jape.Context) { errored := func(err error) { _ = jc.Error(errAccountLoginErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountLogin, zap.Error(err)) + h.logger.Error(errAccountLogin, zap.Error(err)) } challenge := make([]byte, 32) @@ -440,7 +464,7 @@ func (h *HttpHandler) AccountLoginChallenge(jc jape.Context) { _, err := rand.Read(challenge) if err != nil { _ = jc.Error(errAccountGenerateChallengeErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountGenerateChallenge, zap.Error(err)) + h.logger.Error(errAccountGenerateChallenge, zap.Error(err)) return } @@ -456,14 +480,14 @@ func (h *HttpHandler) AccountLoginChallenge(jc jape.Context) { return } - pubkeyExists, _ := h.portal.Accounts().PubkeyExists(hex.EncodeToString(decodedKey[1:])) + pubkeyExists, _ := h.accounts.PubkeyExists(hex.EncodeToString(decodedKey[1:])) if pubkeyExists { errored(errPubkeyNotExist) return } - result := h.portal.Database().Create(&models.S5Challenge{ + result := h.db.Create(&models.S5Challenge{ Challenge: base64.RawURLEncoding.EncodeToString(challenge), Type: "login", }) @@ -487,7 +511,7 @@ func (h *HttpHandler) AccountLogin(jc jape.Context) { errored := func(err error) { _ = jc.Error(errAccountLoginErr, http.StatusInternalServerError) - h.portal.Logger().Error(errAccountLogin, zap.Error(err)) + h.logger.Error(errAccountLogin, zap.Error(err)) } decodedKey, err := base64.RawURLEncoding.DecodeString(request.Pubkey) @@ -503,7 +527,7 @@ func (h *HttpHandler) AccountLogin(jc jape.Context) { var challenge models.S5Challenge - result := h.portal.Database().Model(&models.S5Challenge{}).Where(&models.S5Challenge{Pubkey: request.Pubkey, Type: "login"}).First(&challenge) + result := h.db.Model(&models.S5Challenge{}).Where(&models.S5Challenge{Pubkey: request.Pubkey, Type: "login"}).First(&challenge) if result.RowsAffected == 0 || result.Error != nil { errored(err) @@ -551,14 +575,14 @@ func (h *HttpHandler) AccountLogin(jc jape.Context) { return } - jwt, err := h.portal.Accounts().LoginPubkey(request.Pubkey) + jwt, err := h.accounts.LoginPubkey(request.Pubkey) if err != nil { errored(errAccountLoginErr) return } - result = h.portal.Database().Delete(&challenge) + result = h.db.Delete(&challenge) if result.Error != nil { errored(errAccountLoginErr) @@ -569,7 +593,7 @@ func (h *HttpHandler) AccountLogin(jc jape.Context) { } func (h *HttpHandler) AccountInfo(jc jape.Context) { - _, user := h.portal.Accounts().AccountExists(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)) + _, user := h.accounts.AccountExists(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)) info := &AccountInfoResponse{ Email: user.Email, @@ -589,7 +613,7 @@ func (h *HttpHandler) AccountInfo(jc jape.Context) { } func (h *HttpHandler) AccountStats(jc jape.Context) { - _, user := h.portal.Accounts().AccountExists(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)) + _, user := h.accounts.AccountExists(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)) info := &AccountStatsResponse{ AccountInfoResponse: AccountInfoResponse{ @@ -624,10 +648,10 @@ func (h *HttpHandler) AccountPins(jc jape.Context) { errored := func(err error) { _ = jc.Error(errFailedToGetPinsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFailedToGetPins, zap.Error(err)) + h.logger.Error(errFailedToGetPins, zap.Error(err)) } - pins, err := h.portal.Accounts().AccountPins(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64), cursor) + pins, err := h.accounts.AccountPins(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64), cursor) if err != nil { errored(err) @@ -657,7 +681,7 @@ func (h *HttpHandler) AccountPinDelete(jc jape.Context) { errored := func(err error) { _ = jc.Error(errFailedToDelPinErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFailedToDelPin, zap.Error(err)) + h.logger.Error(errFailedToDelPin, zap.Error(err)) } decodedCid, err := encoding.CIDFromString(cid) @@ -669,7 +693,7 @@ func (h *HttpHandler) AccountPinDelete(jc jape.Context) { hash := hex.EncodeToString(decodedCid.Hash.HashBytes()) - err = h.portal.Accounts().DeletePinByHash(hash, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) + err = h.accounts.DeletePinByHash(hash, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) if err != nil { errored(err) @@ -686,7 +710,7 @@ func (h *HttpHandler) AccountPin(jc jape.Context) { errored := func(err error) { _ = jc.Error(errFailedToAddPinErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFailedToAddPin, zap.Error(err)) + h.logger.Error(errFailedToAddPin, zap.Error(err)) } decodedCid, err := encoding.CIDFromString(cid) @@ -696,12 +720,12 @@ func (h *HttpHandler) AccountPin(jc jape.Context) { return } - h.portal.Logger().Info("CID", zap.String("cidStr", cid)) - h.portal.Logger().Info("hash", zap.String("hash", hex.EncodeToString(decodedCid.Hash.HashBytes()))) + h.logger.Info("CID", zap.String("cidStr", cid)) + h.logger.Info("hash", zap.String("hash", hex.EncodeToString(decodedCid.Hash.HashBytes()))) hash := hex.EncodeToString(decodedCid.Hash.HashBytes()) - err = h.portal.Accounts().PinByHash(hash, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) + err = h.accounts.PinByHash(hash, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64))) if err != nil { errored(err) @@ -733,19 +757,19 @@ func (h *HttpHandler) DirectoryUpload(jc jape.Context) { errored := func(err error) { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) } if !strings.HasPrefix(contentType, "multipart/form-data") { _ = jc.Error(errNotMultiformErr, http.StatusBadRequest) - h.portal.Logger().Error(errorNotMultiform) + h.logger.Error(errorNotMultiform) return } - err := r.ParseMultipartForm(h.portal.Config().GetInt64("core.post-upload-limit")) + err := r.ParseMultipartForm(h.config.GetInt64("core.post-upload-limit")) if jc.Check(errMultiformParse, err) != nil { - h.portal.Logger().Error(errMultiformParse, zap.Error(err)) + h.logger.Error(errMultiformParse, zap.Error(err)) return } @@ -763,26 +787,26 @@ func (h *HttpHandler) DirectoryUpload(jc jape.Context) { defer func(file multipart.File) { err := file.Close() if err != nil { - h.portal.Logger().Error(errClosingStream, zap.Error(err)) + h.logger.Error(errClosingStream, zap.Error(err)) } }(file) var rs io.ReadSeeker - hash, err := h.portal.Storage().GetHashSmall(rs) + hash, err := h.storage.GetHashSmall(rs) _, err = rs.Seek(0, io.SeekStart) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - if exists, upload := h.portal.Storage().FileExists(hash); exists { + if exists, upload := h.storage.FileExists(hash); exists { uploadMap[fileHeader.Filename] = upload continue } - hash, err = h.portal.Storage().PutFileSmall(rs, "s5", false) + hash, err = h.storage.PutFileSmall(rs, "s5", false) if err != nil { errored(err) @@ -807,7 +831,7 @@ func (h *HttpHandler) DirectoryUpload(jc jape.Context) { } mimeType := http.DetectContentType(mimeBytes[:]) - upload, err := h.portal.Storage().CreateUpload(hash, mimeType, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)), jc.Request.RemoteAddr, uint64(fileHeader.Size), "s5") + upload, err := h.storage.CreateUpload(hash, mimeType, uint(jc.Request.Context().Value(middleware.S5AuthUserIDKey).(uint64)), jc.Request.RemoteAddr, uint64(fileHeader.Size), "s5") if err != nil { errored(err) @@ -859,32 +883,32 @@ func (h *HttpHandler) DirectoryUpload(jc jape.Context) { var rs = bytes.NewReader(appData) - hash, err := h.portal.Storage().GetHashSmall(rs) + hash, err := h.storage.GetHashSmall(rs) _, err = rs.Seek(0, io.SeekStart) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } - if exists, upload := h.portal.Storage().FileExists(hash); exists { + if exists, upload := h.storage.FileExists(hash); exists { cid, err := encoding.CIDFromHash(hash, upload.Size, types.CIDTypeMetadataWebapp, types.HashTypeBlake3) if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } cidStr, err := cid.ToString() if err != nil { _ = jc.Error(errUploadingFileErr, http.StatusInternalServerError) - h.portal.Logger().Error(errUploadingFile, zap.Error(err)) + h.logger.Error(errUploadingFile, zap.Error(err)) return } jc.Encode(map[string]string{"hash": cidStr}) return } - hash, err = h.portal.Storage().PutFileSmall(rs, "s5", false) + hash, err = h.storage.PutFileSmall(rs, "s5", false) if err != nil { errored(err) @@ -917,7 +941,7 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } @@ -928,14 +952,14 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { err = dlUriProvider.Start() if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } _, err = dlUriProvider.Next() if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } @@ -944,7 +968,7 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { }) if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } @@ -956,7 +980,7 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { nodeId, err := encoding.DecodeNodeId(nodeIdStr) if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } availableNodesIds[i] = nodeId @@ -969,7 +993,7 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } @@ -979,7 +1003,7 @@ func (h *HttpHandler) DebugDownloadUrls(jc jape.Context) { nodeIdStr, err := nodeId.ToString() if err != nil { _ = jc.Error(errFetchingUrlsErr, http.StatusInternalServerError) - h.portal.Logger().Error(errFetchingUrls, zap.Error(err)) + h.logger.Error(errFetchingUrls, zap.Error(err)) return } output[i] = locations[nodeIdStr].BytesURL() @@ -1058,13 +1082,13 @@ func (h *HttpHandler) RegistrySubscription(jc jape.Context) { // Accept the WebSocket connection c, err := websocket.Accept(jc.ResponseWriter, jc.Request, nil) if err != nil { - h.portal.Logger().Error("error accepting websocket connection", zap.Error(err)) + h.logger.Error("error accepting websocket connection", zap.Error(err)) return } defer func(c *websocket.Conn, code websocket.StatusCode, reason string) { err := c.Close(code, reason) if err != nil { - h.portal.Logger().Error("error closing websocket connection", zap.Error(err)) + h.logger.Error("error closing websocket connection", zap.Error(err)) } for _, listener := range listeners { @@ -1079,10 +1103,10 @@ func (h *HttpHandler) RegistrySubscription(jc jape.Context) { if err != nil { if websocket.CloseStatus(err) == websocket.StatusNormalClosure { // Normal closure - h.portal.Logger().Info("websocket connection closed normally") + h.logger.Info("websocket connection closed normally") } else { // Handle different types of errors - h.portal.Logger().Error("error in websocket connection", zap.Error(err)) + h.logger.Error("error in websocket connection", zap.Error(err)) } break } @@ -1092,37 +1116,37 @@ func (h *HttpHandler) RegistrySubscription(jc jape.Context) { method, err := decoder.DecodeInt() if err != nil { - h.portal.Logger().Error("error decoding method", zap.Error(err)) + h.logger.Error("error decoding method", zap.Error(err)) break } if method != 2 { - h.portal.Logger().Error("invalid method", zap.Int64("method", int64(method))) + h.logger.Error("invalid method", zap.Int64("method", int64(method))) break } sre, err := decoder.DecodeBytes() if err != nil { - h.portal.Logger().Error("error decoding sre", zap.Error(err)) + h.logger.Error("error decoding sre", zap.Error(err)) break } off, err := h.getNode().Services().Registry().Listen(sre, func(entry s5interfaces.SignedRegistryEntry) { encoded, err := msgpack.Marshal(entry) if err != nil { - h.portal.Logger().Error("error encoding entry", zap.Error(err)) + h.logger.Error("error encoding entry", zap.Error(err)) return } err = c.Write(ctx, websocket.MessageBinary, encoded) if err != nil { - h.portal.Logger().Error("error writing to websocket", zap.Error(err)) + h.logger.Error("error writing to websocket", zap.Error(err)) } }) if err != nil { - h.portal.Logger().Error("error listening to registry", zap.Error(err)) + h.logger.Error("error listening to registry", zap.Error(err)) break } @@ -1131,10 +1155,7 @@ func (h *HttpHandler) RegistrySubscription(jc jape.Context) { } func (h *HttpHandler) getNode() s5interfaces.Node { - proto, _ := h.portal.ProtocolRegistry().Get("s5") - protoInstance := proto.(*protocols.S5Protocol) - - return protoInstance.Node() + return h.protocol.Node() } func (h *HttpHandler) DownloadBlob(jc jape.Context) { @@ -1276,7 +1297,7 @@ func (h *HttpHandler) DownloadMetadata(jc jape.Context) { cidDecoded, err := encoding.CIDFromString(cid) if jc.Check("error decoding cid", err) != nil { - h.portal.Logger().Error("error decoding cid", zap.Error(err)) + h.logger.Error("error decoding cid", zap.Error(err)) return } @@ -1293,7 +1314,7 @@ func (h *HttpHandler) DownloadMetadata(jc jape.Context) { meta, err := h.getNode().GetMetadataByCID(cidDecoded) if jc.Check("error getting metadata", err) != nil { - h.portal.Logger().Error("error getting metadata", zap.Error(err)) + h.logger.Error("error getting metadata", zap.Error(err)) return } @@ -1331,7 +1352,7 @@ func (h *HttpHandler) DownloadFile(jc jape.Context) { hashBytes = cidDecoded.Hash.HashBytes() } - file := h.portal.Storage().NewFile(hashBytes) + file := h.storage.NewFile(hashBytes) if !file.Exists() { jc.ResponseWriter.WriteHeader(http.StatusNotFound) @@ -1341,7 +1362,7 @@ func (h *HttpHandler) DownloadFile(jc jape.Context) { defer func(file io.ReadCloser) { err := file.Close() if err != nil { - h.portal.Logger().Error("error closing file", zap.Error(err)) + h.logger.Error("error closing file", zap.Error(err)) } }(file) diff --git a/cmd/portal/init.go b/cmd/portal/init.go index ada72c6..125ea11 100644 --- a/cmd/portal/init.go +++ b/cmd/portal/init.go @@ -2,44 +2,12 @@ package main import ( "crypto/ed25519" - "git.lumeweb.com/LumeWeb/portal/api" - "git.lumeweb.com/LumeWeb/portal/config" - "git.lumeweb.com/LumeWeb/portal/interfaces" - "git.lumeweb.com/LumeWeb/portal/logger" - "git.lumeweb.com/LumeWeb/portal/protocols" + "github.com/spf13/viper" "go.sia.tech/core/wallet" + "go.uber.org/zap" ) -type initFunc func(p interfaces.Portal) error - -func initConfig(p interfaces.Portal) error { - return config.Init(p) -} - -func initIdentity(p interfaces.Portal) error { - var seed [32]byte - identitySeed := p.Config().GetString("core.identity") - - if identitySeed == "" { - p.Logger().Info("Generating new identity seed") - identitySeed = wallet.NewSeedPhrase() - p.Config().Set("core.identity", identitySeed) - err := p.Config().WriteConfig() - if err != nil { - return err - } - } - err := wallet.SeedFromPhrase(&seed, identitySeed) - if err != nil { - return err - } - - p.SetIdentity(ed25519.PrivateKey(wallet.KeyFromSeed(&seed, 0))) - - return nil -} - -func initCheckRequiredConfig(p interfaces.Portal) error { +func initCheckRequiredConfig(logger zap.Logger, config *viper.Viper) error { required := []string{ "core.domain", "core.port", @@ -57,84 +25,31 @@ func initCheckRequiredConfig(p interfaces.Portal) error { } for _, key := range required { - if !p.Config().IsSet(key) { - p.Logger().Fatal(key + " is required") + if !config.IsSet(key) { + logger.Fatal(key + " is required") } } return nil } -func initLogger(p interfaces.Portal) error { - p.SetLogger(logger.Init(p.Config())) +func NewIdentity(config *viper.Viper, logger *zap.Logger) (ed25519.PrivateKey, error) { + var seed [32]byte + identitySeed := config.GetString("core.identity") - return nil -} - -func initAccess(p interfaces.Portal) error { - p.SetCasbin(api.GetCasbin(p.Logger())) - return nil -} - -func initDatabase(p interfaces.Portal) error { - return p.DatabaseService().Init() -} - -func initProtocols(p interfaces.Portal) error { - return protocols.Init(p.ProtocolRegistry()) -} - -func initStorage(p interfaces.Portal) error { - return p.Storage().Init() -} - -func initAPI(p interfaces.Portal) error { - return api.Init(p.ApiRegistry()) -} - -func initializeProtocolRegistry(p interfaces.Portal) error { - for _, _func := range p.ProtocolRegistry().All() { - err := _func.Initialize(p) + if identitySeed == "" { + logger.Info("Generating new identity seed") + identitySeed = wallet.NewSeedPhrase() + config.Set("core.identity", identitySeed) + err := config.WriteConfig() if err != nil { - return err + return nil, err } } - - return nil -} - -func initializeAPIRegistry(p interfaces.Portal) error { - for protoName, _func := range p.ApiRegistry().All() { - proto, err := p.ProtocolRegistry().Get(protoName) - if err != nil { - return err - } - err = _func.Initialize(p, proto) - if err != nil { - return err - } + err := wallet.SeedFromPhrase(&seed, identitySeed) + if err != nil { + return nil, err } - return nil -} - -func initCron(p interfaces.Portal) error { - return p.CronService().Init() -} - -func getInitList() []initFunc { - return []initFunc{ - initConfig, - initIdentity, - initLogger, - initCheckRequiredConfig, - initAccess, - initDatabase, - initProtocols, - initStorage, - initAPI, - initializeProtocolRegistry, - initializeAPIRegistry, - initCron, - } + return ed25519.PrivateKey(wallet.KeyFromSeed(&seed, 0)), nil } diff --git a/cmd/portal/main.go b/cmd/portal/main.go index df4fbd5..c420531 100644 --- a/cmd/portal/main.go +++ b/cmd/portal/main.go @@ -1,12 +1,82 @@ package main -import "go.uber.org/zap" +import ( + "context" + "git.lumeweb.com/LumeWeb/portal/api" + "git.lumeweb.com/LumeWeb/portal/api/registry" + _config "git.lumeweb.com/LumeWeb/portal/config" + "git.lumeweb.com/LumeWeb/portal/cron" + "git.lumeweb.com/LumeWeb/portal/db" + _logger "git.lumeweb.com/LumeWeb/portal/logger" + "git.lumeweb.com/LumeWeb/portal/protocols" + "git.lumeweb.com/LumeWeb/portal/storage" + "github.com/spf13/viper" + "go.uber.org/fx" + "go.uber.org/fx/fxevent" + "go.uber.org/zap" + "net" + "net/http" +) func main() { - portal := NewPortal() - err := portal.Initialize() + + logger := _logger.NewLogger() + config, err := _config.NewConfig(logger) + if err != nil { - portal.Logger().Fatal("Failed to initialize portal", zap.Error(err)) + logger.Fatal("Failed to load config", zap.Error(err)) } - portal.Run() + + protocols.RegisterProtocols() + api.RegisterApis() + + fx.New( + fx.Provide(_logger.NewFallbackLogger), + fx.Provide(func() *viper.Viper { + return config + }), + + fx.Decorate(func() *zap.Logger { + return logger + }, + fx.WithLogger(func(logger *zap.Logger) *fxevent.ZapLogger { + return &fxevent.ZapLogger{Logger: logger} + })), + fx.Invoke(initCheckRequiredConfig), + fx.Provide(NewIdentity), + db.Module, + storage.Module, + cron.Module, + protocols.BuildProtocols(config), + api.BuildApis(config), + fx.Provide(api.NewCasbin), + fx.Provide(func(lc fx.Lifecycle, config *viper.Viper) *http.Server { + srv := &http.Server{ + Addr: config.GetString("core.port"), + Handler: registry.GetRouter(), + } + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + ln, err := net.Listen("tcp", srv.Addr) + if err != nil { + return err + } + + go func() { + err := srv.Serve(ln) + if err != nil { + logger.Fatal("Failed to serve", zap.Error(err)) + } + }() + + return nil + }, + OnStop: func(ctx context.Context) error { + return srv.Shutdown(ctx) + }, + }) + return srv + }), + ).Run() } diff --git a/cmd/portal/portal.go b/cmd/portal/portal.go deleted file mode 100644 index 65bc7ea..0000000 --- a/cmd/portal/portal.go +++ /dev/null @@ -1,133 +0,0 @@ -package main - -import ( - "crypto/ed25519" - "git.lumeweb.com/LumeWeb/portal/account" - "git.lumeweb.com/LumeWeb/portal/api" - "git.lumeweb.com/LumeWeb/portal/cron" - "git.lumeweb.com/LumeWeb/portal/db" - "git.lumeweb.com/LumeWeb/portal/interfaces" - "git.lumeweb.com/LumeWeb/portal/protocols" - "git.lumeweb.com/LumeWeb/portal/storage" - "github.com/casbin/casbin/v2" - "github.com/go-co-op/gocron/v2" - "github.com/spf13/viper" - "go.uber.org/zap" - "gorm.io/gorm" - "net/http" - "strconv" -) - -var ( - _ interfaces.Portal = (*PortalImpl)(nil) -) - -type PortalImpl struct { - apiRegistry interfaces.APIRegistry - protocolRegistry interfaces.ProtocolRegistry - logger *zap.Logger - identity ed25519.PrivateKey - storage interfaces.StorageService - database interfaces.Database - casbin *casbin.Enforcer - accounts interfaces.AccountService - cron interfaces.CronService -} - -func NewPortal() interfaces.Portal { - portal := &PortalImpl{ - apiRegistry: api.NewRegistry(), - protocolRegistry: protocols.NewProtocolRegistry(), - storage: nil, - database: nil, - } - - storageServ := storage.NewStorageService(portal) - database := db.NewDatabase(portal) - accountService := account.NewAccountService(portal) - cronService := cron.NewCronServiceImpl(portal) - portal.storage = storageServ - portal.database = database - portal.accounts = accountService - portal.cron = cronService - - return portal -} - -func (p *PortalImpl) DatabaseService() interfaces.Database { - return p.database -} -func (p *PortalImpl) Database() *gorm.DB { - return p.database.Get() -} - -func (p *PortalImpl) Cron() gocron.Scheduler { - return p.cron.Scheduler() -} -func (p *PortalImpl) CronService() interfaces.CronService { - return p.cron -} - -func (p *PortalImpl) Initialize() error { - for _, initFunc := range getInitList() { - if err := initFunc(p); err != nil { - return err - } - } - - return nil -} -func (p *PortalImpl) Run() { - for _, initFunc := range getStartList() { - if err := initFunc(p); err != nil { - p.logger.Fatal("Failed to start", zap.Error(err)) - } - } - p.logger.Fatal("HTTP server stopped", zap.Error(http.ListenAndServe(":"+strconv.FormatUint(uint64(p.Config().GetUint("core.port")), 10), p.apiRegistry.Router()))) -} - -func (p *PortalImpl) Config() *viper.Viper { - return viper.GetViper() -} - -func (p *PortalImpl) Logger() *zap.Logger { - if p.logger == nil { - logger, _ := zap.NewDevelopment() - return logger - } - - return p.logger -} - -func (p *PortalImpl) ApiRegistry() interfaces.APIRegistry { - return p.apiRegistry -} - -func (p *PortalImpl) Identity() ed25519.PrivateKey { - return p.identity -} -func (p *PortalImpl) Storage() interfaces.StorageService { - return p.storage -} - -func (p *PortalImpl) SetIdentity(identity ed25519.PrivateKey) { - p.identity = identity -} - -func (p *PortalImpl) SetLogger(logger *zap.Logger) { - p.logger = logger -} -func (p *PortalImpl) ProtocolRegistry() interfaces.ProtocolRegistry { - return p.protocolRegistry -} -func (p *PortalImpl) Casbin() *casbin.Enforcer { - return p.casbin -} - -func (p *PortalImpl) SetCasbin(e *casbin.Enforcer) { - p.casbin = e -} - -func (p *PortalImpl) Accounts() interfaces.AccountService { - return p.accounts -} diff --git a/cmd/portal/start.go b/cmd/portal/start.go deleted file mode 100644 index 3adb021..0000000 --- a/cmd/portal/start.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import "git.lumeweb.com/LumeWeb/portal/interfaces" - -type startFunc func(p interfaces.Portal) error - -func startProtocolRegistry(p interfaces.Portal) error { - for _, _func := range p.ProtocolRegistry().All() { - err := _func.Start() - if err != nil { - return err - } - } - - return nil -} - -func startDatabase(p interfaces.Portal) error { - return p.DatabaseService().Start() -} - -func startCron(p interfaces.Portal) error { - return p.CronService().Start() -} - -func getStartList() []startFunc { - return []startFunc{ - startProtocolRegistry, - startDatabase, - startCron, - } -} diff --git a/config/config.go b/config/config.go index 1ea44d0..82c8688 100644 --- a/config/config.go +++ b/config/config.go @@ -2,8 +2,9 @@ package config import ( "errors" - "git.lumeweb.com/LumeWeb/portal/interfaces" + _logger "git.lumeweb.com/LumeWeb/portal/logger" "github.com/spf13/viper" + "go.uber.org/zap" ) var ( @@ -14,8 +15,11 @@ var ( } ) -func Init(p interfaces.Portal) error { - logger := p.Logger() +func NewConfig(logger *zap.Logger) (*viper.Viper, error) { + if logger == nil { + logger = _logger.NewFallbackLogger() + } + viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -32,14 +36,24 @@ func Init(p interfaces.Portal) error { logger.Info("Config file not found, using default settings.") err := viper.SafeWriteConfig() if err != nil { - return err + return nil, err } - return writeDefaults() + err = writeDefaults() + if err != nil { + return nil, err + } + + return viper.GetViper(), nil } - return err + return nil, err } - return writeDefaults() + err = writeDefaults() + if err != nil { + return nil, err + } + + return viper.GetViper(), nil } func writeDefaults() error { diff --git a/cron/cron.go b/cron/cron.go index 7e801ab..5f2ae2d 100644 --- a/cron/cron.go +++ b/cron/cron.go @@ -1,48 +1,64 @@ package cron import ( - "git.lumeweb.com/LumeWeb/portal/interfaces" + "context" + "go.uber.org/fx" "go.uber.org/zap" "github.com/go-co-op/gocron/v2" ) -var ( - _ interfaces.CronService = (*CronServiceImpl)(nil) +type CronService interface { + Scheduler() gocron.Scheduler + RegisterService(service CronableService) +} + +type CronableService interface { + LoadInitialTasks(cron CronService) error +} + +type CronServiceParams struct { + fx.In + Logger *zap.Logger + Scheduler gocron.Scheduler +} + +var Module = fx.Module("cron", + fx.Options( + fx.Provide(NewCronService), + ), ) type CronServiceImpl struct { scheduler gocron.Scheduler - services []interfaces.CronableService - portal interfaces.Portal + services []CronableService + logger *zap.Logger } func (c *CronServiceImpl) Scheduler() gocron.Scheduler { return c.scheduler } -func NewCronServiceImpl(portal interfaces.Portal) interfaces.CronService { - return &CronServiceImpl{ - portal: portal, - } -} - -func (c *CronServiceImpl) Init() error { - s, err := gocron.NewScheduler() - if err != nil { - return err +func NewCronService(lc fx.Lifecycle, params CronServiceParams) *CronServiceImpl { + sc := &CronServiceImpl{ + logger: params.Logger, + scheduler: params.Scheduler, } - c.scheduler = s + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return sc.start() + }, + }) - return nil + return sc } -func (c *CronServiceImpl) Start() error { +func (c *CronServiceImpl) start() error { for _, service := range c.services { err := service.LoadInitialTasks(c) if err != nil { - c.portal.Logger().Fatal("Failed to load initial tasks for service", zap.Error(err)) + c.logger.Fatal("Failed to load initial tasks for service", zap.Error(err)) } } @@ -51,6 +67,6 @@ func (c *CronServiceImpl) Start() error { return nil } -func (c *CronServiceImpl) RegisterService(service interfaces.CronableService) { +func (c *CronServiceImpl) RegisterService(service CronableService) { c.services = append(c.services, service) } diff --git a/db/db.go b/db/db.go index 3eff48a..717f544 100644 --- a/db/db.go +++ b/db/db.go @@ -1,69 +1,57 @@ package db import ( + "context" "fmt" "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" "github.com/spf13/viper" - "go.uber.org/zap" + "go.uber.org/fx" "gorm.io/driver/mysql" "gorm.io/gorm" ) -var ( - _ interfaces.Database = (*DatabaseImpl)(nil) +type DatabaseParams struct { + fx.In + Config *viper.Viper +} + +var Module = fx.Module("db", + fx.Options( + fx.Provide(NewDatabase), + ), ) -type DatabaseImpl struct { - db *gorm.DB - portal interfaces.Portal -} +func NewDatabase(lc fx.Lifecycle, params DatabaseParams) *gorm.DB { + username := params.Config.GetString("core.db.username") + password := params.Config.GetString("core.db.password") + host := params.Config.GetString("core.db.host") + port := params.Config.GetString("core.db.port") + dbname := params.Config.GetString("core.db.name") + charset := params.Config.GetString("core.db.charset") -func NewDatabase(p interfaces.Portal) interfaces.Database { - return &DatabaseImpl{ - portal: p, - } -} - -// Init initializes the database connection -func (d *DatabaseImpl) Init() error { - // Retrieve DB config from Viper - username := viper.GetString("core.db.username") - password := viper.GetString("core.db.password") - host := viper.GetString("core.db.host") - port := viper.GetString("core.db.port") - dbname := viper.GetString("core.db.name") - charset := viper.GetString("core.db.charset") - - // Construct DSN dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=True&loc=Local", username, password, host, port, dbname, charset) - // Open DB connection db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if err != nil { - d.portal.Logger().Error("Failed to connect to database", zap.Error(err)) + panic(err) } - d.db = db - return nil -} + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return db.AutoMigrate( + &models.APIKey{}, + &models.Blocklist{}, + &models.Download{}, + &models.Pin{}, + &models.PublicKey{}, + &models.Upload{}, + &models.User{}, + &models.S5Challenge{}, + &models.TusLock{}, + &models.TusUpload{}, + ) + }, + }) -// Start performs any additional setup -func (d *DatabaseImpl) Start() error { - return d.db.AutoMigrate( - &models.APIKey{}, - &models.Blocklist{}, - &models.Download{}, - &models.Pin{}, - &models.PublicKey{}, - &models.Upload{}, - &models.User{}, - &models.S5Challenge{}, - &models.TusLock{}, - &models.TusUpload{}, - ) -} - -func (d *DatabaseImpl) Get() *gorm.DB { - return d.db + return db } diff --git a/go.mod b/go.mod index 3bb0fb6..9308d89 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( git.lumeweb.com/LumeWeb/libs5-go v0.0.0-20240124213331-6b9a4fb7dc4d github.com/AfterShip/email-verifier v1.4.0 github.com/aws/aws-sdk-go-v2 v1.24.0 - github.com/aws/aws-sdk-go-v2/config v1.26.2 + github.com/aws/aws-sdk-go-v2/Config v1.26.2 github.com/aws/aws-sdk-go-v2/credentials v1.16.13 github.com/aws/aws-sdk-go-v2/service/s3 v1.47.7 github.com/casbin/casbin/v2 v2.81.0 @@ -26,6 +26,7 @@ require ( go.sia.tech/core v0.1.12 go.sia.tech/jape v0.11.1 go.sia.tech/renterd v1.0.2 + go.uber.org/fx v1.20.1 go.uber.org/zap v1.26.0 golang.org/x/crypto v0.18.0 gorm.io/driver/mysql v1.5.2 @@ -101,6 +102,7 @@ require ( gitlab.com/NebulousLabs/threadgroup v0.0.0-20200608151952-38921fbef213 // indirect go.sia.tech/mux v1.2.0 // indirect go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca // indirect + go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/net v0.20.0 // indirect diff --git a/go.sum b/go.sum index 040441c..055d0d4 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/aws/aws-sdk-go-v2 v1.24.0 h1:890+mqQ+hTpNuw0gGP6/4akolQkSToDJgHfQE7Aw github.com/aws/aws-sdk-go-v2 v1.24.0/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 h1:OCs21ST2LrepDfD3lwlQiOqIGp6JiEUqG84GzTDoyJs= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4/go.mod h1:usURWEKSNNAcAZuzRn/9ZYPT8aZQkR7xcCtunK/LkJo= -github.com/aws/aws-sdk-go-v2/config v1.26.2 h1:+RWLEIWQIGgrz2pBPAUoGgNGs1TOyF4Hml7hCnYj2jc= -github.com/aws/aws-sdk-go-v2/config v1.26.2/go.mod h1:l6xqvUxt0Oj7PI/SUXYLNyZ9T/yBPn3YTQcJLLOdtR8= +github.com/aws/aws-sdk-go-v2/Config v1.26.2 h1:+RWLEIWQIGgrz2pBPAUoGgNGs1TOyF4Hml7hCnYj2jc= +github.com/aws/aws-sdk-go-v2/Config v1.26.2/go.mod h1:l6xqvUxt0Oj7PI/SUXYLNyZ9T/yBPn3YTQcJLLOdtR8= github.com/aws/aws-sdk-go-v2/credentials v1.16.13 h1:WLABQ4Cp4vXtXfOWOS3MEZKr6AAYUpMczLhgKtAjQ/8= github.com/aws/aws-sdk-go-v2/credentials v1.16.13/go.mod h1:Qg6x82FXwW0sJHzYruxGiuApNo31UEtJvXVSZAXeWiw= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58= @@ -56,6 +56,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 h1:HJeiuZ2fldpd0WqngyMR6KW7ofkX github.com/aws/aws-sdk-go-v2/service/sts v1.26.6/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= +github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -333,6 +335,12 @@ go.sia.tech/renterd v1.0.2/go.mod h1:Lu70aWeRH90NRvd27m7yi0kDA0IPmeZJCk/SgVPq3T8 go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca h1:aZMg2AKevn7jKx+wlusWQfwSM5pNU9aGtRZme29q3O4= go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca/go.mod h1:h/1afFwpxzff6/gG5i1XdAgPK7dEY6FaibhK7N5F86Y= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= +go.uber.org/dig v1.17.0/go.mod h1:rTxpf7l5I0eBTlE6/9RL+lDybC7WFwY2QH55ZSjy1mU= +go.uber.org/fx v1.20.1 h1:zVwVQGS8zYvhh9Xxcu4w1M6ESyeMzebzj2NbSayZ4Mk= +go.uber.org/fx v1.20.1/go.mod h1:iSYNbHf2y55acNCwCXKx7LbWb5WG1Bnue5RDXz1OREg= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= diff --git a/interfaces/account.go b/interfaces/account.go deleted file mode 100644 index 769856d..0000000 --- a/interfaces/account.go +++ /dev/null @@ -1,17 +0,0 @@ -package interfaces - -import "git.lumeweb.com/LumeWeb/portal/db/models" - -type AccountService interface { - EmailExists(email string) (bool, models.User) - PubkeyExists(pubkey string) (bool, models.PublicKey) - AccountExists(id uint64) (bool, models.User) - CreateAccount(email string, password string) (*models.User, error) - AddPubkeyToAccount(user models.User, pubkey string) error - LoginPassword(email string, password string) (string, error) - LoginPubkey(pubkey string) (string, error) - AccountPins(id uint64, createdAfter uint64) ([]models.Pin, error) - DeletePinByHash(hash string, accountID uint) error - PinByHash(hash string, accountID uint) error - PinByID(uploadId uint, accountID uint) error -} diff --git a/interfaces/api.go b/interfaces/api.go deleted file mode 100644 index e47ae67..0000000 --- a/interfaces/api.go +++ /dev/null @@ -1,16 +0,0 @@ -package interfaces - -import ( - "git.lumeweb.com/LumeWeb/portal/api/router" -) - -type API interface { - Initialize(portal Portal, protocol Protocol) error -} - -type APIRegistry interface { - All() map[string]API - Register(name string, APIRegistry API) - Get(name string) (API, error) - Router() *router.ProtocolRouter -} diff --git a/interfaces/cron.go b/interfaces/cron.go deleted file mode 100644 index c35721a..0000000 --- a/interfaces/cron.go +++ /dev/null @@ -1,14 +0,0 @@ -package interfaces - -import "github.com/go-co-op/gocron/v2" - -type CronService interface { - Scheduler() gocron.Scheduler - RegisterService(service CronableService) - Service -} - -type CronableService interface { - LoadInitialTasks(cron CronService) error - Service -} diff --git a/interfaces/database.go b/interfaces/database.go deleted file mode 100644 index ce2ad30..0000000 --- a/interfaces/database.go +++ /dev/null @@ -1,8 +0,0 @@ -package interfaces - -import "gorm.io/gorm" - -type Database interface { - Get() *gorm.DB - Service -} diff --git a/interfaces/file.go b/interfaces/file.go deleted file mode 100644 index a8edb78..0000000 --- a/interfaces/file.go +++ /dev/null @@ -1,21 +0,0 @@ -package interfaces - -import ( - "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "git.lumeweb.com/LumeWeb/portal/db/models" - "io" - "time" -) - -type File interface { - Record() (*models.Upload, error) - Hash() []byte - HashString() string - Name() string - Modtime() time.Time - Mime() string - Size() uint64 - CID() *encoding.CID - Exists() bool - io.ReadSeekCloser -} diff --git a/interfaces/portal.go b/interfaces/portal.go deleted file mode 100644 index 6d0d24a..0000000 --- a/interfaces/portal.go +++ /dev/null @@ -1,30 +0,0 @@ -package interfaces - -import ( - "crypto/ed25519" - "github.com/casbin/casbin/v2" - "github.com/go-co-op/gocron/v2" - "github.com/spf13/viper" - "go.uber.org/zap" - "gorm.io/gorm" -) - -type Portal interface { - Initialize() error - Run() - Config() *viper.Viper - Logger() *zap.Logger - ApiRegistry() APIRegistry - ProtocolRegistry() ProtocolRegistry - Identity() ed25519.PrivateKey - Storage() StorageService - SetIdentity(identity ed25519.PrivateKey) - SetLogger(logger *zap.Logger) - Database() *gorm.DB - DatabaseService() Database - Casbin() *casbin.Enforcer - SetCasbin(e *casbin.Enforcer) - Accounts() AccountService - CronService() CronService - Cron() gocron.Scheduler -} diff --git a/interfaces/protocol.go b/interfaces/protocol.go deleted file mode 100644 index f399466..0000000 --- a/interfaces/protocol.go +++ /dev/null @@ -1,12 +0,0 @@ -package interfaces - -type Protocol interface { - Initialize(portal Portal) error - Start() error -} - -type ProtocolRegistry interface { - Register(name string, protocol Protocol) - Get(name string) (Protocol, error) - All() map[string]Protocol -} diff --git a/interfaces/service.go b/interfaces/service.go deleted file mode 100644 index 7d47252..0000000 --- a/interfaces/service.go +++ /dev/null @@ -1,6 +0,0 @@ -package interfaces - -type Service interface { - Init() error - Start() error -} diff --git a/interfaces/storage.go b/interfaces/storage.go deleted file mode 100644 index e9737fc..0000000 --- a/interfaces/storage.go +++ /dev/null @@ -1,32 +0,0 @@ -package interfaces - -import ( - "git.lumeweb.com/LumeWeb/portal/db/models" - "github.com/aws/aws-sdk-go-v2/service/s3" - tusd "github.com/tus/tusd/v2/pkg/handler" - "io" -) - -type TusPreUploadCreateCallback func(hook tusd.HookEvent) (tusd.HTTPResponse, tusd.FileInfoChanges, error) -type TusPreFinishResponseCallback func(hook tusd.HookEvent) (tusd.HTTPResponse, error) - -type StorageService interface { - Portal() Portal - PutFileSmall(file io.ReadSeeker, bucket string, generateProof bool) ([]byte, error) - PutFile(file io.Reader, bucket string, hash []byte) error - BuildUploadBufferTus(basePath string, preUploadCb TusPreUploadCreateCallback, preFinishCb TusPreFinishResponseCallback) (*tusd.Handler, tusd.DataStore, *s3.Client, error) - FileExists(hash []byte) (bool, models.Upload) - GetHashSmall(file io.ReadSeeker) ([]byte, error) - GetHash(file io.Reader) ([]byte, int64, error) - GetFile(hash []byte, start int64) (io.ReadCloser, int64, error) - CreateUpload(hash []byte, mime string, uploaderID uint, uploaderIP string, size uint64, protocol string) (*models.Upload, error) - TusUploadExists(hash []byte) (bool, models.TusUpload) - CreateTusUpload(hash []byte, uploadID string, uploaderID uint, uploaderIP string, protocol string) (*models.TusUpload, error) - TusUploadProgress(uploadID string) error - TusUploadCompleted(uploadID string) error - DeleteTusUpload(uploadID string) error - ScheduleTusUpload(uploadID string, attempt int) error - Tus() *tusd.Handler - NewFile(hash []byte) File - Service -} diff --git a/logger/logger.go b/logger/logger.go index 722682c..a570042 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -7,7 +7,13 @@ import ( "os" ) -func Init(viper *viper.Viper) *zap.Logger { +func NewFallbackLogger() *zap.Logger { + logger, _ := zap.NewDevelopment() + + return logger +} + +func NewLogger() *zap.Logger { // Create a new atomic level atomicLevel := zap.NewAtomicLevel() diff --git a/protocols/protocols.go b/protocols/protocols.go index dfe6f26..2666aca 100644 --- a/protocols/protocols.go +++ b/protocols/protocols.go @@ -1,8 +1,47 @@ package protocols -import "git.lumeweb.com/LumeWeb/portal/interfaces" +import ( + "context" + "git.lumeweb.com/LumeWeb/portal/protocols/registry" + "github.com/spf13/viper" + "go.uber.org/fx" +) -func Init(registry interfaces.ProtocolRegistry) error { - registry.Register("s5", NewS5Protocol()) - return nil +func RegisterProtocols() { + registry.Register(registry.ProtocolEntry{ + Key: "s5", + Module: S5ProtocolModule, + InitFunc: InitS5Protocol, + }) +} + +func BuildProtocols(config *viper.Viper) fx.Option { + var options []fx.Option + for _, entry := range registry.GetRegistry() { + if config.GetBool("protocols." + entry.Key + ".enabled") { + options = append(options, entry.Module) + if entry.InitFunc != nil { + options = append(options, fx.Invoke(entry.InitFunc)) + } + } + } + + return fx.Options(options...) +} + +func SetupLifecycles(lifecycle fx.Lifecycle, protocols []registry.Protocol) { + for _, entry := range registry.GetRegistry() { + for _, protocol := range protocols { + if protocol.Name() == entry.Key { + lifecycle.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return protocol.Start(ctx) + }, + OnStop: func(ctx context.Context) error { + return protocol.Stop(ctx) + }, + }) + } + } + } } diff --git a/protocols/registry.go b/protocols/registry.go deleted file mode 100644 index 09e7e18..0000000 --- a/protocols/registry.go +++ /dev/null @@ -1,43 +0,0 @@ -package protocols - -import ( - "errors" - "git.lumeweb.com/LumeWeb/portal/interfaces" -) - -var ( - _ interfaces.ProtocolRegistry = (*ProtocolRegistryImpl)(nil) -) - -type ProtocolRegistryImpl struct { - protocols map[string]interfaces.Protocol -} - -func NewProtocolRegistry() interfaces.ProtocolRegistry { - return &ProtocolRegistryImpl{ - protocols: make(map[string]interfaces.Protocol), - } -} - -func (r *ProtocolRegistryImpl) Register(name string, protocol interfaces.Protocol) { - if _, exists := r.protocols[name]; exists { - panic("protocol already registered") - } - r.protocols[name] = protocol -} - -func (r *ProtocolRegistryImpl) Get(name string) (interfaces.Protocol, error) { - protocol, exists := r.protocols[name] - if !exists { - return nil, errors.New("protocol not found") - } - return protocol, nil -} - -func (r *ProtocolRegistryImpl) All() map[string]interfaces.Protocol { - pMap := make(map[string]interfaces.Protocol) - for key, value := range r.protocols { - pMap[key] = value - } - return pMap -} diff --git a/protocols/registry/registry.go b/protocols/registry/registry.go new file mode 100644 index 0000000..8a60568 --- /dev/null +++ b/protocols/registry/registry.go @@ -0,0 +1,40 @@ +package registry + +import ( + "context" + "go.uber.org/fx" +) + +const GroupName = "protocols" + +type Protocol interface { + Name() string + Init() error + Start(ctx context.Context) error + Stop(ctx context.Context) error +} + +type ProtocolEntry struct { + Key string + Module fx.Option + InitFunc interface{} +} + +var protocolEntry []ProtocolEntry + +func Register(entry ProtocolEntry) { + protocolEntry = append(protocolEntry, entry) +} + +func GetRegistry() []ProtocolEntry { + return protocolEntry +} + +func FindProtocolByName(name string, protocols []Protocol) Protocol { + for _, protocol := range protocols { + if protocol.Name() == name { + return protocol + } + } + return nil +} diff --git a/protocols/s5.go b/protocols/s5.go index 0de097a..8376529 100644 --- a/protocols/s5.go +++ b/protocols/s5.go @@ -1,6 +1,7 @@ package protocols import ( + "context" "crypto/ed25519" "fmt" s5config "git.lumeweb.com/LumeWeb/libs5-go/config" @@ -10,47 +11,83 @@ import ( s5node "git.lumeweb.com/LumeWeb/libs5-go/node" s5storage "git.lumeweb.com/LumeWeb/libs5-go/storage" "git.lumeweb.com/LumeWeb/libs5-go/types" - "git.lumeweb.com/LumeWeb/portal/interfaces" + "git.lumeweb.com/LumeWeb/portal/protocols/registry" + "git.lumeweb.com/LumeWeb/portal/storage" + "github.com/spf13/viper" bolt "go.etcd.io/bbolt" + "go.uber.org/fx" "go.uber.org/zap" "time" ) var ( - _ interfaces.Protocol = (*S5Protocol)(nil) _ s5interfaces.ProviderStore = (*S5ProviderStore)(nil) + _ registry.Protocol = (*S5Protocol)(nil) ) type S5Protocol struct { - node s5interfaces.Node - portal interfaces.Portal + node s5interfaces.Node + config *viper.Viper + logger *zap.Logger + storage *storage.StorageServiceImpl + identity ed25519.PrivateKey + providerStore *S5ProviderStore } -func NewS5Protocol() *S5Protocol { - return &S5Protocol{} +type S5ProtocolParams struct { + fx.In + Config *viper.Viper + Logger *zap.Logger + Storage *storage.StorageServiceImpl + Identity ed25519.PrivateKey + ProviderStore *S5ProviderStore } -func (s *S5Protocol) Initialize(portal interfaces.Portal) error { - s.portal = portal +type S5ProtocolResult struct { + fx.Out + Protocol registry.Protocol `group:"protocol"` +} - logger := portal.Logger() - config := portal.Config() +var S5ProtocolModule = fx.Module("s5_protocol", + fx.Provide(NewS5Protocol), + fx.Provide(func(protocol *S5Protocol) *S5ProviderStore { + return &S5ProviderStore{proto: protocol} + }), +) +func NewS5Protocol( + params S5ProtocolParams, +) (S5ProtocolResult, error) { + return S5ProtocolResult{ + Protocol: &S5Protocol{ + config: params.Config, + logger: params.Logger, + storage: params.Storage, + identity: params.Identity, + providerStore: params.ProviderStore, + }, + }, nil +} + +func InitS5Protocol(s5 *S5Protocol) error { + return s5.Init() +} +func (s *S5Protocol) Init() error { cfg := &s5config.NodeConfig{ P2P: s5config.P2PConfig{ Network: "", Peers: s5config.PeersConfig{Initial: []string{}}, }, - KeyPair: s5ed.New(portal.Identity()), + KeyPair: s5ed.New(s.identity), DB: nil, - Logger: portal.Logger().Named("s5"), + Logger: s.logger.Named("s5"), HTTP: s5config.HTTPConfig{}, } - pconfig := config.Sub("protocol.s5") + pconfig := s.config.Sub("protocol.s5") if pconfig == nil { - logger.Fatal("Missing protocol.s5 config") + s.logger.Fatal("Missing protocol.s5 Config") } err := pconfig.Unmarshal(cfg) @@ -58,41 +95,41 @@ func (s *S5Protocol) Initialize(portal interfaces.Portal) error { return err } - cfg.HTTP.API.Domain = fmt.Sprintf("s5.%s", config.GetString("core.domain")) + cfg.HTTP.API.Domain = fmt.Sprintf("s5.%s", s.config.GetString("core.domain")) - if config.IsSet("core.externalPort") { - cfg.HTTP.API.Port = config.GetUint("core.externalPort") + if s.config.IsSet("core.externalPort") { + cfg.HTTP.API.Port = s.config.GetUint("core.externalPort") } else { - cfg.HTTP.API.Port = config.GetUint("core.port") + cfg.HTTP.API.Port = s.config.GetUint("core.port") } dbPath := pconfig.GetString("dbPath") if dbPath == "" { - logger.Fatal("protocol.s5.dbPath is required") + s.logger.Fatal("protocol.s5.dbPath is required") } _, p, err := ed25519.GenerateKey(nil) if err != nil { - logger.Fatal("Failed to generate key", zap.Error(err)) + s.logger.Fatal("Failed to generate key", zap.Error(err)) } cfg.KeyPair = s5ed.New(p) db, err := bolt.Open(dbPath, 0600, nil) if err != nil { - logger.Fatal("Failed to open db", zap.Error(err)) + s.logger.Fatal("Failed to open db", zap.Error(err)) } cfg.DB = db s.node = s5node.NewNode(cfg) - s.node.SetProviderStore(&S5ProviderStore{proto: s}) + s.node.SetProviderStore(s.providerStore) return nil } -func (s *S5Protocol) Start() error { +func (s *S5Protocol) Start(ctx context.Context) error { err := s.node.Start() if err != nil { return err @@ -104,10 +141,19 @@ func (s *S5Protocol) Start() error { return err } - s.portal.Logger().Info("S5 protocol started", zap.String("identity", identity), zap.String("network", s.node.NetworkId()), zap.String("domain", s.node.Config().HTTP.API.Domain)) + s.logger.Info("S5 protocol started", zap.String("identity", identity), zap.String("network", s.node.NetworkId()), zap.String("domain", s.node.Config().HTTP.API.Domain)) return nil } + +func (s *S5Protocol) Name() string { + return "s5" +} + +func (s *S5Protocol) Stop(ctx context.Context) error { + return nil +} + func (s *S5Protocol) Node() s5interfaces.Node { return s.node } @@ -122,13 +168,13 @@ func (s S5ProviderStore) CanProvide(hash *encoding.Multihash, kind []types.Stora case types.StorageLocationTypeArchive, types.StorageLocationTypeFile, types.StorageLocationTypeFull: rawHash := hash.HashBytes() - if exists, upload := s.proto.portal.Storage().TusUploadExists(rawHash); exists { + if exists, upload := s.proto.storage.TusUploadExists(rawHash); exists { if upload.Completed { return true } } - if exists, _ := s.proto.portal.Storage().FileExists(rawHash); exists { + if exists, _ := s.proto.storage.FileExists(rawHash); exists { return true } } @@ -146,7 +192,7 @@ func (s S5ProviderStore) Provide(hash *encoding.Multihash, kind []types.StorageL case types.StorageLocationTypeArchive: return s5storage.NewStorageLocation(int(types.StorageLocationTypeArchive), []string{}, calculateExpiry(24*time.Hour)), nil case types.StorageLocationTypeFile, types.StorageLocationTypeFull: - return s5storage.NewStorageLocation(int(types.StorageLocationTypeFull), []string{generateDownloadUrl(hash, s.proto.portal)}, calculateExpiry(24*time.Hour)), nil + return s5storage.NewStorageLocation(int(types.StorageLocationTypeFull), []string{generateDownloadUrl(hash, s.proto.config, s.proto.logger)}, calculateExpiry(24*time.Hour)), nil } } @@ -161,12 +207,12 @@ func calculateExpiry(duration time.Duration) int64 { return time.Now().Add(duration).Unix() } -func generateDownloadUrl(hash *encoding.Multihash, portal interfaces.Portal) string { - domain := portal.Config().GetString("core.domain") +func generateDownloadUrl(hash *encoding.Multihash, config *viper.Viper, logger *zap.Logger) string { + domain := config.GetString("core.domain") hashStr, err := hash.ToBase64Url() if err != nil { - portal.Logger().Error("error encoding hash", zap.Error(err)) + logger.Error("error encoding hash", zap.Error(err)) } return fmt.Sprintf("https://s5.%s/s5/download/%s", domain, hashStr) diff --git a/storage/file.go b/storage/file.go index f3d9472..f159572 100644 --- a/storage/file.go +++ b/storage/file.go @@ -1,30 +1,25 @@ package storage import ( - "encoding/hex" - "errors" - "git.lumeweb.com/LumeWeb/libs5-go/encoding" - "git.lumeweb.com/LumeWeb/libs5-go/types" - "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" - "io" - "time" -) - -var ( - _ interfaces.File = (*FileImpl)(nil) + "encoding/hex" + "errors" + "git.lumeweb.com/LumeWeb/libs5-go/encoding" + "git.lumeweb.com/LumeWeb/libs5-go/types" + "git.lumeweb.com/LumeWeb/portal/db/models" + "io" + "time" ) type FileImpl struct { reader io.ReadCloser hash []byte - storage interfaces.StorageService + storage *StorageServiceImpl record *models.Upload cid *encoding.CID read bool } -func NewFile(hash []byte, storage interfaces.StorageService) *FileImpl { +func NewFile(hash []byte, storage *StorageServiceImpl) *FileImpl { return &FileImpl{hash: hash, storage: storage, read: false} } diff --git a/storage/locker.go b/storage/locker.go index 99a0e0d..3764ea9 100644 --- a/storage/locker.go +++ b/storage/locker.go @@ -3,9 +3,9 @@ package storage import ( "context" "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" tusd "github.com/tus/tusd/v2/pkg/handler" "go.uber.org/zap" + "gorm.io/gorm" "os" "sync" "time" @@ -17,9 +17,11 @@ var ( ) type MySQLLocker struct { - storage interfaces.StorageService + storage *StorageServiceImpl AcquirerPollInterval time.Duration HolderPollInterval time.Duration + db *gorm.DB + logger *zap.Logger } type Lock struct { @@ -32,14 +34,14 @@ type Lock struct { once sync.Once } -func NewMySQLLocker(storage interfaces.StorageService) *MySQLLocker { - return &MySQLLocker{storage: storage, HolderPollInterval: 5 * time.Second, AcquirerPollInterval: 2 * time.Second} +func NewMySQLLocker(db *gorm.DB, logger *zap.Logger) *MySQLLocker { + return &MySQLLocker{HolderPollInterval: 5 * time.Second, AcquirerPollInterval: 2 * time.Second, db: db, logger: logger} } func (l *Lock) released() error { - err := l.lockRecord.Released(l.locker.storage.Portal().Database()) + err := l.lockRecord.Released(l.locker.db) if err != nil { - l.locker.storage.Portal().Logger().Error("Failed to release lock", zap.Error(err)) + l.locker.logger.Error("Failed to release lock", zap.Error(err)) return err } @@ -47,7 +49,7 @@ func (l *Lock) released() error { } func (l *Lock) Lock(ctx context.Context, requestUnlock func()) error { - db := l.locker.storage.Portal().Database() + db := l.locker.db for { err := l.lockRecord.TryLock(db, ctx) @@ -111,7 +113,7 @@ func (l *Lock) Unlock() error { close(l.stopHolderPoll) }) - return l.lockRecord.Delete(l.locker.storage.Portal().Database()) + return l.lockRecord.Delete(l.locker.db) } func (m *MySQLLocker) NewLock(id string) (tusd.Lock, error) { diff --git a/storage/storage.go b/storage/storage.go index f4e8c2b..2679be8 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -8,9 +8,10 @@ import ( "fmt" "git.lumeweb.com/LumeWeb/libs5-go/encoding" "git.lumeweb.com/LumeWeb/libs5-go/types" + "git.lumeweb.com/LumeWeb/portal/account" "git.lumeweb.com/LumeWeb/portal/api/middleware" + "git.lumeweb.com/LumeWeb/portal/cron" "git.lumeweb.com/LumeWeb/portal/db/models" - "git.lumeweb.com/LumeWeb/portal/interfaces" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" @@ -18,12 +19,15 @@ import ( s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/go-co-op/gocron/v2" "github.com/google/uuid" + "github.com/spf13/viper" tusd "github.com/tus/tusd/v2/pkg/handler" - s3store "github.com/tus/tusd/v2/pkg/s3store" + "github.com/tus/tusd/v2/pkg/s3store" "go.sia.tech/renterd/api" busClient "go.sia.tech/renterd/bus/client" workerClient "go.sia.tech/renterd/worker/client" + "go.uber.org/fx" "go.uber.org/zap" + "gorm.io/gorm" "io" "lukechampine.com/blake3" "net/http" @@ -32,17 +36,35 @@ import ( "time" ) -var ( - _ interfaces.StorageService = (*StorageServiceImpl)(nil) +type TusPreUploadCreateCallback func(hook tusd.HookEvent) (tusd.HTTPResponse, tusd.FileInfoChanges, error) +type TusPreFinishResponseCallback func(hook tusd.HookEvent) (tusd.HTTPResponse, error) + +type StorageServiceParams struct { + fx.In + Config *viper.Viper + Logger *zap.Logger + Db *gorm.DB + Accounts *account.AccountServiceImpl + Cron *cron.CronServiceImpl +} + +var Module = fx.Module("storage", + fx.Provide( + NewStorageService, + ), ) type StorageServiceImpl struct { - portal interfaces.Portal busClient *busClient.Client workerClient *workerClient.Client tus *tusd.Handler tusStore tusd.DataStore s3Client *s3.Client + config *viper.Viper + logger *zap.Logger + db *gorm.DB + accounts *account.AccountServiceImpl + cron *cron.CronServiceImpl } func (s *StorageServiceImpl) Tus() *tusd.Handler { @@ -53,13 +75,13 @@ func (s *StorageServiceImpl) Start() error { return nil } -func (s *StorageServiceImpl) Portal() interfaces.Portal { - return s.portal -} - -func NewStorageService(portal interfaces.Portal) interfaces.StorageService { +func NewStorageService(params StorageServiceParams) *StorageServiceImpl { return &StorageServiceImpl{ - portal: portal, + config: params.Config, + logger: params.Logger, + db: params.Db, + accounts: params.Accounts, + cron: params.Cron, } } @@ -103,12 +125,12 @@ func (s StorageServiceImpl) PutFile(file io.Reader, bucket string, hash []byte) return nil } -func (s *StorageServiceImpl) BuildUploadBufferTus(basePath string, preUploadCb interfaces.TusPreUploadCreateCallback, preFinishCb interfaces.TusPreFinishResponseCallback) (*tusd.Handler, tusd.DataStore, *s3.Client, error) { +func (s *StorageServiceImpl) BuildUploadBufferTus(basePath string, preUploadCb TusPreUploadCreateCallback, preFinishCb TusPreFinishResponseCallback) (*tusd.Handler, tusd.DataStore, *s3.Client, error) { customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { if service == s3.ServiceID { return aws.Endpoint{ - URL: s.portal.Config().GetString("core.storage.s3.endpoint"), - SigningRegion: s.portal.Config().GetString("core.storage.s3.region"), + URL: s.config.GetString("core.storage.s3.endpoint"), + SigningRegion: s.config.GetString("core.storage.s3.region"), }, nil } return aws.Endpoint{}, &aws.EndpointNotFoundError{} @@ -117,8 +139,8 @@ func (s *StorageServiceImpl) BuildUploadBufferTus(basePath string, preUploadCb i cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-1"), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( - s.portal.Config().GetString("core.storage.s3.accessKey"), - s.portal.Config().GetString("core.storage.s3.secretKey"), + s.config.GetString("core.storage.s3.accessKey"), + s.config.GetString("core.storage.s3.secretKey"), "", )), config.WithEndpointResolverWithOptions(customResolver), @@ -129,9 +151,9 @@ func (s *StorageServiceImpl) BuildUploadBufferTus(basePath string, preUploadCb i s3Client := s3.NewFromConfig(cfg) - store := s3store.New(s.portal.Config().GetString("core.storage.s3.bufferBucket"), s3Client) + store := s3store.New(s.config.GetString("core.storage.s3.bufferBucket"), s3Client) - locker := NewMySQLLocker(s) + locker := NewMySQLLocker(s.db, s.logger) composer := tusd.NewStoreComposer() store.UseIn(composer) @@ -151,10 +173,10 @@ func (s *StorageServiceImpl) BuildUploadBufferTus(basePath string, preUploadCb i return handler, store, s3Client, err } -func (s *StorageServiceImpl) Init() error { +func (s *StorageServiceImpl) init() error { - addr := s.portal.Config().GetString("core.sia.url") - passwd := s.portal.Config().GetString("core.sia.key") + addr := s.config.GetString("core.sia.url") + passwd := s.config.GetString("core.sia.key") addrURL, err := url.Parse(addr) @@ -210,13 +232,13 @@ func (s *StorageServiceImpl) Init() error { s.tusStore = store s.s3Client = s3client - s.portal.CronService().RegisterService(s) + s.cron.RegisterService(s) go s.tusWorker() return nil } -func (s *StorageServiceImpl) LoadInitialTasks(cron interfaces.CronService) error { +func (s *StorageServiceImpl) LoadInitialTasks(cron cron.CronService) error { return nil } @@ -250,7 +272,7 @@ func (s *StorageServiceImpl) FileExists(hash []byte) (bool, models.Upload) { hashStr := hex.EncodeToString(hash) var upload models.Upload - result := s.portal.Database().Model(&models.Upload{}).Where(&models.Upload{Hash: hashStr}).First(&upload) + result := s.db.Model(&models.Upload{}).Where(&models.Upload{Hash: hashStr}).First(&upload) return result.RowsAffected > 0, upload } @@ -293,7 +315,7 @@ func (s *StorageServiceImpl) CreateUpload(hash []byte, mime string, uploaderID u Size: size, } - result := s.portal.Database().Create(upload) + result := s.db.Create(upload) if result.Error != nil { return nil, result.Error @@ -309,7 +331,7 @@ func (s *StorageServiceImpl) tusWorker() { hash, ok := info.Upload.MetaData["hash"] errorResponse := tusd.HTTPResponse{StatusCode: 400, Header: nil} if !ok { - s.portal.Logger().Error("Missing hash in metadata") + s.logger.Error("Missing hash in metadata") continue } @@ -317,7 +339,7 @@ func (s *StorageServiceImpl) tusWorker() { if !ok { errorResponse.Body = "Missing user id in context" info.Upload.StopUpload(errorResponse) - s.portal.Logger().Error("Missing user id in context") + s.logger.Error("Missing user id in context") continue } @@ -328,7 +350,7 @@ func (s *StorageServiceImpl) tusWorker() { if err != nil { errorResponse.Body = "Could not decode hash" info.Upload.StopUpload(errorResponse) - s.portal.Logger().Error("Could not decode hash", zap.Error(err)) + s.logger.Error("Could not decode hash", zap.Error(err)) continue } @@ -336,19 +358,19 @@ func (s *StorageServiceImpl) tusWorker() { if err != nil { errorResponse.Body = "Could not create tus upload" info.Upload.StopUpload(errorResponse) - s.portal.Logger().Error("Could not create tus upload", zap.Error(err)) + s.logger.Error("Could not create tus upload", zap.Error(err)) continue } case info := <-s.tus.UploadProgress: err := s.TusUploadProgress(info.Upload.ID) if err != nil { - s.portal.Logger().Error("Could not update tus upload", zap.Error(err)) + s.logger.Error("Could not update tus upload", zap.Error(err)) continue } case info := <-s.tus.TerminatedUploads: err := s.DeleteTusUpload(info.Upload.ID) if err != nil { - s.portal.Logger().Error("Could not delete tus upload", zap.Error(err)) + s.logger.Error("Could not delete tus upload", zap.Error(err)) continue } @@ -358,12 +380,12 @@ func (s *StorageServiceImpl) tusWorker() { } err := s.TusUploadCompleted(info.Upload.ID) if err != nil { - s.portal.Logger().Error("Could not complete tus upload", zap.Error(err)) + s.logger.Error("Could not complete tus upload", zap.Error(err)) continue } err = s.ScheduleTusUpload(info.Upload.ID, 0) if err != nil { - s.portal.Logger().Error("Could not schedule tus upload", zap.Error(err)) + s.logger.Error("Could not schedule tus upload", zap.Error(err)) continue } @@ -375,7 +397,7 @@ func (s *StorageServiceImpl) TusUploadExists(hash []byte) (bool, models.TusUploa hashStr := hex.EncodeToString(hash) var upload models.TusUpload - result := s.portal.Database().Model(&models.TusUpload{}).Where(&models.TusUpload{Hash: hashStr}).First(&upload) + result := s.db.Model(&models.TusUpload{}).Where(&models.TusUpload{Hash: hashStr}).First(&upload) return result.RowsAffected > 0, upload } @@ -392,7 +414,7 @@ func (s *StorageServiceImpl) CreateTusUpload(hash []byte, uploadID string, uploa Protocol: protocol, } - result := s.portal.Database().Create(upload) + result := s.db.Create(upload) if result.Error != nil { return nil, result.Error @@ -405,13 +427,13 @@ func (s *StorageServiceImpl) TusUploadProgress(uploadID string) error { find := &models.TusUpload{UploadID: uploadID} var upload models.TusUpload - result := s.portal.Database().Model(&models.TusUpload{}).Where(find).First(&upload) + result := s.db.Model(&models.TusUpload{}).Where(find).First(&upload) if result.RowsAffected == 0 { return errors.New("upload not found") } - result = s.portal.Database().Model(&models.TusUpload{}).Where(find).Update("updated_at", time.Now()) + result = s.db.Model(&models.TusUpload{}).Where(find).Update("updated_at", time.Now()) if result.Error != nil { return result.Error @@ -424,18 +446,18 @@ func (s *StorageServiceImpl) TusUploadCompleted(uploadID string) error { find := &models.TusUpload{UploadID: uploadID} var upload models.TusUpload - result := s.portal.Database().Model(&models.TusUpload{}).Where(find).First(&upload) + result := s.db.Model(&models.TusUpload{}).Where(find).First(&upload) if result.RowsAffected == 0 { return errors.New("upload not found") } - result = s.portal.Database().Model(&models.TusUpload{}).Where(find).Update("completed", true) + result = s.db.Model(&models.TusUpload{}).Where(find).Update("completed", true) return nil } func (s *StorageServiceImpl) DeleteTusUpload(uploadID string) error { - result := s.portal.Database().Where(&models.TusUpload{UploadID: uploadID}).Delete(&models.TusUpload{}) + result := s.db.Where(&models.TusUpload{UploadID: uploadID}).Delete(&models.TusUpload{}) if result.Error != nil { return result.Error @@ -448,7 +470,7 @@ func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string, attempt int) err find := &models.TusUpload{UploadID: uploadID} var upload models.TusUpload - result := s.portal.Database().Model(&models.TusUpload{}).Where(find).First(&upload) + result := s.db.Model(&models.TusUpload{}).Where(find).First(&upload) if result.RowsAffected == 0 { return errors.New("upload not found") @@ -460,18 +482,18 @@ func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string, attempt int) err job = gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(time.Now().Add(time.Duration(attempt) * time.Minute))) } - _, err := s.portal.Cron().NewJob(job, task, gocron.WithEventListeners(gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) { - s.portal.Logger().Error("Error running job", zap.Error(err)) + _, err := s.cron.Scheduler().NewJob(job, task, gocron.WithEventListeners(gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) { + s.logger.Error("Error running job", zap.Error(err)) err = s.ScheduleTusUpload(uploadID, attempt+1) if err != nil { - s.portal.Logger().Error("Error rescheduling job", zap.Error(err)) + s.logger.Error("Error rescheduling job", zap.Error(err)) } }), gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) { - s.portal.Logger().Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID)) + s.logger.Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID)) err := s.DeleteTusUpload(uploadID) if err != nil { - s.portal.Logger().Error("Error deleting tus upload", zap.Error(err)) + s.logger.Error("Error deleting tus upload", zap.Error(err)) } }))) @@ -489,38 +511,38 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo ctx := context.Background() tusUpload, err := s.tusStore.GetUpload(ctx, upload.UploadID) if err != nil { - s.portal.Logger().Error("Could not get upload", zap.Error(err)) + s.logger.Error("Could not get upload", zap.Error(err)) return err } reader, err := tusUpload.GetReader(ctx) if err != nil { - s.portal.Logger().Error("Could not get tus file", zap.Error(err)) + s.logger.Error("Could not get tus file", zap.Error(err)) return err } hash, byteCount, err := s.GetHash(reader) if err != nil { - s.portal.Logger().Error("Could not compute hash", zap.Error(err)) + s.logger.Error("Could not compute hash", zap.Error(err)) return err } dbHash, err := hex.DecodeString(upload.Hash) if err != nil { - s.portal.Logger().Error("Could not decode hash", zap.Error(err)) + s.logger.Error("Could not decode hash", zap.Error(err)) return err } if !bytes.Equal(hash, dbHash) { - s.portal.Logger().Error("Hashes do not match", zap.Any("upload", upload), zap.Any("hash", hash), zap.Any("dbHash", dbHash)) + s.logger.Error("Hashes do not match", zap.Any("upload", upload), zap.Any("hash", hash), zap.Any("dbHash", dbHash)) return err } reader, err = tusUpload.GetReader(ctx) if err != nil { - s.portal.Logger().Error("Could not get tus file", zap.Error(err)) + s.logger.Error("Could not get tus file", zap.Error(err)) return err } @@ -529,7 +551,7 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo _, err = reader.Read(mimeBuf[:]) if err != nil { - s.portal.Logger().Error("Could not read mime", zap.Error(err)) + s.logger.Error("Could not read mime", zap.Error(err)) return err } @@ -537,28 +559,28 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo upload.MimeType = mimeType - if tx := s.Portal().Database().Save(upload); tx.Error != nil { - s.portal.Logger().Error("Could not update tus upload", zap.Error(tx.Error)) + if tx := s.db.Save(upload); tx.Error != nil { + s.logger.Error("Could not update tus upload", zap.Error(tx.Error)) return tx.Error } reader, err = tusUpload.GetReader(ctx) if err != nil { - s.portal.Logger().Error("Could not get tus file", zap.Error(err)) + s.logger.Error("Could not get tus file", zap.Error(err)) return err } err = s.PutFile(reader, upload.Protocol, dbHash) if err != nil { - s.portal.Logger().Error("Could not upload file", zap.Error(err)) + s.logger.Error("Could not upload file", zap.Error(err)) return err } s3InfoId, _ := splitS3Ids(upload.UploadID) _, err = s.s3Client.DeleteObjects(ctx, &s3.DeleteObjectsInput{ - Bucket: aws.String(s.portal.Config().GetString("core.storage.s3.bufferBucket")), + Bucket: aws.String(s.config.GetString("core.storage.s3.bufferBucket")), Delete: &s3types.Delete{ Objects: []s3types.ObjectIdentifier{ { @@ -573,19 +595,19 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo }) if err != nil { - s.portal.Logger().Error("Could not delete upload metadata", zap.Error(err)) + s.logger.Error("Could not delete upload metadata", zap.Error(err)) return err } newUpload, err := s.CreateUpload(dbHash, mimeType, upload.UploaderID, upload.UploaderIP, uint64(byteCount), upload.Protocol) if err != nil { - s.portal.Logger().Error("Could not create upload", zap.Error(err)) + s.logger.Error("Could not create upload", zap.Error(err)) return err } - err = s.portal.Accounts().PinByID(newUpload.ID, upload.UploaderID) + err = s.accounts.PinByID(newUpload.ID, upload.UploaderID) if err != nil { - s.portal.Logger().Error("Could not pin upload", zap.Error(err)) + s.logger.Error("Could not pin upload", zap.Error(err)) return err } @@ -665,6 +687,6 @@ func (s *StorageServiceImpl) GetFile(hash []byte, start int64) (io.ReadCloser, i return object.Content, int64(upload.Size), nil } -func (s *StorageServiceImpl) NewFile(hash []byte) interfaces.File { +func (s *StorageServiceImpl) NewFile(hash []byte) *FileImpl { return NewFile(hash, s) }