From 83947f416f5690313c910b2e93d677dcfff04388 Mon Sep 17 00:00:00 2001 From: lukechampine Date: Thu, 9 Jan 2020 17:58:48 -0500 Subject: [PATCH] unexport constants and refactor to idiomatic Go --- blake3.go | 372 ++++++++++++++++++++++--------------------------- blake3_test.go | 14 +- 2 files changed, 172 insertions(+), 214 deletions(-) diff --git a/blake3.go b/blake3.go index 5f5c917..53293ca 100644 --- a/blake3.go +++ b/blake3.go @@ -10,40 +10,38 @@ import ( ) const ( - OUT_LEN = 32 - KEY_LEN = 32 - BLOCK_LEN = 64 - CHUNK_LEN = 1024 - - CHUNK_START = 1 << 0 - CHUNK_END = 1 << 1 - PARENT = 1 << 2 - ROOT = 1 << 3 - KEYED_HASH = 1 << 4 - DERIVE_KEY_CONTEXT = 1 << 5 - DERIVE_KEY_MATERIAL = 1 << 6 + blockLen = 64 + chunkLen = 1024 ) -var IV = [8]uint32{ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +// flags +const ( + flagChunkStart = 1 << iota + flagChunkEnd + flagParent + flagRoot + flagKeyedHash + flagDeriveKeyContext + flagDeriveKeyMaterial +) + +var iv = [8]uint32{ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, + 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, } -var MSG_PERMUTATION = [16]uint{2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8} - -func rotate_right(x uint32, n int) uint32 { - return (x >> n) | (x << (32 - n)) -} - -// The mixing function, G, which mixes either a column or a diagonal. func g(state *[16]uint32, a, b, c, d int, mx, my uint32) { + rotr := func(x uint32, n int) uint32 { + return (x >> n) | (x << (32 - n)) + } state[a] = state[a] + state[b] + mx - state[d] = rotate_right(state[d]^state[a], 16) + state[d] = rotr(state[d]^state[a], 16) state[c] = state[c] + state[d] - state[b] = rotate_right(state[b]^state[c], 12) + state[b] = rotr(state[b]^state[c], 12) state[a] = state[a] + state[b] + my - state[d] = rotate_right(state[d]^state[a], 8) + state[d] = rotr(state[d]^state[a], 8) state[c] = state[c] + state[d] - state[b] = rotate_right(state[b]^state[c], 7) + state[b] = rotr(state[b]^state[c], 7) } func round(state, m *[16]uint32) { @@ -60,33 +58,20 @@ func round(state, m *[16]uint32) { } func permute(m *[16]uint32) { - var permuted [16]uint32 + permuted := [16]uint32{2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8} for i := range permuted { - permuted[i] = m[MSG_PERMUTATION[i]] + permuted[i] = m[permuted[i]] } *m = permuted } -func compress(chaining_value *[8]uint32, block_words *[16]uint32, counter uint64, block_len uint32, flags uint32) [16]uint32 { +func compress(cv [8]uint32, block [16]uint32, counter uint64, blockLen uint32, flags uint32) [16]uint32 { state := [16]uint32{ - chaining_value[0], - chaining_value[1], - chaining_value[2], - chaining_value[3], - chaining_value[4], - chaining_value[5], - chaining_value[6], - chaining_value[7], - IV[0], - IV[1], - IV[2], - IV[3], - uint32(counter), - uint32(counter >> 32), - block_len, - flags, + cv[0], cv[1], cv[2], cv[3], + cv[4], cv[5], cv[6], cv[7], + iv[0], iv[1], iv[2], iv[3], + uint32(counter), uint32(counter >> 32), blockLen, flags, } - block := *block_words round(&state, &block) // round 1 permute(&block) @@ -102,52 +87,52 @@ func compress(chaining_value *[8]uint32, block_words *[16]uint32, counter uint64 permute(&block) round(&state, &block) // round 7 - for i := range chaining_value { + for i := range cv { state[i] ^= state[i+8] - state[i+8] ^= chaining_value[i] + state[i+8] ^= cv[i] } return state } -func first_8_words(compression_output [16]uint32) (out [8]uint32) { - copy(out[:], compression_output[:8]) +func first8(words [16]uint32) (out [8]uint32) { + copy(out[:], words[:8]) return } -func words_from_litte_endian_bytes(bytes []byte, words []uint32) { +func bytesToWords(bytes []byte, words []uint32) { for i := 0; i < len(bytes); i += 4 { words[i/4] = binary.LittleEndian.Uint32(bytes[i:]) } } -// Each chunk or parent node can produce either an 8-word chaining value or, by -// setting the ROOT flag, any number of final output bytes. The output struct -// captures the state just prior to choosing between those two possibilities. -type output struct { - input_chaining_value [8]uint32 - block_words [16]uint32 - counter uint64 - block_len uint32 - flags uint32 +func wordsToBlock(words []uint32, bytes []byte) { + for i, w := range words { + binary.LittleEndian.PutUint32(bytes[i*4:], w) + } } -func (o *output) chaining_value() [8]uint32 { - return first_8_words(compress( - &o.input_chaining_value, - &o.block_words, - o.counter, - o.block_len, - o.flags, - )) +// Each chunk or parent node can produce either an 8-word chaining value or, by +// setting flagRoot, any number of final output bytes. The output struct +// captures the state just prior to choosing between those two possibilities. +type output struct { + inChain [8]uint32 + blockWords [16]uint32 + counter uint64 + blockLen uint32 + flags uint32 +} + +func (o *output) chainingValue() [8]uint32 { + return first8(compress(o.inChain, o.blockWords, o.counter, o.blockLen, o.flags)) } // An OutputReader produces an unbounded stream of output from its initial // state. type OutputReader struct { - o *output - block [BLOCK_LEN]byte - remaining int - blocks_output uint64 + o *output + block [blockLen]byte + remaining int + blocksoutput uint64 } // Read implements io.Reader. Read always return len(p), nil. @@ -156,21 +141,19 @@ func (or *OutputReader) Read(p []byte) (int, error) { for len(p) > 0 { if or.remaining == 0 { words := compress( - &or.o.input_chaining_value, - &or.o.block_words, - or.blocks_output, - or.o.block_len, - or.o.flags|ROOT, + or.o.inChain, + or.o.blockWords, + or.blocksoutput, + or.o.blockLen, + or.o.flags|flagRoot, ) - for i, w := range words { - binary.LittleEndian.PutUint32(or.block[i*4:], w) - } - or.remaining = BLOCK_LEN - or.blocks_output++ + wordsToBlock(words[:], or.block[:]) + or.remaining = blockLen + or.blocksoutput++ } // copy from output buffer - n := copy(p, or.block[BLOCK_LEN-or.remaining:]) + n := copy(p, or.block[blockLen-or.remaining:]) or.remaining -= n p = p[n:] } @@ -178,104 +161,92 @@ func (or *OutputReader) Read(p []byte) (int, error) { } type chunkState struct { - chaining_value [8]uint32 - chunk_counter uint64 - block [BLOCK_LEN]byte - block_len byte - blocks_compressed byte - flags uint32 -} - -func (cs *chunkState) len() int { - return BLOCK_LEN*int(cs.blocks_compressed) + int(cs.block_len) -} - -func (cs *chunkState) start_flag() uint32 { - if cs.blocks_compressed == 0 { - return CHUNK_START - } - return 0 + chainingValue [8]uint32 + chunkCounter uint64 + block [blockLen]byte + blockLen int + bytesConsumed int + flags uint32 } func (cs *chunkState) update(input []byte) { for len(input) > 0 { // If the block buffer is full, compress it and clear it. More - // input is coming, so this compression is not CHUNK_END. - if cs.block_len == BLOCK_LEN { - var block_words [16]uint32 - words_from_litte_endian_bytes(cs.block[:], block_words[:]) - cs.chaining_value = first_8_words(compress( - &cs.chaining_value, - &block_words, - cs.chunk_counter, - BLOCK_LEN, - cs.flags|cs.start_flag(), + // input is coming, so this compression is not flagChunkEnd. + if cs.blockLen == blockLen { + var blockWords [16]uint32 + bytesToWords(cs.block[:], blockWords[:]) + cs.chainingValue = first8(compress( + cs.chainingValue, + blockWords, + cs.chunkCounter, + blockLen, + cs.flags, )) - cs.blocks_compressed++ - cs.block = [BLOCK_LEN]byte{} - cs.block_len = 0 + cs.block = [blockLen]byte{} + cs.blockLen = 0 + // After the first chunk has been compressed, clear the start flag. + cs.flags &^= flagChunkStart } // Copy input bytes into the block buffer. - n := copy(cs.block[cs.block_len:], input) - cs.block_len += byte(n) + n := copy(cs.block[cs.blockLen:], input) + cs.blockLen += n + cs.bytesConsumed += n input = input[n:] } } func (cs *chunkState) output() *output { - var block_words [16]uint32 - words_from_litte_endian_bytes(cs.block[:], block_words[:]) + var blockWords [16]uint32 + bytesToWords(cs.block[:], blockWords[:]) return &output{ - input_chaining_value: cs.chaining_value, - block_words: block_words, - block_len: uint32(cs.block_len), - counter: cs.chunk_counter, - flags: cs.flags | cs.start_flag() | CHUNK_END, + inChain: cs.chainingValue, + blockWords: blockWords, + blockLen: uint32(cs.blockLen), + counter: cs.chunkCounter, + flags: cs.flags | flagChunkEnd, } } -func newChunkState(key [8]uint32, chunk_counter uint64, flags uint32) chunkState { +func newChunkState(key [8]uint32, chunkCounter uint64, flags uint32) chunkState { return chunkState{ - chaining_value: key, - chunk_counter: chunk_counter, - flags: flags, + chainingValue: key, + chunkCounter: chunkCounter, + // compress the first chunk with the start flag set + flags: flags | flagChunkStart, } } -func parent_output(left_child_cv [8]uint32, right_child_cv [8]uint32, key [8]uint32, flags uint32) *output { - var block_words [16]uint32 - copy(block_words[:8], left_child_cv[:]) - copy(block_words[8:], right_child_cv[:]) +func parentOutput(left, right [8]uint32, key [8]uint32, flags uint32) *output { + var blockWords [16]uint32 + copy(blockWords[:8], left[:]) + copy(blockWords[8:], right[:]) return &output{ - input_chaining_value: key, - block_words: block_words, - counter: 0, // Always 0 for parent nodes. - block_len: BLOCK_LEN, // Always BLOCK_LEN (64) for parent nodes. - flags: PARENT | flags, + inChain: key, + blockWords: blockWords, + counter: 0, // Always 0 for parent nodes. + blockLen: blockLen, // Always blockLen (64) for parent nodes. + flags: flagParent | flags, } } -func parent_cv(left_child_cv [8]uint32, right_child_cv [8]uint32, key [8]uint32, flags uint32) [8]uint32 { - return parent_output(left_child_cv, right_child_cv, key, flags).chaining_value() -} - // Hasher implements hash.Hash. type Hasher struct { - chunk_state chunkState - key [8]uint32 - cv_stack [54][8]uint32 // Space for 54 subtree chaining values: - cv_stack_len byte // 2^54 * CHUNK_LEN = 2^64 - flags uint32 - out_size int + cs chunkState + key [8]uint32 + chainStack [54][8]uint32 // space for 54 subtrees (2^54 * chunkLen = 2^64) + stackSize int // index within chainStack + flags uint32 + size int // output size, for Sum } -func newHasher(key [8]uint32, flags uint32, out_size int) *Hasher { +func newHasher(key [8]uint32, flags uint32, size int) *Hasher { return &Hasher{ - chunk_state: newChunkState(key, 0, flags), - key: key, - flags: flags, - out_size: out_size, + cs: newChunkState(key, 0, flags), + key: key, + flags: flags, + size: size, } } @@ -283,90 +254,89 @@ func newHasher(key [8]uint32, flags uint32, out_size int) *Hasher { // is unkeyed. func New(size int, key []byte) *Hasher { if key == nil { - return newHasher(IV, 0, size) + return newHasher(iv, 0, size) } - var key_words [8]uint32 - words_from_litte_endian_bytes(key[:], key_words[:]) - return newHasher(key_words, KEYED_HASH, size) + var keyWords [8]uint32 + bytesToWords(key[:], keyWords[:]) + return newHasher(keyWords, flagKeyedHash, size) } // NewFromDerivedKey returns a Hasher whose key was derived from the supplied // context string. func NewFromDerivedKey(size int, ctx string) *Hasher { - h := newHasher(IV, DERIVE_KEY_CONTEXT, KEY_LEN) + const ( + derivedKeyLen = 32 + ) + h := newHasher(iv, flagDeriveKeyContext, derivedKeyLen) h.Write([]byte(ctx)) key := h.Sum(nil) - var key_words [8]uint32 - words_from_litte_endian_bytes(key, key_words[:]) - return newHasher(key_words, DERIVE_KEY_MATERIAL, size) + var keyWords [8]uint32 + bytesToWords(key, keyWords[:]) + return newHasher(keyWords, flagDeriveKeyMaterial, size) } -func (h *Hasher) push_stack(cv [8]uint32) { - h.cv_stack[h.cv_stack_len] = cv - h.cv_stack_len++ -} - -func (h *Hasher) pop_stack() [8]uint32 { - h.cv_stack_len-- - return h.cv_stack[h.cv_stack_len] -} - -func (h *Hasher) add_chunk_chaining_value(new_cv [8]uint32, total_chunks uint64) { +func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) { // This chunk might complete some subtrees. For each completed subtree, // its left child will be the current top entry in the CV stack, and - // its right child will be the current value of `new_cv`. Pop each left - // child off the stack, merge it with `new_cv`, and overwrite `new_cv` + // its right child will be the current value of `cv`. Pop each left + // child off the stack, merge it with `cv`, and overwrite `cv` // with the result. After all these merges, push the final value of - // `new_cv` onto the stack. The number of completed subtrees is given + // `cv` onto the stack. The number of completed subtrees is given // by the number of trailing 0-bits in the new total number of chunks. - for total_chunks&1 == 0 { - new_cv = parent_cv(h.pop_stack(), new_cv, h.key, h.flags) - total_chunks >>= 1 + right := cv + for totalChunks&1 == 0 { + // pop + h.stackSize-- + left := h.chainStack[h.stackSize] + // merge + right = parentOutput(left, right, h.key, h.flags).chainingValue() + totalChunks >>= 1 } - h.push_stack(new_cv) + h.chainStack[h.stackSize] = right + h.stackSize++ } // Reset implements hash.Hash. func (h *Hasher) Reset() { - h.chunk_state = newChunkState(h.key, 0, h.flags) - h.cv_stack_len = 0 + h.cs = newChunkState(h.key, 0, h.flags) + h.stackSize = 0 } // BlockSize implements hash.Hash. func (h *Hasher) BlockSize() int { return 64 } // Size implements hash.Hash. -func (h *Hasher) Size() int { return h.out_size } +func (h *Hasher) Size() int { return h.size } // Write implements hash.Hash. -func (h *Hasher) Write(input []byte) (int, error) { - written := len(input) - for len(input) > 0 { +func (h *Hasher) Write(p []byte) (int, error) { + lenp := len(p) + for len(p) > 0 { // If the current chunk is complete, finalize it and reset the - // chunk state. More input is coming, so this chunk is not ROOT. - if h.chunk_state.len() == CHUNK_LEN { - chunk_cv := h.chunk_state.output().chaining_value() - total_chunks := h.chunk_state.chunk_counter + 1 - h.add_chunk_chaining_value(chunk_cv, total_chunks) - h.chunk_state = newChunkState(h.key, total_chunks, h.flags) + // chunk state. More input is coming, so this chunk is not flagRoot. + if h.cs.bytesConsumed == chunkLen { + cv := h.cs.output().chainingValue() + totalChunks := h.cs.chunkCounter + 1 + h.addChunkChainingValue(cv, totalChunks) + h.cs = newChunkState(h.key, totalChunks, h.flags) } // Compress input bytes into the current chunk state. - n := len(input) - if n > CHUNK_LEN-h.chunk_state.len() { - n = CHUNK_LEN - h.chunk_state.len() + n := chunkLen - h.cs.bytesConsumed + if n > len(p) { + n = len(p) } - h.chunk_state.update(input[:n]) - input = input[n:] + h.cs.update(p[:n]) + p = p[n:] } - return written, nil + return lenp, nil } // Sum implements hash.Hash. -func (h *Hasher) Sum(out_slice []byte) []byte { +func (h *Hasher) Sum(b []byte) []byte { out := make([]byte, h.Size()) h.XOF().Read(out) - return append(out_slice, out...) + return append(b, out...) } // XOF returns an OutputReader initialized with the current hash state. @@ -374,13 +344,11 @@ func (h *Hasher) XOF() *OutputReader { // Starting with the output from the current chunk, compute all the // parent chaining values along the right edge of the tree, until we // have the root output. - var output = h.chunk_state.output() - var parent_nodes_remaining = h.cv_stack_len - for parent_nodes_remaining > 0 { - parent_nodes_remaining-- - output = parent_output( - h.cv_stack[parent_nodes_remaining], - output.chaining_value(), + output := h.cs.output() + for i := h.stackSize - 1; i >= 0; i-- { + output = parentOutput( + h.chainStack[i], + output.chainingValue(), h.key, h.flags, ) diff --git a/blake3_test.go b/blake3_test.go index 343dbca..2079fbf 100644 --- a/blake3_test.go +++ b/blake3_test.go @@ -11,17 +11,7 @@ import ( "lukechampine.com/blake3" ) -func toHex(data []byte) string { - return hex.EncodeToString(data) -} - -func fromHex(s string) []byte { - data, err := hex.DecodeString(s) - if err != nil { - panic(err) - } - return data -} +func toHex(data []byte) string { return hex.EncodeToString(data) } func TestVectors(t *testing.T) { data, err := ioutil.ReadFile("testdata/vectors.json") @@ -89,7 +79,7 @@ func BenchmarkWrite(b *testing.B) { func BenchmarkChunk(b *testing.B) { h := blake3.New(32, nil) - buf := make([]byte, blake3.CHUNK_LEN) + buf := make([]byte, 1024) out := make([]byte, 32) for i := 0; i < b.N; i++ { h.Write(buf)