# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
import random
import sys
from io import StringIO

import torch
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import LegacyFairseqTask
from fairseq_cli import generate, interactive, preprocess, train, validate
import fairseq.distributed.utils as distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf


def dummy_dictionary(vocab_size, prefix="token_"):
    """Build a finalized ``Dictionary`` containing synthetic symbols.

    Args:
        vocab_size: number of symbols to add.
        prefix: string prepended to each symbol's running index, so the
            symbols are ``prefix + "0"``, ``prefix + "1"``, and so on.

    Returns:
        The ``Dictionary`` after ``finalize`` has been called on it.
    """
    vocab = Dictionary()
    symbols = [prefix + str(index) for index in range(vocab_size)]
    for symbol in symbols:
        vocab.add_symbol(symbol)
    # padding_factor=1 keeps finalize() from adding extra padding symbols
    vocab.finalize(padding_factor=1)
    return vocab


def dummy_dataloader(
    samples, padding_idx=1, eos_idx=2, batch_size=None,
):
    """Wrap ``samples`` in a ``DataLoader`` and return an iterator over it.

    Args:
        samples: sequence of sample dicts. Any sample without an ``"id"`` key
            is modified in place and receives its position as id.
        padding_idx: padding index handed to ``collate``.
        eos_idx: end-of-sentence index handed to ``collate``.
        batch_size: batch size; when ``None`` all samples form one batch.

    Returns:
        An iterator over the collated batches.
    """
    effective_batch_size = len(samples) if batch_size is None else batch_size

    # fill in ids that the caller did not provide
    for position, item in enumerate(samples):
        if "id" in item:
            continue
        item["id"] = position

    def _collate(batch):
        return collate(batch, padding_idx, eos_idx)

    loader = torch.utils.data.DataLoader(
        TestDataset(samples),
        batch_size=effective_batch_size,
        collate_fn=_collate,
    )
    return iter(loader)


def sequence_generator_setup():
    """Create the fixtures used to exercise sequence generation.

    A two-symbol dummy dictionary is built, two identical source sentences
    ``[w1, w2, eos]`` are created, and a ``TestTranslationTask`` model is
    built whose ``args.beam_probs`` holds one hard-coded probability table per
    decoding step.

    Returns:
        Tuple ``(tgt_dict, w1, w2, src_tokens, src_lengths, model)``.
    """
    vocab = dummy_dictionary(vocab_size=2)

    eos = vocab.eos()
    w1, w2 = 4, 5

    # two identical source sentences
    source_row = [w1, w2, eos]
    src_tokens = torch.LongTensor([source_row, source_row])
    src_lengths = torch.LongTensor([2, 2])

    unk = 0.0
    # One table per step; one row per hypothesis (two rows for sentence 1,
    # then two rows for sentence 2). Columns are: eos, unk, w1, w2.
    # Trailing comments give the prefix of the row and its score.
    step_tables = [
        # step 0:
        [
            # sentence 1:
            [0.0, unk, 0.9, 0.1],  # beam 1
            [0.0, unk, 0.9, 0.1],  # beam 2
            # sentence 2:
            [0.0, unk, 0.7, 0.3],
            [0.0, unk, 0.7, 0.3],
        ],
        # step 1:
        [
            # sentence 1:
            [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
            [0.0, unk, 0.9, 0.1],  # w2: 0.1
            # sentence 2:
            [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
            [0.00, unk, 0.10, 0.9],  # w2: 0.3
        ],
        # step 2:
        [
            # sentence 1:
            [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
            [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
            # sentence 2:
            [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
            [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
        ],
        # step 3:
        [
            # sentence 1:
            # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
            [1.0, unk, 0.0, 0.0],
            # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
            [1.0, unk, 0.0, 0.0],
            # sentence 2:
            # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
            [0.1, unk, 0.5, 0.4],
            # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
            [1.0, unk, 0.0, 0.0],
        ],
    ]

    args = argparse.Namespace()
    args.beam_probs = [torch.FloatTensor(rows) for rows in step_tables]

    # the same dictionary serves as source and target dictionary
    task = TestTranslationTask.setup_task(args, vocab, vocab)
    model = task.build_model(args)

    return task.target_dictionary, w1, w2, src_tokens, src_lengths, model


def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
    """Write random parallel text files (and optional alignments) to disk.

    For each of the splits ``train``, ``valid`` and ``test`` a ``.in`` and a
    ``.out`` file is written into ``data_dir``. Every file has
    ``num_examples`` lines, each consisting of between 1 and ``maxlen``
    space-separated lowercase letters.

    Args:
        data_dir: existing directory that receives the files.
        num_examples: number of lines per file.
        maxlen: maximum number of tokens per line.
        alignment: when true, additionally write ``<split>.align`` files
            holding random ``srcidx-tgtidx`` pairs for every line pair.
    """
    splits = ("train", "valid", "test")

    def _write_random_text(filename):
        # Draw enough character codes for the worst case in one go;
        # 97 is ord("a"), so the codes fall in "a".."z".
        noise = torch.rand(num_examples * maxlen)
        codes = (97 + torch.floor(26 * noise).int()).tolist()
        cursor = 0
        with open(os.path.join(data_dir, filename), "w") as out:
            for _ in range(num_examples):
                length = random.randint(1, maxlen)
                letters = [chr(code) for code in codes[cursor : cursor + length]]
                print(" ".join(letters), file=out)
                cursor += length

    def _write_random_alignments(filename_src, filename_tgt, filename):
        src_path = os.path.join(data_dir, filename_src)
        tgt_path = os.path.join(data_dir, filename_tgt)
        out_path = os.path.join(data_dir, filename)
        with open(src_path, "r") as src_f, open(tgt_path, "r") as tgt_f, open(
            out_path, "w"
        ) as out:
            for src_line, tgt_line in zip(src_f, tgt_f):
                src_len = len(src_line.split())
                tgt_len = len(tgt_line.split())
                mean_len = (src_len + tgt_len) // 2
                count = random.randint(mean_len // 2, 2 * mean_len)
                # random token positions within each side of the pair
                src_positions = torch.floor(torch.rand(count) * src_len).int()
                tgt_positions = torch.floor(torch.rand(count) * tgt_len).int()
                pairs = [
                    f"{s}-{t}"
                    for s, t in zip(src_positions.tolist(), tgt_positions.tolist())
                ]
                print(" ".join(pairs), file=out)

    for split in splits:
        for side in ("in", "out"):
            _write_random_text(f"{split}.{side}")

    if not alignment:
        return
    for split in splits:
        _write_random_alignments(f"{split}.in", f"{split}.out", f"{split}.align")


def preprocess_lm_data(data_dir):
    """Run fairseq preprocessing in ``--only-source`` mode on the ``.out`` files.

    The ``train.out``, ``valid.out`` and ``test.out`` files inside ``data_dir``
    are used as input and ``data_dir`` is also the destination directory.
    """
    parser = options.get_preprocessing_parser()

    argv = ["--only-source"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        argv.append(flag)
        argv.append(os.path.join(data_dir, split + ".out"))
    argv.extend(["--destdir", data_dir])

    preprocess.main(parser.parse_args(argv))


def preprocess_translation_data(data_dir, extra_flags=None):
    """Run fairseq preprocessing for the ``in`` -> ``out`` language pair.

    Args:
        data_dir: directory holding ``train``/``valid``/``test`` prefixed
            files; it is also used as destination directory.
        extra_flags: optional list of further command-line flags appended
            after the built-in ones.
    """
    parser = options.get_preprocessing_parser()

    base_flags = ["--source-lang", "in", "--target-lang", "out"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        base_flags.append(flag)
        base_flags.append(os.path.join(data_dir, split))
    # a threshold of 0 is passed for both the target and source side
    base_flags.extend(["--thresholdtgt", "0", "--thresholdsrc", "0"])
    base_flags.extend(["--destdir", data_dir])

    parsed = parser.parse_args(base_flags + (extra_flags or []))
    preprocess.main(parsed)


def preprocess_summarization_data(data_dir, extra_flags=None):
    """Run fairseq preprocessing for ``in`` -> ``out`` with a joined dictionary.

    Args:
        data_dir: directory holding ``train``/``valid``/``test`` prefixed
            files; it is also used as destination directory.
        extra_flags: optional list of further command-line flags appended
            after the built-in ones.
    """
    parser = options.get_preprocessing_parser()

    base_flags = ["--source-lang", "in", "--target-lang", "out"]
    for flag, split in (
        ("--trainpref", "train"),
        ("--validpref", "valid"),
        ("--testpref", "test"),
    ):
        base_flags.append(flag)
        base_flags.append(os.path.join(data_dir, split))
    base_flags.extend(["--thresholdtgt", "0", "--thresholdsrc", "0"])
    # the flag that distinguishes this helper from the translation variant
    base_flags.append("--joined-dictionary")
    base_flags.extend(["--destdir", data_dir])

    parsed = parser.parse_args(base_flags + (extra_flags or []))
    preprocess.main(parsed)


def create_laser_data_and_config_json(data_dir):
    """Create dummy data per language pair and write ``laserconfig.json``.

    For every combination of the hard-coded source and target languages a
    sub-directory ``<src>-<tgt>`` is created inside ``data_dir``, filled with
    dummy data and preprocessed with ``--dataset-impl cached``. A JSON config
    listing all training sets is then written to ``data_dir``.

    Returns:
        The file object the config was written through. It is already closed
        when returned; its ``name`` attribute holds the config path.
    """
    src_langs = ["de", "fr", "ru", "tr", "zh"]
    tgt_langs = ["en", "es"]
    lang_pairs = [(src, tgt) for src in src_langs for tgt in tgt_langs]

    train_entries = []
    src_vocab = None
    tgt_vocab = None

    for src_lang, tgt_lang in lang_pairs:
        langpair_path = os.path.join(data_dir, f"{src_lang}-{tgt_lang}")
        os.mkdir(langpair_path)
        create_dummy_data(langpair_path)
        preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"])

        # the vocab paths of the last processed pair end up in the config
        src_vocab = os.path.join(langpair_path, "dict.in.txt")
        tgt_vocab = os.path.join(langpair_path, "dict.out.txt")

        entry = {
            "id": 0 if tgt_lang == "en" else 1,
            "src": os.path.join(langpair_path, "train.in-out.in"),
            "tgt": os.path.join(langpair_path, "train.in-out.out"),
        }
        train_entries.append(entry)

    config_json = {
        "src_vocab": src_vocab,
        "tgt_vocab": tgt_vocab,
        "train": train_entries,
    }

    config_path = os.path.join(data_dir, "laserconfig.json")
    with open(config_path, "w") as config_file:
        json.dump(config_json, config_file)

    return config_file


def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    """Train a model for one epoch through the fairseq training entry point.

    Args:
        data_dir: data directory; also used as ``--save-dir``.
        arch: value for ``--arch``.
        extra_flags: optional list of additional training flags.
        task: value for ``--task``.
        run_validation: when true, run ``validate.main`` afterwards on
            ``checkpoint_last.pt`` inside ``data_dir``.
        lang_flags: language flags; defaults to source ``in`` / target ``out``.
        extra_valid_flags: optional list of additional validation flags.
        world_size: value for ``--distributed-world-size``.
    """
    if lang_flags is None:
        lang_flags = ["--source-lang", "in", "--target-lang", "out"]

    train_parser = options.get_training_parser()
    train_argv = [
        "--task", task,
        data_dir,
        "--save-dir", data_dir,
        "--arch", arch,
        "--optimizer", "nag",
        "--lr", "0.05",
        "--max-tokens", "500",
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", str(world_size),
        "--num-workers", "0",
    ]
    train_args = options.parse_args_and_arch(
        train_parser, train_argv + lang_flags + (extra_flags or [])
    )

    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if not run_validation:
        return

    # validate the checkpoint written by the training run above
    validate_parser = options.get_validation_parser()
    valid_argv = [
        "--task", task,
        data_dir,
        "--path", os.path.join(data_dir, "checkpoint_last.pt"),
        "--valid-subset", "valid",
        "--max-tokens", "500",
        "--no-progress-bar",
        "--num-workers", "0",
    ]
    validate_args = options.parse_args_and_arch(
        validate_parser, valid_argv + lang_flags + (extra_valid_flags or [])
    )
    validate.main(validate_args)


def generate_main(data_dir, extra_flags=None, path=None):
    """Run generation in batch mode and then interactively on a fixed input.

    Args:
        data_dir: data directory passed to the generation parser.
        extra_flags: optional list of additional generation flags; defaults
            to ``["--print-alignment"]`` when ``None``.
        path: checkpoint path; defaults to ``checkpoint_last.pt`` inside
            ``data_dir``.

    The interactive run temporarily replaces ``sys.stdin`` with an in-memory
    stream containing ``"h e l l o\\n"``. The original stream is restored even
    when ``interactive.main`` raises, so a failure here cannot leave the
    process with a fake stdin.
    """
    if extra_flags is None:
        extra_flags = [
            "--print-alignment",
        ]
    if path is None:
        path = os.path.join(data_dir, "checkpoint_last.pt")
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            "--path",
            path,
            "--beam",
            "3",
            "--batch-size",
            "64",
            "--max-len-b",
            "5",
            "--gen-subset",
            "valid",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively, reusing the parsed args with the
    # interactive-specific fields overridden
    generate_args.buffer_size = 0
    generate_args.input = "-"
    generate_args.batch_size = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO("h e l l o\n")
    try:
        interactive.main(generate_args)
    finally:
        # always undo the global stdin swap, even on failure
        sys.stdin = orig_stdin


class TestDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that serves items of an in-memory sequence."""

    def __init__(self, data):
        super().__init__()
        # no size information is tracked for this dataset
        self.sizes = None
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored item at ``index``."""
        item = self.data[index]
        return item


class TestTranslationTask(LegacyFairseqTask):
    """Task stub that stores the given dictionaries and builds ``TestModel``."""

    def __init__(self, args, src_dict, tgt_dict, model):
        super().__init__(args)
        self.src_dict, self.tgt_dict = src_dict, tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        """Create the task directly from the given objects."""
        task = cls(args, src_dict, tgt_dict, model)
        return task

    def build_model(self, args):
        """Build a ``TestModel`` bound to this task."""
        model = TestModel.build_model(args, self)
        return model

    @property
    def source_dictionary(self):
        """The source dictionary given at construction time."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """The target dictionary given at construction time."""
        return self.tgt_dict


class TestModel(FairseqEncoderDecoderModel):
    """Encoder-decoder pairing ``TestEncoder`` with ``TestIncrementalDecoder``."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build both halves from the task's source and target dictionaries."""
        return cls(
            TestEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )


class TestEncoder(FairseqEncoder):
    """Encoder stub that passes the source tokens through unchanged."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Return an ``EncoderOut`` whose ``encoder_out`` is ``src_tokens``.

        ``src_lengths`` and any extra keyword arguments are ignored; all other
        ``EncoderOut`` fields are ``None``.
        """
        passthrough = src_tokens
        return EncoderOut(
            encoder_out=passthrough,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Select rows of ``encoder_out.encoder_out`` along dim 0 by ``new_order``."""
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestIncrementalDecoder(FairseqIncrementalDecoder):
    """Decoder stub that emits pre-defined probabilities instead of computing them.

    The probabilities come either from ``args.probs`` (a 3-dim tensor indexed
    by step along dim 1) or from ``args.beam_probs`` (one table per step).
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert any(hasattr(args, name) for name in ("beam_probs", "probs"))
        if not hasattr(args, "max_decoder_positions"):
            args.max_decoder_positions = 100
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """Return ``(probs, {"attn": [attn]})`` for the requested steps.

        With an ``incremental_state`` only the last previous token is used and
        a step counter kept in that state selects the single step to emit.
        Without it, one step per target position is emitted.
        """
        incremental = incremental_state is not None
        if incremental:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # work out which step indices this call covers
        if incremental:
            # the running step number is cached in the incremental state
            cached = utils.get_incremental_state(self, incremental_state, "step")
            step = 0 if cached is None else cached
            utils.set_incremental_state(self, incremental_state, "step", step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # the output is expressed directly as raw probabilities
        if hasattr(self.args, "probs"):
            assert (
                self.args.probs.dim() == 3
            ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            eos = self.dictionary.eos()
            tables = self.args.beam_probs
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for slot, step in enumerate(steps):
                if step >= len(tables):
                    # past the last table: all probability mass goes to eos
                    probs[:, slot, eos] = 1.0
                else:
                    # args.beam_probs gives the probability for every vocab
                    # element, starting with eos, then unknown, and then the
                    # rest of the vocab
                    probs[:, slot, eos:] = tables[step]

        # the attention weights are random
        attn = torch.rand(bbsz, tgt_len, src_len)

        device = prev_output_tokens.device
        return probs.to(device), {"attn": [attn.to(device)]}

    def get_normalized_probs(self, net_output, log_probs, _):
        """Return the decoder output as probabilities or log-probabilities."""
        # forward() already returns probabilities as first output element
        probs = net_output[0]
        return probs.log() if log_probs else probs

    def max_positions(self):
        """Return ``args.max_decoder_positions``."""
        return self.args.max_decoder_positions


class TestReshapingEncoder(FairseqEncoder):
    """Encoder stub whose output is the source tokens viewed as pairs."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Return ``src_tokens`` reshaped to ``(batch, -1, 2)``.

        When the second dimension has odd size it is first padded on the
        right so that it becomes even.
        """
        b_sz, t_sz = src_tokens.shape
        remainder = t_sz % 2
        paired = src_tokens
        if remainder > 0:
            paired = F.pad(paired, (0, 2 - remainder))
        paired = paired.view(b_sz, -1, 2)

        return EncoderOut(
            encoder_out=paired,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Select rows of ``encoder_out.encoder_out`` along dim 0 by ``new_order``."""
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestReshapingModel(FairseqEncoderDecoderModel):
    """Encoder-decoder pairing ``TestReshapingEncoder`` with ``TestIncrementalDecoder``."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build both halves from the task's source and target dictionaries."""
        return cls(
            TestReshapingEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )


class TestAdditionalInputEncoder(FairseqEncoder):
    """Encoder stub that insists on receiving a ``fancy_other_input`` kwarg."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Check the extra input, then pass ``src_tokens`` through unchanged.

        Asserts that ``fancy_other_input`` was supplied as a keyword argument
        and is not ``None``.
        """
        assert "fancy_other_input" in kwargs
        extra_input = kwargs["fancy_other_input"]
        assert extra_input is not None
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Select rows of ``encoder_out.encoder_out`` along dim 0 by ``new_order``."""
        reordered = encoder_out.encoder_out.index_select(0, new_order)
        return EncoderOut(
            encoder_out=reordered,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestAdditionalInputModel(FairseqEncoderDecoderModel):
    """Encoder-decoder pairing ``TestAdditionalInputEncoder`` with ``TestIncrementalDecoder``."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build both halves from the task's source and target dictionaries."""
        return cls(
            TestAdditionalInputEncoder(args, task.source_dictionary),
            TestIncrementalDecoder(args, task.target_dictionary),
        )

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """Run the encoder, then the decoder, passing ``kwargs`` to both."""
        encoded = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        # NOTE(review): the same kwargs go to the decoder, whose forward() in
        # this file declares no **kwargs -- confirm whether this path is
        # exercised with extra inputs.
        return self.decoder(prev_output_tokens, encoder_out=encoded, **kwargs)


def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    """Train a language model for one epoch through the fairseq training entry point.

    Args:
        data_dir: data directory; also used as ``--save-dir``.
        arch: value for ``--arch``.
        extra_flags: optional list of additional training flags.
        run_validation: when true, run ``validate.main`` afterwards on
            ``checkpoint_last.pt`` inside ``data_dir``.
        extra_valid_flags: optional list of additional validation flags.
        task: value for ``--task``.
        world_size: value for ``--distributed-world-size``.
    """
    train_parser = options.get_training_parser()
    train_argv = [
        "--task", task,
        data_dir,
        "--arch", arch,
        "--optimizer", "adam",
        "--lr", "0.0001",
        "--max-tokens", "500",
        "--tokens-per-sample", "500",
        "--save-dir", data_dir,
        "--max-epoch", "1",
        "--no-progress-bar",
        "--distributed-world-size", str(world_size),
        "--ddp-backend", "no_c10d",
        "--num-workers", "0",
    ]
    train_args = options.parse_args_and_arch(
        train_parser, train_argv + (extra_flags or [])
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if not run_validation:
        return

    # validate the checkpoint written by the training run above
    validate_parser = options.get_validation_parser()
    valid_argv = [
        "--task", task,
        data_dir,
        "--path", os.path.join(data_dir, "checkpoint_last.pt"),
        "--valid-subset", "valid",
        "--max-tokens", "500",
        "--no-progress-bar",
        "--num-workers", "0",
    ]
    validate_args = options.parse_args_and_arch(
        validate_parser, valid_argv + (extra_valid_flags or [])
    )
    validate.main(validate_args)