diff --git a/pkg/handler/unrouted_handler.go b/pkg/handler/unrouted_handler.go index ad950b6..74072b6 100644 --- a/pkg/handler/unrouted_handler.go +++ b/pkg/handler/unrouted_handler.go @@ -9,6 +9,7 @@ import ( "math" "net" "net/http" + uri "net/url" "os" "regexp" "strconv" @@ -112,12 +113,10 @@ func newHookEvent(info FileInfo, r *http.Request) HookEvent { // such as PostFile, HeadFile, PatchFile and DelFile. In addition the GetFile method // is provided which is, however, not part of the specification. type UnroutedHandler struct { - config Config - composer *StoreComposer - isBasePathAbs bool - basePath string - logger *log.Logger - extensions string + config Config + composer *StoreComposer + logger *log.Logger + extensions string // CompleteUploads is used to send notifications whenever an upload is // completed by a user. The HookEvent will contain information about this @@ -149,6 +148,10 @@ type UnroutedHandler struct { CreatedUploads chan HookEvent // Metrics provides numbers of the usage for this handler. Metrics Metrics + // GetID is a customisable callback used by the handler to determine the id of an upload in a request + GetID func(r *http.Request) (string, error) + // GetURL is a customisable callback used by the handler to determine the url that uniquely identifies the resource id + GetURL func(r *http.Request, id string) (*uri.URL, error) } // NewUnroutedHandler creates a new handler without routing using the given @@ -175,8 +178,6 @@ func NewUnroutedHandler(config Config) (*UnroutedHandler, error) { handler := &UnroutedHandler{ config: config, composer: config.StoreComposer, - basePath: config.BasePath, - isBasePathAbs: config.isAbs, CompleteUploads: make(chan HookEvent), TerminatedUploads: make(chan HookEvent), UploadProgress: make(chan HookEvent), @@ -184,6 +185,21 @@ func NewUnroutedHandler(config Config) (*UnroutedHandler, error) { logger: config.Logger, extensions: extensions, Metrics: newMetrics(), + GetID: func(r *http.Request) (string, error) { + return extractIDFromPath(r.URL.Path) + }, + GetURL: func(r *http.Request, id string) (*uri.URL, error) { + if config.isAbs { + return uri.Parse(config.BasePath + id) + } + + // Read origin and protocol from request + host, proto := getHostAndProtocol(r, config.RespectForwardedHeaders) + + url := proto + "://" + host + config.BasePath + id + + return uri.Parse(url) + }, } return handler, nil @@ -289,12 +305,33 @@ func (handler *UnroutedHandler) PostFile(w http.ResponseWriter, r *http.Request) } // Parse Upload-Concat header - isPartial, isFinal, partialUploadIDs, err := parseConcat(concatHeader) + isPartial, isFinal, partialUploadPaths, err := parseConcat(concatHeader) if err != nil { handler.sendError(w, r, err) return } + var partialUploadIDs []string + if len(partialUploadPaths) > 0 { + partialUploadIDs = make([]string, len(partialUploadPaths)) + origURL := r.URL + for i, path := range partialUploadPaths { + // insert partialUpload URI and call GetID + r.URL, err = r.URL.Parse(path) + if err != nil { + handler.sendError(w, r, err) + return + } + + partialUploadIDs[i], err = handler.GetID(r) + if err != nil { + handler.sendError(w, r, err) + return + } + } + r.URL = origURL + } + // If the upload is a final upload created by concatenation multiple partial // uploads the size is sum of all sizes of these files (no need for // Upload-Length header) @@ -364,11 +401,15 @@ func (handler *UnroutedHandler) PostFile(w http.ResponseWriter, r *http.Request) // Add the Location header directly after creating the new resource to even // include it in cases of failure when an error is returned - url := handler.absFileURL(r, id) - w.Header().Set("Location", url) + url, err := handler.GetURL(r, id) + if err != nil { + handler.sendError(w, r, err) + return + } + w.Header().Set("Location", url.String()) handler.Metrics.incUploadsCreated() - handler.log("UploadCreated", "id", id, "size", i64toa(size), "url", url) + handler.log("UploadCreated", "id", id, "size", i64toa(size), "url", url.String()) if handler.config.NotifyCreatedUploads { handler.CreatedUploads <- newHookEvent(info, r) @@ -416,7 +457,7 @@ func (handler *UnroutedHandler) PostFile(w http.ResponseWriter, r *http.Request) func (handler *UnroutedHandler) HeadFile(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - id, err := extractIDFromPath(r.URL.Path) + id, err := handler.GetID(r) if err != nil { handler.sendError(w, r, err) return @@ -452,7 +493,12 @@ func (handler *UnroutedHandler) HeadFile(w http.ResponseWriter, r *http.Request) if info.IsFinal { v := "final;" for _, uploadID := range info.PartialUploads { - v += handler.absFileURL(r, uploadID) + " " + url, err := handler.GetURL(r, uploadID) + if err != nil { + handler.sendError(w, r, err) + return + } + v += url.String() + " " } // Remove trailing space v = v[:len(v)-1] @@ -493,7 +539,7 @@ func (handler *UnroutedHandler) PatchFile(w http.ResponseWriter, r *http.Request return } - id, err := extractIDFromPath(r.URL.Path) + id, err := handler.GetID(r) if err != nil { handler.sendError(w, r, err) return @@ -695,7 +741,7 @@ func (handler *UnroutedHandler) finishUploadIfComplete(ctx context.Context, uplo func (handler *UnroutedHandler) GetFile(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - id, err := extractIDFromPath(r.URL.Path) + id, err := handler.GetID(r) if err != nil { handler.sendError(w, r, err) return @@ -822,7 +868,7 @@ func (handler *UnroutedHandler) DelFile(w http.ResponseWriter, r *http.Request) return } - id, err := extractIDFromPath(r.URL.Path) + id, err := handler.GetID(r) if err != nil { handler.sendError(w, r, err) return @@ -934,21 +980,6 @@ func (handler *UnroutedHandler) sendResp(w http.ResponseWriter, r *http.Request, handler.log("ResponseOutgoing", "status", strconv.Itoa(status), "method", r.Method, "path", r.URL.Path) } -// Make an absolute URLs to the given upload id. If the base path is absolute -// it will be prepended else the host and protocol from the request is used. -func (handler *UnroutedHandler) absFileURL(r *http.Request, id string) string { - if handler.isBasePathAbs { - return handler.basePath + id - } - - // Read origin and protocol from request - host, proto := getHostAndProtocol(r, handler.config.RespectForwardedHeaders) - - url := proto + "://" + host + handler.basePath + id - - return url -} - type progressWriter struct { Offset int64 } @@ -1171,13 +1202,7 @@ func parseConcat(header string) (isPartial bool, isFinal bool, partialUploads [] continue } - id, extractErr := extractIDFromPath(value) - if extractErr != nil { - err = extractErr - return - } - - partialUploads = append(partialUploads, id) + partialUploads = append(partialUploads, value) } }