import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from functools import partial


tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_llama3_8b")
model = AutoModelForCausalLM.from_pretrained("IlyaGusev/saiga_llama3_8b", torch_dtype=torch.bfloat16)
model = model


def transform_history(history):
    transformed_history = []
    for qa_pair in history:
        transformed_history.append({"role": "user", "content": qa_pair[0]})
        transformed_history.append({"role": "assistant", "content": qa_pair[1]})
    return transformed_history


def predict(message, history):
    # print(history) [[вопрос1, ответ1], [вопрос2, ответ2]...]
    history = transform_history(history)
    history_transformer_format = history + [{"role": "user", "content": message}, 
                                            {"role": "assistant", "content": ""}]

    model_inputs = tokenizer.apply_chat_template(history_transformer_format, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
    )
    generating_func = partial(model.generate, model_inputs)
    t = Thread(target=generating_func, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if  'assistant' not in new_token:
            partial_message += new_token
            yield partial_message


gr.ChatInterface(predict).launch(share=True)