From b030de9714ab4433fe4ef0ad6186826c1ad2eb25 Mon Sep 17 00:00:00 2001 From: Derrick Hammer Date: Wed, 28 Feb 2024 11:36:53 -0500 Subject: [PATCH] feat: add generic s3 upload tracking --- db/db.go | 1 + db/models/s3_upload.go | 10 ++++++ storage/storage.go | 80 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 81 insertions(+), 10 deletions(-) create mode 100644 db/models/s3_upload.go diff --git a/db/db.go b/db/db.go index de45ed6..4817a8d 100644 --- a/db/db.go +++ b/db/db.go @@ -71,6 +71,7 @@ func NewDatabase(lc fx.Lifecycle, params DatabaseParams) *gorm.DB { &models.PublicKey{}, &models.Upload{}, &models.User{}, + &models.S3Upload{}, &models.S5Challenge{}, &models.TusLock{}, &models.TusUpload{}, diff --git a/db/models/s3_upload.go b/db/models/s3_upload.go new file mode 100644 index 0000000..348d1dc --- /dev/null +++ b/db/models/s3_upload.go @@ -0,0 +1,10 @@ +package models + +import "gorm.io/gorm" + +type S3Upload struct { + gorm.Model + UploadID string `gorm:"unique;not null"` + Bucket string `gorm:"not null"` + Key string `gorm:"not null"` +} diff --git a/storage/storage.go b/storage/storage.go index 9f24c9f..09efb24 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,12 +3,15 @@ package storage import ( "bytes" "context" + "errors" "fmt" "io" "math" "net/http" "sort" + "git.lumeweb.com/LumeWeb/portal/db/models" + "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/docker/go-units" @@ -336,13 +339,8 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re return err } - mu, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return err - } + var uploadId string + var lastPartNumber int32 partSize := S3_MULTIPART_MIN_PART_SIZE totalParts := int(math.Ceil(float64(size) / float64(partSize))) @@ -353,6 +351,60 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re var completedParts []types.CompletedPart + var s3Upload models.S3Upload + + s3Upload.Bucket = bucket + s3Upload.Key = key + + ret := s.db.Model(&s3Upload).First(&s3Upload) + if ret.Error != nil { + if !errors.Is(ret.Error, gorm.ErrRecordNotFound) { + return ret.Error + } + } else { + uploadId = s3Upload.UploadID + } + + if uploadId == "" { + mu, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return err + } + + uploadId = *mu.UploadId + + s3Upload.UploadID = uploadId + ret = s.db.Create(&s3Upload) + if ret.Error != nil { + return ret.Error + } + } else { + parts, err := client.ListParts(ctx, &s3.ListPartsInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + UploadId: aws.String(uploadId), + }) + + if err != nil { + return err + } + + for _, part := range parts.Parts { + if uint64(*part.Size) == partSize { + if *part.PartNumber > lastPartNumber { + lastPartNumber = *part.PartNumber + completedParts = append(completedParts, types.CompletedPart{ + ETag: part.ETag, + PartNumber: part.PartNumber, + }) + } + } + } + } + for partNum := 1; partNum <= totalParts; partNum++ { partData := make([]byte, partSize) readSize, err := data.Read(partData) @@ -360,11 +412,15 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re return err } + if partNum <= int(lastPartNumber) { + continue + } + uploadPartOutput, err := client.UploadPart(ctx, &s3.UploadPartInput{ Bucket: aws.String(bucket), Key: aws.String(key), PartNumber: aws.Int32(int32(partNum)), - UploadId: mu.UploadId, + UploadId: aws.String(uploadId), Body: bytes.NewReader(partData[:readSize]), }) if err != nil { @@ -372,7 +428,7 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re _, abortErr := client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(bucket), Key: aws.String(key), - UploadId: mu.UploadId, + UploadId: aws.String(uploadId), }) if abortErr != nil { s.logger.Error("error aborting multipart upload", zap.Error(abortErr)) @@ -396,7 +452,7 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re _, err = client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(bucket), Key: aws.String(key), - UploadId: mu.UploadId, + UploadId: aws.String(uploadId), MultipartUpload: &types.CompletedMultipartUpload{ Parts: completedParts, }, @@ -405,6 +461,10 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re return err } + if tx := s.db.Delete(&s3Upload); tx.Error != nil { + return tx.Error + } + return nil }