Compare commits


2 Commits

2 changed files with 53 additions and 120 deletions


@@ -2034,14 +2034,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 		return err
 	}
 
-	__import, err := s._import.GetImport(ctx, parsedCid.Hash.HashBytes())
-	if err != nil {
-		return err
-	}
-
-	__import.Status = models.ImportStatusProcessing
-
-	err = s._import.SaveImport(ctx, __import, false)
+	err = s._import.UpdateStatus(ctx, parsedCid.Hash.HashBytes(), models.ImportStatusProcessing)
 	if err != nil {
 		return err
 	}
@@ -2067,9 +2060,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return nil, err
 		}
 
-		importReader := _import.NewImportReader(s._import, __import, res.Body, parsedCid.Size, 1, totalStages)
-
-		defer closeBody(importReader)
+		defer closeBody(res.Body)
 
 		if res.StatusCode != http.StatusOK {
 			errMsg := "error fetching URL: " + fetchUrl
@@ -2077,15 +2068,26 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return nil, fmt.Errorf(errMsg+" with status: %s", res.Status)
 		}
 
-		data, err := io.ReadAll(importReader)
+		data, err := io.ReadAll(res.Body)
 		if err != nil {
 			s.logger.Error("error reading response body", zap.Error(err))
 			return nil, err
 		}
 
+		err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 1, totalStages)
+		if err != nil {
+			return nil, err
+		}
+
 		return data, nil
 	}
 
 	saveAndPin := func(upload *metadata.UploadMetadata) error {
+		err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 3, totalStages)
+		if err != nil {
+			return err
+		}
+
 		upload.UserID = userId
 		if err := s.metadata.SaveUpload(ctx, *upload, true); err != nil {
 			return err
@@ -2119,9 +2121,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return fmt.Errorf("hash mismatch")
 		}
 
-		importReader := _import.NewImportReader(s._import, __import, bytes.NewReader(fileData), parsedCid.Size, 2, totalStages)
-
-		upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), importReader, nil, hash)
+		upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), bytes.NewReader(fileData), nil, hash)
 		if err != nil {
 			return err
 		}
@@ -2174,13 +2174,11 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 	}(verifier)
 
-	importReader := _import.NewImportReader(s._import, __import, verifier, parsedCid.Size, 2, totalStages)
-
 	if parsedCid.Size < storage.S3_MULTIPART_MIN_PART_SIZE {
 		_, err = client.PutObject(ctx, &s3.PutObjectInput{
 			Bucket:        aws.String(s.config.Config().Core.Storage.S3.BufferBucket),
 			Key:           aws.String(cid),
-			Body:          importReader,
+			Body:          verifier,
 			ContentLength: aws.Int64(int64(parsedCid.Size)),
 		})
 		if err != nil {
@@ -2188,14 +2186,17 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 			return err
 		}
 	} else {
-		err := s.storage.S3MultipartUpload(ctx, importReader, s.config.Config().Core.Storage.S3.BufferBucket, cid, parsedCid.Size)
+		err := s.storage.S3MultipartUpload(ctx, verifier, s.config.Config().Core.Storage.S3.BufferBucket, cid, parsedCid.Size)
 		if err != nil {
 			s.logger.Error("error uploading object", zap.Error(err))
 			return err
 		}
 	}
 
-	importReader = _import.NewImportReader(s._import, __import, res.Body, parsedCid.Size, 3, totalStages)
+	err = s._import.UpdateProgress(ctx, parsedCid.Hash.HashBytes(), 2, totalStages)
+	if err != nil {
+		return err
+	}
 
 	upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), nil, &renter.MultiPartUploadParams{
 		ReaderFactory: func(start uint, end uint) (io.ReadCloser, error) {
@@ -2216,11 +2217,6 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
 				return nil, err
 			}
 
-			err = importReader.ReadBytes(int(end - start))
-			if err != nil {
-				return nil, err
-			}
-
 			return object.Body, nil
 		},
 		Bucket:   s.config.Config().Core.Storage.S3.BufferBucket,
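Taken together, this file's changes drop the ImportReader wrapper that was threaded through every reader and instead report progress against the import service at fixed checkpoints: UpdateStatus marks the record as processing, and UpdateProgress is called once per completed stage. The following is a minimal, self-contained sketch of that pattern; progressSink, logSink, and fetchAndStore are illustrative stand-ins, not the portal's actual types.

package main

import (
	"context"
	"fmt"
	"io"
	"strings"
)

// progressSink stands in for the slice of ImportService the cron job relies on.
type progressSink interface {
	UpdateStatus(ctx context.Context, hash []byte, status string) error
	UpdateProgress(ctx context.Context, hash []byte, stage, totalStages int) error
}

// logSink only prints each checkpoint; the real service persists it via GORM.
type logSink struct{}

func (logSink) UpdateStatus(_ context.Context, _ []byte, status string) error {
	fmt.Println("status:", status)
	return nil
}

func (logSink) UpdateProgress(_ context.Context, _ []byte, stage, totalStages int) error {
	fmt.Printf("progress: stage %d of %d\n", stage, totalStages)
	return nil
}

// fetchAndStore mirrors the checkpoint pattern introduced in pinImportCronJob:
// mark the import as processing, then bump progress once per completed stage
// (download, buffer upload, metadata save) instead of wrapping every reader.
func fetchAndStore(ctx context.Context, sink progressSink, hash []byte, body io.Reader) error {
	const totalStages = 3

	if err := sink.UpdateStatus(ctx, hash, "processing"); err != nil {
		return err
	}

	data, err := io.ReadAll(body) // stage 1: download the object
	if err != nil {
		return err
	}
	if err := sink.UpdateProgress(ctx, hash, 1, totalStages); err != nil {
		return err
	}

	_ = data // stage 2 (buffer upload) and stage 3 (metadata save) would go here,
	// each followed by its own UpdateProgress call, as in the diff above.
	if err := sink.UpdateProgress(ctx, hash, 2, totalStages); err != nil {
		return err
	}
	return sink.UpdateProgress(ctx, hash, 3, totalStages)
}

func main() {
	_ = fetchAndStore(context.Background(), logSink{}, []byte{0x01}, strings.NewReader("payload"))
}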


@@ -3,7 +3,6 @@ package _import
 import (
 	"context"
 	"errors"
-	"io"
 	"time"
 
 	"git.lumeweb.com/LumeWeb/portal/db/models"
@@ -15,7 +14,6 @@ import (
 var ErrNotFound = gorm.ErrRecordNotFound
 
 var _ ImportService = (*ImportServiceDefault)(nil)
-var _ io.ReadSeekCloser = (*ImportReader)(nil)
 
 type ImportMetadata struct {
 	ID            uint
@@ -32,6 +30,8 @@ type ImportService interface {
 	SaveImport(ctx context.Context, metadata ImportMetadata, skipExisting bool) error
 	GetImport(ctx context.Context, objectHash []byte) (ImportMetadata, error)
 	DeleteImport(ctx context.Context, objectHash []byte) error
+	UpdateProgress(ctx context.Context, objectHash []byte, stage int, totalStages int) error
+	UpdateStatus(ctx context.Context, objectHash []byte, status models.ImportStatus) error
 }
 
 func (u ImportMetadata) IsEmpty() bool {
@@ -63,6 +63,36 @@ type ImportServiceDefault struct {
 	db     *gorm.DB
 }
 
+func (i ImportServiceDefault) UpdateProgress(ctx context.Context, objectHash []byte, stage int, totalStages int) error {
+	_import, err := i.GetImport(ctx, objectHash)
+	if err != nil {
+		return err
+	}
+
+	if _import.IsEmpty() {
+		return ErrNotFound
+	}
+
+	_import.Progress = float64(stage) / float64(totalStages) * 100.0
+
+	return i.SaveImport(ctx, _import, false)
+}
+
+func (i ImportServiceDefault) UpdateStatus(ctx context.Context, objectHash []byte, status models.ImportStatus) error {
+	_import, err := i.GetImport(ctx, objectHash)
+	if err != nil {
+		return err
+	}
+
+	if _import.IsEmpty() {
+		return ErrNotFound
+	}
+
+	_import.Status = status
+
+	return i.SaveImport(ctx, _import, false)
+}
+
 func (i ImportServiceDefault) SaveImport(ctx context.Context, metadata ImportMetadata, skipExisting bool) error {
 	var __import models.Import
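UpdateProgress derives the stored percentage purely from the stage counter: float64(stage) / float64(totalStages) * 100.0, so with the cron job's three stages the record moves through roughly 33%, 67%, and 100%. A tiny illustrative check of that formula (stageProgress is a hypothetical helper, not part of this change):

package main

import "fmt"

// stageProgress reproduces the formula used by UpdateProgress above.
func stageProgress(stage, totalStages int) float64 {
	return float64(stage) / float64(totalStages) * 100.0
}

func main() {
	for stage := 1; stage <= 3; stage++ {
		// With totalStages = 3 this prints 33.3, 66.7, and 100.0.
		fmt.Printf("stage %d/3 -> %.1f%%\n", stage, stageProgress(stage, 3))
	}
}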
@@ -185,96 +215,3 @@ func NewImportService(params ImportServiceParams) *ImportServiceDefault {
 		db:     params.Db,
 	}
 }
-
-type ImportReader struct {
-	service     ImportService
-	meta        ImportMetadata
-	reader      io.Reader
-	size        uint64
-	stage       int
-	totalStages int
-	bytesRead   uint64
-}
-
-func (i *ImportReader) Seek(offset int64, whence int) (int64, error) {
-	if seeker, ok := i.reader.(io.Seeker); ok {
-		// If seeking to the start, reset progress based on recorded bytes
-		if whence == io.SeekStart && offset == 0 {
-			i.bytesRead = 0
-			i.meta.Progress = 0
-			if err := i.service.SaveImport(context.Background(), i.meta, false); err != nil {
-				return 0, err
-			}
-		}
-		return seeker.Seek(offset, whence)
-	}
-
-	return 0, errors.New("Seek not supported")
-}
-
-func (i *ImportReader) Close() error {
-	if closer, ok := i.reader.(io.Closer); ok {
-		return closer.Close()
-	}
-
-	return nil
-}
-
-func (i *ImportReader) Read(p []byte) (n int, err error) {
-	n, err = i.reader.Read(p)
-	if err != nil {
-		if err == io.EOF {
-			return n, err
-		}
-		return 0, err
-	}
-
-	// Update cumulative bytes read
-	i.bytesRead += uint64(n)
-
-	err = i.ReadBytes(0)
-	if err != nil {
-		return 0, err
-	}
-
-	return n, nil
-}
-
-func (i *ImportReader) ReadBytes(n int) (err error) {
-	if n > 0 {
-		i.bytesRead += uint64(n)
-	}
-	stageProgress := float64(100) / float64(i.totalStages)
-
-	// Calculate progress based on bytes read
-	i.meta.Progress = float64(i.bytesRead) / float64(i.size) * 100.0
-
-	// Adjust progress for current stage
-	if i.stage > 1 {
-		i.meta.Progress += float64(i.stage-1) * stageProgress
-	}
-
-	// Ensure progress doesn't exceed 100%
-	if i.meta.Progress > 100 {
-		i.meta.Progress = 100
-	}
-
-	// Save import progress
-	err = i.service.SaveImport(context.Background(), i.meta, false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func NewImportReader(service ImportService, meta ImportMetadata, reader io.Reader, size uint64, stage, totalStages int) *ImportReader {
-	return &ImportReader{
-		service:     service,
-		meta:        meta,
-		reader:      reader,
-		size:        size,
-		stage:       stage,
-		totalStages: totalStages,
-	}
-}
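Both new helpers surface ErrNotFound (an alias for gorm.ErrRecordNotFound) when GetImport returns an error or an empty record, so callers can tell a missing import apart from a database failure. A hedged sketch of how a caller might use that, assuming it lives alongside the _import package; reportStage, svc, and objectHash are placeholders and not part of this changeset:

// reportStage is an illustrative helper (not part of this changeset) that
// tolerates a missing import record when bumping progress.
func reportStage(ctx context.Context, svc ImportService, objectHash []byte, stage, totalStages int) error {
	err := svc.UpdateProgress(ctx, objectHash, stage, totalStages)
	if errors.Is(err, ErrNotFound) {
		// No import row exists for this hash; there is nothing to update.
		return nil
	}
	return err
}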