import re
from typing import List, Dict, Set
import numpy as np
import torch
from ufal.chu_liu_edmonds import chu_liu_edmonds

# Dependency-relation label inventory. A label's position in this list is the
# class index used when decoding label-model predictions (INDEX2TAG is applied
# to argmax outputs), so the order presumably has to match the label model's
# output layer — confirm before editing.
DEPENDENCY_RELATIONS = [
    "acl",
    "advcl",
    "advmod",
    "amod",
    "appos",
    "aux",
    "case",
    "cc",
    "ccomp",
    "conj",
    "cop",
    "csubj",
    "det",
    "iobj",
    "mark",
    "nmod",
    "nsubj",
    "nummod",
    "obj",
    "obl",
    "parataxis",
    "punct",
    "root",
    "vocative",
    "xcomp",
]
# class index -> label
INDEX2TAG = dict(enumerate(DEPENDENCY_RELATIONS))
# label -> class index (inverse of INDEX2TAG)
TAG2INDEX = {tag: idx for idx, tag in INDEX2TAG.items()}


def preprocess_text(text: str) -> List[str]:
    """Split ``text`` into whitespace-separated tokens with punctuation detached.

    A space is inserted before any of ``. , ! ? ( ) · ; :`` that is not already
    preceded by a space, and after any such character that is not already
    followed by one. The surrounding whitespace is stripped and the result is
    split on whitespace.
    """
    punctuation_gap = "(?<! )(?=[.,!?()·;:])|(?<=[.,!?()·;:])(?! )"
    return re.sub(punctuation_gap, " ", text.strip()).split()


def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Gather along ``dim`` with a separate index set for every batch row.

    ``index`` is reshaped so that its non-batch elements lie along ``dim``
    (all other non-batch axes become size 1). It is then broadcast over every
    axis except 0 and ``dim``, and passed to ``gather``. For a 3-D ``input``
    and ``dim=1``, ``result[b, j, k] == input[b, index[b, j], k]``.
    """
    rank = input.dim()
    view_shape = [input.size(0)]
    view_shape.extend(-1 if axis == dim else 1 for axis in range(1, rank))
    # -1 in expand() keeps that axis' current size (batch axis and ``dim``).
    expand_shape = list(input.size())
    expand_shape[0] = expand_shape[dim] = -1
    return input.gather(dim, index.view(view_shape).expand(expand_shape))


def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Return the entries of 1-D ``tokenized`` at the positions in ``start_ids``.

    Values are converted to Python scalars with ``.item()`` and returned in
    position order. Positions in ``start_ids`` outside ``range(len(tokenized))``
    are ignored.
    """
    return [
        value.item()
        for position, value in enumerate(tokenized)
        if position in start_ids
    ]


def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """Keep a single root when decoding attached several words to position 0.

    Args:
        edmonds_head: head token position for each word, in word order. A value
            of 0 means the word is attached to the root position.
        word_ids: per-token word index (``None`` for special tokens). It maps a
            word index to its first token via ``word_ids.index(word)``.
        parent_probs_table: score table indexed ``[0][token][head_token]``.
            It is modified in place.

    Returns:
        ``parent_probs_table``. If more than one word was a root, the root word
        with the most dependents is kept. Every other root word has its
        column-0 score set to 0, so a re-decode prefers another head.
    """
    root_words = [word for word, head in enumerate(edmonds_head) if head == 0]
    if len(root_words) <= 1:
        return parent_probs_table

    # ``edmonds_head`` stores head *token* positions, so a root word's
    # dependents must be counted through its first token, not its word index
    # (counting the word index would e.g. make word 0 always win, since
    # count(0) counts the root attachments themselves).
    first_token = {word: word_ids.index(word) for word in root_words}
    main_root = max(root_words, key=lambda word: edmonds_head.count(first_token[word]))
    for word in root_words:
        if word != main_root:
            # NOTE(review): assumes higher scores are better (probabilities),
            # so 0 discourages attaching this word to the root — confirm.
            parent_probs_table[0][first_token[word]][0] = 0
    return parent_probs_table


def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode a spanning tree from head scores and return per-word heads.

    Args:
        parent_probs_table: score tensor with a leading batch axis of size 1
            (it is squeezed away). If its last two axes differ in size, the
            first column of the last axis is dropped before decoding.
        tokenized_input: collated model inputs. The decoded head tensor (shape
            ``(1, num_tokens)``) is stored in ``tokenized_input["head_labels"]``.
        start_ids: token positions of word-initial sub-tokens.

    Returns:
        The head token position for each token in ``start_ids``, in position
        order. Heads reported as -1 by ``chu_liu_edmonds`` are mapped to 0.
    """
    if parent_probs_table.shape[1] != parent_probs_table.shape[2]:
        parent_probs_table = parent_probs_table[:, :, 1:]
    # detach(): the scores may come from a forward pass with autograd enabled,
    # and .numpy() refuses tensors that require grad.
    scores = parent_probs_table.detach().squeeze(0).cpu().numpy().astype("double")
    heads, _ = chu_liu_edmonds(scores)
    edmonds_heads = torch.tensor(heads).unsqueeze(0)
    edmonds_heads[edmonds_heads == -1] = 0
    tokenized_input["head_labels"] = edmonds_heads
    return get_relevant_tokens(edmonds_heads[0], start_ids)


def get_word_endings(tokenized_input):
    """Map token positions to the (1-based) word they belong to.

    Returns a tuple ``(word_endings, start_ids, word_ids)``:

    * ``word_endings`` maps a token position to ``(end, word_id + 1)``. Here
      ``end`` is the second value returned by ``word_to_tokens`` for that word.
      Position 0 maps to ``(1, 0)`` (the root). Entries are written for every
      position from a word's start through ``end`` inclusive. The ``end`` entry
      is normally overwritten when the following word is processed.
    * ``start_ids`` is the set of first-token positions of every word.
    * ``word_ids`` is ``tokenized_input.word_ids(batch_index=0)`` unchanged.
    """
    word_ids = tokenized_input.word_ids(batch_index=0)
    endings = {0: (1, 0)}
    first_tokens = set()
    for wid in word_ids:
        if wid is None:  # special tokens carry no word index
            continue
        span_start, span_end = tokenized_input.word_to_tokens(
            batch_or_word_index=0, word_index=wid
        )
        first_tokens.add(span_start)
        entry = (span_end, wid + 1)
        endings[span_start] = entry
        endings.update(dict.fromkeys(range(span_start + 1, span_end + 1), entry))
    return endings, first_tokens, word_ids


def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Parse a pre-split sentence into a displaCy-style dependency dict.

    Args:
        dependency_parser: model called with the collated inputs as keyword
            arguments. The third element of its output is the head-score table.
        label_parser: model called with the same inputs (plus ``head_labels``
            written during decoding). Its ``.logits`` are argmax-decoded into
            relation labels.
        tokenizer: presumably a Hugging Face fast tokenizer (``word_ids`` /
            ``word_to_tokens`` are used) — confirm.
        collator: called on a one-element list of encodings to build the batch.
        labels: whether to predict relation labels. If False, labels are "".
        sentence: the words of the sentence.

    Returns:
        ``{"words": [...], "arcs": [...]}``. Each arc has ``start``/``end``
        word positions (1-based, 0 = ROOT), a ``label`` and a ``dir``.
    """
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    # NOTE(review): only the ROOT entry is added here while arcs reference
    # positions 1..n; callers presumably append the sentence words — confirm.
    dep_dict: Dict[str, List[Dict[str, str]]] = {
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }

    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])
    # Inference only: without no_grad the scores require grad and the later
    # .numpy() conversion during tree decoding fails.
    with torch.no_grad():
        _, _, parent_probs_table = dependency_parser(**tokenized_input)

    # Mask every position other than 0 and word-initial sub-tokens, so the
    # tree is decoded over one node per word.
    irrelevant = [
        idx
        for idx in range(parent_probs_table.size(1))
        if idx != 0 and idx not in start_ids
    ]
    if irrelevant:
        # The index tensor must live on the same device as the scores.
        irrelevant_idx = torch.tensor(
            irrelevant, dtype=torch.long, device=parent_probs_table.device
        )
        parent_probs_table.index_fill_(1, irrelevant_idx, torch.nan)
        parent_probs_table.index_fill_(2, irrelevant_idx, torch.nan)

    # Decode, suppress extra roots, then decode again.
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)

    if labels:
        with torch.no_grad():
            logits = label_parser(**tokenized_input).logits
        predictions_labels = np.argmax(logits.detach().cpu().numpy(), axis=-1)
        # One label per word-initial token, so this always lines up with
        # edmonds_head even when truncation dropped trailing words.
        predicted_relations = [
            INDEX2TAG[tag_idx]
            for tag_idx in get_relevant_tokens(predictions_labels[0], start_ids)
        ]
    else:
        predicted_relations = [""] * len(sentence)

    for idx, head in enumerate(edmonds_head):
        dependent = idx + 1
        head_word = word_endings[head][1]
        dep_dict["arcs"].append(
            {
                "start": min(dependent, head_word),
                "end": max(dependent, head_word),
                "label": predicted_relations[idx],
                "dir": "left" if dependent < head_word else "right",
            }
        )

    return dep_dict