diff --git a/post_test.go b/post_test.go index 4010bc1..5f41ddd 100644 --- a/post_test.go +++ b/post_test.go @@ -66,4 +66,78 @@ func TestPost(t *testing.T) { }, Code: http.StatusRequestEntityTooLarge, }).Run(handler, t) + + (&httpTest{ + Name: "Ignore Forwarded headers", + Method: "POST", + ReqHeader: map[string]string{ + "Tus-Resumable": "1.0.0", + "Upload-Length": "300", + "Upload-Metadata": "foo aGVsbG8=, bar d29ybGQ=", + "X-Forwarded-Host": "foo.com", + "X-Forwarded-Proto": "https", + }, + Code: http.StatusCreated, + ResHeader: map[string]string{ + "Location": "http://tus.io/files/foo", + }, + }).Run(handler, t) + + handler, _ = NewHandler(Config{ + MaxSize: 400, + BasePath: "files", + DataStore: postStore{ + t: t, + }, + RespectForwardedHeaders: true, + }) + + (&httpTest{ + Name: "Respect X-Forwarded-* headers", + Method: "POST", + ReqHeader: map[string]string{ + "Tus-Resumable": "1.0.0", + "Upload-Length": "300", + "Upload-Metadata": "foo aGVsbG8=, bar d29ybGQ=", + "X-Forwarded-Host": "foo.com", + "X-Forwarded-Proto": "https", + }, + Code: http.StatusCreated, + ResHeader: map[string]string{ + "Location": "https://foo.com/files/foo", + }, + }).Run(handler, t) + + (&httpTest{ + Name: "Respect Forwarded headers", + Method: "POST", + ReqHeader: map[string]string{ + "Tus-Resumable": "1.0.0", + "Upload-Length": "300", + "Upload-Metadata": "foo aGVsbG8=, bar d29ybGQ=", + "X-Forwarded-Host": "bar.com", + "X-Forwarded-Proto": "http", + "Forwarded": "proto=https,host=foo.com", + }, + Code: http.StatusCreated, + ResHeader: map[string]string{ + "Location": "https://foo.com/files/foo", + }, + }).Run(handler, t) + + (&httpTest{ + Name: "Filter forwarded protocol", + Method: "POST", + ReqHeader: map[string]string{ + "Tus-Resumable": "1.0.0", + "Upload-Length": "300", + "Upload-Metadata": "foo aGVsbG8=, bar d29ybGQ=", + "X-Forwarded-Proto": "aaa", + "Forwarded": "proto=bbb", + }, + Code: http.StatusCreated, + ResHeader: map[string]string{ + "Location": "http://tus.io/files/foo", + }, + }).Run(handler, t) } diff --git a/unrouted_handler.go b/unrouted_handler.go index 920e30a..0428299 100644 --- a/unrouted_handler.go +++ b/unrouted_handler.go @@ -13,7 +13,11 @@ import ( "strings" ) -var reExtractFileID = regexp.MustCompile(`([^/]+)\/?$`) +var ( + reExtractFileID = regexp.MustCompile(`([^/]+)\/?$`) + reForwardedHost = regexp.MustCompile(`host=([^,]+)`) + reForwardedProto = regexp.MustCompile(`proto=(https?)`) +) var ( ErrUnsupportedVersion = errors.New("unsupported version") @@ -65,6 +69,10 @@ type Config struct { NotifyCompleteUploads bool // Logger the logger to use internally Logger *log.Logger + // Respect the X-Forwarded-Host, X-Forwarded-Proto and Forwarded headers + // potentially set by proxies when generating an absolute URL in the + // reponse to POST requests. + RespectForwardedHeaders bool } // UnroutedHandler exposes methods to handle requests as part of the tus protocol, @@ -500,16 +508,51 @@ func (handler *UnroutedHandler) absFileURL(r *http.Request, id string) string { } // Read origin and protocol from request - url := "http://" - if r.TLS != nil { - url = "https://" - } + host, proto := getHostAndProtocol(r, handler.config.RespectForwardedHeaders) - url += r.Host + handler.basePath + id + url := proto + "://" + host + handler.basePath + id return url } +// getHostAndProtocol extracts the host and used protocol (either HTTP or HTTPS) +// from the given request. If `allowForwarded` is set, the X-Forwarded-Host, +// X-Forwarded-Proto and Forwarded headers will also be checked to +// support proxies. +func getHostAndProtocol(r *http.Request, allowForwarded bool) (host, proto string) { + if r.TLS != nil { + proto = "https" + } else { + proto = "http" + } + + host = r.Host + + if !allowForwarded { + return + } + + if h := r.Header.Get("X-Forwarded-Host"); h != "" { + host = h + } + + if h := r.Header.Get("X-Forwarded-Proto"); h == "http" || h == "https" { + proto = h + } + + if h := r.Header.Get("Forwarded"); h != "" { + if r := reForwardedHost.FindStringSubmatch(h); len(r) == 2 { + host = r[1] + } + + if r := reForwardedProto.FindStringSubmatch(h); len(r) == 2 { + proto = r[1] + } + } + + return +} + // The get sum of all sizes for a list of upload ids while checking whether // all of these uploads are finished yet. This is used to calculate the size // of a final resource.