From 3f042b97a0febd6ab9cab8486be4e9700ce38271 Mon Sep 17 00:00:00 2001 From: Marius Date: Tue, 13 Jun 2023 16:17:46 +0200 Subject: [PATCH] cli: Shutdown gracefully by default (#963) * cli: Implement graceful shutdown * Make shutdown timeout configurable * Also shutdown gracefully for SIGTERM * Add comment * Add test for handler.InterruptRequestHandling * Add documentation --- cmd/tusd/cli/flags.go | 2 + cmd/tusd/cli/serve.go | 154 ++++++++++++++++++++++---------- docs/usage-binary.md | 10 +++ pkg/handler/body_reader.go | 4 + pkg/handler/patch_test.go | 53 +++++++++++ pkg/handler/unrouted_handler.go | 44 ++++++++- 6 files changed, 216 insertions(+), 51 deletions(-) diff --git a/cmd/tusd/cli/flags.go b/cmd/tusd/cli/flags.go index 9d4e47c..dcf3fd1 100644 --- a/cmd/tusd/cli/flags.go +++ b/cmd/tusd/cli/flags.go @@ -61,6 +61,7 @@ var Flags struct { TLSCertFile string TLSKeyFile string TLSMode string + ShutdownTimeout int64 } func ParseFlags() { @@ -115,6 +116,7 @@ func ParseFlags() { flag.StringVar(&Flags.TLSCertFile, "tls-certificate", "", "Path to the file containing the x509 TLS certificate to be used. The file should also contain any intermediate certificates and the CA certificate.") flag.StringVar(&Flags.TLSKeyFile, "tls-key", "", "Path to the file containing the key for the TLS certificate.") flag.StringVar(&Flags.TLSMode, "tls-mode", "tls12", "Specify which TLS mode to use; valid modes are tls13, tls12, and tls12-strong.") + flag.Int64Var(&Flags.ShutdownTimeout, "shutdown-timeout", 10*1000, "Timeout in milliseconds for closing connections gracefully during shutdown. After the timeout, tusd will exit regardless of any open connection.") flag.Parse() SetEnabledHooks() diff --git a/cmd/tusd/cli/serve.go b/cmd/tusd/cli/serve.go index 7349b84..4a2908b 100644 --- a/cmd/tusd/cli/serve.go +++ b/cmd/tusd/cli/serve.go @@ -1,10 +1,15 @@ package cli import ( + "context" "crypto/tls" + "errors" "net" "net/http" + "os" + "os/signal" "strings" + "syscall" "time" "github.com/tus/tusd/v2/pkg/handler" @@ -115,62 +120,117 @@ func Serve() { stdout.Printf("You can now upload files to: %s://%s%s", protocol, address, basepath) } - // If we're not using TLS just start the server and, if http.Serve() returns, just return. - if protocol == "http" { - if err = http.Serve(listener, mux); err != nil { - stderr.Fatalf("Unable to serve: %s", err) - } - return - } - - // TODO: Move TLS handling into own file. - // Fall-through for TLS mode. server := &http.Server{ Handler: mux, } - switch Flags.TLSMode { - case TLS13: - server.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS13} - case TLS12: - // Ciphersuite selection comes from - // https://ssl-config.mozilla.org/#server=go&version=1.14.4&config=intermediate&guideline=5.6 - // 128-bit AES modes remain as TLSv1.3 is enabled in this mode, and TLSv1.3 compatibility requires an AES-128 ciphersuite. - server.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - PreferServerCipherSuites: true, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - }, + shutdownComplete := setupSignalHandler(server, handler) + + if protocol == "http" { + // Non-TLS mode + err = server.Serve(listener) + } else { + // TODO: Move TLS handling into own file. + // TLS mode + + switch Flags.TLSMode { + case TLS13: + server.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS13} + + case TLS12: + // Ciphersuite selection comes from + // https://ssl-config.mozilla.org/#server=go&version=1.14.4&config=intermediate&guideline=5.6 + // 128-bit AES modes remain as TLSv1.3 is enabled in this mode, and TLSv1.3 compatibility requires an AES-128 ciphersuite. + server.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + }, + } + + case TLS12STRONG: + // Ciphersuite selection as above, but intersected with + // https://github.com/denji/golang-tls#perfect-ssl-labs-score-with-go + // TLSv1.3 is disabled as it requires an AES-128 ciphersuite. + server.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }, + } + + default: + stderr.Fatalf("Invalid TLS mode chosen. Recommended valid modes are tls13, tls12 (default), and tls12-strong") } - case TLS12STRONG: - // Ciphersuite selection as above, but intersected with - // https://github.com/denji/golang-tls#perfect-ssl-labs-score-with-go - // TLSv1.3 is disabled as it requires an AES-128 ciphersuite. - server.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS12, - PreferServerCipherSuites: true, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - }, - } + // Disable HTTP/2; the default non-TLS mode doesn't support it + server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0) - default: - stderr.Fatalf("Invalid TLS mode chosen. Recommended valid modes are tls13, tls12 (default), and tls12-strong") + err = server.ServeTLS(listener, Flags.TLSCertFile, Flags.TLSKeyFile) } - // Disable HTTP/2; the default non-TLS mode doesn't support it - server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0) - - if err = server.ServeTLS(listener, Flags.TLSCertFile, Flags.TLSKeyFile); err != nil { + // Note: http.Server.Serve and http.Server.ServeTLS always return a non-nil error code. So + // we can assume from here that `err != nil` + if err == http.ErrServerClosed { + // ErrServerClosed means that http.Server.Shutdown was called due to an interruption signal. + // We wait until the interruption procedure is complete or times out and then exit main. + <-shutdownComplete + } else { + // Any other error is relayed to the user. stderr.Fatalf("Unable to serve: %s", err) } } + +func setupSignalHandler(server *http.Server, handler *handler.Handler) <-chan struct{} { + shutdownComplete := make(chan struct{}) + + // We read up to two signals, so use a capacity of 2 here to not miss any signal + c := make(chan os.Signal, 2) + + // os.Interrupt is mapped to SIGINT on Unix and to the termination instructions on Windows. + // On Unix we also listen to SIGTERM. + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + // Signal to the handler that it should stop all long running requests if we shut down + server.RegisterOnShutdown(handler.InterruptRequestHandling) + + go func() { + // First interrupt signal + <-c + stdout.Println("Received interrupt signal. Shutting down tusd...") + + // Wait for second interrupt signal, while also shutting down the existing server + go func() { + <-c + stdout.Println("Received second interrupt signal. Exiting immediately!") + os.Exit(1) + }() + + // Shutdown the server, but with a user-specified timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(Flags.ShutdownTimeout)*time.Millisecond) + defer cancel() + + err := server.Shutdown(ctx) + + if err == nil { + stdout.Println("Shutdown completed. Goodbye!") + } else if errors.Is(err, context.DeadlineExceeded) { + stderr.Println("Shutdown timeout exceeded. Exiting immediately!") + } else { + stderr.Printf("Failed to shutdown gracefully: %s\n", err) + } + + close(shutdownComplete) + }() + + return shutdownComplete +} diff --git a/docs/usage-binary.md b/docs/usage-binary.md index 80872c9..f06e560 100644 --- a/docs/usage-binary.md +++ b/docs/usage-binary.md @@ -224,3 +224,13 @@ $ tusd -help Print tusd version information ``` + +## Graceful shutdown + +If tusd receives a SIGINT or SIGTERM signal, it will initiate a graceful shutdown. SIGINT is usually emitted by pressing Ctrl+C inside the terminal that is running tusd. SIGINT and SIGTERM can also be emitted using the [`kill(1)`](https://man7.org/linux/man-pages/man1/kill.1.html) utility on Unix. Signals in that sense do not exist on Windows, so please refer to the [Go documentation](https://pkg.go.dev/os/signal#hdr-Windows) on how different events are translated into signals on Windows. + +Once the graceful shutdown is started, tusd will stop listening on its port and won't accept new connections anymore. Idle connections are closed down. Already running requests will be given a grace period to complete before their connections are closed as well. PATCH and POST requests with a request body are interrupted, so that data stores can gracefully finish saving all the received data until that point. If all requests have been completed, tusd will exit. + +If not all requests have been completed in the period defined by the `-shutdown-timeout` flag, tusd will exit regardless. By default, tusd will give all requests 10 seconds to complete their processing. If you do not want to wait for requests, use `-shutdown-timeout=0`. + +tusd will also immediately exit if it receives a second SIGINT or SIGTERM signal. It will also always exit immediately if a SIGKILL is received. diff --git a/pkg/handler/body_reader.go b/pkg/handler/body_reader.go index 67d2f11..83daf77 100644 --- a/pkg/handler/body_reader.go +++ b/pkg/handler/body_reader.go @@ -31,6 +31,10 @@ func (r *bodyReader) Read(b []byte) (int, error) { return 0, io.EOF } + // TODO: Mask certain errors that we can safely ignore later on: + // io.EOF, io.UnexpectedEOF, io.ErrClosedPipe, + // read tcp 127.0.0.1:1080->127.0.0.1:56953: read: connection reset by peer, + // read tcp 127.0.0.1:1080->127.0.0.1:9375: i/o timeout n, err := r.reader.Read(b) atomic.AddInt64(&r.bytesCounter, int64(n)) r.err = err diff --git a/pkg/handler/patch_test.go b/pkg/handler/patch_test.go index cf30442..2ad5522 100644 --- a/pkg/handler/patch_test.go +++ b/pkg/handler/patch_test.go @@ -682,4 +682,57 @@ func TestPatch(t *testing.T) { ResBody: "ERR_INTERNAL_SERVER_ERROR: an error while reading the body\n", }).Run(handler, t) }) + + SubTest(t, "InterruptRequestHandling", func(t *testing.T, store *MockFullDataStore, composer *StoreComposer) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + upload := NewMockFullUpload(ctrl) + + gomock.InOrder( + store.EXPECT().GetUpload(gomock.Any(), "yes").Return(upload, nil), + upload.EXPECT().GetInfo(gomock.Any()).Return(FileInfo{ + ID: "yes", + Offset: 0, + Size: 100, + }, nil), + upload.EXPECT().WriteChunk(gomock.Any(), int64(0), NewReaderMatcher("first ")).Return(int64(6), nil), + ) + + handler, _ := NewHandler(Config{ + StoreComposer: composer, + }) + + reader, writer := io.Pipe() + a := assert.New(t) + + go func() { + writer.Write([]byte("first ")) + + handler.InterruptRequestHandling() + + // Wait a short time to ensure that the goroutine in the PATCH + // handler has received and processed the stop event. + <-time.After(10 * time.Millisecond) + + // Assert that the "request body" has been closed. + _, err := writer.Write([]byte("second ")) + a.Equal(err, io.ErrClosedPipe) + }() + + (&httpTest{ + Method: "PATCH", + URL: "yes", + ReqHeader: map[string]string{ + "Tus-Resumable": "1.0.0", + "Content-Type": "application/offset+octet-stream", + "Upload-Offset": "0", + }, + ReqBody: reader, + Code: http.StatusInternalServerError, + ResHeader: map[string]string{ + "Upload-Offset": "", + }, + ResBody: "ERR_SERVER_SHUTDOWN: request has been interrupted because the server is shutting down\n", + }).Run(handler, t) + }) } diff --git a/pkg/handler/unrouted_handler.go b/pkg/handler/unrouted_handler.go index bec8205..0a93feb 100644 --- a/pkg/handler/unrouted_handler.go +++ b/pkg/handler/unrouted_handler.go @@ -43,6 +43,7 @@ var ( ErrUploadStoppedByServer = NewError("ERR_UPLOAD_STOPPED", "upload has been stopped by server", http.StatusBadRequest) ErrUploadRejectedByServer = NewError("ERR_UPLOAD_REJECTED", "upload creation has been rejected by server", http.StatusBadRequest) ErrUploadInterrupted = NewError("ERR_UPLAOD_INTERRUPTED", "upload has been interrupted by another request for this upload resource", http.StatusBadRequest) + ErrServerShutdown = NewError("ERR_SERVER_SHUTDOWN", "request has been interrupted because the server is shutting down", http.StatusInternalServerError) // TODO: These two responses are 500 for backwards compatability. We should discuss // whether it is better to more them to 4XX status codes. @@ -60,6 +61,7 @@ type UnroutedHandler struct { basePath string logger *log.Logger extensions string + serverCtx chan struct{} // CompleteUploads is used to send notifications whenever an upload is // completed by a user. The HookEvent will contain information about this @@ -126,11 +128,29 @@ func NewUnroutedHandler(config Config) (*UnroutedHandler, error) { logger: config.Logger, extensions: extensions, Metrics: newMetrics(), + serverCtx: make(chan struct{}), } return handler, nil } +// InterruptRequestHandling attempts to interrupt long running requests, so +// the server can shutdown gracefully. This function should not be used on +// its own, but as part of http.Server.Shutdown. For example: +// +// server := &http.Server{ +// Handler: handler, +// } +// server.RegisterOnShutdown(handler.InterruptRequestHandling) +// server.Shutdown(ctx) +// +// Note: currently, this function only interrupts POST and PATCH requests +// with a request body. In the future, this might be extended to HEAD, DELETE +// and GET requests. +func (handler UnroutedHandler) InterruptRequestHandling() { + close(handler.serverCtx) +} + // SupportedExtensions returns a comma-separated list of the supported tus extensions. // The availability of an extension usually depends on whether the provided data store // implements some additional interfaces. @@ -596,18 +616,28 @@ func (handler *UnroutedHandler) writeChunk(c *httpContext, resp HTTPResponse, up // We use a context object to allow the hook system to cancel an upload uploadCtx, stopUpload := context.WithCancel(context.Background()) info.stopUpload = stopUpload + // terminateUpload specifies whether the upload should be deleted after // the write has finished terminateUpload := false + + serverShutDown := false + // Cancel the context when the function exits to ensure that the goroutine // is properly cleaned up defer stopUpload() go func() { - // Interrupt the Read() call from the request body - <-uploadCtx.Done() - // TODO: Consider using CloseWithError function from BodyReader - terminateUpload = true + select { + case <-uploadCtx.Done(): + // uploadCtx is done if the upload is stopped by a post-receive hook + terminateUpload = true + case <-handler.serverCtx: + // serverCtx is closed if the server is being shut down + serverShutDown = true + } + + // interrupt the Read() calls from the request body r.Body.Close() }() @@ -639,6 +669,11 @@ func (handler *UnroutedHandler) writeChunk(c *httpContext, resp HTTPResponse, up if terminateUpload { err = ErrUploadStoppedByServer } + + // If the server is closing down, send an error response indicating this. + if serverShutDown { + err = ErrServerShutdown + } } handler.log("ChunkWriteComplete", "id", id, "bytesWritten", i64toa(bytesWritten)) @@ -1104,6 +1139,7 @@ func (handler *UnroutedHandler) lockUpload(c *httpContext, id string) (Lock, err releaseLock := func() { if c.body != nil { handler.log("UploadInterrupted", "id", id, "requestId", getRequestId(c.req)) + // TODO: Consider replacing this with a channel or a context c.body.closeWithError(ErrUploadInterrupted) } }