diff --git a/blake3.go b/blake3.go index 14bc907..f79c10c 100644 --- a/blake3.go +++ b/blake3.go @@ -168,8 +168,6 @@ func (cs *chunkState) chunkCounter() uint64 { return cs.n.counter } -// complete is a helper method that reports whether a full chunk has been -// processed. func (cs *chunkState) complete() bool { return cs.bytesConsumed == chunkSize } @@ -244,42 +242,39 @@ func parentNode(left, right [8]uint32, key [8]uint32, flags uint32) node { // Hasher implements hash.Hash. type Hasher struct { - cs chunkState - key [8]uint32 - chainStack [54][8]uint32 // space for 54 subtrees (2^54 * chunkSize = 2^64) - stackSize int // number of chainStack elements that are valid - flags uint32 - size int // output size, for Sum + cs chunkState + key [8]uint32 + flags uint32 + size int // output size, for Sum + + // log(n) set of Merkle subtree roots, at most one per height. + stack [54][8]uint32 // 2^54 * chunkSize = 2^64 + used uint64 // bit vector indicating which stack elems are valid; also number of chunks added +} + +func (h *Hasher) hasSubtreeAtHeight(i uint64) bool { + return h.used&(1<>= 1 +func (h *Hasher) addChunkChainingValue(cv [8]uint32) { + // seek to first open stack slot, merging subtrees as we go + i := uint64(0) + for ; h.hasSubtreeAtHeight(i); i++ { + cv = parentNode(h.stack[i], cv, h.key, h.flags).chainingValue() } - h.chainStack[h.stackSize] = cv - h.stackSize++ + h.stack[i] = cv + h.used++ } // rootNode computes the root of the Merkle tree. It does not modify the // chainStack. func (h *Hasher) rootNode() node { - // Starting with the node from the current chunk, compute all the - // parent chaining values along the right edge of the tree, until we - // have the root node. n := h.cs.node() - for i := h.stackSize - 1; i >= 0; i-- { - n = parentNode(h.chainStack[i], n.chainingValue(), h.key, h.flags) + for i := uint64(bits.TrailingZeros64(h.used)); i < 64; i++ { + if h.hasSubtreeAtHeight(i) { + n = parentNode(h.stack[i], n.chainingValue(), h.key, h.flags) + } } n.flags |= flagRoot return n @@ -288,7 +283,7 @@ func (h *Hasher) rootNode() node { // Reset implements hash.Hash. func (h *Hasher) Reset() { h.cs = newChunkState(h.key, 0, h.flags) - h.stackSize = 0 + h.used = 0 } // BlockSize implements hash.Hash. @@ -306,9 +301,8 @@ func (h *Hasher) Write(p []byte) (int, error) { // chunks). if h.cs.complete() { cv := h.cs.node().chainingValue() - totalChunks := h.cs.chunkCounter() + 1 - h.addChunkChainingValue(cv, totalChunks) - h.cs = newChunkState(h.key, totalChunks, h.flags) + h.addChunkChainingValue(cv) + h.cs = newChunkState(h.key, h.cs.chunkCounter()+1, h.flags) } // Compress input bytes into the current chunk state.