from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    TextStreamer,
)

# Base model and the LoRA adapter fine-tuned on top of it.
model_name_or_path = "sarvamai/OpenHathi-7B-Hi-v0.1-Base"
peft_model_id = "shuvom/OpenHathi-7B-FT-v0.1_SI"

# Load the base model quantised to 4 bits and let device_map="auto" place the
# weights. Passing load_in_4bit directly to from_pretrained is deprecated in
# favour of an explicit BitsAndBytesConfig.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# The tokenizer is taken from the adapter repo rather than the base model;
# presumably it carries the fine-tune's chat template / added tokens — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# Resize the base embeddings to the adapter tokenizer's vocabulary BEFORE the
# adapter is attached (presumably so the adapter's saved weights line up with
# the model's shapes — confirm against how the adapter was trained).
# TODO: make embedding resizing configurable.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

# Wrap the (resized) base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)

class ChatCompletion:
  """Generate chat replies from a causal LM using the tokenizer's chat template.

  Three entry points share the same prompt-building logic:

  * ``get_completion`` prints the reply to stdout while generating.
  * ``get_chat_completion`` is a generator yielding the growing reply text,
    written to be the callback of ``gr.ChatInterface``.
  * ``get_completion_without_streaming`` blocks and returns the reply text.

  All of them sample (``do_sample=True``).
  """

  # Sampling is always enabled and transformers rejects a non-positive
  # temperature when sampling, so requested temperatures are clamped to this.
  _MIN_TEMPERATURE = 1e-2

  def __init__(self, model, tokenizer, system_prompt=None):
    """Wrap ``model``/``tokenizer`` and switch the model to eval mode.

    Args:
      model: causal LM exposing ``generate``, ``eval`` and ``device``.
      tokenizer: tokenizer providing ``apply_chat_template``, ``decode``,
        ``eos_token`` and ``eos_token_id``.
      system_prompt: default system message, or None for no system message.
    """
    self.model = model
    self.tokenizer = tokenizer
    # Kept for callers that read these attributes. The generation methods
    # build a fresh streamer per call so concurrent requests (e.g. from the
    # Gradio queue) never share one streamer's internal state.
    self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
    self.print_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
    # set the model in inference mode
    self.model.eval()
    self.system_prompt = system_prompt

  def _build_messages(self, prompt, system_prompt=None, message_history=None):
    """Return the message list for a single-prompt completion.

    If ``message_history`` is given it is copied as-is (any system message
    must already be part of it). Otherwise ``system_prompt``, falling back to
    the instance default, opens the conversation when set. ``prompt`` is
    always appended as the final user turn. The caller's list is not mutated.
    """
    messages = []
    if message_history is not None:
      messages.extend(message_history)
    else:
      system_prompt = system_prompt or self.system_prompt
      if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages

  def _build_chat_messages(self, message, history):
    """Turn a Gradio chat ``history`` plus the new ``message`` into messages.

    ``history`` may be a list of ``(user, assistant)`` pairs (Gradio's tuple
    format) or a list of ``{"role": ..., "content": ...}`` dicts (Gradio's
    messages format). None is treated as an empty history.
    """
    messages = []
    if self.system_prompt:
      messages.append({"role": "system", "content": self.system_prompt})
    for turn in history or []:
      if isinstance(turn, dict):
        messages.append({"role": turn["role"], "content": turn["content"]})
      else:
        user_message, assistant_message = turn
        messages.append({"role": "user", "content": user_message})
        # Model replies are "assistant" turns. Labelling them "system" can
        # mis-render the prompt or violate a template's role alternation.
        messages.append({"role": "assistant", "content": assistant_message})
    messages.append({"role": "user", "content": message})
    return messages

  def _encode(self, messages):
    """Render ``messages`` with the chat template and tokenize for the model.

    Returns the tokenizer output moved to ``self.model.device``.
    """
    chat_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Assumes the chat template emits its own special tokens (e.g. BOS), so
    # the tokenizer must not prepend another set.
    inputs = self.tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)
    # The tokenizer returns CPU tensors, while the weights may live on an
    # accelerator (e.g. when loaded with device_map="auto").
    return inputs.to(self.model.device)

  def get_completion(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
    """Generate a reply to ``prompt`` and print it to stdout as it streams.

    Args:
      prompt: the user message to answer.
      system_prompt: overrides the default system prompt. Ignored when
        ``message_history`` is given.
      message_history: prior messages (role/content dicts) to prepend.
      max_new_tokens: generation length cap.
      temperature: sampling temperature, clamped to at least
        ``_MIN_TEMPERATURE``.

    Returns:
      The raw return value of ``model.generate``, not decoded text. With
      default generate settings this is a tensor of token ids that includes
      the prompt.
    """
    inputs = self._encode(self._build_messages(prompt, system_prompt, message_history))
    generated_ids = self.model.generate(
      **inputs,
      # Fresh streamer per call: it keeps a token cache between chunks.
      streamer=TextStreamer(self.tokenizer, skip_prompt=True),
      max_new_tokens=max_new_tokens,
      temperature=max(temperature, self._MIN_TEMPERATURE),
      top_p=0.95,
      do_sample=True,
      eos_token_id=self.tokenizer.eos_token_id,
      repetition_penalty=1.2,
    )
    return generated_ids

  def get_chat_completion(self, message, history):
    """Stream a reply to ``message`` given the prior chat ``history``.

    This is a generator, so it can serve as a ``gr.ChatInterface`` callback.
    After every decoded chunk it yields the full reply accumulated so far,
    with the tokenizer's EOS token text removed.

    Args:
      message: the new user message.
      history: prior turns in Gradio's tuple or messages format (see
        ``_build_chat_messages``).

    Raises:
      Any exception raised by ``model.generate`` in the worker thread is
      re-raised here once the stream has been closed.
    """
    inputs = self._encode(self._build_chat_messages(message, history))
    # One streamer per call: a shared streamer would mix tokens from
    # concurrent requests into one queue.
    streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
    generation_kwargs = dict(
      inputs,
      streamer=streamer,
      max_new_tokens=2048,
      temperature=0.2,
      top_p=0.95,
      eos_token_id=self.tokenizer.eos_token_id,
      do_sample=True,
      repetition_penalty=1.2,
    )
    errors = []

    def _generate():
      # Runs in a worker thread so this generator can consume the streamer.
      try:
        self.model.generate(**generation_kwargs)
      except Exception as exc:
        errors.append(exc)
        # The streamer has no timeout, so without this the loop below would
        # wait forever for tokens that will never arrive.
        streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    eos_token = self.tokenizer.eos_token
    generated_text = ""
    for new_text in streamer:
      if eos_token:
        new_text = new_text.replace(eos_token, "")
      generated_text += new_text
      yield generated_text
    thread.join()
    if errors:
      raise errors[0]
    return generated_text

  def get_completion_without_streaming(self, prompt, system_prompt=None, message_history=None, max_new_tokens=512, temperature=0.0):
    """Generate a reply to ``prompt`` and return it as decoded text.

    Takes the same arguments as ``get_completion``. Only the newly generated
    tokens are decoded (special tokens skipped), so the prompt is not echoed
    back.
    """
    inputs = self._encode(self._build_messages(prompt, system_prompt, message_history))
    outputs = self.model.generate(
      **inputs,
      max_new_tokens=max_new_tokens,
      temperature=max(temperature, self._MIN_TEMPERATURE),
      top_p=0.95,
      do_sample=True,
      eos_token_id=self.tokenizer.eos_token_id,
      repetition_penalty=1.1,
    )
    # generate() returns prompt + continuation; keep only the continuation so
    # the result matches what get_completion streams (skip_prompt=True).
    prompt_length = inputs["input_ids"].shape[-1]
    return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Default persona applied to every conversation served by the UI.
SYSTEM_PROMPT = "You are a native Hindi speaker who can converse at expert level in both Hindi and colloquial Hinglish."
text_generator = ChatCompletion(model, tokenizer, system_prompt=SYSTEM_PROMPT)

# Chat UI backed by the streaming generator callback. queue() enables Gradio's
# request queue and returns the same interface object; launch(debug=True)
# blocks the main thread while the server runs.
demo = gr.ChatInterface(text_generator.get_chat_completion)
demo.queue()
demo.launch(debug=True)