diff --git a/import/import.go b/import/import.go index 7459b47..cef5d49 100644 --- a/import/import.go +++ b/import/import.go @@ -15,7 +15,7 @@ import ( var ErrNotFound = gorm.ErrRecordNotFound var _ ImportService = (*ImportServiceDefault)(nil) -var _ io.Reader = (*ImportReader)(nil) +var _ io.ReadSeekCloser = (*ImportReader)(nil) type ImportMetadata struct { ID uint @@ -180,3 +180,90 @@ func NewImportService(params ImportServiceParams) *ImportServiceDefault { db: params.Db, } } + +type ImportReader struct { + service ImportService + meta ImportMetadata + reader io.Reader + size uint64 + stage int + totalStages int + bytesRead uint64 +} + +func (i *ImportReader) Seek(offset int64, whence int) (int64, error) { + if seeker, ok := i.reader.(io.Seeker); ok { + // If seeking to the start, reset progress based on recorded bytes + if whence == io.SeekStart && offset == 0 { + i.bytesRead = 0 + i.meta.Progress = 0 + if err := i.service.SaveImport(context.Background(), i.meta, false); err != nil { + return 0, err + } + } + return seeker.Seek(offset, whence) + } + + return 0, errors.New("Seek not supported") +} + +func (i *ImportReader) Close() error { + if closer, ok := i.reader.(io.Closer); ok { + return closer.Close() + } + + return nil +} + +func (i *ImportReader) Read(p []byte) (n int, err error) { + n, err = i.reader.Read(p) + if err != nil { + return 0, err + } + + // Update cumulative bytes read + i.bytesRead += uint64(n) + + err = i.ReadBytes(n) + if err != nil { + return 0, err + } + + return n, nil +} + +func (i *ImportReader) ReadBytes(n int) (err error) { + stageProgress := float64(100) / float64(i.totalStages) + + // Calculate progress based on bytes read + i.meta.Progress = float64(i.bytesRead) / float64(i.size) * 100.0 + + // Adjust progress for current stage + if i.stage > 1 { + i.meta.Progress += float64(i.stage-1) * stageProgress + } + + // Ensure progress doesn't exceed 100% + if i.meta.Progress > 100 { + i.meta.Progress = 100 + } + + // Save import progress + err = i.service.SaveImport(context.Background(), i.meta, false) + if err != nil { + return err + } + + return nil +} + +func NewImportReader(service ImportService, meta ImportMetadata, reader io.Reader, size uint64, stage, totalStages int) *ImportReader { + return &ImportReader{ + service: service, + meta: meta, + reader: reader, + size: size, + stage: stage, + totalStages: totalStages, + } +}