diff --git a/api/api.go b/api/api.go index 92581b0..1f029bd 100644 --- a/api/api.go +++ b/api/api.go @@ -6,8 +6,6 @@ import ( "git.lumeweb.com/LumeWeb/portal/config" - "git.lumeweb.com/LumeWeb/portal/api/middleware" - "git.lumeweb.com/LumeWeb/portal/api/registry" "go.uber.org/fx" ) @@ -39,18 +37,6 @@ func BuildApis(cm *config.Manager) fx.Option { return nil })) - options = append(options, fx.Invoke(func(params initParams) error { - for _, protocol := range params.Protocols { - routes, err := protocol.Routes() - if err != nil { - return err - } - middleware.RegisterProtocolSubdomain(cm, routes, protocol.Name()) - } - - return nil - })) - return fx.Module("api", fx.Options(options...)) } diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index 7349bea..47396c7 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -11,9 +11,7 @@ import ( "git.lumeweb.com/LumeWeb/portal/config" "git.lumeweb.com/LumeWeb/portal/account" - "git.lumeweb.com/LumeWeb/portal/api/registry" "github.com/golang-jwt/jwt/v5" - "github.com/julienschmidt/httprouter" "go.sia.tech/jape" ) @@ -65,12 +63,6 @@ func ApplyMiddlewares(handler jape.Handler, middlewares ...interface{}) jape.Han } return handler } -func RegisterProtocolSubdomain(config *config.Manager, mux *httprouter.Router, name string) { - router := registry.GetRouter() - domain := config.Config().Core.Domain - - (router)[name+"."+domain] = mux -} func FindAuthToken(r *http.Request, cookieName string, queryParam string) string { authHeader := ParseAuthTokenHeader(r.Header) diff --git a/api/registry/registry.go b/api/registry/registry.go index 4820057..2f65aae 100644 --- a/api/registry/registry.go +++ b/api/registry/registry.go @@ -3,8 +3,6 @@ package registry import ( "context" - "github.com/julienschmidt/httprouter" - router2 "git.lumeweb.com/LumeWeb/portal/api/router" "go.uber.org/fx" ) @@ -14,7 +12,6 @@ type API interface { Init() error Start(ctx context.Context) error Stop(ctx context.Context) error - Routes() (*httprouter.Router, error) } type APIEntry struct { @@ -23,10 +20,10 @@ type APIEntry struct { } var apiRegistry []APIEntry -var router router2.ProtocolRouter +var router *router2.APIRouter func init() { - router = make(router2.ProtocolRouter) + router = router2.NewAPIRouter() } func Register(entry APIEntry) { @@ -37,6 +34,6 @@ func GetRegistry() []APIEntry { return apiRegistry } -func GetRouter() router2.ProtocolRouter { +func GetRouter() *router2.APIRouter { return router } diff --git a/api/router/router.go b/api/router/router.go index f5ac4ba..63595f3 100644 --- a/api/router/router.go +++ b/api/router/router.go @@ -1,14 +1,92 @@ package router -import "net/http" +import ( + "net/http" -type ProtocolRouter map[string]http.Handler + "git.lumeweb.com/LumeWeb/portal/config" + + "go.uber.org/zap" + + "github.com/julienschmidt/httprouter" +) + +type RoutableAPI interface { + Name() string + Can(w http.ResponseWriter, r *http.Request) bool + Handle(w http.ResponseWriter, r *http.Request) + Routes() (*httprouter.Router, error) +} + +type APIRouter struct { + apis map[string]RoutableAPI + apiDomain map[string]string + apiHandlers map[string]http.Handler + logger *zap.Logger + config *config.Manager +} // Implement the ServeHTTP method on our new type -func (hs ProtocolRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if handler := hs[r.Host]; handler != nil { +func (hs APIRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if handler := hs.getHandlerByDomain(r.Host); handler != nil { handler.ServeHTTP(w, r) - } else { - http.Error(w, "Forbidden", 403) // Or Redirect? + return + } + + for _, api := range hs.apis { + if api.Can(w, r) { + api.Handle(w, r) + return + } + } + + http.NotFound(w, r) +} + +func (hs *APIRouter) RegisterAPI(impl RoutableAPI) { + name := impl.Name() + hs.apis[name] = impl + hs.apiDomain[name+"."+hs.config.Config().Core.Domain] = name +} + +func (hs *APIRouter) getHandlerByDomain(domain string) http.Handler { + if apiName := hs.apiDomain[domain]; apiName != "" { + return hs.getHandler(apiName) + } + + return nil +} + +func (hs *APIRouter) getHandler(protocol string) http.Handler { + if handler := hs.apiHandlers[protocol]; handler == nil { + if proto := hs.apis[protocol]; proto == nil { + hs.logger.Fatal("Protocol not found", zap.String("protocol", protocol)) + return nil + } + + routes, err := hs.apis[protocol].Routes() + + if err != nil { + hs.logger.Fatal("Error getting routes", zap.Error(err)) + return nil + } + + hs.apiHandlers[protocol] = routes + } + + return hs.apiHandlers[protocol] +} + +func NewAPIRouter() *APIRouter { + return &APIRouter{ + apis: make(map[string]RoutableAPI), + apiHandlers: make(map[string]http.Handler), } } + +func (hs *APIRouter) SetLogger(logger *zap.Logger) { + hs.logger = logger +} + +func (hs *APIRouter) SetConfig(config *config.Manager) { + hs.config = config +} diff --git a/cmd/portal/init.go b/cmd/portal/init.go index f088419..16f8152 100644 --- a/cmd/portal/init.go +++ b/cmd/portal/init.go @@ -7,6 +7,8 @@ import ( "net/http" "strconv" + "git.lumeweb.com/LumeWeb/portal/api/router" + "git.lumeweb.com/LumeWeb/portal/config" "git.lumeweb.com/LumeWeb/portal/api/registry" @@ -62,11 +64,32 @@ func NewIdentity(config *config.Manager, logger *zap.Logger) (ed25519.PrivateKey return ed25519.PrivateKey(wallet.KeyFromSeed(&seed, 0)), nil } -func NewServer(lc fx.Lifecycle, config *config.Manager, logger *zap.Logger) (*http.Server, error) { +type NewServerParams struct { + fx.In + Config *config.Manager + Logger *zap.Logger + APIs []registry.API `group:"api"` +} + +func NewServer(lc fx.Lifecycle, params NewServerParams) (*http.Server, error) { + r := registry.GetRouter() + + r.SetConfig(params.Config) + r.SetLogger(params.Logger) + + for _, api := range params.APIs { + routableAPI, ok := interface{}(api).(router.RoutableAPI) + + if !ok { + params.Logger.Fatal("API does not implement RoutableAPI", zap.String("api", api.Name())) + } + + r.RegisterAPI(routableAPI) + } srv := &http.Server{ - Addr: ":" + strconv.FormatUint(uint64(config.Config().Core.Port), 10), - Handler: registry.GetRouter(), + Addr: ":" + strconv.FormatUint(uint64(params.Config.Config().Core.Port), 10), + Handler: r, } lc.Append(fx.Hook{ @@ -79,7 +102,7 @@ func NewServer(lc fx.Lifecycle, config *config.Manager, logger *zap.Logger) (*ht go func() { err := srv.Serve(ln) if err != nil { - logger.Fatal("Failed to serve", zap.Error(err)) + params.Logger.Fatal("Failed to serve", zap.Error(err)) } }()