fix SeekEnd behavior
This commit is contained in:
parent
d19fa689c4
commit
4833fabd0e
67
blake3.go
67
blake3.go
|
@ -9,6 +9,7 @@ import (
|
|||
"errors"
|
||||
"hash"
|
||||
"io"
|
||||
"math"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
|
@ -136,64 +137,68 @@ func wordsToBytes(words []uint32, bytes []byte) {
|
|||
}
|
||||
}
|
||||
|
||||
// An OutputReader produces an seekable stream of output. Up to 2^64 - 1 bytes
|
||||
// can be safely read from the stream.
|
||||
// An OutputReader produces an seekable stream of 2^64 - 1 output bytes.
|
||||
type OutputReader struct {
|
||||
n node
|
||||
block [blockLen]byte
|
||||
unread int
|
||||
n node
|
||||
block [blockLen]byte
|
||||
off uint64
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It always return len(p), nil.
|
||||
// Read implements io.Reader. Callers may assume that Read returns len(p), nil
|
||||
// unless the read would extend beyond the end of the stream.
|
||||
func (or *OutputReader) Read(p []byte) (int, error) {
|
||||
if or.off == math.MaxUint64 {
|
||||
return 0, io.EOF
|
||||
} else if rem := math.MaxUint64 - or.off; uint64(len(p)) > rem {
|
||||
p = p[:rem]
|
||||
}
|
||||
lenp := len(p)
|
||||
for len(p) > 0 {
|
||||
if or.unread == 0 {
|
||||
if or.off%blockLen == 0 {
|
||||
or.n.counter = or.off / blockLen
|
||||
words := or.n.compress()
|
||||
wordsToBytes(words[:], or.block[:])
|
||||
or.unread = blockLen
|
||||
or.n.counter++
|
||||
}
|
||||
|
||||
// copy from output buffer
|
||||
n := copy(p, or.block[blockLen-or.unread:])
|
||||
or.unread -= n
|
||||
n := copy(p, or.block[or.off%blockLen:])
|
||||
p = p[n:]
|
||||
or.off += uint64(n)
|
||||
}
|
||||
return lenp, nil
|
||||
}
|
||||
|
||||
// Seek implements io.Seeker. SeekEnd is defined as 2^64 - 1 bytes, the maximum
|
||||
// safe output of a BLAKE3 stream.
|
||||
// Seek implements io.Seeker.
|
||||
func (or *OutputReader) Seek(offset int64, whence int) (int64, error) {
|
||||
off := int64(or.n.counter*blockLen) + int64(blockLen-or.unread)
|
||||
off := or.off
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
off = offset
|
||||
if offset < 0 {
|
||||
return 0, errors.New("seek position cannot be negative")
|
||||
}
|
||||
off = uint64(offset)
|
||||
case io.SeekCurrent:
|
||||
off += offset
|
||||
if offset < 0 {
|
||||
if uint64(-offset) > off {
|
||||
return 0, errors.New("seek position cannot be negative")
|
||||
}
|
||||
off -= uint64(-offset)
|
||||
} else {
|
||||
off += uint64(offset)
|
||||
}
|
||||
case io.SeekEnd:
|
||||
// BLAKE3 can safely output up to 2^64 - 1 bytes. Seeking to the "end"
|
||||
// of this stream is kind of strange, but perhaps could be useful for
|
||||
// testing overflow scenarios.
|
||||
off = int64(^uint64(0) - uint64(offset))
|
||||
off = uint64(offset) - 1
|
||||
default:
|
||||
panic("invalid whence")
|
||||
}
|
||||
if off < 0 {
|
||||
return 0, errors.New("seek position cannot be negative")
|
||||
}
|
||||
or.off = off
|
||||
or.n.counter = uint64(off) / blockLen
|
||||
or.unread = blockLen - (int(off) % blockLen)
|
||||
|
||||
// If the new offset is not a block boundary, generate the block we are
|
||||
// "inside."
|
||||
if or.unread != 0 {
|
||||
if or.off%blockLen != 0 {
|
||||
words := or.n.compress()
|
||||
wordsToBytes(words[:], or.block[:])
|
||||
}
|
||||
|
||||
return off, nil
|
||||
// NOTE: or.off >= 2^63 will result in a negative return value.
|
||||
// Nothing we can do about this.
|
||||
return int64(or.off), nil
|
||||
}
|
||||
|
||||
type chunkState struct {
|
||||
|
|
|
@ -104,9 +104,20 @@ func TestXOF(t *testing.T) {
|
|||
if !bytes.Equal(outRead[:n], xofRead[:n]) {
|
||||
t.Errorf("XOF output did not match test vector at offset %v:\n\texpected: %x...\n\t got: %x...", offset, outRead[:10], xofRead[:10])
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
// test behavior at end of stream
|
||||
xof := blake3.New(0, nil).XOF()
|
||||
buf := make([]byte, 1024)
|
||||
xof.Seek(-1000, io.SeekEnd)
|
||||
n, err := xof.Read(buf)
|
||||
if n != 1000 || err != nil {
|
||||
t.Errorf("expected (1000, nil) when reading near end of stream, got (%v, %v)", n, err)
|
||||
}
|
||||
n, err = xof.Read(buf)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Errorf("expected (0, EOF) when reading past end of stream, got (%v, %v)", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
type nopReader struct{}
|
||||
|
|
Reference in New Issue