From 7225439860d8675b231f408f8e9a26b3fadcf1e2 Mon Sep 17 00:00:00 2001 From: Christian Kaps <307006+akkie@users.noreply.github.com> Date: Mon, 27 Mar 2023 00:11:41 +0200 Subject: [PATCH] cli: Add flag for disabling CORS headers (#899) --- cmd/tusd/cli/flags.go | 3 ++- cmd/tusd/cli/serve.go | 1 + docs/usage-binary.md | 2 ++ pkg/handler/config.go | 3 +++ pkg/handler/cors_test.go | 16 ++++++++++++++++ pkg/handler/unrouted_handler.go | 2 +- 6 files changed, 25 insertions(+), 2 deletions(-) diff --git a/cmd/tusd/cli/flags.go b/cmd/tusd/cli/flags.go index 072d490..0c4a896 100644 --- a/cmd/tusd/cli/flags.go +++ b/cmd/tusd/cli/flags.go @@ -23,6 +23,7 @@ var Flags struct { ShowGreeting bool DisableDownload bool DisableTermination bool + DisableCors bool Timeout int64 S3Bucket string S3ObjectPrefix string @@ -72,6 +73,7 @@ func ParseFlags() { flag.BoolVar(&Flags.ShowGreeting, "show-greeting", true, "Show the greeting message") flag.BoolVar(&Flags.DisableDownload, "disable-download", false, "Disable the download endpoint") flag.BoolVar(&Flags.DisableTermination, "disable-termination", false, "Disable the termination endpoint") + flag.BoolVar(&Flags.DisableCors, "disable-cors", false, "Disable CORS headers") flag.Int64Var(&Flags.Timeout, "timeout", 6*1000, "Read timeout for connections in milliseconds. A zero value means that reads will not timeout") flag.StringVar(&Flags.S3Bucket, "s3-bucket", "", "Use AWS S3 with this bucket as storage backend (requires the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION environment variables to be set)") flag.StringVar(&Flags.S3ObjectPrefix, "s3-object-prefix", "", "Prefix for S3 object names") @@ -106,7 +108,6 @@ func ParseFlags() { flag.StringVar(&Flags.TLSCertFile, "tls-certificate", "", "Path to the file containing the x509 TLS certificate to be used. The file should also contain any intermediate certificates and the CA certificate.") flag.StringVar(&Flags.TLSKeyFile, "tls-key", "", "Path to the file containing the key for the TLS certificate.") flag.StringVar(&Flags.TLSMode, "tls-mode", "tls12", "Specify which TLS mode to use; valid modes are tls13, tls12, and tls12-strong.") - flag.StringVar(&Flags.CPUProfile, "cpuprofile", "", "write cpu profile to file") flag.Parse() diff --git a/cmd/tusd/cli/serve.go b/cmd/tusd/cli/serve.go index 30f1f73..8449b6a 100644 --- a/cmd/tusd/cli/serve.go +++ b/cmd/tusd/cli/serve.go @@ -29,6 +29,7 @@ func Serve() { RespectForwardedHeaders: Flags.BehindProxy, DisableDownload: Flags.DisableDownload, DisableTermination: Flags.DisableTermination, + DisableCors: Flags.DisableCors, StoreComposer: Composer, NotifyCompleteUploads: true, NotifyTerminatedUploads: true, diff --git a/docs/usage-binary.md b/docs/usage-binary.md index 8b647c2..80872c9 100644 --- a/docs/usage-binary.md +++ b/docs/usage-binary.md @@ -216,6 +216,8 @@ $ tusd -help If set, will listen to a UNIX socket at this location instead of a TCP socket -upload-dir string Directory to store uploads in (default "./data") + -disable-cors + Disables CORS headers. If set to true, tusd will not send any CORS related header. This is useful if you have a proxy sitting in front of tusd that handles CORS (default false) -verbose Enable verbose logging output (default true) -version diff --git a/pkg/handler/config.go b/pkg/handler/config.go index 9b729b8..bc790fb 100644 --- a/pkg/handler/config.go +++ b/pkg/handler/config.go @@ -28,6 +28,9 @@ type Config struct { // DisableTermination indicates whether the server will refuse termination // requests of the uploaded file, by not mounting the DELETE handler. DisableTermination bool + // Disable cors headers. If set to true, tusd will not send any CORS related header. + // This is useful if you have a proxy sitting in front of tusd that handles CORS. + DisableCors bool // NotifyCompleteUploads indicates whether sending notifications about // completed uploads using the CompleteUploads channel should be enabled. NotifyCompleteUploads bool diff --git a/pkg/handler/cors_test.go b/pkg/handler/cors_test.go index e464916..1979f54 100644 --- a/pkg/handler/cors_test.go +++ b/pkg/handler/cors_test.go @@ -96,4 +96,20 @@ func TestCORS(t *testing.T) { t.Errorf("expected header to contain METHOD but got: %#v", methods) } }) + + SubTest(t, "Disable CORS", func(t *testing.T, store *MockFullDataStore, composer *StoreComposer) { + handler, _ := NewHandler(Config{ + StoreComposer: composer, + DisableCors: true, + }) + + (&httpTest{ + Method: "OPTIONS", + ReqHeader: map[string]string{ + "Origin": "tus.io", + }, + Code: http.StatusOK, + ResHeader: map[string]string{}, + }).Run(handler, t) + }) } diff --git a/pkg/handler/unrouted_handler.go b/pkg/handler/unrouted_handler.go index 55e7ccc..e798fc1 100644 --- a/pkg/handler/unrouted_handler.go +++ b/pkg/handler/unrouted_handler.go @@ -217,7 +217,7 @@ func (handler *UnroutedHandler) Middleware(h http.Handler) http.Handler { header := w.Header() - if origin := r.Header.Get("Origin"); origin != "" { + if origin := r.Header.Get("Origin"); !handler.config.DisableCors && origin != "" { header.Set("Access-Control-Allow-Origin", origin) if r.Method == "OPTIONS" {