diff --git a/storage/storage.go b/storage/storage.go
index f6728b8..d3230bd 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -5,7 +5,14 @@ import (
+	"bytes"
 	"context"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
+	"sort"
+
+	"github.com/aws/aws-sdk-go-v2/service/s3/types"
+
+	"github.com/docker/go-units"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
@@ -31,6 +38,8 @@ import (
 )
 
 const PROOF_EXTENSION = ".obao"
+const S3_MULTIPART_MAX_PARTS = 9500
+const S3_MULTIPART_MIN_PART_SIZE = uint64(5 * units.MiB)
 
 var _ StorageService = (*StorageServiceDefault)(nil)
 
@@ -59,6 +68,7 @@ type StorageService interface {
 	DeleteObject(ctx context.Context, protocol StorageProtocol, objectHash []byte) error
 	DeleteObjectProof(ctx context.Context, protocol StorageProtocol, objectHash []byte) error
 	S3Client(ctx context.Context) (*s3.Client, error)
+	S3MultipartUpload(ctx context.Context, data io.ReadCloser, bucket, key string, size uint64) error
 }
 
 type StorageServiceDefault struct {
@@ -319,7 +329,99 @@ func (s StorageServiceDefault) S3Client(ctx context.Context) (*s3.Client, error)
 	}
 
 	return s3.NewFromConfig(cfg), nil
+}
+
+func (s StorageServiceDefault) S3MultipartUpload(ctx context.Context, data io.ReadCloser, bucket, key string, size uint64) error {
+	client, err := s.S3Client(ctx)
+	if err != nil {
+		return err
+	}
+
+	mu, err := client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+	})
+	if err != nil {
+		return err
+	}
+
+	// Abort the multipart upload on any failure so S3 does not retain
+	// (and bill for) the orphaned parts.
+	abort := func() {
+		_, abortErr := client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
+			Bucket:   aws.String(bucket),
+			Key:      aws.String(key),
+			UploadId: mu.UploadId,
+		})
+		if abortErr != nil {
+			s.logger.Error("error aborting multipart upload", zap.Error(abortErr))
+		}
+	}
+
+	partSize := S3_MULTIPART_MIN_PART_SIZE
+	totalParts := int(math.Ceil(float64(size) / float64(partSize)))
+	if totalParts > S3_MULTIPART_MAX_PARTS {
+		// Round the part size up so the remainder is absorbed by the
+		// final part instead of being dropped by integer division.
+		partSize = (size + S3_MULTIPART_MAX_PARTS - 1) / S3_MULTIPART_MAX_PARTS
+		totalParts = int(math.Ceil(float64(size) / float64(partSize)))
+	}
+
+	var completedParts []types.CompletedPart
+
+	// UploadPart consumes the body synchronously, so a single buffer can
+	// be reused across iterations.
+	partData := make([]byte, partSize)
+
+	for partNum := 1; partNum <= totalParts; partNum++ {
+		// io.ReadFull fills the buffer completely except on the final
+		// part, where io.ErrUnexpectedEOF signals a short (valid) read.
+		readSize, err := io.ReadFull(data, partData)
+		if err == io.EOF {
+			break
+		}
+		if err != nil && err != io.ErrUnexpectedEOF {
+			abort()
+			return err
+		}
+
+		uploadPartOutput, err := client.UploadPart(ctx, &s3.UploadPartInput{
+			Bucket:     aws.String(bucket),
+			Key:        aws.String(key),
+			PartNumber: aws.Int32(int32(partNum)),
+			UploadId:   mu.UploadId,
+			Body:       bytes.NewReader(partData[:readSize]),
+		})
+		if err != nil {
+			abort()
+			return err
+		}
+
+		completedParts = append(completedParts, types.CompletedPart{
+			ETag:       uploadPartOutput.ETag,
+			PartNumber: aws.Int32(int32(partNum)),
+		})
+	}
+
+	// Ensure parts are ordered by part number before completing the upload
+	sort.Slice(completedParts, func(i, j int) bool {
+		return *completedParts[i].PartNumber < *completedParts[j].PartNumber
+	})
+
+	_, err = client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
+		Bucket:   aws.String(bucket),
+		Key:      aws.String(key),
+		UploadId: mu.UploadId,
+		MultipartUpload: &types.CompletedMultipartUpload{
+			Parts: completedParts,
+		},
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
 }
 
 func (s StorageServiceDefault) getProofPath(protocol StorageProtocol, objectHash []byte) string {
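
For reference, a minimal caller sketch for the new method (hypothetical names throughout: the file path, bucket, key, and the svc handle are placeholders, and svc is assumed to be an initialized StorageServiceDefault). Since size drives the part-size calculation, it must match the reader's actual length:

	file, err := os.Open("upload.bin")
	if err != nil {
		return err
	}
	defer file.Close()

	info, err := file.Stat()
	if err != nil {
		return err
	}

	// The stream is consumed in partSize chunks until size bytes are read.
	err = svc.S3MultipartUpload(ctx, file, "my-bucket", "objects/abc123.obao", uint64(info.Size()))
	if err != nil {
		return err
	}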