diff --git a/pkg/handler/context.go b/pkg/handler/context.go index 5cf8786..2742591 100644 --- a/pkg/handler/context.go +++ b/pkg/handler/context.go @@ -12,16 +12,25 @@ import ( type httpContext struct { context.Context - res http.ResponseWriter - req *http.Request - body *bodyReader + parentCtx context.Context + res http.ResponseWriter + req *http.Request + body *bodyReader } func newContext(w http.ResponseWriter, r *http.Request) *httpContext { return &httpContext{ - Context: r.Context(), - res: w, - req: r, - body: nil, // body can be filled later for PATCH requests + Context: context.Background(), + parentCtx: r.Context(), + res: w, + req: r, + body: nil, // body can be filled later for PATCH requests } } + +func (hctx *httpContext) Value(key interface{}) interface{} { + if v := hctx.Context.Value(key); v != nil { + return v + } + return hctx.parentCtx.Value(key) +} diff --git a/pkg/handler/context_test.go b/pkg/handler/context_test.go new file mode 100644 index 0000000..282600a --- /dev/null +++ b/pkg/handler/context_test.go @@ -0,0 +1,41 @@ +package handler + +import ( + "context" + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestContext(t *testing.T) { + + t.Run("new context returns values from parent context", func(t *testing.T) { + parentCtx := context.WithValue(context.Background(), "test", "value") + req := http.Request{} + reqWithCtx := req.WithContext(parentCtx) + ctx := newContext(&httptest.ResponseRecorder{}, reqWithCtx) + + ctxToTest := context.WithValue(ctx, "another", "testvalue") + + a := assert.New(t) + + a.Equal("testvalue", ctxToTest.Value("another")) + a.Equal("value", ctxToTest.Value("test")) + }) + + t.Run("parent context cancellation does not cancel the httpContext", func(t *testing.T) { + parentCtx := context.Background() + req := http.Request{} + reqWithCtx := req.WithContext(parentCtx) + ctx := newContext(&httptest.ResponseRecorder{}, reqWithCtx) + + parentCtx.Done() + + a := assert.New(t) + + a.False(errors.Is(ctx.Err(), context.Canceled)) + }) + +}