import argparse
import time

import torch
from datasets import load_dataset
from transformers import AutoProcessor, VisionEncoderDecoderModel


def speedometer(
    model: torch.nn.Module,
    pixel_values: torch.Tensor,
    decoder_input_ids: torch.Tensor,
    processor: "AutoProcessor",
    bad_words_ids: list,
    warmup_iters: int = 100,
    timing_iters: int = 100,
    num_tokens: int = 10,
) -> float:
    """Measure the mean latency of ``model.generate`` in milliseconds.

    Runs ``warmup_iters`` untimed greedy ``generate`` calls, then times
    ``timing_iters`` calls. It prints the mean per-call time and returns it.

    The inputs are moved to ``model.device`` once, before any timing. The
    measurement therefore covers generation only, not a host-to-device copy
    on every call. On CUDA devices the timing uses CUDA events. On any other
    device it falls back to ``time.perf_counter``.

    Args:
        model: Model exposing ``.device`` and ``.generate(...)``.
        pixel_values: Encoder input passed positionally to ``generate``.
        decoder_input_ids: Decoder prompt passed to ``generate``.
        processor: Object whose ``tokenizer`` provides ``pad_token_id`` and
            ``eos_token_id``.
        bad_words_ids: Token-id sequences that ``generate`` must not emit.
        warmup_iters: Number of untimed calls run before timing.
        timing_iters: Number of timed calls; must be positive.
        num_tokens: Value passed as both ``min_length`` and ``max_length``.

    Returns:
        Mean wall time per ``generate`` call, in milliseconds.

    Raises:
        ValueError: If ``timing_iters`` is not positive.
    """
    if timing_iters <= 0:
        raise ValueError(f"timing_iters must be positive, got {timing_iters}")

    device = model.device
    # Transfer once, outside the timed region.
    pixel_values = pixel_values.to(device)
    generate_kwargs = dict(
        decoder_input_ids=decoder_input_ids.to(device),
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=bad_words_ids,
        return_dict_in_generate=True,
        # Equal min/max lengths pin the generated sequence length, so every
        # call does a comparable amount of decoding work.
        min_length=num_tokens,
        max_length=num_tokens,
    )

    def run(n_iters: int) -> None:
        """Call ``generate`` ``n_iters`` times; the outputs are discarded."""
        for _ in range(n_iters):
            model.generate(pixel_values, **generate_kwargs)

    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()
        run(warmup_iters)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        run(timing_iters)
        end.record()
        # Both events must have completed before elapsed_time is valid.
        torch.cuda.synchronize()
        total_ms = start.elapsed_time(end)
    else:
        # CUDA events are unavailable off-GPU; use a monotonic wall clock.
        run(warmup_iters)
        t0 = time.perf_counter()
        run(timing_iters)
        total_ms = (time.perf_counter() - t0) * 1000.0

    mean = total_ms / timing_iters
    print(f"Mean time: {mean} ms")

    return mean


def get_ja_list_of_lists(processor):
    """Collect the ids of vocabulary tokens made only of Japanese characters.

    A token qualifies when it meets both conditions after its leading
    SentencePiece word-boundary markers ("▁") are stripped:

    * it is non-empty, and
    * every character falls in one of the Unicode ranges listed below.

    The result has the shape ``generate(bad_words_ids=...)`` expects: one
    single-id list per qualifying token, in vocabulary iteration order.

    Args:
        processor: Object exposing ``processor.tokenizer.vocab``, a mapping
            of token string -> token id.

    Returns:
        list[list[int]]: ``[[id], [id], ...]`` for every Japanese-only token.
    """
    # Inclusive (start, end) code-point ranges treated as Japanese.
    ja_ranges = (
        (0x3040, 0x309F),  # Hiragana
        (0x30A0, 0x30FF),  # Katakana
        (0x4E00, 0x9FFF),  # CJK Unified Ideographs
        (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
        (0x20000, 0x2A6DF),  # CJK Unified Ideographs Extension B
        (0x31F0, 0x31FF),  # Katakana Phonetic Extensions
        (0xFF00, 0xFFEF),  # Halfwidth and Fullwidth Forms
        (0x3000, 0x303F),  # CJK Symbols and Punctuation
        (0x3200, 0x32FF),  # Enclosed CJK Letters and Months
    )

    def is_japanese(s):
        """Return True if ``s`` is non-empty and every char is in ``ja_ranges``."""
        # Reject the empty string. Otherwise a bare "▁" token (empty after
        # lstrip) would be classified as Japanese and banned.
        return bool(s) and all(
            any(lo <= ord(ch) <= hi for lo, hi in ja_ranges) for ch in s
        )

    return [
        [token_id]
        for token, token_id in processor.tokenizer.vocab.items()
        if is_japanese(token.lstrip("▁"))
    ]


def main():
    """Command-line entry point: time ``generate`` on one sample document.

    Steps:

    1. Parse the CLI arguments.
    2. Load the processor and model from ``--model_path`` and move the model
       to CUDA device 0 if one is available, otherwise to the CPU.
    3. Encode image ``[1]`` of the ``hf-internal-testing/example-documents``
       test split, with the ``<s_synthdog>`` prompt as decoder input.
    4. Call ``speedometer``. The unknown token is always banned; with
       ``--ja_bad_words``, every Japanese-only token is banned as well.

    The warmup/timing/num_tokens flags default to the values that were
    previously hard-coded.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Benchmark VisionEncoderDecoderModel.generate latency, "
            "optionally banning Japanese tokens via bad_words_ids."
        )
    )
    parser.add_argument(
        '--model_path',
        help='Model id or local directory passed to from_pretrained',
        required=True,
    )
    parser.add_argument(
        '--ja_bad_words',
        help='Also ban every Japanese-only vocabulary token via bad_words_ids',
        action="store_true",
        default=False,
    )
    parser.add_argument('--warmup_iters', type=int, default=100, help='Untimed warmup generate calls')
    parser.add_argument('--timing_iters', type=int, default=100, help='Timed generate calls')
    parser.add_argument(
        '--num_tokens', type=int, default=10, help='Value used for both min_length and max_length'
    )
    args = parser.parse_args()

    print("Running speed test on model: ", args.model_path, "with ja_bad_words: ", args.ja_bad_words)

    processor = AutoProcessor.from_pretrained(args.model_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_path)

    # Use a torch.device in both branches rather than mixing int and device.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    dataset = load_dataset("hf-internal-testing/example-documents", split="test")
    image = dataset[1]["image"]

    # Task-start token used as the decoder prompt; presumably this matches a
    # Donut/SynthDoG-style checkpoint (confirm for other models).
    task_prompt = "<s_synthdog>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Always ban the unknown token; optionally add each Japanese-only token.
    bad_words_ids = [[processor.tokenizer.unk_token_id]]
    if args.ja_bad_words:
        bad_words_ids += get_ja_list_of_lists(processor)

    print("Length of bad_words_ids: ", len(bad_words_ids))

    # speedometer prints the mean itself; its return value is not needed here.
    speedometer(
        model,
        pixel_values,
        decoder_input_ids,
        processor,
        bad_words_ids=bad_words_ids,
        warmup_iters=args.warmup_iters,
        timing_iters=args.timing_iters,
        num_tokens=args.num_tokens,
    )


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()