diff --git a/blake3.go b/blake3.go
index f141aa1..c68dee4 100644
--- a/blake3.go
+++ b/blake3.go
@@ -1,7 +1,4 @@
 // Package blake3 implements the BLAKE3 cryptographic hash function.
-//
-// This is a direct port of the Rust reference implementation. It is not
-// optimized for performance.
 package blake3
 
 import (
@@ -34,6 +31,21 @@ var iv = [8]uint32{
 	0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
 }
 
+// helper functions for converting between bytes and BLAKE3 "words"
+
+func bytesToWords(bytes []byte, words []uint32) {
+	for i := range words {
+		words[i] = binary.LittleEndian.Uint32(bytes[i*4:])
+	}
+}
+
+func wordsToBytes(words []uint32, bytes []byte) {
+	for i, w := range words {
+		binary.LittleEndian.PutUint32(bytes[i*4:], w)
+	}
+}
+
+// The g function, split into two parts so that the compiler will inline it.
 func gx(state *[16]uint32, a, b, c, d int, mx uint32) {
 	state[a] += state[b] + mx
 	state[d] = bits.RotateLeft32(state[d]^state[a], -16)
@@ -79,17 +91,29 @@ func permute(m *[16]uint32) {
 	}
 }
 
-// Each chunk or parent node can produce either an 8-word chaining value or, by
-// setting flagRoot, any number of final output bytes. The node struct
-// captures the state just prior to choosing between those two possibilities.
+// A node represents a chunk or parent in the BLAKE3 Merkle tree. In BLAKE3
+// terminology, the elements of the bottom layer (aka "leaves") of the tree are
+// called chunk nodes, and the elements of upper layers (aka "interior nodes")
+// are called parent nodes.
+//
+// Computing a BLAKE3 hash involves splitting the input into chunk nodes, then
+// repeatedly merging these nodes into parent nodes, until only a single "root"
+// node remains. The root node can then be used to generate up to 2^64 - 1 bytes
+// of pseudorandom output.
 type node struct {
-	cv       [8]uint32
+	// the chaining value from the previous state
+	cv [8]uint32
+	// the current state
 	block    [16]uint32
 	counter  uint64
 	blockLen uint32
 	flags    uint32
 }
 
+// compress is the core hash function, generating 16 pseudorandom words from a
+// node. When nodes are being merged into parents, only the first 8 words are
+// used. When the root node is being used to generate output, the full 16 words
+// are used.
 func (n node) compress() [16]uint32 {
 	state := [16]uint32{
 		n.cv[0], n.cv[1], n.cv[2], n.cv[3],
@@ -119,25 +143,19 @@ func (n node) compress() [16]uint32 {
 	return state
 }
 
+// chainingValue returns the first 8 words of the compressed node. This is used
+// in two places. First, when a chunk node is being constructed, its cv is
+// overwritten with this value after each block of input is processed. Second,
+// when two nodes are merged into a parent, each of their chaining values
+// supplies half of the new node's block.
 func (n node) chainingValue() (cv [8]uint32) {
 	full := n.compress()
 	copy(cv[:], full[:8])
 	return
 }
 
-func bytesToWords(bytes []byte, words []uint32) {
-	for i := range words {
-		words[i] = binary.LittleEndian.Uint32(bytes[i*4:])
-	}
-}
-
-func wordsToBytes(words []uint32, bytes []byte) {
-	for i, w := range words {
-		binary.LittleEndian.PutUint32(bytes[i*4:], w)
-	}
-}
-
-// An OutputReader produces an seekable stream of 2^64 - 1 output bytes.
+// An OutputReader produces a seekable stream of 2^64 - 1 pseudorandom output
+// bytes.
 type OutputReader struct {
 	n     node
 	block [blockSize]byte
@@ -201,6 +219,7 @@ func (or *OutputReader) Seek(offset int64, whence int) (int64, error) {
 	return int64(or.off), nil
 }
 
+// chunkState manages the state involved in hashing a single chunk of input.
 type chunkState struct {
 	n             node
 	block         [blockSize]byte
@@ -208,24 +227,30 @@ type chunkState struct {
 	bytesConsumed int
 }
 
+// chunkCounter is the index of this chunk, i.e. the number of chunks that have
+// been processed prior to this one.
 func (cs *chunkState) chunkCounter() uint64 {
 	return cs.n.counter
 }
 
+// update incorporates input into the chunkState.
 func (cs *chunkState) update(input []byte) {
 	for len(input) > 0 {
 		// If the block buffer is full, compress it and clear it. More
 		// input is coming, so this compression is not flagChunkEnd.
 		if cs.blockLen == blockSize {
+			// copy the chunk block (bytes) into the node block and chain it.
 			bytesToWords(cs.block[:], cs.n.block[:])
 			cs.n.cv = cs.n.chainingValue()
+			// clear the start flag for all but the first block
+			cs.n.flags &^= flagChunkStart
+			// reset the chunk block. It must contain zeros, because BLAKE3
+			// blocks are zero-padded.
			cs.block = [blockSize]byte{}
 			cs.blockLen = 0
-			// After the first chunk has been compressed, clear the start flag.
-			cs.n.flags &^= flagChunkStart
 		}
 
-		// Copy input bytes into the block buffer.
+		// Copy input bytes into the chunk block.
 		n := copy(cs.block[cs.blockLen:], input)
 		cs.blockLen += n
 		cs.bytesConsumed += n
@@ -233,6 +258,8 @@ func (cs *chunkState) update(input []byte) {
 	}
 }
 
+// node returns a node containing the chunkState's current state, with the
+// ChunkEnd flag set.
 func (cs *chunkState) node() node {
 	n := cs.n
 	bytesToWords(cs.block[:], n.block[:])
@@ -241,18 +268,20 @@ func (cs *chunkState) node() node {
 	return n
 }
 
-func newChunkState(key [8]uint32, chunkCounter uint64, flags uint32) chunkState {
+func newChunkState(iv [8]uint32, chunkCounter uint64, flags uint32) chunkState {
 	return chunkState{
 		n: node{
-			cv:       key,
+			cv:       iv,
 			counter:  chunkCounter,
 			blockLen: blockSize,
-			// compress the first chunk with the start flag set
+			// compress the first block with the start flag set
 			flags: flags | flagChunkStart,
 		},
 	}
 }
 
+// parentNode returns a node that incorporates the chaining values of two child
+// nodes.
 func parentNode(left, right [8]uint32, key [8]uint32, flags uint32) node {
 	var blockWords [16]uint32
 	copy(blockWords[:8], left[:])
@@ -260,8 +289,8 @@ func parentNode(left, right [8]uint32, key [8]uint32, flags uint32) node {
 	return node{
 		cv:       key,
 		block:    blockWords,
-		counter:  0,         // Always 0 for parent nodes.
-		blockLen: blockSize, // Always blockSize (64) for parent nodes.
+		counter:  0,         // counter is reset for parents
+		blockLen: blockSize, // block is full: 8 words from left, 8 from right
 		flags:    flags | flagParent,
 	}
 }
@@ -271,7 +300,7 @@ type Hasher struct {
 	cs         chunkState
 	key        [8]uint32
 	chainStack [54][8]uint32 // space for 54 subtrees (2^54 * chunkSize = 2^64)
-	stackSize  int           // index within chainStack
+	stackSize  int           // number of chainStack elements that are valid
 	flags      uint32
 	size       int // output size, for Sum
 }
@@ -296,14 +325,15 @@ func New(size int, key []byte) *Hasher {
 	return newHasher(keyWords, flagKeyedHash, size)
 }
 
+// addChunkChainingValue appends a chunk to the right edge of the Merkle tree.
 func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) {
-	// This chunk might complete some subtrees. For each completed subtree,
-	// its left child will be the current top entry in the CV stack, and
-	// its right child will be the current value of `cv`. Pop each left
-	// child off the stack, merge it with `cv`, and overwrite `cv`
-	// with the result. After all these merges, push the final value of
-	// `cv` onto the stack. The number of completed subtrees is given
-	// by the number of trailing 0-bits in the new total number of chunks.
+	// This chunk might complete some subtrees. For each completed subtree, its
+	// left child will be the current top entry in the CV stack, and its right
+	// child will be the current value of cv. Pop each left child off the stack,
+	// merge it with cv, and overwrite cv with the result. After all these
+	// merges, push the final value of cv onto the stack. The number of
+	// completed subtrees is given by the number of trailing 0-bits in the new
+	// total number of chunks.
 	for totalChunks&1 == 0 {
 		// pop and merge
 		h.stackSize--
@@ -314,7 +344,9 @@ func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) {
 	h.stackSize++
 }
 
-func (h *Hasher) finalNode() node {
+// rootNode computes the root of the Merkle tree. It does not modify the
+// chainStack.
+func (h *Hasher) rootNode() node {
 	// Starting with the node from the current chunk, compute all the
 	// parent chaining values along the right edge of the tree, until we
 	// have the root node.
@@ -342,8 +374,9 @@ func (h *Hasher) Size() int { return h.size }
 func (h *Hasher) Write(p []byte) (int, error) {
 	lenp := len(p)
 	for len(p) > 0 {
-		// If the current chunk is complete, finalize it and reset the
-		// chunk state. More input is coming, so this chunk is not flagRoot.
+		// If the current chunk is complete, finalize it and add it to the tree,
+		// then reset the chunk state (but keep incrementing the counter across
+		// chunks).
 		if h.cs.bytesConsumed == chunkSize {
 			cv := h.cs.node().chainingValue()
 			totalChunks := h.cs.chunkCounter() + 1
@@ -363,16 +396,24 @@ func (h *Hasher) Write(p []byte) (int, error) {
 	}
 }
 
 // Sum implements hash.Hash.
-func (h *Hasher) Sum(b []byte) []byte {
-	ret, fill := sliceForAppend(b, h.Size())
-	h.XOF().Read(fill)
-	return ret
+func (h *Hasher) Sum(b []byte) (sum []byte) {
+	// We need to append h.Size() bytes to b. Reuse b's capacity if possible;
+	// otherwise, allocate a new slice.
+	if total := len(b) + h.Size(); cap(b) >= total {
+		sum = b[:total]
+	} else {
+		sum = make([]byte, total)
+		copy(sum, b)
+	}
+	// Read into the appended portion of sum.
+	h.XOF().Read(sum[len(b):])
+	return
 }
 
 // XOF returns an OutputReader initialized with the current hash state.
 func (h *Hasher) XOF() *OutputReader {
 	return &OutputReader{
-		n: h.finalNode(),
+		n: h.rootNode(),
 	}
 }
@@ -392,7 +433,17 @@ func Sum512(b []byte) (out [64]byte) {
 	return
 }
 
-// DeriveKey derives a subkey from ctx and srcKey.
+// DeriveKey derives a subkey from ctx and srcKey. ctx should be hardcoded,
+// globally unique, and application-specific. A good format for ctx strings is:
+//
+//    [application] [commit timestamp] [purpose]
+//
+// e.g.:
+//
+//    example.com 2019-12-25 16:18:03 session tokens v1
+//
+// The purpose of these requirements is to ensure that an attacker cannot trick
+// two different applications into using the same context string.
 func DeriveKey(subKey []byte, ctx string, srcKey []byte) {
 	// construct the derivation Hasher
 	const derivationIVLen = 32
@@ -400,22 +451,11 @@ func DeriveKey(subKey []byte, ctx string, srcKey []byte) {
 	h.Write([]byte(ctx))
 	var derivationIV [8]uint32
 	bytesToWords(h.Sum(make([]byte, 0, derivationIVLen)), derivationIV[:])
-	h = newHasher(derivationIV, flagDeriveKeyMaterial, len(subKey))
+	h = newHasher(derivationIV, flagDeriveKeyMaterial, 0)
 	// derive the subKey
 	h.Write(srcKey)
-	h.Sum(subKey[:0])
+	h.XOF().Read(subKey)
 }
 
 // ensure that Hasher implements hash.Hash
 var _ hash.Hash = (*Hasher)(nil)
-
-func sliceForAppend(in []byte, n int) (head, tail []byte) {
-	if total := len(in) + n; cap(in) >= total {
-		head = in[:total]
-	} else {
-		head = make([]byte, total)
-		copy(head, in)
-	}
-	tail = head[len(in):]
-	return
-}
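As a quick sanity check of the API this change settles on (New/Write/Sum, plus the seekable XOF), here is a minimal usage sketch. It is illustrative only: the import path is assumed, and it assumes New accepts a nil key for unkeyed hashing (only the keyed path is visible in these hunks):

	package main

	import (
		"fmt"
		"io"

		"lukechampine.com/blake3" // assumed import path
	)

	func main() {
		// Incremental hashing via the hash.Hash interface.
		h := blake3.New(32, nil) // 32-byte output; nil key assumed to mean unkeyed
		h.Write([]byte("hello, "))
		h.Write([]byte("world"))
		fmt.Printf("%x\n", h.Sum(nil))

		// Extended output: OutputReader is a seekable stream of up to
		// 2^64 - 1 pseudorandom bytes.
		xof := h.XOF()
		out := make([]byte, 64)
		xof.Read(out)             // first 64 bytes of the output stream
		xof.Seek(0, io.SeekStart) // rewind
		xof.Read(out)             // the same 64 bytes again
	}

Since Sum is now a thin wrapper around XOF().Read, the first Size() bytes of the XOF stream always agree with the Sum digest.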
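The expanded DeriveKey comment pairs naturally with a usage sketch. The key lengths here are arbitrary placeholders; the point is that, with the switch from Sum to XOF().Read (and the newHasher size of 0), subKey may now be any length:

	// srcKey would be a high-entropy master key in practice.
	srcKey := make([]byte, 32)
	// subKey may be any length, since it is filled directly from the XOF.
	subKey := make([]byte, 64)
	blake3.DeriveKey(subKey, "example.com 2019-12-25 16:18:03 session tokens v1", srcKey)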