refactor: use new RetryableTask abstraction and move task function as a private method

Derrick Hammer 2024-01-28 16:26:15 -05:00
parent 1af1ea9505
commit 2a067102da
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
1 changed file with 95 additions and 107 deletions
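Note: the RetryableTask abstraction named in the commit message is defined in the node's cron package and is not part of this diff; only its call site appears below. As a rough, non-authoritative sketch of the shape that call site implies (field names come from the diff; types, comments, and zero-value semantics are assumptions):

// Hypothetical sketch only: field names match the call site in this diff,
// everything else (types, semantics) is assumed.
package cron

import "github.com/google/uuid"

// RetryableTaskParams describes a task the cron service runs and retries on failure.
type RetryableTaskParams struct {
	Name     string                                // job name, used for logging
	Function interface{}                           // task function, e.g. func(*models.TusUpload) error
	Args     []interface{}                         // arguments passed to Function on each attempt
	Attempt  int                                   // current attempt number; 0 means first run
	Limit    int                                   // maximum attempts; 0 presumably means unlimited
	After    func(jobID uuid.UUID, jobName string) // called after a successful run
}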


@@ -17,7 +17,6 @@ import (
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
 	s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
-	"github.com/go-co-op/gocron/v2"
 	"github.com/google/uuid"
 	"github.com/spf13/viper"
 	tusd "github.com/tus/tusd/v2/pkg/handler"
@@ -383,7 +382,7 @@ func (s *StorageServiceImpl) tusWorker() {
 			s.logger.Error("Could not complete tus upload", zap.Error(err))
 			continue
 		}
-		err = s.ScheduleTusUpload(info.Upload.ID, 0)
+		err = s.ScheduleTusUpload(info.Upload.ID)
 		if err != nil {
 			s.logger.Error("Could not schedule tus upload", zap.Error(err))
 			continue
@@ -466,7 +465,7 @@ func (s *StorageServiceImpl) DeleteTusUpload(uploadID string) error {
 	return nil
 }
 
-func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string, attempt int) error {
+func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string) error {
 	find := &models.TusUpload{UploadID: uploadID}
 
 	var upload models.TusUpload
@@ -476,26 +475,22 @@ func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string, attempt int) err
 		return errors.New("upload not found")
 	}
 
-	job, task := s.buildNewTusUploadTask(&upload)
-
-	if attempt > 0 {
-		job = gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(time.Now().Add(time.Duration(attempt) * time.Minute)))
-	}
-
-	_, err := s.cron.Scheduler().NewJob(job, task, gocron.WithEventListeners(gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
-		s.logger.Error("Error running job", zap.Error(err))
-		err = s.ScheduleTusUpload(uploadID, attempt+1)
-		if err != nil {
-			s.logger.Error("Error rescheduling job", zap.Error(err))
-		}
-	}),
-		gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
-			s.logger.Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID))
-			err := s.DeleteTusUpload(uploadID)
-			if err != nil {
-				s.logger.Error("Error deleting tus upload", zap.Error(err))
-			}
-		})))
+	task := s.cron.RetryableTask(cron.RetryableTaskParams{
+		Name:     "tusUpload",
+		Function: s.tusUploadTask,
+		Args:     []interface{}{&upload},
+		Attempt:  0,
+		Limit:    0,
+		After: func(jobID uuid.UUID, jobName string) {
+			s.logger.Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID))
+			err := s.DeleteTusUpload(uploadID)
+			if err != nil {
+				s.logger.Error("Error deleting tus upload", zap.Error(err))
+			}
+		},
+	})
+
+	_, err := s.cron.CreateJob(task)
 
 	if err != nil {
 		return err
@@ -503,118 +498,111 @@ func (s *StorageServiceImpl) ScheduleTusUpload(uploadID string, attempt int) err
 	return nil
 }
 
-func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (job gocron.JobDefinition, task gocron.Task) {
-	job = gocron.OneTimeJob(gocron.OneTimeJobStartImmediately())
-
-	task = gocron.NewTask(
-		func(upload *models.TusUpload) error {
+func (s *StorageServiceImpl) tusUploadTask(upload *models.TusUpload) error {
 	ctx := context.Background()
 	tusUpload, err := s.tusStore.GetUpload(ctx, upload.UploadID)
 	if err != nil {
 		s.logger.Error("Could not get upload", zap.Error(err))
 		return err
 	}
 
 	reader, err := tusUpload.GetReader(ctx)
 	if err != nil {
 		s.logger.Error("Could not get tus file", zap.Error(err))
		return err
 	}
 
 	hash, byteCount, err := s.GetHash(reader)
 
 	if err != nil {
 		s.logger.Error("Could not compute hash", zap.Error(err))
 		return err
 	}
 
 	dbHash, err := hex.DecodeString(upload.Hash)
 
 	if err != nil {
 		s.logger.Error("Could not decode hash", zap.Error(err))
 		return err
 	}
 
 	if !bytes.Equal(hash, dbHash) {
 		s.logger.Error("Hashes do not match", zap.Any("upload", upload), zap.Any("hash", hash), zap.Any("dbHash", dbHash))
 		return err
 	}
 
 	reader, err = tusUpload.GetReader(ctx)
 	if err != nil {
 		s.logger.Error("Could not get tus file", zap.Error(err))
 		return err
 	}
 
 	var mimeBuf [512]byte
 
 	_, err = reader.Read(mimeBuf[:])
 
 	if err != nil {
 		s.logger.Error("Could not read mime", zap.Error(err))
 		return err
 	}
 
 	mimeType := http.DetectContentType(mimeBuf[:])
 
 	upload.MimeType = mimeType
 
 	if tx := s.db.Save(upload); tx.Error != nil {
 		s.logger.Error("Could not update tus upload", zap.Error(tx.Error))
 		return tx.Error
 	}
 
 	reader, err = tusUpload.GetReader(ctx)
 	if err != nil {
 		s.logger.Error("Could not get tus file", zap.Error(err))
 		return err
 	}
 
 	err = s.PutFile(reader, upload.Protocol, dbHash)
 
 	if err != nil {
 		s.logger.Error("Could not upload file", zap.Error(err))
 		return err
 	}
 
 	s3InfoId, _ := splitS3Ids(upload.UploadID)
 
 	_, err = s.s3Client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
 		Bucket: aws.String(s.config.GetString("core.storage.s3.bufferBucket")),
 		Delete: &s3types.Delete{
 			Objects: []s3types.ObjectIdentifier{
 				{
 					Key: aws.String(s3InfoId),
 				},
 				{
 					Key: aws.String(s3InfoId + ".info"),
 				},
 			},
 			Quiet: aws.Bool(true),
 		},
 	})
 
 	if err != nil {
 		s.logger.Error("Could not delete upload metadata", zap.Error(err))
 		return err
 	}
 
 	newUpload, err := s.CreateUpload(dbHash, mimeType, upload.UploaderID, upload.UploaderIP, uint64(byteCount), upload.Protocol)
 	if err != nil {
 		s.logger.Error("Could not create upload", zap.Error(err))
 		return err
 	}
 
 	err = s.accounts.PinByID(newUpload.ID, upload.UploaderID)
 	if err != nil {
 		s.logger.Error("Could not pin upload", zap.Error(err))
 		return err
 	}
 
 	return nil
-		}, upload)
-
-	return job, task
 }
 
 func (s *StorageServiceImpl) getPrefixedHash(hash []byte) []byte {
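For context on what this commit removes: the old ScheduleTusUpload wired retries by hand with gocron/v2 (a one-time job, AfterJobRunsWithError rescheduling with a one-minute-per-attempt delay, AfterJobRuns cleaning up on success). A RetryableTask-style helper could package that same wiring roughly as sketched below. This is an illustration assembled from the removed code, not the repository's actual cron package; every name not visible in the diff (ScheduleRetryableTask, the Limit check, the backoff) is an assumption.

// Illustrative sketch only: reuses the gocron/v2 calls that the removed code
// used. ScheduleRetryableTask, the Limit semantics, and the one-minute backoff
// are assumptions, not the actual cron implementation in this repository.
package cron

import (
	"time"

	"github.com/go-co-op/gocron/v2"
	"github.com/google/uuid"
	"go.uber.org/zap"
)

// RetryableTaskParams mirrors the fields used at the call site in the diff.
type RetryableTaskParams struct {
	Name     string
	Function interface{}
	Args     []interface{}
	Attempt  int
	Limit    int
	After    func(jobID uuid.UUID, jobName string)
}

// ScheduleRetryableTask runs params.Function once and, on error, reschedules
// it with a delay that grows by one minute per attempt, mirroring the behavior
// previously inlined in ScheduleTusUpload.
func ScheduleRetryableTask(scheduler gocron.Scheduler, logger *zap.Logger, params RetryableTaskParams) error {
	// First attempt starts immediately; retries start after Attempt minutes.
	job := gocron.OneTimeJob(gocron.OneTimeJobStartImmediately())
	if params.Attempt > 0 {
		job = gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(
			time.Now().Add(time.Duration(params.Attempt) * time.Minute)))
	}

	task := gocron.NewTask(params.Function, params.Args...)

	_, err := scheduler.NewJob(job, task, gocron.WithEventListeners(
		// On error, log and schedule the next attempt if any attempts remain.
		gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
			logger.Error("Error running job", zap.Error(err))
			if params.Limit > 0 && params.Attempt+1 >= params.Limit {
				logger.Error("Retry limit reached", zap.String("jobName", params.Name))
				return
			}
			retry := params
			retry.Attempt++
			if err := ScheduleRetryableTask(scheduler, logger, retry); err != nil {
				logger.Error("Error rescheduling job", zap.Error(err))
			}
		}),
		// On success, hand control back to the caller-provided callback.
		gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
			if params.After != nil {
				params.After(jobID, jobName)
			}
		}),
	))

	return err
}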