#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Python wrapper around the fairseq-interactive command line tool.

Translate raw text with a trained model. Batches data on-the-fly.
"""

import ast
from collections import namedtuple

import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output


Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")


def make_batches(
    lines, cfg, task, max_positions, encode_fn, constrained_decoding=False
):
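    """Yield Batch namedtuples built from ``lines``.

    When ``constrained_decoding`` is True, tab-delimited constraints are split
    off each input line, encoded with the target dictionary, and packed into a
    constraints tensor that is attached to every yielded batch.
    """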
    def encode_fn_target(x):
        return encode_fn(x)

    if constrained_decoding:
        # Strip (tab-delimited) constraints, if present, from the input lines
        # and store them in batch_constraints.
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Convert each List[str] to List[Tensor]
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        ids = batch["id"]
        src_tokens = batch["net_input"]["src_tokens"]
        src_lengths = batch["net_input"]["src_lengths"]
        constraints = batch.get("constraints", None)

        yield Batch(
            ids=ids,
            src_tokens=src_tokens,
            src_lengths=src_lengths,
            constraints=constraints,
        )


class Translator:
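    """Thin Python wrapper around fairseq's interactive generation.

    Loads a trained translation model (plus the SRC/TGT dictionaries stored in
    ``data_dir``) and exposes a ``translate`` method that batches raw input
    strings on the fly, optionally with lexical constraints.
    """
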
    def __init__(
        self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
    ):

        self.constrained_decoding = constrained_decoding
        self.parser = options.get_generation_parser(interactive=True)
        # buffer_size is currently not used but we just initialize it to
        # batch_size + 1 to avoid any assertion errors.
        self.parser.set_defaults(
            path=checkpoint_path,
            remove_bpe="subword_nmt",
            num_workers=-1,
            batch_size=batch_size,
            buffer_size=batch_size + 1,
        )
        if self.constrained_decoding:
            self.parser.set_defaults(constraints="ordered")
        args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
        # We explicitly set src_lang and tgt_lang here. Normally the data_dir we
        # pass contains {split}-{src_lang}-{tgt_lang}.*.idx files from which
        # fairseq infers the source and target languages (if they are not passed).
        # In deployment we don't use any idx files and only store the SRC and
        # TGT dictionaries.
        args.source_lang = "SRC"
        args.target_lang = "TGT"
        # Since sentences are already truncated to max_seq_len in the engine,
        # we can safely disable skipping of invalid-size inputs here.
        args.skip_invalid_size_inputs_valid_test = False

        # We keep custom architectures in this folder and let fairseq
        # import them from there.
        args.user_dir = "model_configs"
        self.cfg = convert_namespace_to_omegaconf(args)

        utils.import_user_module(self.cfg.common)

        if self.cfg.interactive.buffer_size < 1:
            self.cfg.interactive.buffer_size = 1
        if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
            self.cfg.dataset.batch_size = 1

        assert (
            not self.cfg.generation.sampling
            or self.cfg.generation.nbest == self.cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not self.cfg.dataset.batch_size
            or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        # Fix seed for stochastic decoding
        # if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
        #     np.random.seed(self.cfg.common.seed)
        #     utils.set_torch_seed(self.cfg.common.seed)

        # if not self.constrained_decoding:
        #     self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
        # else:
        #     self.use_cuda = False

        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.cfg.task)

        # Load ensemble
        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(self.cfg)

        # Initialize generator
        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # Handle tokenization and BPE
        self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
        self.bpe = self.task.build_bpe(self.cfg.bpe)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

    def encode_fn(self, x):
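        # Tokenize first (if a tokenizer is configured), then apply BPE.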
        if self.tokenizer is not None:
            x = self.tokenizer.encode(x)
        if self.bpe is not None:
            x = self.bpe.encode(x)
        return x

    def decode_fn(self, x):
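        # Reverse of encode_fn: strip BPE first, then detokenize.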
        if self.bpe is not None:
            x = self.bpe.decode(x)
        if self.tokenizer is not None:
            x = self.tokenizer.decode(x)
        return x

    def translate(self, inputs, constraints=None):
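        """Translate a list of raw input strings.

        ``constraints`` must be a parallel list of constraint strings when the
        Translator was built with constrained_decoding=True, and None otherwise.
        Returns the detokenized hypotheses (nbest per input) in input order.
        """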
        if self.constrained_decoding and constraints is None:
            raise ValueError("Constraints cant be None in constrained decoding mode")
        if not self.constrained_decoding and constraints is not None:
            raise ValueError("Cannot pass constraints during normal translation")
        if constraints:
            constrained_decoding = True
            modified_inputs = []
            for _input, constraint in zip(inputs, constraints):
                modified_inputs.append(_input + f"\t{constraint}")
            inputs = modified_inputs
        else:
            constrained_decoding = False

        start_id = 0
        results = []
        final_translations = []
        for batch in make_batches(
            inputs,
            self.cfg,
            self.task,
            self.max_positions,
            self.encode_fn,
            constrained_decoding,
        ):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=constraints
            )

            list_constraints = [[] for _ in range(bsz)]
            if constrained_decoding:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id_, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                constraints = list_constraints[i]
                results.append(
                    (
                        start_id + id_,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                        },
                    )
                )

        # sort output to match input order
        for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if self.src_dict is not None:
                src_str = self.src_dict.string(
                    src_tokens, self.cfg.common_eval.post_process
                )

            # Process top predictions
            for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe="subword_nmt",
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        self.generator
                    ),
                )
                detok_hypo_str = self.decode_fn(hypo_str)
                final_translations.append(detok_hypo_str)
        return final_translations
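

# A minimal usage sketch, not part of the original wrapper: the data directory
# and checkpoint paths below are placeholders and must point at a real fairseq
# setup (a data_dir holding the SRC/TGT dictionaries and a trained checkpoint).
if __name__ == "__main__":
    translator = Translator(
        data_dir="path/to/data_dir",  # placeholder: directory with SRC/TGT dictionaries
        checkpoint_path="path/to/checkpoint.pt",  # placeholder: trained model checkpoint
        batch_size=8,
    )
    print(translator.translate(["Hello world"]))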