feat: add generic s3 upload tracking

This commit is contained in:
Derrick Hammer 2024-02-28 11:36:53 -05:00
parent b2325eb9af
commit b030de9714
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
3 changed files with 81 additions and 10 deletions

View File

@ -71,6 +71,7 @@ func NewDatabase(lc fx.Lifecycle, params DatabaseParams) *gorm.DB {
&models.PublicKey{}, &models.PublicKey{},
&models.Upload{}, &models.Upload{},
&models.User{}, &models.User{},
&models.S3Upload{},
&models.S5Challenge{}, &models.S5Challenge{},
&models.TusLock{}, &models.TusLock{},
&models.TusUpload{}, &models.TusUpload{},

10
db/models/s3_upload.go Normal file
View File

@ -0,0 +1,10 @@
package models
import "gorm.io/gorm"
type S3Upload struct {
gorm.Model
UploadID string `gorm:"unique;not null"`
Bucket string `gorm:"not null"`
Key string `gorm:"not null"`
}

View File

@ -3,12 +3,15 @@ package storage
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"net/http" "net/http"
"sort" "sort"
"git.lumeweb.com/LumeWeb/portal/db/models"
"github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/docker/go-units" "github.com/docker/go-units"
@ -336,13 +339,8 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
return err return err
} }
mu, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ var uploadId string
Bucket: aws.String(bucket), var lastPartNumber int32
Key: aws.String(key),
})
if err != nil {
return err
}
partSize := S3_MULTIPART_MIN_PART_SIZE partSize := S3_MULTIPART_MIN_PART_SIZE
totalParts := int(math.Ceil(float64(size) / float64(partSize))) totalParts := int(math.Ceil(float64(size) / float64(partSize)))
@ -353,6 +351,60 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
var completedParts []types.CompletedPart var completedParts []types.CompletedPart
var s3Upload models.S3Upload
s3Upload.Bucket = bucket
s3Upload.Key = key
ret := s.db.Model(&s3Upload).First(&s3Upload)
if ret.Error != nil {
if !errors.Is(ret.Error, gorm.ErrRecordNotFound) {
return ret.Error
}
} else {
uploadId = s3Upload.UploadID
}
if uploadId == "" {
mu, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})
if err != nil {
return err
}
uploadId = *mu.UploadId
s3Upload.UploadID = uploadId
ret = s.db.Create(&s3Upload)
if ret.Error != nil {
return ret.Error
}
} else {
parts, err := client.ListParts(ctx, &s3.ListPartsInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
UploadId: aws.String(uploadId),
})
if err != nil {
return err
}
for _, part := range parts.Parts {
if uint64(*part.Size) == partSize {
if *part.PartNumber > lastPartNumber {
lastPartNumber = *part.PartNumber
completedParts = append(completedParts, types.CompletedPart{
ETag: part.ETag,
PartNumber: part.PartNumber,
})
}
}
}
}
for partNum := 1; partNum <= totalParts; partNum++ { for partNum := 1; partNum <= totalParts; partNum++ {
partData := make([]byte, partSize) partData := make([]byte, partSize)
readSize, err := data.Read(partData) readSize, err := data.Read(partData)
@ -360,11 +412,15 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
return err return err
} }
if partNum <= int(lastPartNumber) {
continue
}
uploadPartOutput, err := client.UploadPart(ctx, &s3.UploadPartInput{ uploadPartOutput, err := client.UploadPart(ctx, &s3.UploadPartInput{
Bucket: aws.String(bucket), Bucket: aws.String(bucket),
Key: aws.String(key), Key: aws.String(key),
PartNumber: aws.Int32(int32(partNum)), PartNumber: aws.Int32(int32(partNum)),
UploadId: mu.UploadId, UploadId: aws.String(uploadId),
Body: bytes.NewReader(partData[:readSize]), Body: bytes.NewReader(partData[:readSize]),
}) })
if err != nil { if err != nil {
@ -372,7 +428,7 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
_, abortErr := client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ _, abortErr := client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(bucket), Bucket: aws.String(bucket),
Key: aws.String(key), Key: aws.String(key),
UploadId: mu.UploadId, UploadId: aws.String(uploadId),
}) })
if abortErr != nil { if abortErr != nil {
s.logger.Error("error aborting multipart upload", zap.Error(abortErr)) s.logger.Error("error aborting multipart upload", zap.Error(abortErr))
@ -396,7 +452,7 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
_, err = client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ _, err = client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(bucket), Bucket: aws.String(bucket),
Key: aws.String(key), Key: aws.String(key),
UploadId: mu.UploadId, UploadId: aws.String(uploadId),
MultipartUpload: &types.CompletedMultipartUpload{ MultipartUpload: &types.CompletedMultipartUpload{
Parts: completedParts, Parts: completedParts,
}, },
@ -405,6 +461,10 @@ func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.Re
return err return err
} }
if tx := s.db.Delete(&s3Upload); tx.Error != nil {
return tx.Error
}
return nil return nil
} }