from os import truncate
from sacremoses import MosesPunctNormalizer
from sacremoses import MosesTokenizer
from sacremoses import MosesDetokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
import codecs
from tqdm import tqdm
from indicnlp.tokenize import indic_tokenize
from indicnlp.tokenize import indic_detokenize
from indicnlp.normalize import indic_normalize
from indicnlp.transliterate import unicode_transliterate
from mosestokenizer import MosesSentenceSplitter
from indicnlp.tokenize import sentence_tokenize

from inference.custom_interactive import Translator


INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"]


def split_sentences(paragraph, language):
    """Split a paragraph into sentences with a language-specific splitter.

    paragraph: text to split.
    language: language code. "en" is handled by the Moses sentence splitter;
        codes listed in INDIC are handled by the Indic NLP sentence splitter.

    Returns whatever the selected splitter returns. For a language that is
    neither "en" nor in INDIC, no splitter is selected and None is returned.
    """
    if language in INDIC:
        return sentence_tokenize.sentence_split(paragraph, lang=language)

    if language != "en":
        # Unsupported language: fall through without a result.
        return None

    # The Moses splitter takes a list of paragraphs; pass this one alone.
    with MosesSentenceSplitter(language) as moses_splitter:
        return moses_splitter([paragraph])


def add_token(sent, tag_infos):
    """Prefix a sentence with special tag tokens.

    sent: the sentence to tag.
    tag_infos: list of (tag_type, tag) tuples.

    Each tuple becomes a token of the form __{tag_type}__{tag}__. The tokens
    are joined with single spaces and placed in front of the sentence,
    separated from it by one space. With no tag_infos the result is the
    sentence preceded by a single space.
    """
    # str.join (rather than an f-string) keeps the TypeError on non-str parts.
    prefix = " ".join(
        "".join(("__", tag_type, "__", tag, "__")) for tag_type, tag in tag_infos
    )
    return prefix + " " + sent


def apply_lang_tags(sents, src_lang, tgt_lang):
    """Prefix every sentence with its source and target language tags.

    sents: iterable of sentences; each one is stripped of surrounding
        whitespace before tagging.
    src_lang: source language code, emitted as the "src" tag.
    tgt_lang: target language code, emitted as the "tgt" tag.

    Returns a new list of tagged sentences in the same order.
    """
    lang_tags = [("src", src_lang), ("tgt", tgt_lang)]
    return [add_token(sentence.strip(), lang_tags) for sentence in sents]


def truncate_long_sentences(sents, max_seq_len=200):
    """Cut sentences that exceed a whitespace-token limit.

    sents: iterable of sentences (strings).
    max_seq_len: maximum number of whitespace-separated tokens to keep per
        sentence. Defaults to 200, the limit that was previously hard-coded.

    Returns a new list in the same order. Sentences within the limit are
    passed through unchanged (their original spacing is preserved); longer
    ones are replaced by their first ``max_seq_len`` tokens joined with single
    spaces, and a warning showing the first and last five tokens is printed.
    """
    new_sents = []

    for sent in sents:
        words = sent.split()
        if len(words) > max_seq_len:
            # Short preview so the warning identifies the sentence without
            # dumping the whole thing.
            preview = " ".join(words[:5]) + " .... " + " ".join(words[-5:])
            sent = " ".join(words[:max_seq_len])
            print(
                f"WARNING: Sentence {preview} truncated to {max_seq_len} tokens as it exceeds maximum length limit"
            )

        new_sents.append(sent)
    return new_sents


class Model:
    """End-to-end translation wrapper around a trained model directory.

    Bundles the tokenizers/normalizers, the BPE model and the Translator, and
    exposes batch_translate (list of sentences) and translate_paragraph
    (single string) as entry points.
    """

    def __init__(self, expdir):
        """Load tokenizers, the BPE model and the translator.

        expdir: experiment directory. The following paths beneath it are
            read: vocab/vocab.SRC, vocab/bpe_codes.32k.SRC, final_bin and
            model/checkpoint_best.pt.
        """
        self.expdir = expdir
        self.en_tok = MosesTokenizer(lang="en")
        self.en_normalizer = MosesPunctNormalizer()
        self.en_detok = MosesDetokenizer(lang="en")
        self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
        print("Initializing vocab and bpe")
        # read_vocabulary and the BPE constructor consume their input file
        # during the call, so the handles are closed right afterwards instead
        # of being left open for the lifetime of the process.
        with codecs.open(f"{expdir}/vocab/vocab.SRC", encoding="utf-8") as vocab_file:
            self.vocabulary = read_vocabulary(vocab_file, 5)
        with codecs.open(
            f"{expdir}/vocab/bpe_codes.32k.SRC", encoding="utf-8"
        ) as codes_file:
            self.bpe = BPE(
                codes_file,
                -1,
                "@@",
                self.vocabulary,
                None,
            )

        print("Initializing model for translation")
        # initialize the model
        self.translator = Translator(
            f"{expdir}/final_bin", f"{expdir}/model/checkpoint_best.pt", batch_size=100
        )

    def batch_translate(self, batch, src_lang, tgt_lang):
        """Translate a list of sentences from src_lang to tgt_lang.

        batch: list of source sentences (must be a list).
        src_lang: source language code.
        tgt_lang: target language code.

        Returns the list of postprocessed translations. An empty batch
        returns an empty list without invoking the translator.
        """
        assert isinstance(batch, list)
        if not batch:
            # Nothing to translate; skip the whole pipeline.
            return []

        preprocessed_sents = self.preprocess(batch, lang=src_lang)
        bpe_sents = self.apply_bpe(preprocessed_sents)
        tagged_sents = apply_lang_tags(bpe_sents, src_lang, tgt_lang)
        # Truncation runs after tagging, so the language tags count as tokens.
        tagged_sents = truncate_long_sentences(tagged_sents)

        translations = self.translator.translate(tagged_sents)
        return self.postprocess(translations, tgt_lang)

    def translate_paragraph(self, paragraph, src_lang, tgt_lang):
        """Translate a paragraph from src_lang to tgt_lang.

        paragraph: source text (must be a str). It is split into sentences,
            translated as one batch, and the translations are joined with
            single spaces.

        Returns the translated paragraph as a single string.
        """
        assert isinstance(paragraph, str)
        sents = split_sentences(paragraph, src_lang)
        postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
        return " ".join(postprocessed_sents)

    def preprocess_sent(self, sent, normalizer, lang):
        """Normalize and tokenize one sentence; script-convert if not English.

        sent: raw sentence; surrounding whitespace is stripped.
        normalizer: object with a normalize(text) method, used only when
            lang is not "en" (callers pass None for English).
        lang: language code of the sentence.

        Returns the tokens joined with single spaces. For non-English input
        the result is additionally transliterated from ``lang`` to "hi".
        """
        if lang == "en":
            return " ".join(
                self.en_tok.tokenize(
                    self.en_normalizer.normalize(sent.strip()), escape=False
                )
            )

        tokenized = " ".join(
            indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), lang)
        )
        transliterated = unicode_transliterate.UnicodeIndicTransliterator.transliterate(
            tokenized, lang, "hi"
        )
        # Remove the spaces around a standalone Devanagari virama so it is
        # re-attached to its neighbours (presumably split off by tokenization).
        return transliterated.replace(" ् ", "्")

    def preprocess(self, sents, lang):
        """Normalize, tokenize and (for non-English) script-convert sentences.

        sents: iterable of raw sentences; iteration is wrapped in tqdm.
        lang: language code shared by all sentences.

        Returns the list of preprocessed sentences, in input order.
        """
        if lang == "en":
            # The English path uses the Moses normalizer held on the instance.
            normalizer = None
        else:
            normalizer = indic_normalize.IndicNormalizerFactory().get_normalizer(lang)

        return [self.preprocess_sent(sent, normalizer, lang) for sent in tqdm(sents)]

    def postprocess(self, sents, lang, common_lang="hi"):
        """Detokenize translations, converting script back for non-English.

        sents: iterable of translated sentences (token strings).
        lang: target language code.
        common_lang: language code whose script the translations are
            transliterated from when lang is not "en" (default "hi").

        Returns the list of postprocessed sentences, in input order. English
        output is split on single spaces and Moses-detokenized; other
        languages are transliterated from common_lang to lang and then
        detokenized with the Indic NLP trivial detokenizer.
        """
        if lang == "en":
            return [self.en_detok.detokenize(sent.split(" ")) for sent in sents]

        return [
            indic_detokenize.trivial_detokenize(
                self.xliterator.transliterate(sent, common_lang, lang), lang
            )
            for sent in sents
        ]

    def apply_bpe(self, sents):
        """Apply the loaded BPE model to each sentence and return the list."""
        return [self.bpe.process_line(sent) for sent in sents]