import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Markdown rendered at the top of the demo page.
DESCRIPTION = """\
# Prot2Text Demo

A demo to generate a protein's function from its amino acid sequence and its structure using [Prot2Text Base v1.1](https://huggingface.co/habdine/Prot2Text-Base-v1-1). To test this model, simply enter the AlphaFoldDB ID of the protein below.
"""

# Upper bound and initial value of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 256
DEFAULT_MAX_NEW_TOKENS = 100


# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; only keep it enabled for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(
    "habdine/Prot2Text-Base-v1-1",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "habdine/Prot2Text-Base-v1-1",
    trust_remote_code=True,
)
model = model.to(device)
# Inference only: put the module in evaluation mode.
model.eval()


@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the generated description for the protein named by *message*.

    Generation runs in a background thread and its output is read back
    through a ``TextIteratorStreamer``.

    Args:
        message: The user's chat message, used as the protein ID
            (``protein_pdbID``) after surrounding whitespace is stripped.
        chat_history: Previous chat messages supplied by the chat interface.
            Not used by this function.
        max_new_tokens: Forwarded to the model's generation call.
        do_sample: Forwarded to the model's generation call.
        temperature: Forwarded to the model's generation call.
        top_p: Forwarded to the model's generation call.
        top_k: Forwarded to the model's generation call.
        repetition_penalty: Forwarded to the model's generation call.

    Yields:
        The text generated so far (cumulative, not just the newest chunk).

    Raises:
        gr.Error: If *message* is blank, or if the generation thread fails.
    """
    protein_id = message.strip()
    if not protein_id:
        raise gr.Error("Please enter the AlphaFoldDB ID of a protein.")

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        protein_pdbID=protein_id,
        tokenizer=tokenizer,
        device=device,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Exceptions raised in the worker thread would otherwise be lost, leaving
    # the loop below blocked until the streamer timeout expires.
    errors: list[Exception] = []

    def _run_generation() -> None:
        try:
            model.generate_protein_description(**generate_kwargs)
        except Exception as exc:  # reported to the caller after the loop
            errors.append(exc)
            # Signal end-of-stream so the consumer stops waiting immediately.
            streamer.end()

    t = Thread(target=_run_generation)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    t.join()
    if errors:
        raise gr.Error(f"Generation failed for '{protein_id}': {errors[0]}") from errors[0]


# Chat UI around `generate`. The additional inputs are passed to `generate`
# positionally after (message, chat_history), so their order must match the
# function's parameter order.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        # max_new_tokens
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        # do_sample
        gr.Checkbox(label="Do Sample"),
        # temperature
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        # top_p
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        # top_k
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        # repetition_penalty
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
    ],
    stop_btn=None,
    # Example protein IDs offered as one-click inputs.
    examples=[
        ["P0A0V1"],
        ["Q10MK9"],
        ["Q6K5W5"],
        ["Q65WY8"],
    ],
    cache_examples=False,
    type="messages",
)

# Page layout: description text, a "duplicate this Space" button, then the chat UI.
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    chat_interface.render()

if __name__ == "__main__":
    # Enable the request queue with max_size=20, then serve on all
    # interfaces at port 7860.
    demo.queue(max_size=20)
    demo.launch(server_name="0.0.0.0", server_port=7860)