import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))


# DESCRIPTION = ""
# if not torch.cuda.is_available():
#     DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")



tokenizer = AutoTokenizer.from_pretrained("Back-up/T5-pretrain")
model = AutoModelForSeq2SeqLM.from_pretrained("Back-up/T5-large-QA")
model.to(device)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    tokenized_text = tokenizer.encode(message, return_tensors="pt").to(model.device)

    model.eval()
    summary_ids = model.generate(
                        tokenized_text,
                        max_length=1024,
                        min_length=8,
                        num_beams=5,
                        repetition_penalty=2.5,
                        length_penalty=1.0,
                        early_stopping=True
                    )
    output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    yield output

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Trường đại học Nông Lâm thành phố Hồ Chí Minh nằm ở đâu?"],
        ["Mục tiêu chiến lược của trường đại học Nông Lâm thành phố Hồ Chí Minh là gì?"],
        ["Sinh viên được khen thưởng cá nhân và tập thể khi nào?"],
        ["Điều kiện cơ bản để được hỗ trợ vay tiền sinh viên là gì?"],
        ["Trường Đại học Nông Lâm đã trải qua bao nhiêu năm hoạt động tính đến năm 2023?"],
        ["Những hành vi nào của sinh viên bị coi là vi phạm quy định của Nhà trường?"],
        ["Địa chỉ của Phân hiệu Trường Đại học Nông Lâm tại Ninh Thuận?"],
        ["Làm thế nào khi sinh viên không hài lòng với việc giải quyết thắc mắc của Trưởng Bộ môn?"],
        ["Làm thế để yêu cầu phúc khảo bài thi?"],
        ["Nghĩa vụ của sinh viên là gì?"],
        ["Viết cho tôi một chương trình tính số nguyên tố bằng python."]
    ],
)

with gr.Blocks(css="style.css") as demo:
    chat_interface.render()


if __name__ == "__main__":
    demo.queue(max_size=20).launch()