#!/usr/bin/env python3

# This is a Python port of the Rust reference implementation of BLAKE3:
# https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs

from __future__ import annotations
from dataclasses import dataclass

OUT_LEN = 32
KEY_LEN = 32
BLOCK_LEN = 64
CHUNK_LEN = 1024

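# Domain-separation flag bits for the compression function.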
CHUNK_START = 1 << 0
CHUNK_END = 1 << 1
PARENT = 1 << 2
ROOT = 1 << 3
KEYED_HASH = 1 << 4
DERIVE_KEY_CONTEXT = 1 << 5
DERIVE_KEY_MATERIAL = 1 << 6

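# These are the same initialization words as SHA-256: the first 32 bits of
# the fractional parts of the square roots of the first eight primes.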
IV = [
    0x6A09E667,
    0xBB67AE85,
    0x3C6EF372,
    0xA54FF53A,
    0x510E527F,
    0x9B05688C,
    0x1F83D9AB,
    0x5BE0CD19,
]

MSG_PERMUTATION = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]


def mask32(x: int) -> int:
    # Python ints are arbitrary-precision, so 32-bit wrapping must be explicit.
    return x & 0xFFFFFFFF


def add32(x: int, y: int) -> int:
    return mask32(x + y)


def rightrotate32(x: int, n: int) -> int:
    # Assumes x already fits in 32 bits, which holds for every state word
    # in this module.
    return mask32(x << (32 - n)) | (x >> n)


# The mixing function, G, which mixes either a column or a diagonal.
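# The rotation distances (16, 12, 8, 7) are the same as in BLAKE2s, whose
# quarter-round comes from ChaCha.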
def g(state: list[int], a: int, b: int, c: int, d: int, mx: int, my: int) -> None:
    state[a] = add32(state[a], add32(state[b], mx))
    state[d] = rightrotate32(state[d] ^ state[a], 16)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 12)
    state[a] = add32(state[a], add32(state[b], my))
    state[d] = rightrotate32(state[d] ^ state[a], 8)
    state[c] = add32(state[c], state[d])
    state[b] = rightrotate32(state[b] ^ state[c], 7)


# One full round of the compression function. (The name mirrors the Rust
# reference and shadows Python's built-in round(), which is never used here.)
def round(state: list[int], m: list[int]) -> None:
    # Mix the columns.
    g(state, 0, 4, 8, 12, m[0], m[1])
    g(state, 1, 5, 9, 13, m[2], m[3])
    g(state, 2, 6, 10, 14, m[4], m[5])
    g(state, 3, 7, 11, 15, m[6], m[7])
    # Mix the diagonals.
    g(state, 0, 5, 10, 15, m[8], m[9])
    g(state, 1, 6, 11, 12, m[10], m[11])
    g(state, 2, 7, 8, 13, m[12], m[13])
    g(state, 3, 4, 9, 14, m[14], m[15])


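# Permute the message words between rounds. One application maps
# m[0] <- m[2], m[1] <- m[6], and so on, per MSG_PERMUTATION; compress()
# applies this six times, between its seven rounds.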
def permute(m: list[int]) -> None:
    original = list(m)
    for i in range(16):
        m[i] = original[MSG_PERMUTATION[i]]


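# The compression function. The 16-word state holds the 8-word input chaining
# value, the first 4 IV words, the 64-bit counter split into two 32-bit
# little-endian halves, the block length, and the flags.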
def compress(
    chaining_value: list[int],
    block_words: list[int],
    counter: int,
    block_len: int,
    flags: int,
) -> list[int]:
    state = [
        chaining_value[0],
        chaining_value[1],
        chaining_value[2],
        chaining_value[3],
        chaining_value[4],
        chaining_value[5],
        chaining_value[6],
        chaining_value[7],
        IV[0],
        IV[1],
        IV[2],
        IV[3],
        mask32(counter),        # low 32 bits of the 64-bit counter
        mask32(counter >> 32),  # high 32 bits
        block_len,
        flags,
    ]

    assert len(block_words) == 16
    block = list(block_words)

    round(state, block)  # round 1
    permute(block)
    round(state, block)  # round 2
    permute(block)
    round(state, block)  # round 3
    permute(block)
    round(state, block)  # round 4
    permute(block)
    round(state, block)  # round 5
    permute(block)
    round(state, block)  # round 6
    permute(block)
    round(state, block)  # round 7

    # The first eight state words become the output chaining value. XORing
    # the input chaining value into the second eight makes the full 16-word
    # output usable for extended (root) output.
    for i in range(8):
        state[i] ^= state[i + 8]
        state[i + 8] ^= chaining_value[i]

    return state


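# For example, words_from_little_endian_bytes(b"\x01\x00\x00\x00\xff\xff\xff\xff")
# returns [0x00000001, 0xFFFFFFFF].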
def words_from_little_endian_bytes(b: bytes) -> list[int]:
    assert len(b) % 4 == 0
    return [int.from_bytes(b[i : i + 4], "little") for i in range(0, len(b), 4)]


# Each chunk or parent node can produce either an 8-word chaining value or, by
# setting the ROOT flag, any number of final output bytes. The Output struct
# captures the state just prior to choosing between those two possibilities.
@dataclass
class Output:
    input_chaining_value: list[int]
    block_words: list[int]
    counter: int
    block_len: int
    flags: int

    def chaining_value(self) -> list[int]:
        return compress(
            self.input_chaining_value,
            self.block_words,
            self.counter,
            self.block_len,
            self.flags,
        )[:8]

    def root_output_bytes(self, length: int) -> bytes:
        output_bytes = bytearray()
        i = 0
        while i < length:
            # Each 64-byte output block reruns the same compression with an
            # incremented output block counter; this is what lets BLAKE3
            # produce output of any length (i.e. act as an XOF).
            words = compress(
                self.input_chaining_value,
                self.block_words,
                i // BLOCK_LEN,
                self.block_len,
                self.flags | ROOT,
            )
            # The requested length might not be a multiple of 4, so the last
            # word may be taken only partially.
            for word in words:
                word_bytes = word.to_bytes(4, "little")
                take = min(len(word_bytes), length - i)
                output_bytes.extend(word_bytes[:take])
                i += take
        return output_bytes


@dataclass
class ChunkState:
    chaining_value: list[int]
    chunk_counter: int
    block: bytearray
    block_len: int
    blocks_compressed: int
    flags: int

    # This explicit __init__ takes precedence over the one @dataclass would
    # otherwise generate.
    def __init__(self, key_words: list[int], chunk_counter: int, flags: int) -> None:
        self.chaining_value = list(key_words)  # copy, to avoid aliasing the caller's list (e.g. IV)
        self.chunk_counter = chunk_counter
        self.block = bytearray(BLOCK_LEN)
        self.block_len = 0
        self.blocks_compressed = 0
        self.flags = flags

    def len(self) -> int:
        return BLOCK_LEN * self.blocks_compressed + self.block_len

    def start_flag(self) -> int:
        if self.blocks_compressed == 0:
            return CHUNK_START
        else:
            return 0

    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the block buffer is full, compress it and clear it. More
            # input_bytes is coming, so this compression is not CHUNK_END.
            if self.block_len == BLOCK_LEN:
                block_words = words_from_little_endian_bytes(self.block)
                self.chaining_value = compress(
                    self.chaining_value,
                    block_words,
                    self.chunk_counter,
                    BLOCK_LEN,
                    self.flags | self.start_flag(),
                )[:8]
                self.blocks_compressed += 1
                self.block = bytearray(BLOCK_LEN)
                self.block_len = 0

            # Copy input bytes into the block buffer.
            want = BLOCK_LEN - self.block_len
            take = min(want, len(input_bytes))
            self.block[self.block_len : self.block_len + take] = input_bytes[:take]
            self.block_len += take
            input_bytes = input_bytes[take:]

    def output(self) -> Output:
        block_words = words_from_little_endian_bytes(self.block)
        return Output(
            self.chaining_value,
            block_words,
            self.chunk_counter,
            self.block_len,
            self.flags | self.start_flag() | CHUNK_END,
        )


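# A parent node's block is the concatenation of its two children's 8-word
# chaining values (exactly one 64-byte block). Parent compressions always use
# a counter of 0 and a block length of BLOCK_LEN.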
def parent_output(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> Output:
    return Output(
        key_words, left_child_cv + right_child_cv, 0, BLOCK_LEN, PARENT | flags
    )


def parent_cv(
    left_child_cv: list[int],
    right_child_cv: list[int],
    key_words: list[int],
    flags: int,
) -> list[int]:
    return parent_output(
        left_child_cv, right_child_cv, key_words, flags
    ).chaining_value()


# An incremental hasher that can accept any number of writes.
@dataclass
class Hasher:
    chunk_state: ChunkState
    key_words: list[int]
    cv_stack: list[list[int]]
    flags: int

    def _init(self, key_words: list[int], flags: int) -> None:
        assert len(key_words) == 8
        self.chunk_state = ChunkState(key_words, 0, flags)
        self.key_words = key_words
        self.cv_stack = []
        self.flags = flags

    # Construct a new `Hasher` for the regular hash function.
    def __init__(self) -> None:
        self._init(IV, 0)

    # Construct a new `Hasher` for the keyed hash function.
    @classmethod
    def new_keyed(cls, key: bytes) -> Hasher:
        keyed_hasher = cls()
        key_words = words_from_little_endian_bytes(key)
        keyed_hasher._init(key_words, KEYED_HASH)
        return keyed_hasher

    # Construct a new `Hasher` for the key derivation function. The context
    # string should be hardcoded, globally unique, and application-specific.
    @classmethod
    def new_derive_key(cls, context: str) -> Hasher:
        context_hasher = cls()
        context_hasher._init(IV, DERIVE_KEY_CONTEXT)
        context_hasher.update(context.encode("utf8"))
        context_key = context_hasher.finalize(KEY_LEN)
        context_key_words = words_from_little_endian_bytes(context_key)
        derive_key_hasher = cls()
        derive_key_hasher._init(context_key_words, DERIVE_KEY_MATERIAL)
        return derive_key_hasher
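
    # Example usage of the KDF mode (a sketch; the context string below is
    # hypothetical, and real context strings should be hardcoded and
    # globally unique):
    #
    #   kdf = Hasher.new_derive_key("example.com 2024-01-01 session keys")
    #   kdf.update(input_key_material)
    #   derived_key = kdf.finalize(KEY_LEN)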

    # Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail.
    def add_chunk_chaining_value(self, new_cv: list[int], total_chunks: int) -> None:
        # This chunk might complete some subtrees. For each completed subtree,
        # its left child will be the current top entry in the CV stack, and
        # its right child will be the current value of `new_cv`. Pop each left
        # child off the stack, merge it with `new_cv`, and overwrite `new_cv`
        # with the result. After all these merges, push the final value of
        # `new_cv` onto the stack. The number of completed subtrees is given
        # by the number of trailing 0-bits in the new total number of chunks.
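        # For example, when chunk 4 is added (total_chunks = 4 = 0b100, two
        # trailing zeros), the stack holds [cv(1..2), cv(3)]: the first merge
        # produces cv(3..4), the second produces cv(1..4), and that single
        # value is pushed back onto the stack.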
        while total_chunks & 1 == 0:
            new_cv = parent_cv(self.cv_stack.pop(), new_cv, self.key_words, self.flags)
            total_chunks >>= 1
        self.cv_stack.append(new_cv)

    # Add input to the hash state. This can be called any number of times.
    def update(self, input_bytes: bytes) -> None:
        while input_bytes:
            # If the current chunk is complete, finalize it and reset the
            # chunk state. More input is coming, so this chunk is not ROOT.
            if self.chunk_state.len() == CHUNK_LEN:
                chunk_cv = self.chunk_state.output().chaining_value()
                total_chunks = self.chunk_state.chunk_counter + 1
                self.add_chunk_chaining_value(chunk_cv, total_chunks)
                self.chunk_state = ChunkState(self.key_words, total_chunks, self.flags)

            # Compress input bytes into the current chunk state.
            want = CHUNK_LEN - self.chunk_state.len()
            take = min(want, len(input_bytes))
            self.chunk_state.update(input_bytes[:take])
            input_bytes = input_bytes[take:]

    # Finalize the hash and write any number of output bytes.
    def finalize(self, length: int = OUT_LEN) -> bytes:
        # Starting with the Output from the current chunk, compute all the
        # parent chaining values along the right edge of the tree, until we
        # have the root Output.
        output = self.chunk_state.output()
        parent_nodes_remaining = len(self.cv_stack)
        while parent_nodes_remaining > 0:
            parent_nodes_remaining -= 1
            output = parent_output(
                self.cv_stack[parent_nodes_remaining],
                output.chaining_value(),
                self.key_words,
                self.flags,
            )
        return output.root_output_bytes(length)


# If this file is executed directly, hash standard input.
if __name__ == "__main__":
    import sys

    hasher = Hasher()
    while buf := sys.stdin.buffer.read(65536):
        hasher.update(buf)
    print(hasher.finalize().hex())
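
# A quick way to sanity-check this module (a sketch; assumes the optional
# `blake3` package from PyPI is installed):
#
#   import os
#   from blake3 import blake3
#   data = os.urandom(100000)
#   h = Hasher()
#   h.update(data)
#   assert h.finalize() == blake3(data).digest()
#
# Hashing empty input should also reproduce the published BLAKE3 test vector:
#
#   assert Hasher().finalize().hex() == (
#       "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262"
#   )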