"""
Mostly a TaxonomicTree class that implements a taxonomy, plus helpers for
walking and searching the tree, serializing it to JSON, and batching.

A tree is an arrangement of TaxonomicNodes.
"""


import itertools
import json


class TaxonomicNode:
    """One named node of a TaxonomicTree.

    A node stores its ``name``, the integer ``index`` it was given when created
    (taken from ``root.size``), a reference to the owning ``root`` object, and
    its child nodes keyed by name.
    """

    __slots__ = ("name", "index", "root", "_children")

    def __init__(self, name, index, root):
        self.name = name
        self.index = index
        # Owning tree; its ``size`` attribute is used as the next index to hand out.
        self.root = root
        # Child nodes keyed by their name (dict keeps insertion order).
        self._children = {}

    def __repr__(self):
        return f"{type(self).__name__}(name={self.name!r}, index={self.index!r})"

    def add(self, name):
        """Insert the path *name* below this node, creating missing nodes.

        *name* is a sequence of names relative to this node. Each newly created
        node takes ``self.root.size`` as its index and bumps that counter.

        Returns:
            The number of nodes newly created (0 when *name* is empty or the
            whole path already exists).
        """
        if not name:
            return 0

        first, rest = name[0], name[1:]
        added = 0
        child = self._children.get(first)
        if child is None:
            child = TaxonomicNode(first, self.root.size, self.root)
            self._children[first] = child
            self.root.size += 1
            added = 1

        return added + child.add(rest)

    def children(self, name):
        """Return the ``(name, index)`` pairs of the direct children at path *name*.

        *name* is relative to this node; an empty path means this node's own
        children. A path that does not exist yields an empty set.
        """
        if not name:
            return {(child.name, child.index) for child in self._children.values()}

        child = self._children.get(name[0])
        if child is None:
            return set()

        return child.children(name[1:])

    def descendants(self, prefix=None):
        """Iterate over all (name, index) pairs in the subtree that match *prefix*.

        Names are tuples that start with this node's name. With an empty
        *prefix*, this node and every node below it are yielded (pre-order).
        Otherwise only the subtree under the path *prefix* is yielded. If that
        path does not exist, nothing is yielded.
        """
        if not prefix:
            yield (self.name,), self.index
            for child in self._children.values():
                for name, i in child.descendants():
                    yield (self.name, *name), i
            return

        first, rest = prefix[0], prefix[1:]
        if first not in self._children:
            return

        for name, i in self._children[first].descendants(rest):
            yield (self.name, *name), i

    def values(self):
        """Iterate over all (name, index) pairs in the subtree rooted here (pre-order)."""
        # Identical to an unfiltered descendants() walk; delegate to avoid duplication.
        yield from self.descendants()

    @classmethod
    def from_dict(cls, dct, root):
        """Rebuild a node (recursively) from a dict with "name", "index", "children".

        Every rebuilt node is attached to *root*. ``root.size`` is not touched.
        """
        node = cls(dct["name"], dct["index"], root)
        node._children = {
            child["name"]: cls.from_dict(child, root) for child in dct["children"]
        }
        return node


class TaxonomicTree:
    """
    Trie of taxonomic names supporting lookup of children and descendants.

    Every node added to the tree receives a distinct integer index, handed out
    sequentially from ``size``.
    """

    def __init__(self):
        # Top-level nodes, keyed by name.
        self.kingdoms = {}
        # Number of nodes created so far; doubles as the next index to assign.
        self.size = 0

    def add(self, name: list[str]):
        """Insert the path *name*, creating any nodes that do not exist yet."""
        if name:
            head = name[0]
            kingdom = self.kingdoms.get(head)
            if kingdom is None:
                kingdom = self.kingdoms[head] = TaxonomicNode(head, self.size, self)
                self.size += 1
            kingdom.add(name[1:])

    def children(self, name=None):
        """Return the ``(name, index)`` pairs directly under path *name*.

        With no path, the kingdoms themselves are returned. An unknown path
        yields an empty set.
        """
        if not name:
            return {(kingdom.name, kingdom.index) for kingdom in self.kingdoms.values()}

        kingdom = self.kingdoms.get(name[0])
        return set() if kingdom is None else kingdom.children(name[1:])

    def descendants(self, prefix=None):
        """Yield every (name, index) pair in the tree that lies under *prefix*.

        With no prefix, every node in the tree is yielded.
        """
        if prefix:
            kingdom = self.kingdoms.get(prefix[0])
            if kingdom is not None:
                yield from kingdom.descendants(prefix[1:])
            return

        for kingdom in self.kingdoms.values():
            yield from kingdom.descendants()

    def values(self):
        """Yield every (name, index) pair in the tree, kingdom by kingdom."""
        yield from itertools.chain.from_iterable(
            kingdom.values() for kingdom in self.kingdoms.values()
        )

    def __len__(self):
        return self.size

    @classmethod
    def from_dict(cls, dct):
        """Rebuild a tree from a dict with "kingdoms" (node dicts) and "size"."""
        tree = cls()
        rebuilt = {}
        for entry in dct["kingdoms"]:
            rebuilt[entry["name"]] = TaxonomicNode.from_dict(entry, tree)
        tree.kingdoms = rebuilt
        tree.size = dct["size"]
        return tree


class TaxonomicJsonEncoder(json.JSONEncoder):
    """JSON encoder for TaxonomicTree and TaxonomicNode objects.

    Produces the dict shapes consumed by ``TaxonomicTree.from_dict`` and
    ``TaxonomicNode.from_dict``. Nested nodes are returned as node objects and
    are encoded recursively by ``json``.
    """

    def default(self, obj):
        if isinstance(obj, TaxonomicNode):
            return {
                "name": obj.name,
                "index": obj.index,
                "children": list(obj._children.values()),
            }
        if isinstance(obj, TaxonomicTree):
            return {
                "kingdoms": list(obj.kingdoms.values()),
                "size": obj.size,
            }
        # Defer to the base class, which raises a proper
        # "Object of type ... is not JSON serializable" TypeError.
        # (super() is already bound, so self must not be passed again.)
        return super().default(obj)


def batched(iterable, n):
    """Split *iterable* into chunks of *n* items and transpose each chunk.

    Unlike ``itertools.batched``, each yielded value is ``zip(*chunk)``. For
    an iterable of equal-length records (e.g. ``(name, index)`` pairs), this
    gives one tuple per field. The last chunk may hold fewer than *n* items.
    Each yielded ``zip`` is a one-shot iterator.

        >>> [list(b) for b in batched([("a", 0), ("b", 1), ("c", 2)], 2)]
        [[('a', 'b'), (0, 1)], [('c',), (2,)]]

    Raises:
        ValueError: if ``n < 1``. This is raised immediately at call time,
        not on first iteration.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # iter(callable, sentinel) keeps pulling n-item slices until one comes back empty.
    return (
        zip(*batch) for batch in iter(lambda: tuple(itertools.islice(it, n)), ())
    )