diff --git a/bao/bao.go b/bao/bao.go index 406d30b..44c4dc7 100644 --- a/bao/bao.go +++ b/bao/bao.go @@ -27,31 +27,50 @@ var _ io.ReadCloser = (*Verifier)(nil) var ErrVerifyFailed = errors.New("verification failed") type Verifier struct { - r io.ReadCloser - proof Result - read uint64 + r io.ReadCloser + proof Result + read uint64 + buffer *bytes.Buffer } -func (v Verifier) Read(p []byte) (n int, err error) { - var buf [VERIFY_CHUNK_SIZE]byte - - n, err = v.r.Read(buf[:]) - if err != nil { - return 0, err +func (v *Verifier) Read(p []byte) (int, error) { + // Initial attempt to read from the buffer + n, err := v.buffer.Read(p) + if n == len(p) { + // If the buffer already had enough data to fulfill the request, return immediately + return n, nil + } else if err != nil && err != io.EOF { + // For errors other than EOF, return the error immediately + return n, err } - if !bao.Verify(buf[:n], v.read, v.proof.Proof, v.proof.Hash) { - return 0, ErrVerifyFailed + buf := make([]byte, VERIFY_CHUNK_SIZE) + // Continue reading from the source and verifying until we have enough data or hit an error + for v.buffer.Len() < len(p)-n { + bytesRead, err := v.r.Read(buf) + if err != nil && err != io.EOF { + return n, err // Return any read error immediately + } + + if !bao.Verify(buf[:bytesRead], v.read, v.proof.Proof, v.proof.Hash) { + return n, ErrVerifyFailed + } + + v.read += uint64(bytesRead) + v.buffer.Write(buf[:bytesRead]) // Append new data to the buffer + + if err == io.EOF { + // If EOF, break the loop as no more data can be read + break + } } - v.read += uint64(n) - - copy(p, buf[:n]) - - return n, nil + // Attempt to read the remainder of the data from the buffer + additionalBytes, _ := v.buffer.Read(p[n:]) + return n + additionalBytes, nil } -func (v Verifier) Close() error { +func (v *Verifier) Close() error { return v.r.Close() } @@ -145,5 +164,9 @@ func Hash(r io.Reader) (*Result, error) { } func NewVerifier(r io.ReadCloser, proof Result) *Verifier { - return &Verifier{r: r, proof: proof} + return &Verifier{ + r: r, + proof: proof, + buffer: new(bytes.Buffer), + } }