diff --git a/interfaces/storage.go b/interfaces/storage.go index 936c06d..bb0be3d 100644 --- a/interfaces/storage.go +++ b/interfaces/storage.go @@ -16,7 +16,7 @@ type StorageService interface { BuildUploadBufferTus(basePath string, preUploadCb TusPreUploadCreateCallback, preFinishCb TusPreFinishResponseCallback) (*tusd.Handler, tusd.DataStore, *s3.Client, error) FileExists(hash []byte) (bool, models.Upload) GetHashSmall(file io.ReadSeeker) ([]byte, error) - GetHash(file io.Reader) ([]byte, error) + GetHash(file io.Reader) ([]byte, int64, error) CreateUpload(hash []byte, uploaderID uint, uploaderIP string, size uint64, protocol string) (*models.Upload, error) TusUploadExists(hash []byte) (bool, models.TusUpload) CreateTusUpload(hash []byte, uploadID string, uploaderID uint, uploaderIP string, protocol string) (*models.TusUpload, error) diff --git a/storage/storage.go b/storage/storage.go index 08da0d8..696d34e 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -275,18 +275,18 @@ func (s *StorageServiceImpl) GetHashSmall(file io.ReadSeeker) ([]byte, error) { return hash[:], nil } -func (s *StorageServiceImpl) GetHash(file io.Reader) ([]byte, error) { +func (s *StorageServiceImpl) GetHash(file io.Reader) ([]byte, int64, error) { hasher := blake3.New(64, nil) - _, err := io.Copy(hasher, file) + totalBytes, err := io.Copy(hasher, file) if err != nil { - return nil, err + return nil, 0, err } hash := hasher.Sum(nil) - return hash[:32], nil + return hash[:32], totalBytes, nil } func (s *StorageServiceImpl) CreateUpload(hash []byte, uploaderID uint, uploaderIP string, size uint64, protocol string) (*models.Upload, error) { @@ -506,7 +506,7 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo return err } - hash, err := s.GetHash(reader) + hash, byteCount, err := s.GetHash(reader) if err != nil { s.portal.Logger().Error("Could not compute hash", zap.Error(err)) @@ -544,7 +544,7 @@ func (s *StorageServiceImpl) buildNewTusUploadTask(upload *models.TusUpload) (jo return err } - _, err = s.CreateUpload(dbHash, upload.UploaderID, upload.UploaderIP, uint64(info.Size), upload.Protocol) + _, err = s.CreateUpload(dbHash, upload.UploaderID, upload.UploaderIP, uint64(byteCount), upload.Protocol) if err != nil { s.portal.Logger().Error("Could not create upload", zap.Error(err)) return err