From d7a46190802a5b8312d44839228ffe46947e69d2 Mon Sep 17 00:00:00 2001 From: oliverpool Date: Wed, 1 Feb 2017 08:37:46 +0100 Subject: [PATCH] Prevent race condition on errorstotal increment The HTTPError interface is used to have detailed metrics (by code as well) --- metrics.go | 73 ++++++++++++++++------ prometheuscollector/prometheuscollector.go | 11 ++-- unrouted_handler.go | 15 +++-- 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/metrics.go b/metrics.go index 177d965..76c2a8d 100644 --- a/metrics.go +++ b/metrics.go @@ -1,6 +1,7 @@ package tusd import ( + "sync" "sync/atomic" ) @@ -13,7 +14,7 @@ type Metrics struct { // RequestTotal counts the number of incoming requests per method RequestsTotal map[string]*uint64 // ErrorsTotal counts the number of returned errors by their message - ErrorsTotal map[string]*uint64 + ErrorsTotal ErrorsTotalMap BytesReceived *uint64 UploadsFinished *uint64 UploadsCreated *uint64 @@ -22,43 +23,35 @@ type Metrics struct { // incRequestsTotal increases the counter for this request method atomically by // one. The method must be one of GET, HEAD, POST, PATCH, DELETE. -func (m Metrics) incRequestsTotal(method string) { +func (m *Metrics) incRequestsTotal(method string) { if ptr, ok := m.RequestsTotal[method]; ok { atomic.AddUint64(ptr, 1) } } // incErrorsTotal increases the counter for this error atomically by one. -func (m Metrics) incErrorsTotal(err error) { - msg := err.Error() - - if addr, ok := m.ErrorsTotal[msg]; ok { - atomic.AddUint64(addr, 1) - } else { - addr := new(uint64) - *addr = 1 - m.ErrorsTotal[msg] = addr - } +func (m *Metrics) incErrorsTotal(err HTTPError) { + m.ErrorsTotal.incError(err) } // incBytesReceived increases the number of received bytes atomically be the // specified number. -func (m Metrics) incBytesReceived(delta uint64) { +func (m *Metrics) incBytesReceived(delta uint64) { atomic.AddUint64(m.BytesReceived, delta) } // incUploadsFinished increases the counter for finished uploads atomically by one. -func (m Metrics) incUploadsFinished() { +func (m *Metrics) incUploadsFinished() { atomic.AddUint64(m.UploadsFinished, 1) } // incUploadsCreated increases the counter for completed uploads atomically by one. -func (m Metrics) incUploadsCreated() { +func (m *Metrics) incUploadsCreated() { atomic.AddUint64(m.UploadsCreated, 1) } // incUploadsTerminated increases the counter for completed uploads atomically by one. -func (m Metrics) incUploadsTerminated() { +func (m *Metrics) incUploadsTerminated() { atomic.AddUint64(m.UploadsTerminated, 1) } @@ -80,7 +73,49 @@ func newMetrics() Metrics { } } -func newErrorsTotalMap() map[string]*uint64 { - m := make(map[string]*uint64, 20) - return m +// ErrorsTotalMap stores the counter for the different http errors. +type ErrorsTotalMap struct { + sync.RWMutex + m map[HTTPError]*uint64 +} + +func newErrorsTotalMap() ErrorsTotalMap { + m := make(map[HTTPError]*uint64, 20) + return ErrorsTotalMap{ + m: m, + } +} + +// incErrorsTotal increases the counter for this error atomically by one. +func (e *ErrorsTotalMap) incError(err HTTPError) { + // The goal is to have a valid ptr to the number of HTTPError + e.RLock() + if ptr, ok := e.m[err]; !ok { + // The ptr does not seem to exist for this err + // Hence we create it (using a write lock) + e.RUnlock() + e.Lock() + // We ensure that the ptr wasn't created in the meantime + if ptr, ok = e.m[err]; !ok { + ptr = new(uint64) + *ptr = 0 + e.m[err] = ptr + } + e.Unlock() + } else { + e.RUnlock() + } + // We can then increase the counter + atomic.AddUint64(e.m[err], 1) +} + +// Load retrieves the map of the counter pointers atomically +func (e *ErrorsTotalMap) Load() (m map[HTTPError]*uint64) { + m = make(map[HTTPError]*uint64, len(e.m)) + e.RLock() + for err, ptr := range e.m { + m[err] = ptr + } + e.RUnlock() + return } diff --git a/prometheuscollector/prometheuscollector.go b/prometheuscollector/prometheuscollector.go index 25c8943..0113f20 100644 --- a/prometheuscollector/prometheuscollector.go +++ b/prometheuscollector/prometheuscollector.go @@ -13,6 +13,8 @@ import ( "github.com/tus/tusd" + "strconv" + "github.com/prometheus/client_golang/prometheus" ) @@ -23,8 +25,8 @@ var ( []string{"method"}, nil) errorsTotalDesc = prometheus.NewDesc( "tusd_errors_total", - "Total number of erorrs per cause.", - []string{"cause"}, nil) + "Total number of errors per cause.", + []string{"status", "cause"}, nil) bytesReceivedDesc = prometheus.NewDesc( "tusd_bytes_received", "Number of bytes received for uploads.", @@ -73,12 +75,13 @@ func (c Collector) Collect(metrics chan<- prometheus.Metric) { ) } - for error, valuePtr := range c.metrics.ErrorsTotal { + for httpError, valuePtr := range c.metrics.ErrorsTotal.Load() { metrics <- prometheus.MustNewConstMetric( errorsTotalDesc, prometheus.GaugeValue, float64(atomic.LoadUint64(valuePtr)), - error, + strconv.Itoa(httpError.StatusCode()), + httpError.Error(), ) } diff --git a/unrouted_handler.go b/unrouted_handler.go index 7886ef2..c338d9f 100644 --- a/unrouted_handler.go +++ b/unrouted_handler.go @@ -603,9 +603,12 @@ func (handler *UnroutedHandler) sendError(w http.ResponseWriter, r *http.Request err = ErrNotFound } - status := 500 - if statusErr, ok := err.(HTTPError); ok { - status = statusErr.StatusCode() + statusErr, ok := err.(HTTPError) + if !ok { + statusErr = httpError{ + error: err, + statusCode: 500, // default status code + } } reason := err.Error() + "\n" @@ -615,12 +618,12 @@ func (handler *UnroutedHandler) sendError(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Length", strconv.Itoa(len(reason))) - w.WriteHeader(status) + w.WriteHeader(statusErr.StatusCode()) w.Write([]byte(reason)) - handler.log("ResponseOutgoing", "status", strconv.Itoa(status), "method", r.Method, "path", r.URL.Path, "error", err.Error()) + handler.log("ResponseOutgoing", "status", strconv.Itoa(statusErr.StatusCode()), "method", r.Method, "path", r.URL.Path, "error", err.Error()) - go handler.Metrics.incErrorsTotal(err) + go handler.Metrics.incErrorsTotal(statusErr) } // sendResp writes the header to w with the specified status code.