diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go index 47396c7..5bff6be 100644 --- a/api/middleware/middleware.go +++ b/api/middleware/middleware.go @@ -186,3 +186,11 @@ func GetUserFromContext(ctx context.Context, key ...string) uint { return userId } +func CtxAborted(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/api/s5/s5.go b/api/s5/s5.go index a86c9ed..e5c41c9 100644 --- a/api/s5/s5.go +++ b/api/s5/s5.go @@ -844,9 +844,13 @@ func (s *S5API) accountPinDelete(jc jape.Context) { jc.ResponseWriter.WriteHeader(http.StatusNoContent) } -func (s *S5API) getManifestCids(cid *encoding.CID) ([]*encoding.CID, error) { +func (s *S5API) getManifestCids(ctx context.Context, cid *encoding.CID) ([]*encoding.CID, error) { var cids []*encoding.CID + if middleware.CtxAborted(ctx) { + return nil, ctx.Err() + } + manifest, err := s.getNode().Services().Storage().GetMetadataByCID(cid) if err != nil { return nil, err @@ -869,6 +873,9 @@ func (s *S5API) getManifestCids(cid *encoding.CID) ([]*encoding.CID, error) { dir := manifest.(*s5libmetadata.DirectoryMetadata) lo.ForEach(lo.Values(dir.Directories.Items()), func(d *s5libmetadata.DirectoryReference, _i int) { + if middleware.CtxAborted(ctx) { + return + } entry, err := s.getNode().Services().Registry().Get(d.PublicKey) if err != nil || entry == nil { s.logger.Error("Error getting registry entry", zap.Error(err)) @@ -881,7 +888,7 @@ func (s *S5API) getManifestCids(cid *encoding.CID) ([]*encoding.CID, error) { return } - childCids, err := s.getManifestCids(cid) + childCids, err := s.getManifestCids(ctx, cid) if err != nil { s.logger.Error("Error getting child manifest CIDs", zap.Error(err)) return @@ -902,6 +909,10 @@ func (s *S5API) getManifestCids(cid *encoding.CID) ([]*encoding.CID, error) { }) } + if middleware.CtxAborted(ctx) { + return nil, ctx.Err() + } + return cids, nil } @@ -917,7 +928,7 @@ func (s *S5API) accountPinManifest(jc jape.Context, userId uint, cid *encoding.C cid *encoding.CID } - cids, err := s.getManifestCids(cid) + cids, err := s.getManifestCids(jc.Request.Context(), cid) if err != nil { s.sendErrorResponse(jc, NewS5Error(ErrKeyInvalidOperation, err)) return @@ -982,6 +993,10 @@ func (s *S5API) accountPinManifest(jc jape.Context, userId uint, cid *encoding.C }() q.Wait() + + if middleware.CtxAborted(jc.Request.Context()) { + return + } jc.Encode(&results) } @@ -1134,6 +1149,10 @@ func (s *S5API) pinEntity(ctx context.Context, userId uint, cid *encoding.CID) e return nil } + if middleware.CtxAborted(ctx) { + return ctx.Err() + } + jobName := fmt.Sprintf("pin-import-%s", cid64) if job := s.cron.GetJobByName(jobName); job == nil {