diff --git a/blake3.go b/blake3.go index 3a5e34d..094582d 100644 --- a/blake3.go +++ b/blake3.go @@ -9,6 +9,7 @@ import ( "errors" "hash" "io" + "math/bits" ) const ( @@ -32,31 +33,40 @@ var iv = [8]uint32{ 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, } -func g(state *[16]uint32, a, b, c, d int, mx, my uint32) { - rotr := func(x uint32, n int) uint32 { - return (x >> n) | (x << (32 - n)) - } - state[a] = state[a] + state[b] + mx - state[d] = rotr(state[d]^state[a], 16) - state[c] = state[c] + state[d] - state[b] = rotr(state[b]^state[c], 12) - state[a] = state[a] + state[b] + my - state[d] = rotr(state[d]^state[a], 8) - state[c] = state[c] + state[d] - state[b] = rotr(state[b]^state[c], 7) +func gx(state *[16]uint32, a, b, c, d int, mx uint32) { + state[a] += state[b] + mx + state[d] = bits.RotateLeft32(state[d]^state[a], -16) + state[c] += state[d] + state[b] = bits.RotateLeft32(state[b]^state[c], -12) +} + +func gy(state *[16]uint32, a, b, c, d int, my uint32) { + state[a] += state[b] + my + state[d] = bits.RotateLeft32(state[d]^state[a], -8) + state[c] += state[d] + state[b] = bits.RotateLeft32(state[b]^state[c], -7) } func round(state *[16]uint32, m *[16]uint32) { // Mix the columns. - g(state, 0, 4, 8, 12, m[0], m[1]) - g(state, 1, 5, 9, 13, m[2], m[3]) - g(state, 2, 6, 10, 14, m[4], m[5]) - g(state, 3, 7, 11, 15, m[6], m[7]) + gx(state, 0, 4, 8, 12, m[0]) + gy(state, 0, 4, 8, 12, m[1]) + gx(state, 1, 5, 9, 13, m[2]) + gy(state, 1, 5, 9, 13, m[3]) + gx(state, 2, 6, 10, 14, m[4]) + gy(state, 2, 6, 10, 14, m[5]) + gx(state, 3, 7, 11, 15, m[6]) + gy(state, 3, 7, 11, 15, m[7]) + // Mix the diagonals. - g(state, 0, 5, 10, 15, m[8], m[9]) - g(state, 1, 6, 11, 12, m[10], m[11]) - g(state, 2, 7, 8, 13, m[12], m[13]) - g(state, 3, 4, 9, 14, m[14], m[15]) + gx(state, 0, 5, 10, 15, m[8]) + gy(state, 0, 5, 10, 15, m[9]) + gx(state, 1, 6, 11, 12, m[10]) + gy(state, 1, 6, 11, 12, m[11]) + gx(state, 2, 7, 8, 13, m[12]) + gy(state, 2, 7, 8, 13, m[13]) + gx(state, 3, 4, 9, 14, m[14]) + gy(state, 3, 4, 9, 14, m[15]) } func permute(m *[16]uint32) {