import os
import math
import argparse
import gradio

import torch
from torch.utils.data import DataLoader

from tqdm import tqdm

from transformers import (
    PreTrainedTokenizerBase,
    DataCollatorForSeq2Seq,
)

from model import load_model_for_inference

from dataset import DatasetReader, count_lines

from accelerate import Accelerator, DistributedType, find_executable_batch_size

from typing import Optional


def encode_string(text):
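    """Escape carriage returns, newlines and tabs so every translation is written on a single output line."""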
    return text.replace("\r", r"\r").replace("\n", r"\n").replace("\t", r"\t")


def get_dataloader(
    accelerator: Accelerator,
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
    prompt: str,
) -> DataLoader:
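    """Build a DataLoader over `filename`.

    On TPU, batches are padded to `max_length` so every batch has the same shape;
    otherwise, dynamic padding to a multiple of 8 is used for efficiency.
    """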
    dataset = DatasetReader(
        filename=filename,
        tokenizer=tokenizer,
        max_length=max_length,
        prompt=prompt,
    )
    if accelerator.distributed_type == DistributedType.TPU:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding="max_length",
            max_length=max_length,
            label_pad_token_id=tokenizer.pad_token_id,
            return_tensors="pt",
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,
            # max_length=max_length, No need to set max_length here, we already truncate in the preprocess function
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=data_collator,
        num_workers=0,  # Disable multiprocessing
    )


def main(
    input_string: str,
    source_lang: Optional[str],
    target_lang: Optional[str],
    model_name: str = "facebook/m2m100_1.2B",
    starting_batch_size: int = 8,
    lora_weights_name_or_path: Optional[str] = None,
    force_auto_device_map: bool = False,
    precision: Optional[str] = None,
    max_length: int = 256,
    num_beams: int = 4,
    num_return_sequences: int = 1,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
    keep_special_tokens: bool = False,
    keep_tokenization_spaces: bool = False,
    repetition_penalty: Optional[float] = None,
    prompt: Optional[str] = None,
    trust_remote_code: bool = False,
):
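    """Translate `input_string` with the selected Hugging Face model and return the translated text.

    The input is written to a temporary file, translated in batches with Accelerate
    (retrying with smaller batch sizes on OOM), and the result is read back from the
    output file, so this function can be used directly as a Gradio callback.
    """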
    accelerator = Accelerator()
    sentences_path = "input.txt"
    output_path = "output.txt"
    with open(sentences_path, "w", encoding="utf-8") as f:
        f.write(input_string)

    if force_auto_device_map and starting_batch_size >= 64:
        print(
            f"WARNING: You are using a very large batch size ({starting_batch_size}) and the auto_device_map  flag. "
            f"auto_device_map will offload model parameters to the CPU when they don't fit on the GPU VRAM. "
            f"If you use a very large batch size, it will offload a lot of parameters to the CPU and slow down the "
            f"inference. You should consider using a smaller batch size, i.e '--starting_batch_size 8'"
        )

    if precision is None:
        quantization = None
        dtype = None
    elif precision == "8" or precision == "4":
        quantization = int(precision)
        dtype = None
    elif precision == "fp16":
        quantization = None
        dtype = "float16"
    elif precision == "bf16":
        quantization = None
        dtype = "bfloat16"
    elif precision == "32":
        quantization = None
        dtype = "float32"
    else:
        raise ValueError(
            f"Precision {precision} not supported. Please choose between 8, 4, fp16, bf16, 32 or None."
        )

    model, tokenizer = load_model_for_inference(
        weights_path=model_name,
        quantization=quantization,
        lora_weights_name_or_path=lora_weights_name_or_path,
        torch_dtype=dtype,
        force_auto_device_map=force_auto_device_map,
        trust_remote_code=trust_remote_code,
    )

    is_translation_model = hasattr(tokenizer, "lang_code_to_id")
    lang_code_to_idx = None

    if (
        is_translation_model
        and (source_lang is None or target_lang is None)
        and "small100" not in model_name
    ):
        raise ValueError(
            f"The model you are using requires a source and target language. "
            f"Please specify them with --source-lang and --target-lang. "
            f"The supported languages are: {tokenizer.lang_code_to_id.keys()}"
        )
    if not is_translation_model and (
        source_lang is not None or target_lang is not None
    ):
        if prompt is None:
            print(
                "WARNING: You are using a model that does not support source and target languages parameters "
                "but you specified them. You probably want to use m2m100/nllb200 for translation or "
                "set --prompt to define the task for you model. "
            )
        else:
            print(
                "WARNING: You are using a model that does not support source and target languages parameters "
                "but you specified them."
            )

    if prompt is not None and "%%SENTENCE%%" not in prompt:
        raise ValueError(
            f"The prompt must contain the %%SENTENCE%% token to indicate where the sentence should be inserted. "
            f"Your prompt: {prompt}"
        )

    if is_translation_model:
        try:
            _ = tokenizer.lang_code_to_id[source_lang]
        except KeyError:
            raise KeyError(
                f"Language {source_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            )
        tokenizer.src_lang = source_lang

        try:
            lang_code_to_idx = tokenizer.lang_code_to_id[target_lang]
        except KeyError:
            raise KeyError(
                f"Language {target_lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
            )
        if "small100" in model_name:
            tokenizer.tgt_lang = target_lang
            # We don't need to force the BOS token, so we set is_translation_model to False
            is_translation_model = False

    if model.config.model_type == "seamless_m4t":
        # Loading a seamless_m4t model, we need to set a few things to ensure compatibility

        supported_langs = tokenizer.additional_special_tokens
        supported_langs = [lang.replace("__", "") for lang in supported_langs]

        if source_lang is None or target_lang is None:
            raise ValueError(
                f"The model you are using requires a source and target language. "
                f"Please specify them with --source-lang and --target-lang. "
                f"The supported languages are: {supported_langs}"
            )

        if source_lang not in supported_langs:
            raise ValueError(
                f"Language {source_lang} not found in tokenizer. Available languages: {supported_langs}"
            )
        if target_lang not in supported_langs:
            raise ValueError(
                f"Language {target_lang} not found in tokenizer. Available languages: {supported_langs}"
            )

        tokenizer.src_lang = source_lang

    gen_kwargs = {
        "max_new_tokens": max_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }

    if repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = repetition_penalty

    if is_translation_model:
        gen_kwargs["forced_bos_token_id"] = lang_code_to_idx

    if model.config.model_type == "seamless_m4t":
        gen_kwargs["tgt_lang"] = target_lang

    if accelerator.is_main_process:
        print(
            f"** Translation **\n"
            f"Input file: {sentences_path}\n"
            f"Output file: {output_path}\n"
            f"Source language: {source_lang}\n"
            f"Target language: {target_lang}\n"
            f"Force target lang as BOS token: {is_translation_model}\n"
            f"Prompt: {prompt}\n"
            f"Starting batch size: {starting_batch_size}\n"
            f"Device: {str(accelerator.device).split(':')[0]}\n"
            f"Num. Devices: {accelerator.num_processes}\n"
            f"Distributed_type: {accelerator.distributed_type}\n"
            f"Max length: {max_length}\n"
            f"Quantization: {quantization}\n"
            f"Precision: {dtype}\n"
            f"Model: {model_name}\n"
            f"LoRA weights: {lora_weights_name_or_path}\n"
            f"Force auto device map: {force_auto_device_map}\n"
            f"Keep special tokens: {keep_special_tokens}\n"
            f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
        )
        print("** Generation parameters **")
        print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
        print("\n")

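    # find_executable_batch_size re-runs `inference`, halving the batch size each time a CUDA
    # out-of-memory error is raised, starting from `starting_batch_size`.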
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size, sentences_path, output_path):
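        """Translate `sentences_path` into `output_path` with the current (auto-adjusted) batch size."""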
        nonlocal model, tokenizer, max_length, gen_kwargs, precision, prompt, is_translation_model

        print(f"Translating {sentences_path} with batch size {batch_size}")

        total_lines: int = count_lines(sentences_path)

        data_loader = get_dataloader(
            accelerator=accelerator,
            filename=sentences_path,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            prompt=prompt,
        )

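        # Let Accelerate place the model and dataloader on the right device(s) and shard the data across processes.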
        model, data_loader = accelerator.prepare(model, data_loader)

        samples_seen: int = 0

        with tqdm(
            total=total_lines,
            desc="Dataset translation",
            leave=True,
            ascii=True,
            disable=(not accelerator.is_main_process),
        ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
            with torch.no_grad():
                for step, batch in enumerate(data_loader):
                    batch["input_ids"] = batch["input_ids"]
                    batch["attention_mask"] = batch["attention_mask"]

                    generated_tokens = accelerator.unwrap_model(model).generate(
                        **batch,
                        **gen_kwargs,
                    )

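                    # Pad generations to the same length on every process before gathering them, since
                    # sequences generated on different processes may have different lengths.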
                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                    )

                    generated_tokens = (
                        accelerator.gather(generated_tokens).cpu().numpy()
                    )

                    tgt_text = tokenizer.batch_decode(
                        generated_tokens,
                        skip_special_tokens=not keep_special_tokens,
                        clean_up_tokenization_spaces=not keep_tokenization_spaces,
                    )
                    if accelerator.is_main_process:
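                        # On the last batch, drop the duplicated samples that gather() adds when the number of
                        # sentences is not evenly divisible across processes.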
                        if (
                            step
                            == math.ceil(
                                math.ceil(total_lines / batch_size)
                                / accelerator.num_processes
                            )
                            - 1
                        ):
                            tgt_text = tgt_text[
                                : (total_lines * num_return_sequences) - samples_seen
                            ]
                        else:
                            samples_seen += len(tgt_text)

                        print(
                            "\n".join(
                                [encode_string(sentence) for sentence in tgt_text]
                            ),
                            file=output_file,
                        )

                    pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])

        print(f"Translation done. Output written to {output_path}\n")

    output_dir = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(output_dir, exist_ok=True)
    inference(sentences_path=sentences_path, output_path=output_path)

    print("Translation done.\n")
    with open(output_path, "r", encoding="utf-8") as f:
        return f.read()


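# The original command-line entry point is kept below for reference. It predates the Gradio wrapper
# above, and some of the arguments it passes (e.g. sentences_dir, output_path) no longer exist in the
# current main() signature.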
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Run the translation experiments")
#     input_group = parser.add_mutually_exclusive_group(required=True)
#     input_group.add_argument(
#         "--sentences_path",
#         default=None,
#         type=str,
#         help="Path to a txt file containing the sentences to translate. One sentence per line.",
#     )

#     input_group.add_argument(
#         "--sentences_dir",
#         type=str,
#         default=None,
#         help="Path to a directory containing the sentences to translate. "
#         "Sentences must be in  .txt files containing containing one sentence per line.",
#     )

#     parser.add_argument(
#         "--files_extension",
#         type=str,
#         default="txt",
#         help="If sentences_dir is specified, extension of the files to translate. Defaults to txt. "
#         "If set to an empty string, we will translate all files in the directory.",
#     )

#     parser.add_argument(
#         "--output_path",
#         type=str,
#         required=True,
#         help="Path to a txt file where the translated sentences will be written. If the input is a directory, "
#         "the output will be a directory with the same structure.",
#     )

#     parser.add_argument(
#         "--source_lang",
#         type=str,
#         default=None,
#         required=False,
#         help="Source language id. See: supported_languages.md. Required for m2m100 and nllb200",
#     )

#     parser.add_argument(
#         "--target_lang",
#         type=str,
#         default=None,
#         required=False,
#         help="Source language id. See: supported_languages.md. Required for m2m100 and nllb200",
#     )

#     parser.add_argument(
#         "--starting_batch_size",
#         type=int,
#         default=128,
#         help="Starting batch size, we will automatically reduce it if we find an OOM error."
#         "If you use multiple devices, we will divide this number by the number of devices.",
#     )

#     parser.add_argument(
#         "--model_name",
#         type=str,
#         default="facebook/m2m100_1.2B",
#         help="Path to the model to use. See: https://huggingface.co/models",
#     )

#     parser.add_argument(
#         "--lora_weights_name_or_path",
#         type=str,
#         default=None,
#         help="If the model uses LoRA weights, path to those weights. See: https://github.com/huggingface/peft",
#     )

#     parser.add_argument(
#         "--force_auto_device_map",
#         action="store_true",
#         help=" Whether to force the use of the auto device map. If set to True, "
#         "the model will be split across GPUs and CPU to fit the model in memory. "
#         "If set to False, a full copy of the model will be loaded  into each GPU. Defaults to False.",
#     )

#     parser.add_argument(
#         "--max_length",
#         type=int,
#         default=256,
#         help="Maximum number of tokens in the source sentence and generated sentence. "
#         "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
#     )

#     parser.add_argument(
#         "--num_beams",
#         type=int,
#         default=5,
#         help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
#     )

#     parser.add_argument(
#         "--num_return_sequences",
#         type=int,
#         default=1,
#         help="Number of possible translation to return for each sentence (num_return_sequences<=num_beams).",
#     )

#     parser.add_argument(
#         "--precision",
#         type=str,
#         default=None,
#         choices=["bf16", "fp16", "32", "4", "8"],
#         help="Precision of the model. bf16, fp16 or 32, 8 , 4 "
#         "(4bits/8bits quantification, requires bitsandbytes library: https://github.com/TimDettmers/bitsandbytes). "
#         "If None, we will use the torch.dtype of the model weights.",
#     )

#     parser.add_argument(
#         "--do_sample",
#         action="store_true",
#         help="Use sampling instead of beam search.",
#     )

#     parser.add_argument(
#         "--temperature",
#         type=float,
#         default=0.8,
#         help="Temperature for sampling, value used only if do_sample is True.",
#     )

#     parser.add_argument(
#         "--top_k",
#         type=int,
#         default=100,
#         help="If do_sample is True, will sample from the top k most likely tokens.",
#     )

#     parser.add_argument(
#         "--top_p",
#         type=float,
#         default=0.75,
#         help="If do_sample is True, will sample from the top k most likely tokens.",
#     )

#     parser.add_argument(
#         "--keep_special_tokens",
#         action="store_true",
#         help="Keep special tokens in the decoded text.",
#     )

#     parser.add_argument(
#         "--keep_tokenization_spaces",
#         action="store_true",
#         help="Do not clean spaces in the decoded text.",
#     )

#     parser.add_argument(
#         "--repetition_penalty",
#         type=float,
#         default=None,
#         help="Repetition penalty.",
#     )

#     parser.add_argument(
#         "--prompt",
#         type=str,
#         default=None,
#         help="Prompt to use for generation. "
#         "It must include the special token %%SENTENCE%% which will be replaced by the sentence to translate.",
#     )

#     parser.add_argument(
#         "--trust_remote_code",
#         action="store_true",
#         help="If set we will trust remote code in HuggingFace models. This is required for some models.",
#     )

#     args = parser.parse_args()

#     main(
#         sentences_path=args.sentences_path,
#         sentences_dir=args.sentences_dir,
#         files_extension=args.files_extension,
#         output_path=args.output_path,
#         source_lang=args.source_lang,
#         target_lang=args.target_lang,
#         starting_batch_size=args.starting_batch_size,
#         model_name=args.model_name,
#         max_length=args.max_length,
#         num_beams=args.num_beams,
#         num_return_sequences=args.num_return_sequences,
#         precision=args.precision,
#         do_sample=args.do_sample,
#         temperature=args.temperature,
#         top_k=args.top_k,
#         top_p=args.top_p,
#         keep_special_tokens=args.keep_special_tokens,
#         keep_tokenization_spaces=args.keep_tokenization_spaces,
#         repetition_penalty=args.repetition_penalty,
#         prompt=args.prompt,
#         trust_remote_code=args.trust_remote_code,
#     )

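# Minimal Gradio wrapper around main(): the four textboxes map to input_string, source_lang,
# target_lang and model_name. launch() starts a local web server for the demo.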
demo = gradio.Interface(
    fn=main,
    inputs=[
        gradio.Textbox(label="Text to translate (one sentence per line)", lines=5),
        gradio.Textbox(label="Source language code"),
        gradio.Textbox(label="Target language code"),
        gradio.Textbox(label="Model name", value="facebook/m2m100_1.2B"),
    ],
    outputs=gradio.Textbox(label="Translation"),
)
demo.launch()