unexport constants and refactor to idiomatic Go

This commit is contained in:
lukechampine 2020-01-09 17:58:48 -05:00
parent 2ca7badf67
commit 83947f416f
2 changed files with 172 additions and 214 deletions

344
blake3.go
View File

@ -10,40 +10,38 @@ import (
) )
const ( const (
OUT_LEN = 32 blockLen = 64
KEY_LEN = 32 chunkLen = 1024
BLOCK_LEN = 64
CHUNK_LEN = 1024
CHUNK_START = 1 << 0
CHUNK_END = 1 << 1
PARENT = 1 << 2
ROOT = 1 << 3
KEYED_HASH = 1 << 4
DERIVE_KEY_CONTEXT = 1 << 5
DERIVE_KEY_MATERIAL = 1 << 6
) )
var IV = [8]uint32{ // flags
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, const (
flagChunkStart = 1 << iota
flagChunkEnd
flagParent
flagRoot
flagKeyedHash
flagDeriveKeyContext
flagDeriveKeyMaterial
)
var iv = [8]uint32{
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
} }
var MSG_PERMUTATION = [16]uint{2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8} func g(state *[16]uint32, a, b, c, d int, mx, my uint32) {
rotr := func(x uint32, n int) uint32 {
func rotate_right(x uint32, n int) uint32 {
return (x >> n) | (x << (32 - n)) return (x >> n) | (x << (32 - n))
} }
// The mixing function, G, which mixes either a column or a diagonal.
func g(state *[16]uint32, a, b, c, d int, mx, my uint32) {
state[a] = state[a] + state[b] + mx state[a] = state[a] + state[b] + mx
state[d] = rotate_right(state[d]^state[a], 16) state[d] = rotr(state[d]^state[a], 16)
state[c] = state[c] + state[d] state[c] = state[c] + state[d]
state[b] = rotate_right(state[b]^state[c], 12) state[b] = rotr(state[b]^state[c], 12)
state[a] = state[a] + state[b] + my state[a] = state[a] + state[b] + my
state[d] = rotate_right(state[d]^state[a], 8) state[d] = rotr(state[d]^state[a], 8)
state[c] = state[c] + state[d] state[c] = state[c] + state[d]
state[b] = rotate_right(state[b]^state[c], 7) state[b] = rotr(state[b]^state[c], 7)
} }
func round(state, m *[16]uint32) { func round(state, m *[16]uint32) {
@ -60,33 +58,20 @@ func round(state, m *[16]uint32) {
} }
func permute(m *[16]uint32) { func permute(m *[16]uint32) {
var permuted [16]uint32 permuted := [16]uint32{2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8}
for i := range permuted { for i := range permuted {
permuted[i] = m[MSG_PERMUTATION[i]] permuted[i] = m[permuted[i]]
} }
*m = permuted *m = permuted
} }
func compress(chaining_value *[8]uint32, block_words *[16]uint32, counter uint64, block_len uint32, flags uint32) [16]uint32 { func compress(cv [8]uint32, block [16]uint32, counter uint64, blockLen uint32, flags uint32) [16]uint32 {
state := [16]uint32{ state := [16]uint32{
chaining_value[0], cv[0], cv[1], cv[2], cv[3],
chaining_value[1], cv[4], cv[5], cv[6], cv[7],
chaining_value[2], iv[0], iv[1], iv[2], iv[3],
chaining_value[3], uint32(counter), uint32(counter >> 32), blockLen, flags,
chaining_value[4],
chaining_value[5],
chaining_value[6],
chaining_value[7],
IV[0],
IV[1],
IV[2],
IV[3],
uint32(counter),
uint32(counter >> 32),
block_len,
flags,
} }
block := *block_words
round(&state, &block) // round 1 round(&state, &block) // round 1
permute(&block) permute(&block)
@ -102,52 +87,52 @@ func compress(chaining_value *[8]uint32, block_words *[16]uint32, counter uint64
permute(&block) permute(&block)
round(&state, &block) // round 7 round(&state, &block) // round 7
for i := range chaining_value { for i := range cv {
state[i] ^= state[i+8] state[i] ^= state[i+8]
state[i+8] ^= chaining_value[i] state[i+8] ^= cv[i]
} }
return state return state
} }
func first_8_words(compression_output [16]uint32) (out [8]uint32) { func first8(words [16]uint32) (out [8]uint32) {
copy(out[:], compression_output[:8]) copy(out[:], words[:8])
return return
} }
func words_from_litte_endian_bytes(bytes []byte, words []uint32) { func bytesToWords(bytes []byte, words []uint32) {
for i := 0; i < len(bytes); i += 4 { for i := 0; i < len(bytes); i += 4 {
words[i/4] = binary.LittleEndian.Uint32(bytes[i:]) words[i/4] = binary.LittleEndian.Uint32(bytes[i:])
} }
} }
func wordsToBlock(words []uint32, bytes []byte) {
for i, w := range words {
binary.LittleEndian.PutUint32(bytes[i*4:], w)
}
}
// Each chunk or parent node can produce either an 8-word chaining value or, by // Each chunk or parent node can produce either an 8-word chaining value or, by
// setting the ROOT flag, any number of final output bytes. The output struct // setting flagRoot, any number of final output bytes. The output struct
// captures the state just prior to choosing between those two possibilities. // captures the state just prior to choosing between those two possibilities.
type output struct { type output struct {
input_chaining_value [8]uint32 inChain [8]uint32
block_words [16]uint32 blockWords [16]uint32
counter uint64 counter uint64
block_len uint32 blockLen uint32
flags uint32 flags uint32
} }
func (o *output) chaining_value() [8]uint32 { func (o *output) chainingValue() [8]uint32 {
return first_8_words(compress( return first8(compress(o.inChain, o.blockWords, o.counter, o.blockLen, o.flags))
&o.input_chaining_value,
&o.block_words,
o.counter,
o.block_len,
o.flags,
))
} }
// An OutputReader produces an unbounded stream of output from its initial // An OutputReader produces an unbounded stream of output from its initial
// state. // state.
type OutputReader struct { type OutputReader struct {
o *output o *output
block [BLOCK_LEN]byte block [blockLen]byte
remaining int remaining int
blocks_output uint64 blocksoutput uint64
} }
// Read implements io.Reader. Read always return len(p), nil. // Read implements io.Reader. Read always return len(p), nil.
@ -156,21 +141,19 @@ func (or *OutputReader) Read(p []byte) (int, error) {
for len(p) > 0 { for len(p) > 0 {
if or.remaining == 0 { if or.remaining == 0 {
words := compress( words := compress(
&or.o.input_chaining_value, or.o.inChain,
&or.o.block_words, or.o.blockWords,
or.blocks_output, or.blocksoutput,
or.o.block_len, or.o.blockLen,
or.o.flags|ROOT, or.o.flags|flagRoot,
) )
for i, w := range words { wordsToBlock(words[:], or.block[:])
binary.LittleEndian.PutUint32(or.block[i*4:], w) or.remaining = blockLen
} or.blocksoutput++
or.remaining = BLOCK_LEN
or.blocks_output++
} }
// copy from output buffer // copy from output buffer
n := copy(p, or.block[BLOCK_LEN-or.remaining:]) n := copy(p, or.block[blockLen-or.remaining:])
or.remaining -= n or.remaining -= n
p = p[n:] p = p[n:]
} }
@ -178,104 +161,92 @@ func (or *OutputReader) Read(p []byte) (int, error) {
} }
type chunkState struct { type chunkState struct {
chaining_value [8]uint32 chainingValue [8]uint32
chunk_counter uint64 chunkCounter uint64
block [BLOCK_LEN]byte block [blockLen]byte
block_len byte blockLen int
blocks_compressed byte bytesConsumed int
flags uint32 flags uint32
} }
func (cs *chunkState) len() int {
return BLOCK_LEN*int(cs.blocks_compressed) + int(cs.block_len)
}
func (cs *chunkState) start_flag() uint32 {
if cs.blocks_compressed == 0 {
return CHUNK_START
}
return 0
}
func (cs *chunkState) update(input []byte) { func (cs *chunkState) update(input []byte) {
for len(input) > 0 { for len(input) > 0 {
// If the block buffer is full, compress it and clear it. More // If the block buffer is full, compress it and clear it. More
// input is coming, so this compression is not CHUNK_END. // input is coming, so this compression is not flagChunkEnd.
if cs.block_len == BLOCK_LEN { if cs.blockLen == blockLen {
var block_words [16]uint32 var blockWords [16]uint32
words_from_litte_endian_bytes(cs.block[:], block_words[:]) bytesToWords(cs.block[:], blockWords[:])
cs.chaining_value = first_8_words(compress( cs.chainingValue = first8(compress(
&cs.chaining_value, cs.chainingValue,
&block_words, blockWords,
cs.chunk_counter, cs.chunkCounter,
BLOCK_LEN, blockLen,
cs.flags|cs.start_flag(), cs.flags,
)) ))
cs.blocks_compressed++ cs.block = [blockLen]byte{}
cs.block = [BLOCK_LEN]byte{} cs.blockLen = 0
cs.block_len = 0 // After the first chunk has been compressed, clear the start flag.
cs.flags &^= flagChunkStart
} }
// Copy input bytes into the block buffer. // Copy input bytes into the block buffer.
n := copy(cs.block[cs.block_len:], input) n := copy(cs.block[cs.blockLen:], input)
cs.block_len += byte(n) cs.blockLen += n
cs.bytesConsumed += n
input = input[n:] input = input[n:]
} }
} }
func (cs *chunkState) output() *output { func (cs *chunkState) output() *output {
var block_words [16]uint32 var blockWords [16]uint32
words_from_litte_endian_bytes(cs.block[:], block_words[:]) bytesToWords(cs.block[:], blockWords[:])
return &output{ return &output{
input_chaining_value: cs.chaining_value, inChain: cs.chainingValue,
block_words: block_words, blockWords: blockWords,
block_len: uint32(cs.block_len), blockLen: uint32(cs.blockLen),
counter: cs.chunk_counter, counter: cs.chunkCounter,
flags: cs.flags | cs.start_flag() | CHUNK_END, flags: cs.flags | flagChunkEnd,
} }
} }
func newChunkState(key [8]uint32, chunk_counter uint64, flags uint32) chunkState { func newChunkState(key [8]uint32, chunkCounter uint64, flags uint32) chunkState {
return chunkState{ return chunkState{
chaining_value: key, chainingValue: key,
chunk_counter: chunk_counter, chunkCounter: chunkCounter,
flags: flags, // compress the first chunk with the start flag set
flags: flags | flagChunkStart,
} }
} }
func parent_output(left_child_cv [8]uint32, right_child_cv [8]uint32, key [8]uint32, flags uint32) *output { func parentOutput(left, right [8]uint32, key [8]uint32, flags uint32) *output {
var block_words [16]uint32 var blockWords [16]uint32
copy(block_words[:8], left_child_cv[:]) copy(blockWords[:8], left[:])
copy(block_words[8:], right_child_cv[:]) copy(blockWords[8:], right[:])
return &output{ return &output{
input_chaining_value: key, inChain: key,
block_words: block_words, blockWords: blockWords,
counter: 0, // Always 0 for parent nodes. counter: 0, // Always 0 for parent nodes.
block_len: BLOCK_LEN, // Always BLOCK_LEN (64) for parent nodes. blockLen: blockLen, // Always blockLen (64) for parent nodes.
flags: PARENT | flags, flags: flagParent | flags,
} }
} }
func parent_cv(left_child_cv [8]uint32, right_child_cv [8]uint32, key [8]uint32, flags uint32) [8]uint32 {
return parent_output(left_child_cv, right_child_cv, key, flags).chaining_value()
}
// Hasher implements hash.Hash. // Hasher implements hash.Hash.
type Hasher struct { type Hasher struct {
chunk_state chunkState cs chunkState
key [8]uint32 key [8]uint32
cv_stack [54][8]uint32 // Space for 54 subtree chaining values: chainStack [54][8]uint32 // space for 54 subtrees (2^54 * chunkLen = 2^64)
cv_stack_len byte // 2^54 * CHUNK_LEN = 2^64 stackSize int // index within chainStack
flags uint32 flags uint32
out_size int size int // output size, for Sum
} }
func newHasher(key [8]uint32, flags uint32, out_size int) *Hasher { func newHasher(key [8]uint32, flags uint32, size int) *Hasher {
return &Hasher{ return &Hasher{
chunk_state: newChunkState(key, 0, flags), cs: newChunkState(key, 0, flags),
key: key, key: key,
flags: flags, flags: flags,
out_size: out_size, size: size,
} }
} }
@ -283,90 +254,89 @@ func newHasher(key [8]uint32, flags uint32, out_size int) *Hasher {
// is unkeyed. // is unkeyed.
func New(size int, key []byte) *Hasher { func New(size int, key []byte) *Hasher {
if key == nil { if key == nil {
return newHasher(IV, 0, size) return newHasher(iv, 0, size)
} }
var key_words [8]uint32 var keyWords [8]uint32
words_from_litte_endian_bytes(key[:], key_words[:]) bytesToWords(key[:], keyWords[:])
return newHasher(key_words, KEYED_HASH, size) return newHasher(keyWords, flagKeyedHash, size)
} }
// NewFromDerivedKey returns a Hasher whose key was derived from the supplied // NewFromDerivedKey returns a Hasher whose key was derived from the supplied
// context string. // context string.
func NewFromDerivedKey(size int, ctx string) *Hasher { func NewFromDerivedKey(size int, ctx string) *Hasher {
h := newHasher(IV, DERIVE_KEY_CONTEXT, KEY_LEN) const (
derivedKeyLen = 32
)
h := newHasher(iv, flagDeriveKeyContext, derivedKeyLen)
h.Write([]byte(ctx)) h.Write([]byte(ctx))
key := h.Sum(nil) key := h.Sum(nil)
var key_words [8]uint32 var keyWords [8]uint32
words_from_litte_endian_bytes(key, key_words[:]) bytesToWords(key, keyWords[:])
return newHasher(key_words, DERIVE_KEY_MATERIAL, size) return newHasher(keyWords, flagDeriveKeyMaterial, size)
} }
func (h *Hasher) push_stack(cv [8]uint32) { func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) {
h.cv_stack[h.cv_stack_len] = cv
h.cv_stack_len++
}
func (h *Hasher) pop_stack() [8]uint32 {
h.cv_stack_len--
return h.cv_stack[h.cv_stack_len]
}
func (h *Hasher) add_chunk_chaining_value(new_cv [8]uint32, total_chunks uint64) {
// This chunk might complete some subtrees. For each completed subtree, // This chunk might complete some subtrees. For each completed subtree,
// its left child will be the current top entry in the CV stack, and // its left child will be the current top entry in the CV stack, and
// its right child will be the current value of `new_cv`. Pop each left // its right child will be the current value of `cv`. Pop each left
// child off the stack, merge it with `new_cv`, and overwrite `new_cv` // child off the stack, merge it with `cv`, and overwrite `cv`
// with the result. After all these merges, push the final value of // with the result. After all these merges, push the final value of
// `new_cv` onto the stack. The number of completed subtrees is given // `cv` onto the stack. The number of completed subtrees is given
// by the number of trailing 0-bits in the new total number of chunks. // by the number of trailing 0-bits in the new total number of chunks.
for total_chunks&1 == 0 { right := cv
new_cv = parent_cv(h.pop_stack(), new_cv, h.key, h.flags) for totalChunks&1 == 0 {
total_chunks >>= 1 // pop
h.stackSize--
left := h.chainStack[h.stackSize]
// merge
right = parentOutput(left, right, h.key, h.flags).chainingValue()
totalChunks >>= 1
} }
h.push_stack(new_cv) h.chainStack[h.stackSize] = right
h.stackSize++
} }
// Reset implements hash.Hash. // Reset implements hash.Hash.
func (h *Hasher) Reset() { func (h *Hasher) Reset() {
h.chunk_state = newChunkState(h.key, 0, h.flags) h.cs = newChunkState(h.key, 0, h.flags)
h.cv_stack_len = 0 h.stackSize = 0
} }
// BlockSize implements hash.Hash. // BlockSize implements hash.Hash.
func (h *Hasher) BlockSize() int { return 64 } func (h *Hasher) BlockSize() int { return 64 }
// Size implements hash.Hash. // Size implements hash.Hash.
func (h *Hasher) Size() int { return h.out_size } func (h *Hasher) Size() int { return h.size }
// Write implements hash.Hash. // Write implements hash.Hash.
func (h *Hasher) Write(input []byte) (int, error) { func (h *Hasher) Write(p []byte) (int, error) {
written := len(input) lenp := len(p)
for len(input) > 0 { for len(p) > 0 {
// If the current chunk is complete, finalize it and reset the // If the current chunk is complete, finalize it and reset the
// chunk state. More input is coming, so this chunk is not ROOT. // chunk state. More input is coming, so this chunk is not flagRoot.
if h.chunk_state.len() == CHUNK_LEN { if h.cs.bytesConsumed == chunkLen {
chunk_cv := h.chunk_state.output().chaining_value() cv := h.cs.output().chainingValue()
total_chunks := h.chunk_state.chunk_counter + 1 totalChunks := h.cs.chunkCounter + 1
h.add_chunk_chaining_value(chunk_cv, total_chunks) h.addChunkChainingValue(cv, totalChunks)
h.chunk_state = newChunkState(h.key, total_chunks, h.flags) h.cs = newChunkState(h.key, totalChunks, h.flags)
} }
// Compress input bytes into the current chunk state. // Compress input bytes into the current chunk state.
n := len(input) n := chunkLen - h.cs.bytesConsumed
if n > CHUNK_LEN-h.chunk_state.len() { if n > len(p) {
n = CHUNK_LEN - h.chunk_state.len() n = len(p)
} }
h.chunk_state.update(input[:n]) h.cs.update(p[:n])
input = input[n:] p = p[n:]
} }
return written, nil return lenp, nil
} }
// Sum implements hash.Hash. // Sum implements hash.Hash.
func (h *Hasher) Sum(out_slice []byte) []byte { func (h *Hasher) Sum(b []byte) []byte {
out := make([]byte, h.Size()) out := make([]byte, h.Size())
h.XOF().Read(out) h.XOF().Read(out)
return append(out_slice, out...) return append(b, out...)
} }
// XOF returns an OutputReader initialized with the current hash state. // XOF returns an OutputReader initialized with the current hash state.
@ -374,13 +344,11 @@ func (h *Hasher) XOF() *OutputReader {
// Starting with the output from the current chunk, compute all the // Starting with the output from the current chunk, compute all the
// parent chaining values along the right edge of the tree, until we // parent chaining values along the right edge of the tree, until we
// have the root output. // have the root output.
var output = h.chunk_state.output() output := h.cs.output()
var parent_nodes_remaining = h.cv_stack_len for i := h.stackSize - 1; i >= 0; i-- {
for parent_nodes_remaining > 0 { output = parentOutput(
parent_nodes_remaining-- h.chainStack[i],
output = parent_output( output.chainingValue(),
h.cv_stack[parent_nodes_remaining],
output.chaining_value(),
h.key, h.key,
h.flags, h.flags,
) )

View File

@ -11,17 +11,7 @@ import (
"lukechampine.com/blake3" "lukechampine.com/blake3"
) )
func toHex(data []byte) string { func toHex(data []byte) string { return hex.EncodeToString(data) }
return hex.EncodeToString(data)
}
func fromHex(s string) []byte {
data, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return data
}
func TestVectors(t *testing.T) { func TestVectors(t *testing.T) {
data, err := ioutil.ReadFile("testdata/vectors.json") data, err := ioutil.ReadFile("testdata/vectors.json")
@ -89,7 +79,7 @@ func BenchmarkWrite(b *testing.B) {
func BenchmarkChunk(b *testing.B) { func BenchmarkChunk(b *testing.B) {
h := blake3.New(32, nil) h := blake3.New(32, nil)
buf := make([]byte, blake3.CHUNK_LEN) buf := make([]byte, 1024)
out := make([]byte, 32) out := make([]byte, 32)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
h.Write(buf) h.Write(buf)