# A100 ZeroGPU (Hugging Face Spaces)
import spaces

# flash attention (install the prebuilt wheel at startup; skip the CUDA build)
import os
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation',
               env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
               shell=True)

# Phantom Package
import torch
from PIL import Image
from utils.utils import *
from model.load_model import load_model

# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor

# accelerator (provides the target device for ZeroGPU)
accel = Accelerator()

# load the 1.8B model
model_1_8, tokenizer_1_8 = load_model(size='1.8b')

# load the 3.8B model
model_3_8, tokenizer_3_8 = load_model(size='3.8b')

# load the 7B model
model_7, tokenizer_7 = load_model(size='7b')
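# NOTE: all three checkpoints stay resident after startup; their weights are cast to
# bfloat16 and moved to the ZeroGPU device on demand inside bot_streaming below.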

def threading_function(inputs, streamer, device, model, tokenizer, temperature, new_max_token, top_p):

    # preprocess the inputs for generation (Phantom propagation step)
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device)
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer,
                              'do_sample': True,
                              'max_new_tokens': new_max_token,
                              'top_p': top_p,
                              'temperature': temperature})
    return model.generate(**generation_kwargs)

@spaces.GPU  # ZeroGPU: a GPU is attached only for the duration of this call
def bot_streaming(message, history, link, temperature, new_max_token, top_p):

    # model selection
    if "1.8B" in link:
        model = model_1_8
        tokenizer = tokenizer_1_8
    elif "3.8B" in link:
        model = model_3_8
        tokenizer = tokenizer_3_8
    elif "7B" in link:
        model = model_7
        tokenizer = tokenizer_7
    else:
        raise ValueError(f"Unsupported model size: {link}")
    
    # cast any float32/float16 parameters to bfloat16
    for param in model.parameters():
        if param.dtype in (torch.float32, torch.float16):
            param.data = param.data.to(torch.bfloat16)

    # move parameters that are still on CPU to the ZeroGPU device
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.data.to(accel.device)

    try:
        # build the model input: single image + question, or question only
        if len(message['files']) == 1:
            # load the image as an RGB tensor
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            inputs = [{'image': image.to(accel.device), 'question': message['text']}]
        elif len(message['files']) > 1:
            raise ValueError("Only a single image is supported.")
        else:
            inputs = [{'question': message['text']}]

        # Text Generation
        with torch.inference_mode():
            # streamer yields decoded tokens as soon as they are generated
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

            # run generation in a background thread so tokens can be streamed
            thread = Thread(target=threading_function,
                            kwargs=dict(inputs=inputs,
                                        streamer=streamer,
                                        model=model,
                                        tokenizer=tokenizer,
                                        device=accel.device,
                                        temperature=temperature,
                                        new_max_token=new_max_token,
                                        top_p=top_p))
            thread.start()
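            # the worker thread pushes decoded tokens into the streamer as generation proceeds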

            # accumulate streamed text as it arrives
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
            thread.join()

        # post-process the raw generated text
        response = output_filtering(generated_text, model)

    except Exception:
        response = "The input may contain an unsupported format (e.g., pdf, video, or audio). Only a single image is supported in this version."

    # server-side log (not shown to the user)
    text = message['text']
    files = message['files']
    print('-----------------------------')
    print(f'Link: {link}')
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    print(f'Response: {response}')
    print('-----------------------------\n')


    # stream the response back character by character for a typing effect
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.012)
        yield buffer

demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"),
                                           gr.Slider(0, 1, 0.9, label="temperature"),
                                           gr.Slider(1, 1024, 128, label="new_max_token"),
                                           gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="Phantom",
                        description="Phantom is a family of super-efficient 0.5B, 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy. "
                                    "Inference speed depends heavily on being assigned a non-scheduled GPU, so when all GPUs are busy, inference may stall indefinitely. "
                                    "Note that history-based conversation (referring to previous dialogue) is not supported.",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()