fix SeekEnd behavior

This commit is contained in:
lukechampine 2020-01-10 14:39:34 -05:00
parent d19fa689c4
commit 4833fabd0e
2 changed files with 48 additions and 32 deletions

View File

@ -9,6 +9,7 @@ import (
"errors" "errors"
"hash" "hash"
"io" "io"
"math"
"math/bits" "math/bits"
) )
@ -136,64 +137,68 @@ func wordsToBytes(words []uint32, bytes []byte) {
} }
} }
// An OutputReader produces an seekable stream of output. Up to 2^64 - 1 bytes // An OutputReader produces an seekable stream of 2^64 - 1 output bytes.
// can be safely read from the stream.
type OutputReader struct { type OutputReader struct {
n node n node
block [blockLen]byte block [blockLen]byte
unread int off uint64
} }
// Read implements io.Reader. It always return len(p), nil. // Read implements io.Reader. Callers may assume that Read returns len(p), nil
// unless the read would extend beyond the end of the stream.
func (or *OutputReader) Read(p []byte) (int, error) { func (or *OutputReader) Read(p []byte) (int, error) {
if or.off == math.MaxUint64 {
return 0, io.EOF
} else if rem := math.MaxUint64 - or.off; uint64(len(p)) > rem {
p = p[:rem]
}
lenp := len(p) lenp := len(p)
for len(p) > 0 { for len(p) > 0 {
if or.unread == 0 { if or.off%blockLen == 0 {
or.n.counter = or.off / blockLen
words := or.n.compress() words := or.n.compress()
wordsToBytes(words[:], or.block[:]) wordsToBytes(words[:], or.block[:])
or.unread = blockLen
or.n.counter++
} }
// copy from output buffer n := copy(p, or.block[or.off%blockLen:])
n := copy(p, or.block[blockLen-or.unread:])
or.unread -= n
p = p[n:] p = p[n:]
or.off += uint64(n)
} }
return lenp, nil return lenp, nil
} }
// Seek implements io.Seeker. SeekEnd is defined as 2^64 - 1 bytes, the maximum // Seek implements io.Seeker.
// safe output of a BLAKE3 stream.
func (or *OutputReader) Seek(offset int64, whence int) (int64, error) { func (or *OutputReader) Seek(offset int64, whence int) (int64, error) {
off := int64(or.n.counter*blockLen) + int64(blockLen-or.unread) off := or.off
switch whence { switch whence {
case io.SeekStart: case io.SeekStart:
off = offset if offset < 0 {
return 0, errors.New("seek position cannot be negative")
}
off = uint64(offset)
case io.SeekCurrent: case io.SeekCurrent:
off += offset if offset < 0 {
if uint64(-offset) > off {
return 0, errors.New("seek position cannot be negative")
}
off -= uint64(-offset)
} else {
off += uint64(offset)
}
case io.SeekEnd: case io.SeekEnd:
// BLAKE3 can safely output up to 2^64 - 1 bytes. Seeking to the "end" off = uint64(offset) - 1
// of this stream is kind of strange, but perhaps could be useful for
// testing overflow scenarios.
off = int64(^uint64(0) - uint64(offset))
default: default:
panic("invalid whence") panic("invalid whence")
} }
if off < 0 { or.off = off
return 0, errors.New("seek position cannot be negative")
}
or.n.counter = uint64(off) / blockLen or.n.counter = uint64(off) / blockLen
or.unread = blockLen - (int(off) % blockLen) if or.off%blockLen != 0 {
// If the new offset is not a block boundary, generate the block we are
// "inside."
if or.unread != 0 {
words := or.n.compress() words := or.n.compress()
wordsToBytes(words[:], or.block[:]) wordsToBytes(words[:], or.block[:])
} }
// NOTE: or.off >= 2^63 will result in a negative return value.
return off, nil // Nothing we can do about this.
return int64(or.off), nil
} }
type chunkState struct { type chunkState struct {

View File

@ -104,9 +104,20 @@ func TestXOF(t *testing.T) {
if !bytes.Equal(outRead[:n], xofRead[:n]) { if !bytes.Equal(outRead[:n], xofRead[:n]) {
t.Errorf("XOF output did not match test vector at offset %v:\n\texpected: %x...\n\t got: %x...", offset, outRead[:10], xofRead[:10]) t.Errorf("XOF output did not match test vector at offset %v:\n\texpected: %x...\n\t got: %x...", offset, outRead[:10], xofRead[:10])
} }
} }
} }
// test behavior at end of stream
xof := blake3.New(0, nil).XOF()
buf := make([]byte, 1024)
xof.Seek(-1000, io.SeekEnd)
n, err := xof.Read(buf)
if n != 1000 || err != nil {
t.Errorf("expected (1000, nil) when reading near end of stream, got (%v, %v)", n, err)
}
n, err = xof.Read(buf)
if n != 0 || err != io.EOF {
t.Errorf("expected (0, EOF) when reading past end of stream, got (%v, %v)", n, err)
}
} }
type nopReader struct{} type nopReader struct{}