diff --git a/blake3.go b/blake3.go index 53293ca..11b2a94 100644 --- a/blake3.go +++ b/blake3.go @@ -44,7 +44,7 @@ func g(state *[16]uint32, a, b, c, d int, mx, my uint32) { state[b] = rotr(state[b]^state[c], 7) } -func round(state, m *[16]uint32) { +func round(state *[16]uint32, m [16]uint32) { // Mix the columns. g(state, 0, 4, 8, 12, m[0], m[1]) g(state, 1, 5, 9, 13, m[2], m[3]) @@ -65,37 +65,50 @@ func permute(m *[16]uint32) { *m = permuted } -func compress(cv [8]uint32, block [16]uint32, counter uint64, blockLen uint32, flags uint32) [16]uint32 { +// Each chunk or parent node can produce either an 8-word chaining value or, by +// setting flagRoot, any number of final output bytes. The node struct +// captures the state just prior to choosing between those two possibilities. +type node struct { + cv [8]uint32 + block [16]uint32 + counter uint64 + blockLen uint32 + flags uint32 +} + +func (n node) compress() [16]uint32 { state := [16]uint32{ - cv[0], cv[1], cv[2], cv[3], - cv[4], cv[5], cv[6], cv[7], + n.cv[0], n.cv[1], n.cv[2], n.cv[3], + n.cv[4], n.cv[5], n.cv[6], n.cv[7], iv[0], iv[1], iv[2], iv[3], - uint32(counter), uint32(counter >> 32), blockLen, flags, + uint32(n.counter), uint32(n.counter >> 32), n.blockLen, n.flags, } - round(&state, &block) // round 1 + block := n.block + round(&state, block) // round 1 permute(&block) - round(&state, &block) // round 2 + round(&state, block) // round 2 permute(&block) - round(&state, &block) // round 3 + round(&state, block) // round 3 permute(&block) - round(&state, &block) // round 4 + round(&state, block) // round 4 permute(&block) - round(&state, &block) // round 5 + round(&state, block) // round 5 permute(&block) - round(&state, &block) // round 6 + round(&state, block) // round 6 permute(&block) - round(&state, &block) // round 7 + round(&state, block) // round 7 - for i := range cv { + for i := range n.cv { state[i] ^= state[i+8] - state[i+8] ^= cv[i] + state[i+8] ^= n.cv[i] } return state } -func first8(words [16]uint32) (out [8]uint32) { - copy(out[:], words[:8]) +func (n node) chainingValue() (cv [8]uint32) { + full := n.compress() + copy(cv[:], full[:8]) return } @@ -105,68 +118,48 @@ func bytesToWords(bytes []byte, words []uint32) { } } -func wordsToBlock(words []uint32, bytes []byte) { +func wordsToBytes(words []uint32, bytes []byte) { for i, w := range words { binary.LittleEndian.PutUint32(bytes[i*4:], w) } } -// Each chunk or parent node can produce either an 8-word chaining value or, by -// setting flagRoot, any number of final output bytes. The output struct -// captures the state just prior to choosing between those two possibilities. -type output struct { - inChain [8]uint32 - blockWords [16]uint32 - counter uint64 - blockLen uint32 - flags uint32 -} - -func (o *output) chainingValue() [8]uint32 { - return first8(compress(o.inChain, o.blockWords, o.counter, o.blockLen, o.flags)) -} - // An OutputReader produces an unbounded stream of output from its initial // state. type OutputReader struct { - o *output - block [blockLen]byte - remaining int - blocksoutput uint64 + n node + block [blockLen]byte + blockUsed int } // Read implements io.Reader. Read always return len(p), nil. func (or *OutputReader) Read(p []byte) (int, error) { lenp := len(p) for len(p) > 0 { - if or.remaining == 0 { - words := compress( - or.o.inChain, - or.o.blockWords, - or.blocksoutput, - or.o.blockLen, - or.o.flags|flagRoot, - ) - wordsToBlock(words[:], or.block[:]) - or.remaining = blockLen - or.blocksoutput++ + if or.blockUsed == 0 { + words := or.n.compress() + wordsToBytes(words[:], or.block[:]) + or.blockUsed = blockLen + or.n.counter++ } // copy from output buffer - n := copy(p, or.block[blockLen-or.remaining:]) - or.remaining -= n + n := copy(p, or.block[blockLen-or.blockUsed:]) + or.blockUsed -= n p = p[n:] } return lenp, nil } type chunkState struct { - chainingValue [8]uint32 - chunkCounter uint64 + n node block [blockLen]byte blockLen int bytesConsumed int - flags uint32 +} + +func (cs *chunkState) chunkCounter() uint64 { + return cs.n.counter } func (cs *chunkState) update(input []byte) { @@ -174,19 +167,12 @@ func (cs *chunkState) update(input []byte) { // If the block buffer is full, compress it and clear it. More // input is coming, so this compression is not flagChunkEnd. if cs.blockLen == blockLen { - var blockWords [16]uint32 - bytesToWords(cs.block[:], blockWords[:]) - cs.chainingValue = first8(compress( - cs.chainingValue, - blockWords, - cs.chunkCounter, - blockLen, - cs.flags, - )) + bytesToWords(cs.block[:], cs.n.block[:]) + cs.n.cv = cs.n.chainingValue() cs.block = [blockLen]byte{} cs.blockLen = 0 // After the first chunk has been compressed, clear the start flag. - cs.flags &^= flagChunkStart + cs.n.flags &^= flagChunkStart } // Copy input bytes into the block buffer. @@ -197,37 +183,36 @@ func (cs *chunkState) update(input []byte) { } } -func (cs *chunkState) output() *output { - var blockWords [16]uint32 - bytesToWords(cs.block[:], blockWords[:]) - return &output{ - inChain: cs.chainingValue, - blockWords: blockWords, - blockLen: uint32(cs.blockLen), - counter: cs.chunkCounter, - flags: cs.flags | flagChunkEnd, - } +func (cs *chunkState) node() node { + n := cs.n + bytesToWords(cs.block[:], n.block[:]) + n.blockLen = uint32(cs.blockLen) + n.flags |= flagChunkEnd + return n } func newChunkState(key [8]uint32, chunkCounter uint64, flags uint32) chunkState { return chunkState{ - chainingValue: key, - chunkCounter: chunkCounter, - // compress the first chunk with the start flag set - flags: flags | flagChunkStart, + n: node{ + cv: key, + counter: chunkCounter, + blockLen: blockLen, + // compress the first chunk with the start flag set + flags: flags | flagChunkStart, + }, } } -func parentOutput(left, right [8]uint32, key [8]uint32, flags uint32) *output { +func parentNode(left, right [8]uint32, key [8]uint32, flags uint32) node { var blockWords [16]uint32 copy(blockWords[:8], left[:]) copy(blockWords[8:], right[:]) - return &output{ - inChain: key, - blockWords: blockWords, - counter: 0, // Always 0 for parent nodes. - blockLen: blockLen, // Always blockLen (64) for parent nodes. - flags: flagParent | flags, + return node{ + cv: key, + block: blockWords, + counter: 0, // Always 0 for parent nodes. + blockLen: blockLen, // Always blockLen (64) for parent nodes. + flags: flags | flagParent, } } @@ -264,12 +249,11 @@ func New(size int, key []byte) *Hasher { // NewFromDerivedKey returns a Hasher whose key was derived from the supplied // context string. func NewFromDerivedKey(size int, ctx string) *Hasher { - const ( - derivedKeyLen = 32 - ) + const derivedKeyLen = 32 h := newHasher(iv, flagDeriveKeyContext, derivedKeyLen) h.Write([]byte(ctx)) - key := h.Sum(nil) + key := make([]byte, derivedKeyLen) + h.Sum(key[:0]) var keyWords [8]uint32 bytesToWords(key, keyWords[:]) return newHasher(keyWords, flagDeriveKeyMaterial, size) @@ -289,7 +273,7 @@ func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) { h.stackSize-- left := h.chainStack[h.stackSize] // merge - right = parentOutput(left, right, h.key, h.flags).chainingValue() + right = parentNode(left, right, h.key, h.flags).chainingValue() totalChunks >>= 1 } h.chainStack[h.stackSize] = right @@ -315,8 +299,8 @@ func (h *Hasher) Write(p []byte) (int, error) { // If the current chunk is complete, finalize it and reset the // chunk state. More input is coming, so this chunk is not flagRoot. if h.cs.bytesConsumed == chunkLen { - cv := h.cs.output().chainingValue() - totalChunks := h.cs.chunkCounter + 1 + cv := h.cs.node().chainingValue() + totalChunks := h.cs.chunkCounter() + 1 h.addChunkChainingValue(cv, totalChunks) h.cs = newChunkState(h.key, totalChunks, h.flags) } @@ -334,29 +318,54 @@ func (h *Hasher) Write(p []byte) (int, error) { // Sum implements hash.Hash. func (h *Hasher) Sum(b []byte) []byte { - out := make([]byte, h.Size()) - h.XOF().Read(out) - return append(b, out...) + ret, fill := sliceForAppend(b, h.Size()) + h.XOF().Read(fill) + return ret } // XOF returns an OutputReader initialized with the current hash state. func (h *Hasher) XOF() *OutputReader { - // Starting with the output from the current chunk, compute all the + // Starting with the node from the current chunk, compute all the // parent chaining values along the right edge of the tree, until we - // have the root output. - output := h.cs.output() + // have the root node. + n := h.cs.node() for i := h.stackSize - 1; i >= 0; i-- { - output = parentOutput( - h.chainStack[i], - output.chainingValue(), - h.key, - h.flags, - ) + n = parentNode(h.chainStack[i], n.chainingValue(), h.key, h.flags) } + n.flags |= flagRoot return &OutputReader{ - o: output, + n: n, } } +// Sum256 returns the unkeyed BLAKE3 hash of b, truncated to 256 bits. +func Sum256(b []byte) [32]byte { + var out [32]byte + h := New(32, nil) + h.Write(b) + h.Sum(out[:0]) + return out +} + +// Sum512 returns the unkeyed BLAKE3 hash of b, truncated to 512 bits. +func Sum512(b []byte) [64]byte { + var out [64]byte + h := New(64, nil) + h.Write(b) + h.Sum(out[:0]) + return out +} + // ensure that Hasher implements hash.Hash var _ hash.Hash = (*Hasher)(nil) + +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} diff --git a/blake3_test.go b/blake3_test.go index 2079fbf..a4a74cb 100644 --- a/blake3_test.go +++ b/blake3_test.go @@ -80,7 +80,7 @@ func BenchmarkWrite(b *testing.B) { func BenchmarkChunk(b *testing.B) { h := blake3.New(32, nil) buf := make([]byte, 1024) - out := make([]byte, 32) + out := make([]byte, 0, 32) for i := 0; i < b.N; i++ { h.Write(buf) h.Sum(out)