import argparse
import onnx
import os
import requests
import shutil
import subprocess
import sys
import torch

from onnxruntime_genai.models.builder import create_model
from PIL import Image
from transformers import AutoConfig, AutoProcessor, AutoModelForCausalLM


def _download_image(url):
    """Download ``url`` and open it as a PIL image.

    The whole response body is read into memory (instead of wrapping the raw
    socket stream), so any ``Content-Encoding`` such as gzip is decoded by
    ``requests``. A non-2xx response raises ``requests.HTTPError`` here rather
    than surfacing later as an opaque PIL decode error.
    """
    import io  # only this helper needs it

    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def build_vision(args):
    """Export, optimize and 4-bit quantize the Phi-3.5 vision encoder to ONNX.

    Relies on module-level globals set in the ``__main__`` block:
    ``user_prompt``, ``prompt_suffix``, ``assistant_prompt``, ``processor``
    and ``model``.

    Stages (each intermediate lives in its own folder under ``args.output``
    and is removed once the next stage has consumed it):
      1. ``torch.onnx.export`` of ``model.model.vision_embed_tokens``, traced
         with four sample images downloaded from the web.
      2. Re-save with all weights in a single external-data file.
      3. ``onnxruntime.transformers.optimizer`` (model type ``clip``).
      4. ``onnxruntime.quantization.matmul_4bits_quantizer`` (block size 32),
         writing the final model to
         ``args.output/phi-3.5-v-instruct-vision.onnx``.

    Args:
        args: parsed CLI namespace from ``get_args`` (reads ``output``,
            ``precision`` and ``execution_provider``).

    Raises:
        requests.HTTPError: a sample image could not be downloaded.
        subprocess.CalledProcessError: the optimizer or quantizer subprocess
            exited non-zero; the folder holding that stage's input is left on
            disk for inspection.
    """
    # "dml" requests are traced with PyTorch on CUDA.
    device = args.execution_provider.replace("dml", "cuda")

    # Sample prompt with four images; only used to build realistic dummy
    # inputs for tracing.
    prompt = f"{user_prompt}<|image_1|>\n <|image_2|>\n <|image_3|>\n <|image_4|>\n What is shown in these four images?{prompt_suffix}{assistant_prompt}"
    urls = [
        "https://www.ilankelman.org/stopsigns/australia.jpg",
        "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000",
        "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain",
        "https://wallpaper.dog/large/10809054.jpg",
    ]
    images = [_download_image(url) for url in urls]
    inputs = processor(prompt, images, return_tensors="pt").to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)

    # TorchScript export
    dummy_inputs = (
        inputs["pixel_values"],  # pixel_values
        inputs["image_sizes"],   # image_sizes (presumably integer H/W per image - TODO confirm)
    )
    dynamic_axes = {
        "pixel_values": {0: "num_images", 1: "max_num_crops", 3: "height", 4: "width"},
        "image_sizes": {0: "num_images"},
        "image_features": {0: "num_image_tokens"},
    }
    filename = "phi-3.5-v-instruct-vision.onnx"

    temp_folder_1 = os.path.join(args.output, "vision_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)

    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.vision_embed_tokens,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["pixel_values", "image_sizes"],
        output_names=["image_features"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )

    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Re-save with every tensor in one external-data file.
    temp_folder_2 = os.path.join(args.output, "vision_after_export")
    os.makedirs(temp_folder_2, exist_ok=True)

    fpath_2 = os.path.join(temp_folder_2, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)

    # ORT transformer optimizer
    temp_folder_3 = os.path.join(args.output, "vision_after_opt")
    os.makedirs(temp_folder_3, exist_ok=True)
    fpath_3 = os.path.join(temp_folder_3, filename)
    # check=True: a failed stage must abort before its input folder is deleted.
    subprocess.run(
        [
            f"{sys.executable}", "-m", "onnxruntime.transformers.optimizer",
            "--input", fpath_2,
            "--output", fpath_3,
            "--model_type", "clip",
            "--num_heads", str(16),
            "--hidden_size", str(1024),
            "--use_external_data_format",
            "--opt_level", str(0),
            "--disable_shape_inference",
        ],
        check=True,
    )
    shutil.rmtree(temp_folder_2)

    # ORT 4-bits quantizer
    fpath_4 = os.path.join(args.output, filename)
    cmd = [
        f"{sys.executable}", "-m", "onnxruntime.quantization.matmul_4bits_quantizer",
        "--input_model", fpath_3,
        "--output_model", fpath_4,
        "--block_size", str(32),
    ]
    # accuracy_level 4 is requested only for fp32 exports.
    if args.precision == torch.float32:
        cmd.extend(["--accuracy_level", str(4)])
    subprocess.run(cmd, check=True)
    shutil.rmtree(temp_folder_3)


def build_embedding(args):
    """Export the combined text/image embedding module to ONNX.

    Traces ``model.model.combined_embed`` with random dummy inputs. The inputs
    are ``input_ids`` and ``image_features``; the output is ``inputs_embeds``.
    The result is re-saved with all weights in one external-data file at
    ``args.output/phi-3.5-v-instruct-embedding.onnx``.

    Relies on module-level globals ``config`` and ``model`` set in the
    ``__main__`` block.

    Args:
        args: parsed CLI namespace from ``get_args`` (reads ``output``,
            ``precision`` and ``execution_provider``).
    """
    # "dml" requests are traced with PyTorch on CUDA.
    device = args.execution_provider.replace("dml", "cuda")

    # TorchScript export
    batch_size, sequence_length, num_img_tokens = 2, 8, 2
    input_ids = torch.randint(
        low=0,
        high=config.vocab_size,
        size=(batch_size, sequence_length),
        device=device,
        dtype=torch.int64,
    )
    # Mark the first `num_img_tokens` positions of sequence 0 with -1;
    # presumably these are the image placeholder slots that combined_embed
    # fills from `image_features` (one row per -1 id) - TODO confirm.
    input_ids[0, :num_img_tokens] = -1
    image_features = torch.randn(
        num_img_tokens, config.hidden_size, device=device, dtype=args.precision
    )
    dummy_inputs = (
        input_ids,       # input_ids: torch.LongTensor
        image_features,  # image_features: Optional[torch.FloatTensor] = None,
    )
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "image_features": {0: "num_image_tokens"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    }
    filename = "phi-3.5-v-instruct-embedding.onnx"

    temp_folder_1 = os.path.join(args.output, "embedding_init_export")
    os.makedirs(temp_folder_1, exist_ok=True)

    fpath_1 = os.path.join(temp_folder_1, filename)
    torch.onnx.export(
        model.model.combined_embed,
        args=dummy_inputs,
        f=fpath_1,
        export_params=True,
        input_names=["input_ids", "image_features"],
        output_names=["inputs_embeds"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
        do_constant_folding=True,
    )

    onnx.checker.check_model(fpath_1)
    onnx.shape_inference.infer_shapes_path(fpath_1)
    onnx_model = onnx.load_model(fpath_1, load_external_data=True)

    # Write the final model with every tensor in one external-data file.
    fpath_2 = os.path.join(args.output, filename)
    onnx.save_model(
        onnx_model,
        fpath_2,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{filename}.data",
        size_threshold=0,
        convert_attribute=False,
    )
    shutil.rmtree(temp_folder_1)


def build_text(args):
    """Build the int4 text decoder with the onnxruntime-genai model builder.

    Embeddings are excluded from this model (``exclude_embeds``). Presumably
    this is because ``build_embedding`` exports them separately. The model is
    written as ``phi-3.5-v-instruct-text.onnx``. For fp32 exports, int4
    accuracy level 4 is also requested.

    Args:
        args: parsed CLI namespace from ``get_args`` (reads ``input``,
            ``output``, ``precision``, ``execution_provider``, ``cache_dir``).
    """
    builder_options = {
        "exclude_embeds": "true",
        "filename": "phi-3.5-v-instruct-text.onnx",
    }
    if args.precision == torch.float32:
        builder_options["int4_accuracy_level"] = 4

    create_model(
        None,  # model_name is left unset; weights are read from args.input
        args.input,
        args.output,
        "int4",
        args.execution_provider,
        args.cache_dir,
        **builder_options,
    )


def get_args():
    """Parse the command-line options for the Phi-3.5 vision ONNX export.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``precision``,
        ``execution_provider`` and ``cache_dir``. ``precision`` is converted
        from "fp16"/"fp32" to ``torch.float16``/``torch.float32``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add(
        "-i", "--input",
        required=True,
        help="Path to folder on disk containing the Hugging Face config, model, tokenizer, etc.",
    )
    add(
        "-o", "--output",
        required=True,
        help="Path to folder to store ONNX model and additional files (e.g. GenAI config, external data files, etc.)",
    )
    add(
        "-p", "--precision",
        required=True,
        choices=["fp16", "fp32"],
        help="Precision to export PyTorch components with",
    )
    add(
        "-e", "--execution_provider",
        required=True,
        choices=["cpu", "cuda", "dml"],
        help="Execution provider for Phi-3.5 vision components",
    )
    add(
        "-c", "--cache_dir",
        required=False,
        default=os.path.join('.', 'cache_dir'),
        help="Cache directory for Hugging Face files and temporary ONNX external data files",
    )

    parsed = parser.parse_args()
    # `choices` guarantees the key is one of these two.
    dtype_for = {"fp16": torch.float16, "fp32": torch.float32}
    parsed.precision = dtype_for[parsed.precision]
    return parsed

if __name__ == "__main__":
    # Chat-template fragments, read as module globals by build_vision().
    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = "<|end|>\n"

    args = get_args()

    # config, processor and model are module globals consumed by the build_*
    # functions. "dml" requests are loaded onto CUDA for PyTorch tracing.
    config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.input,
        trust_remote_code=True,
        torch_dtype=args.precision,
    ).to(args.execution_provider.replace("dml", "cuda"))

    # Build model components in order: vision encoder, embedding, text decoder.
    for build_component in (build_vision, build_embedding, build_text):
        build_component(args)