import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from scripts.prepare_alpaca import generate_prompt

# Hugging Face Hub helper and Gradio UI dependencies
from huggingface_hub import hf_hub_download
import gradio as gr
import os
import glob
import json


# Debugging aid: uncomment to make CUDA kernel launches synchronous so CUDA errors
# are reported at the offending call.
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Per the PyTorch docs, "high" allows faster, reduced-precision kernels (e.g. TF32)
# for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")
# Alternative value for model_load's `quantize` argument (8-bit LLM.int8 mode):
# quantize: Optional[str] = "llm.int8",

def model_load(
    adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned_15k.pth"),
    pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
    quantize: Optional[str] = None,
):
    """Build the LLaMA-Adapter model and load base and adapter weights into it.

    Args:
        adapter_path: Checkpoint with the fine-tuned adapter weights.
        pretrained_path: Checkpoint with the pretrained LLaMA weights; the model
            variant is determined from it via ``llama_model_lookup``.
        quantize: Quantization mode forwarded to ``EmptyInitOnDevice`` as
            ``quantization_mode`` (``None`` means no quantization, e.g.
            ``"llm.int8"`` for LLM.int8()).

    Returns:
        The model switched to eval mode and wrapped by ``fabric.setup_module``.
    """
    fabric = L.Fabric(devices=1)

    # Use bfloat16 only on a CUDA device that supports it; otherwise fall back to fp32.
    if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported():
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    with lazy_load(pretrained_path) as base_ckpt, lazy_load(adapter_path) as adapter_ckpt:
        variant = llama_model_lookup(base_ckpt)

        # Instantiate the module directly on the target device/dtype.
        with EmptyInitOnDevice(
                device=fabric.device, dtype=weight_dtype, quantization_mode=quantize
        ):
            llama = LLaMA.from_name(variant)

        # Pretrained weights first, then the adapter weights on top. strict=False is
        # presumably needed because each checkpoint covers only part of the
        # parameters (base vs. adapter) — confirm against the checkpoint contents.
        for ckpt in (base_ckpt, adapter_ckpt):
            llama.load_state_dict(ckpt, strict=False)

    llama.eval()
    return fabric.setup_module(llama)


def instruct_generate(
    img_path: str = " ",
    prompt: str = "What food do lamas eat?",
    input: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.8,
    top_k: int = 200,
) -> str:
    """Generate a response for an instruction and an optional (object-list) input.

    This only works with checkpoints from the instruction-tuned LLaMA-Adapter
    model (see ``finetune_adapter.py``). It uses the module-level ``model``,
    ``tokenizer`` and ``input_value_2_real`` created when the script starts.

    Args:
        img_path: Scene image path coming from the UI. Accepted so the signature
            matches the Gradio inputs; it is not used for generation.
        prompt: The instruction (Alpaca style).
        input: Optional input (Alpaca style). If it matches a key of
            ``input_value_2_real`` it is replaced by the mapped value; every
            "..." occurrence is then removed.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; higher values make the output more
            random.
        top_k: Number of most probable tokens considered when sampling
            (``None`` is passed through unchanged).

    Returns:
        The decoded text after the "### Response:" marker, stripped of
        surrounding whitespace (the whole decoded text if the marker is absent).
    """
    # Translate the text displayed in the UI back to the real model input.
    input = input_value_2_real.get(input, input)
    input = input.replace("...", "")

    sample = {"instruction": prompt, "input": input}
    full_prompt = generate_prompt(sample)
    encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)

    # Gradio sliders hand numbers over as floats; token counts and top-k must be
    # integers for tensor sizes / top-k selection.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k) if top_k is not None else None

    y = generate(
        model,
        idx=encoded,
        # NOTE(review): max_seq_length equals max_new_tokens, so the prompt shares
        # this budget — confirm against generate() that this is intended.
        max_seq_length=max_new_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )

    output = tokenizer.decode(y)
    # Keep the segment after the first "### Response:" marker (and before any
    # repeated marker). The marker presumably comes from the prompt template
    # built by generate_prompt; fall back to the full text if it is missing.
    parts = output.split("### Response:")
    output = parts[1].strip() if len(parts) > 1 else output.strip()
    print(output)
    return output

# Concrete configuration.
# Alternative: download the checkpoints from the Hugging Face Hub instead of
# reading local files.
# pretrained_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="lit-llama.pth")
# tokenizer_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="tokenizer.model")
# adapter_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="lit-llama-adapter-finetuned_15k.pth")
# Relative paths to the local adapter, tokenizer, base checkpoint and example files.
adapter_path = "lit-llama-adapter-finetuned_15k.pth"
tokenizer_path = "tokenizer.model"
pretrained_path = "lit-llama.pth"
example_path = "example.json"
# If 1024 does not work (presumably out of memory), reduce it to 512.
# NOTE(review): max_seq_len and max_batch_size are not referenced in the visible
# code; generation length is driven by the "Max length" slider passed to
# instruct_generate as max_new_tokens.
max_seq_len = 1024
max_batch_size = 1

# Loaded once when the script starts; instruct_generate uses these globals.
model = model_load(adapter_path, pretrained_path)
tokenizer = Tokenizer(tokenizer_path)
# Load the example scenes. Read explicitly as UTF-8 so parsing does not depend on
# the platform's default text encoding.
with open(example_path, 'r', encoding='utf-8') as f:
    example_dict = json.load(f)

# Map each scene's UI display text ("input_display") to its real model input
# ("input"); instruct_generate uses this to restore the full input.
input_value_2_real = {
    scene_dict["input_display"]: scene_dict["input"]
    for scene_dict in example_dict.values()
}

def create_instruct_demo():
    """Build the instruction-following UI: scene inputs, examples and output box.

    Example rows are built from ``caption_demo/*.png`` files, each looked up in
    the module-level ``example_dict`` by its file name up to the first ".".

    Returns:
        The ``gr.Blocks`` holding the components, wired to ``instruct_generate``.
    """
    with gr.Blocks() as instruct_demo:
        with gr.Row():
            with gr.Column():
                scene_img = gr.Image(label='Scene', type='filepath', shape=(1024, 320), height=320, width=1024, interactive=False)

                object_list = gr.Textbox(
                    lines=5, label="Object List", placeholder="Please click one from the examples below", interactive=False)

                instruction = gr.Textbox(
                    lines=2, label="Instruction", placeholder="Please input the instruction. E.g.Please turn on the lamp")
                max_len = gr.Slider(minimum=512, maximum=1024,
                                    value=1024, label="Max length")
                with gr.Accordion(label='Advanced options', open=False):
                    temp = gr.Slider(minimum=0, maximum=1,
                                     value=0.8, label="Temperature")
                    top_k = gr.Slider(minimum=100, maximum=300,
                                      value=200, label="Top k")

                run_button = gr.Button("Run")

            with gr.Column():
                outputs = gr.Textbox(lines=20, label="Output")

        # Order must match instruct_generate's positional parameters:
        # (img_path, prompt, input, max_new_tokens, temperature, top_k).
        inputs = [scene_img, instruction, object_list, max_len, temp, top_k]

        # Build the example rows; each row supplies only the first three inputs
        # (image, instruction, object list). glob.glob's order depends on the file
        # system, so sort it to keep the example gallery stable across machines.
        examples = []
        for example_img in sorted(glob.glob("caption_demo/*.png")):
            scene_name = os.path.basename(example_img).split(".")[0]
            scene = example_dict[scene_name]
            examples.append([example_img, scene["instruction"], scene["input"]])

        gr.Examples(
            examples=examples,
            inputs=inputs,
            outputs=outputs,
            fn=instruct_generate,
            # Pre-compute example outputs only when the SYSTEM env var is
            # "spaces" (presumably set on Hugging Face Spaces).
            cache_examples=os.getenv('SYSTEM') == 'spaces'
        )
        run_button.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
    return instruct_demo


# Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
# Markdown header rendered at the top of the page.
description = """
# TaPA
The official demo for **Embodied Task Planning with Large Language Models**.
"""

# Top-level page, styled by style.css: the header followed by the
# instruction-following tab. Gradio renders components in creation order.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(description)
    with gr.TabItem("Instruction-Following"):
        create_instruct_demo()

# Enable the request queue with a single worker (concurrency_count=1) and expose
# the API, then start the server. This runs at module level.
demo.queue(api_open=True, concurrency_count=1).launch()