refactor: add context to all tus apis

This commit is contained in:
Derrick Hammer 2024-02-16 22:08:34 -05:00
parent c468a81543
commit 6845dac609
Signed by: pcfreak30
GPG Key ID: C997C339BE476FF2
1 changed files with 18 additions and 20 deletions

View File

@ -102,7 +102,7 @@ func (t *TusHandler) Init() error {
return blankResp, blankChanges, errors.New("file already exists") return blankResp, blankChanges, errors.New("file already exists")
} }
exists, _ := t.UploadExists(decodedHash.HashBytes()) exists, _ := t.UploadExists(hook.Context, decodedHash.HashBytes())
if exists { if exists {
return blankResp, blankChanges, errors.New("file is already being uploaded") return blankResp, blankChanges, errors.New("file is already being uploaded")
@ -175,16 +175,16 @@ func (t *TusHandler) Tus() *tusd.Handler {
return t.tus return t.tus
} }
func (t *TusHandler) UploadExists(hash []byte) (bool, models.TusUpload) { func (t *TusHandler) UploadExists(ctx context.Context, hash []byte) (bool, models.TusUpload) {
hashStr := hex.EncodeToString(hash) hashStr := hex.EncodeToString(hash)
var upload models.TusUpload var upload models.TusUpload
result := t.db.Model(&models.TusUpload{}).Where(&models.TusUpload{Hash: hashStr}).First(&upload) result := t.db.WithContext(ctx).Model(&models.TusUpload{}).Where(&models.TusUpload{Hash: hashStr}).First(&upload)
return result.RowsAffected > 0, upload return result.RowsAffected > 0, upload
} }
func (t *TusHandler) CreateUpload(hash []byte, uploadID string, uploaderID uint, uploaderIP string, protocol string) (*models.TusUpload, error) { func (t *TusHandler) CreateUpload(ctx context.Context, hash []byte, uploadID string, uploaderID uint, uploaderIP string, protocol string) (*models.TusUpload, error) {
hashStr := hex.EncodeToString(hash) hashStr := hex.EncodeToString(hash)
upload := &models.TusUpload{ upload := &models.TusUpload{
@ -196,7 +196,7 @@ func (t *TusHandler) CreateUpload(hash []byte, uploadID string, uploaderID uint,
Protocol: protocol, Protocol: protocol,
} }
result := t.db.Create(upload) result := t.db.WithContext(ctx).Create(upload)
if result.Error != nil { if result.Error != nil {
return nil, result.Error return nil, result.Error
@ -204,7 +204,7 @@ func (t *TusHandler) CreateUpload(hash []byte, uploadID string, uploaderID uint,
return upload, nil return upload, nil
} }
func (t *TusHandler) UploadProgress(uploadID string) error { func (t *TusHandler) UploadProgress(ctx context.Context, uploadID string) error {
find := &models.TusUpload{UploadID: uploadID} find := &models.TusUpload{UploadID: uploadID}
@ -215,7 +215,7 @@ func (t *TusHandler) UploadProgress(uploadID string) error {
return errors.New("upload not found") return errors.New("upload not found")
} }
result = t.db.Model(&models.TusUpload{}).Where(find).Update("updated_at", time.Now()) result = t.db.WithContext(ctx).Model(&models.TusUpload{}).Where(find).Update("updated_at", time.Now())
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
@ -223,7 +223,7 @@ func (t *TusHandler) UploadProgress(uploadID string) error {
return nil return nil
} }
func (t *TusHandler) UploadCompleted(uploadID string) error { func (t *TusHandler) UploadCompleted(ctx context.Context, uploadID string) error {
find := &models.TusUpload{UploadID: uploadID} find := &models.TusUpload{UploadID: uploadID}
@ -234,12 +234,12 @@ func (t *TusHandler) UploadCompleted(uploadID string) error {
return errors.New("upload not found") return errors.New("upload not found")
} }
result = t.db.Model(&models.TusUpload{}).Where(find).Update("completed", true) result = t.db.WithContext(ctx).Model(&models.TusUpload{}).Where(find).Update("completed", true)
return nil return nil
} }
func (t *TusHandler) DeleteUpload(uploadID string) error { func (t *TusHandler) DeleteUpload(ctx context.Context, uploadID string) error {
result := t.db.Where(&models.TusUpload{UploadID: uploadID}).Delete(&models.TusUpload{}) result := t.db.WithContext(ctx).Where(&models.TusUpload{UploadID: uploadID}).Delete(&models.TusUpload{})
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
@ -248,11 +248,11 @@ func (t *TusHandler) DeleteUpload(uploadID string) error {
return nil return nil
} }
func (t *TusHandler) ScheduleUpload(uploadID string) error { func (t *TusHandler) ScheduleUpload(ctx context.Context, uploadID string) error {
find := &models.TusUpload{UploadID: uploadID} find := &models.TusUpload{UploadID: uploadID}
var upload models.TusUpload var upload models.TusUpload
result := t.db.Model(&models.TusUpload{}).Where(find).First(&upload) result := t.db.WithContext(ctx).Model(&models.TusUpload{}).Where(find).First(&upload)
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return errors.New("upload not found") return errors.New("upload not found")
@ -266,7 +266,7 @@ func (t *TusHandler) ScheduleUpload(uploadID string) error {
Limit: 0, Limit: 0,
After: func(jobID uuid.UUID, jobName string) { After: func(jobID uuid.UUID, jobName string) {
t.logger.Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID)) t.logger.Info("Job finished", zap.String("jobName", jobName), zap.String("uploadID", uploadID))
err := t.DeleteUpload(uploadID) err := t.DeleteUpload(ctx, uploadID)
if err != nil { if err != nil {
t.logger.Error("Error deleting tus upload", zap.Error(err)) t.logger.Error("Error deleting tus upload", zap.Error(err))
} }
@ -281,9 +281,8 @@ func (t *TusHandler) ScheduleUpload(uploadID string) error {
return nil return nil
} }
func (t *TusHandler) GetUploadReader(hash []byte, start int64) (io.ReadCloser, error) { func (t *TusHandler) GetUploadReader(ctx context.Context, hash []byte, start int64) (io.ReadCloser, error) {
ctx := context.Background() exists, upload := t.UploadExists(ctx, hash)
exists, upload := t.UploadExists(hash)
if !exists { if !exists {
return nil, metadata.ErrNotFound return nil, metadata.ErrNotFound
@ -313,8 +312,7 @@ func (t *TusHandler) GetUploadReader(hash []byte, start int64) (io.ReadCloser, e
return reader, nil return reader, nil
} }
func (t *TusHandler) uploadTask(upload *models.TusUpload) error { func (t *TusHandler) uploadTask(ctx context.Context, upload *models.TusUpload) error {
ctx := context.Background()
tusUpload, err := t.tusStore.GetUpload(ctx, upload.UploadID) tusUpload, err := t.tusStore.GetUpload(ctx, upload.UploadID)
if err != nil { if err != nil {
t.logger.Error("Could not get upload", zap.Error(err)) t.logger.Error("Could not get upload", zap.Error(err))
@ -366,7 +364,7 @@ func (t *TusHandler) uploadTask(upload *models.TusUpload) error {
return err return err
} }
info, err := tusUpload.GetInfo(context.Background()) info, err := tusUpload.GetInfo(ctx)
if err != nil { if err != nil {
t.logger.Error("Could not get tus info", zap.Error(err)) t.logger.Error("Could not get tus info", zap.Error(err))
return err return err