diff --git a/cmd/tusd/main.go b/cmd/tusd/main.go index bc4b2b0..bdf7408 100644 --- a/cmd/tusd/main.go +++ b/cmd/tusd/main.go @@ -6,6 +6,7 @@ import ( "github.com/tus/tusd/filestore" "github.com/tus/tusd/limitedstore" "log" + "net" "net/http" "os" "time" @@ -82,12 +83,74 @@ func main() { http.Handle(basepath, http.StripPrefix(basepath, handler)) - server := &http.Server{ - Addr: address, - ReadTimeout: time.Duration(timeout) * time.Millisecond, + timeoutDuration := time.Duration(timeout) * time.Millisecond + listener, err := NewListener(address, timeoutDuration, timeoutDuration) + if err != nil { + stderr.Fatalf("Unable to create listener: %s", err) } - if err = server.ListenAndServe(); err != nil { - stderr.Fatalf("Unable to listen: %s", err) + if err = http.Serve(listener, nil); err != nil { + stderr.Fatalf("Unable to serve: %s", err) } } + +// Listener wraps a net.Listener, and gives a place to store the timeout +// parameters. On Accept, it will wrap the net.Conn with our own Conn for us. +// Original implementation taken from https://gist.github.com/jbardin/9663312 +// Thanks! <3 +type Listener struct { + net.Listener + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (l *Listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + tc := &Conn{ + Conn: c, + ReadTimeout: l.ReadTimeout, + WriteTimeout: l.WriteTimeout, + } + return tc, nil +} + +// Conn wraps a net.Conn, and sets a deadline for every read +// and write operation. +type Conn struct { + net.Conn + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (c *Conn) Read(b []byte) (int, error) { + err := c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) + if err != nil { + return 0, err + } + return c.Conn.Read(b) +} + +func (c *Conn) Write(b []byte) (int, error) { + err := c.Conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + if err != nil { + return 0, err + } + return c.Conn.Write(b) +} + +func NewListener(addr string, readTimeout, writeTimeout time.Duration) (net.Listener, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + tl := &Listener{ + Listener: l, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + } + return tl, nil +}