Prevent race condition on errorstotal increment

The HTTPError interface is used to have detailed metrics (by code as
well)
This commit is contained in:
oliverpool 2017-02-01 08:37:46 +01:00
parent 5491b8ff12
commit d7a4619080
3 changed files with 70 additions and 29 deletions

View File

@ -1,6 +1,7 @@
package tusd package tusd
import ( import (
"sync"
"sync/atomic" "sync/atomic"
) )
@ -13,7 +14,7 @@ type Metrics struct {
// RequestTotal counts the number of incoming requests per method // RequestTotal counts the number of incoming requests per method
RequestsTotal map[string]*uint64 RequestsTotal map[string]*uint64
// ErrorsTotal counts the number of returned errors by their message // ErrorsTotal counts the number of returned errors by their message
ErrorsTotal map[string]*uint64 ErrorsTotal ErrorsTotalMap
BytesReceived *uint64 BytesReceived *uint64
UploadsFinished *uint64 UploadsFinished *uint64
UploadsCreated *uint64 UploadsCreated *uint64
@ -22,43 +23,35 @@ type Metrics struct {
// incRequestsTotal increases the counter for this request method atomically by // incRequestsTotal increases the counter for this request method atomically by
// one. The method must be one of GET, HEAD, POST, PATCH, DELETE. // one. The method must be one of GET, HEAD, POST, PATCH, DELETE.
func (m Metrics) incRequestsTotal(method string) { func (m *Metrics) incRequestsTotal(method string) {
if ptr, ok := m.RequestsTotal[method]; ok { if ptr, ok := m.RequestsTotal[method]; ok {
atomic.AddUint64(ptr, 1) atomic.AddUint64(ptr, 1)
} }
} }
// incErrorsTotal increases the counter for this error atomically by one. // incErrorsTotal increases the counter for this error atomically by one.
func (m Metrics) incErrorsTotal(err error) { func (m *Metrics) incErrorsTotal(err HTTPError) {
msg := err.Error() m.ErrorsTotal.incError(err)
if addr, ok := m.ErrorsTotal[msg]; ok {
atomic.AddUint64(addr, 1)
} else {
addr := new(uint64)
*addr = 1
m.ErrorsTotal[msg] = addr
}
} }
// incBytesReceived increases the number of received bytes atomically be the // incBytesReceived increases the number of received bytes atomically be the
// specified number. // specified number.
func (m Metrics) incBytesReceived(delta uint64) { func (m *Metrics) incBytesReceived(delta uint64) {
atomic.AddUint64(m.BytesReceived, delta) atomic.AddUint64(m.BytesReceived, delta)
} }
// incUploadsFinished increases the counter for finished uploads atomically by one. // incUploadsFinished increases the counter for finished uploads atomically by one.
func (m Metrics) incUploadsFinished() { func (m *Metrics) incUploadsFinished() {
atomic.AddUint64(m.UploadsFinished, 1) atomic.AddUint64(m.UploadsFinished, 1)
} }
// incUploadsCreated increases the counter for completed uploads atomically by one. // incUploadsCreated increases the counter for completed uploads atomically by one.
func (m Metrics) incUploadsCreated() { func (m *Metrics) incUploadsCreated() {
atomic.AddUint64(m.UploadsCreated, 1) atomic.AddUint64(m.UploadsCreated, 1)
} }
// incUploadsTerminated increases the counter for completed uploads atomically by one. // incUploadsTerminated increases the counter for completed uploads atomically by one.
func (m Metrics) incUploadsTerminated() { func (m *Metrics) incUploadsTerminated() {
atomic.AddUint64(m.UploadsTerminated, 1) atomic.AddUint64(m.UploadsTerminated, 1)
} }
@ -80,7 +73,49 @@ func newMetrics() Metrics {
} }
} }
func newErrorsTotalMap() map[string]*uint64 { // ErrorsTotalMap stores the counter for the different http errors.
m := make(map[string]*uint64, 20) type ErrorsTotalMap struct {
return m sync.RWMutex
m map[HTTPError]*uint64
}
func newErrorsTotalMap() ErrorsTotalMap {
m := make(map[HTTPError]*uint64, 20)
return ErrorsTotalMap{
m: m,
}
}
// incErrorsTotal increases the counter for this error atomically by one.
func (e *ErrorsTotalMap) incError(err HTTPError) {
// The goal is to have a valid ptr to the number of HTTPError
e.RLock()
if ptr, ok := e.m[err]; !ok {
// The ptr does not seem to exist for this err
// Hence we create it (using a write lock)
e.RUnlock()
e.Lock()
// We ensure that the ptr wasn't created in the meantime
if ptr, ok = e.m[err]; !ok {
ptr = new(uint64)
*ptr = 0
e.m[err] = ptr
}
e.Unlock()
} else {
e.RUnlock()
}
// We can then increase the counter
atomic.AddUint64(e.m[err], 1)
}
// Load retrieves the map of the counter pointers atomically
func (e *ErrorsTotalMap) Load() (m map[HTTPError]*uint64) {
m = make(map[HTTPError]*uint64, len(e.m))
e.RLock()
for err, ptr := range e.m {
m[err] = ptr
}
e.RUnlock()
return
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/tus/tusd" "github.com/tus/tusd"
"strconv"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -23,8 +25,8 @@ var (
[]string{"method"}, nil) []string{"method"}, nil)
errorsTotalDesc = prometheus.NewDesc( errorsTotalDesc = prometheus.NewDesc(
"tusd_errors_total", "tusd_errors_total",
"Total number of erorrs per cause.", "Total number of errors per cause.",
[]string{"cause"}, nil) []string{"status", "cause"}, nil)
bytesReceivedDesc = prometheus.NewDesc( bytesReceivedDesc = prometheus.NewDesc(
"tusd_bytes_received", "tusd_bytes_received",
"Number of bytes received for uploads.", "Number of bytes received for uploads.",
@ -73,12 +75,13 @@ func (c Collector) Collect(metrics chan<- prometheus.Metric) {
) )
} }
for error, valuePtr := range c.metrics.ErrorsTotal { for httpError, valuePtr := range c.metrics.ErrorsTotal.Load() {
metrics <- prometheus.MustNewConstMetric( metrics <- prometheus.MustNewConstMetric(
errorsTotalDesc, errorsTotalDesc,
prometheus.GaugeValue, prometheus.GaugeValue,
float64(atomic.LoadUint64(valuePtr)), float64(atomic.LoadUint64(valuePtr)),
error, strconv.Itoa(httpError.StatusCode()),
httpError.Error(),
) )
} }

View File

@ -603,9 +603,12 @@ func (handler *UnroutedHandler) sendError(w http.ResponseWriter, r *http.Request
err = ErrNotFound err = ErrNotFound
} }
status := 500 statusErr, ok := err.(HTTPError)
if statusErr, ok := err.(HTTPError); ok { if !ok {
status = statusErr.StatusCode() statusErr = httpError{
error: err,
statusCode: 500, // default status code
}
} }
reason := err.Error() + "\n" reason := err.Error() + "\n"
@ -615,12 +618,12 @@ func (handler *UnroutedHandler) sendError(w http.ResponseWriter, r *http.Request
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Content-Length", strconv.Itoa(len(reason))) w.Header().Set("Content-Length", strconv.Itoa(len(reason)))
w.WriteHeader(status) w.WriteHeader(statusErr.StatusCode())
w.Write([]byte(reason)) w.Write([]byte(reason))
handler.log("ResponseOutgoing", "status", strconv.Itoa(status), "method", r.Method, "path", r.URL.Path, "error", err.Error()) handler.log("ResponseOutgoing", "status", strconv.Itoa(statusErr.StatusCode()), "method", r.Method, "path", r.URL.Path, "error", err.Error())
go handler.Metrics.incErrorsTotal(err) go handler.Metrics.incErrorsTotal(statusErr)
} }
// sendResp writes the header to w with the specified status code. // sendResp writes the header to w with the specified status code.