From f6be51b94217e7351dab77cb4e2f1d2f5907504e Mon Sep 17 00:00:00 2001
From: Derrick Hammer
Date: Sat, 23 Mar 2024 10:43:18 -0400
Subject: [PATCH] refactor: move away from granular tracking to simpler,
 stage-based tracking

---
 api/s5/s5.go     | 46 +++++++++++------------
 import/import.go | 95 ------------------------------------------------
 2 files changed, 21 insertions(+), 120 deletions(-)

diff --git a/api/s5/s5.go b/api/s5/s5.go
index e34a16b..259df14 100644
--- a/api/s5/s5.go
+++ b/api/s5/s5.go
@@ -2034,14 +2034,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 		return err
 	}
 
-	__import, err := s._import.GetImport(ctx, parsedCid.Hash.HashBytes())
-	if err != nil {
-		return err
-	}
-
-	__import.Status = models.ImportStatusProcessing
-
-	err = s._import.SaveImport(ctx, __import, false)
+	err = s._import.UpdateStatus(ctx, parsedCid.Hash.HashBytes(), models.ImportStatusProcessing)
 	if err != nil {
 		return err
 	}
@@ -2067,9 +2060,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return nil, err
 		}
 
-		importReader := _import.NewImportReader(s._import, __import, res.Body, parsedCid.Size, 1, totalStages)
-
-		defer closeBody(importReader)
+		defer closeBody(res.Body)
 
 		if res.StatusCode != http.StatusOK {
 			errMsg := "error fetching URL: " + fetchUrl
@@ -2077,15 +2068,26 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return nil, fmt.Errorf(errMsg+" with status: %s", res.Status)
 		}
-		data, err := io.ReadAll(importReader)
+		data, err := io.ReadAll(res.Body)
 		if err != nil {
 			s.logger.Error("error reading response body", zap.Error(err))
 			return nil, err
 		}
+
+		err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 1, totalStages)
+		if err != nil {
+			return nil, err
+		}
+
 		return data, nil
 	}
 
 	saveAndPin := func(upload *metadata.UploadMetadata) error {
+		err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 3, totalStages)
+		if err != nil {
+			return err
+		}
+
 		upload.UserID = userId
 		if err := s.metadata.SaveUpload(ctx, *upload, true); err != nil {
 			return err
 		}
@@ -2119,9 +2121,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return fmt.Errorf("hash mismatch")
 		}
 
-		importReader := _import.NewImportReader(s._import, __import, bytes.NewReader(fileData), parsedCid.Size, 2, totalStages)
-
-		upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), importReader, nil, hash)
+		upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), bytes.NewReader(fileData), nil, hash)
 		if err != nil {
 			return err
 		}
@@ -2174,13 +2174,11 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			}
 		}(verifier)
 
-		importReader := _import.NewImportReader(s._import, __import, verifier, parsedCid.Size, 2, totalStages)
-
 		if parsedCid.Size < storage.S3_MULTIPART_MIN_PART_SIZE {
 			_, err = client.PutObject(ctx, &s3.PutObjectInput{
 				Bucket:        aws.String(s.config.Config().Core.Storage.S3.BufferBucket),
 				Key:           aws.String(cid),
-				Body:          importReader,
+				Body:          verifier,
 				ContentLength: aws.Int64(int64(parsedCid.Size)),
 			})
 			if err != nil {
@@ -2188,14 +2186,17 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 				return err
 			}
 		} else {
-			err := s.storage.S3MultipartUpload(ctx, importReader, s.config.Config().Core.Storage.S3.BufferBucket, cid, parsedCid.Size)
+			err := s.storage.S3MultipartUpload(ctx, verifier, s.config.Config().Core.Storage.S3.BufferBucket, cid, parsedCid.Size)
 			if err != nil {
 				s.logger.Error("error uploading object", zap.Error(err))
 				return err
 			}
 		}
 
-		importReader = _import.NewImportReader(s._import, __import, res.Body, parsedCid.Size, 3, totalStages)
+		err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 2, totalStages)
+		if err != nil {
+			return err
+		}
 
 		upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), nil, &renter.MultiPartUploadParams{
 			ReaderFactory: func(start uint, end uint) (io.ReadCloser, error) {
@@ -2216,11 +2217,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 					return nil, err
 				}
 
-				err = importReader.ReadBytes(int(end - start))
-				if err != nil {
-					return nil, err
-				}
 
 				return object.Body, nil
 			},
 			Bucket: s.config.Config().Core.Storage.S3.BufferBucket,
diff --git a/import/import.go b/import/import.go
index 6c84eed..3f4967f 100644
--- a/import/import.go
+++ b/import/import.go
@@ -3,7 +3,6 @@ package _import
 import (
 	"context"
 	"errors"
-	"io"
 	"time"
 
 	"git.lumeweb.com/LumeWeb/portal/db/models"
@@ -15,7 +14,6 @@ import (
 var ErrNotFound = gorm.ErrRecordNotFound
 
 var _ ImportService = (*ImportServiceDefault)(nil)
-var _ io.ReadSeekCloser = (*ImportReader)(nil)
 
 type ImportMetadata struct {
 	ID       uint
@@ -217,96 +215,3 @@ func NewImportService(params ImportServiceParams) *ImportServiceDefault {
 		db: params.Db,
 	}
 }
-
-type ImportReader struct {
-	service     ImportService
-	meta        ImportMetadata
-	reader      io.Reader
-	size        uint64
-	stage       int
-	totalStages int
-	bytesRead   uint64
-}
-
-func (i *ImportReader) Seek(offset int64, whence int) (int64, error) {
-	if seeker, ok := i.reader.(io.Seeker); ok {
-		// If seeking to the start, reset progress based on recorded bytes
-		if whence == io.SeekStart && offset == 0 {
-			i.bytesRead = 0
-			i.meta.Progress = 0
-			if err := i.service.SaveImport(context.Background(), i.meta, false); err != nil {
-				return 0, err
-			}
-		}
-		return seeker.Seek(offset, whence)
-	}
-
-	return 0, errors.New("Seek not supported")
-}
-
-func (i *ImportReader) Close() error {
-	if closer, ok := i.reader.(io.Closer); ok {
-		return closer.Close()
-	}
-
-	return nil
-}
-
-func (i *ImportReader) Read(p []byte) (n int, err error) {
-	n, err = i.reader.Read(p)
-	if err != nil {
-		if err == io.EOF {
-			return n, err
-		}
-		return 0, err
-	}
-
-	// Update cumulative bytes read
-	i.bytesRead += uint64(n)
-
-	err = i.ReadBytes(0)
-	if err != nil {
-		return 0, err
-	}
-
-	return n, nil
-}
-
-func (i *ImportReader) ReadBytes(n int) (err error) {
-	if n > 0 {
-		i.bytesRead += uint64(n)
-	}
-	stageProgress := float64(100) / float64(i.totalStages)
-
-	// Calculate progress based on bytes read
-	i.meta.Progress = float64(i.bytesRead) / float64(i.size) * 100.0
-
-	// Adjust progress for current stage
-	if i.stage > 1 {
-		i.meta.Progress += float64(i.stage-1) * stageProgress
-	}
-
-	// Ensure progress doesn't exceed 100%
-	if i.meta.Progress > 100 {
-		i.meta.Progress = 100
-	}
-
-	// Save import progress
-	err = i.service.SaveImport(context.Background(), i.meta, false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func NewImportReader(service ImportService, meta ImportMetadata, reader io.Reader, size uint64, stage, totalStages int) *ImportReader {
-	return &ImportReader{
-		service:     service,
-		meta:        meta,
-		reader:      reader,
-		size:        size,
-		stage:       stage,
-		totalStages: totalStages,
-	}
-}
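
Note: the new code path relies on two ImportService methods, UpdateStatus and
UpdateProgress, whose implementations are not included in this diff. The sketch
below is one plausible shape for them, assuming the gorm-backed
ImportServiceDefault from import/import.go and a models.Import row keyed by the
object hash; the models.Import field and column names and the linear
stage-to-percentage mapping are illustrative assumptions, not code from the
repository.

	package _import

	import (
		"context"

		"git.lumeweb.com/LumeWeb/portal/db/models"
	)

	// UpdateStatus sets only the status of the import row identified by
	// the object hash (e.g. models.ImportStatusProcessing). The "status"
	// column name and the Hash field are assumptions.
	func (m *ImportServiceDefault) UpdateStatus(ctx context.Context, objectHash []byte, status models.ImportStatus) error {
		return m.db.WithContext(ctx).
			Model(&models.Import{}).
			Where(&models.Import{Hash: objectHash}).
			Update("status", status).Error
	}

	// UpdateProgress replaces the per-Read bookkeeping that ImportReader
	// used to do with a coarse "stage N of M" percentage, written once
	// per stage instead of once per Read call.
	func (m *ImportServiceDefault) UpdateProgress(ctx context.Context, objectHash []byte, stage int, totalStages int) error {
		progress := float64(stage) / float64(totalStages) * 100.0
		return m.db.WithContext(ctx).
			Model(&models.Import{}).
			Where(&models.Import{Hash: objectHash}).
			Update("progress", progress).Error
	}

The upshot of the refactor is visible at the call sites: the old ImportReader
issued a SaveImport (a database write) on every Read and had to wrap every
reader and seeker in the pipeline, while the stage-based variant writes at most
once per stage and lets res.Body, verifier, and bytes.NewReader flow through
untouched.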