from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
import random
import json
from openai import OpenAI

# Load environment variables from a local .env file (if present) before reading them.
load_dotenv()
# Hugging Face API token; used to authenticate InferenceClient calls below.
# May be None if HF_TOKEN is not set -- TODO confirm callers handle that case.
API_TOKEN = os.getenv('HF_TOKEN')

def format_prompt(message, history):
  """Build a Mistral-style instruction prompt from a chat history.

  Each past (user, assistant) exchange is rendered as
  ``[INST] user [/INST] assistant</s> `` and the new message is appended
  as a final open ``[INST] ... [/INST]`` turn.

  Args:
    message: the latest user message.
    history: iterable of (user_prompt, bot_response) pairs.

  Returns:
    The full prompt string, starting with the ``<s>`` BOS marker.
  """
  segments = ["<s>"]
  for past_user, past_bot in history:
    segments.append(f"[INST] {past_user} [/INST] {past_bot}</s> ")
  segments.append(f"[INST] {message} [/INST]")
  return "".join(segments)

def format_prompt_openai(system_prompt, message, history):
  """Convert a chat history into an OpenAI chat-completions message list.

  Args:
    system_prompt: optional system message; omitted when empty.
    message: the latest user message (always appended last).
    history: iterable of (user_prompt, bot_response) pairs.

  Returns:
    A list of ``{"role": ..., "content": ...}`` dicts in conversation order.
  """
  convo = [{"role": "system", "content": system_prompt}] if system_prompt != '' else []
  for past_user, past_bot in history:
    convo.extend((
        {"role": "user", "content": past_user},
        {"role": "assistant", "content": past_bot},
    ))
  convo.append({"role": "user", "content": message})
  return convo

def chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty):
  """Stream a text-generation response from a Hugging Face Inference endpoint.

  Args:
    prompt: latest user message.
    history: list of (user_prompt, bot_response) pairs.
    chat_client: HF model id or endpoint URL for InferenceClient.
    temperature: sampling temperature; clamped to a minimum of 1e-2
      because the endpoint rejects a temperature of 0.
    max_new_tokens: generation length cap.
    top_p: nucleus-sampling threshold.
    repetition_penalty: repetition penalty factor.

  Returns:
    The streaming generator produced by ``InferenceClient.text_generation``.
  """
  client = InferenceClient(
      chat_client,
      token=API_TOKEN
  )
  # Coerce UI-supplied values (sliders often yield floats/strings) to the
  # numeric types the API expects. Previously only temperature/top_p were
  # coerced, which could send a float max_new_tokens to the endpoint.
  temperature = float(temperature)
  if temperature < 1e-2:
      temperature = 1e-2
  top_p = float(top_p)
  max_new_tokens = int(max_new_tokens)
  repetition_penalty = float(repetition_penalty)

  generate_kwargs = dict(
      temperature=temperature,
      max_new_tokens=max_new_tokens,
      top_p=top_p,
      repetition_penalty=repetition_penalty,
      do_sample=True,
      seed=random.randint(0, 10**7),
  )
  formatted_prompt = format_prompt(prompt, history)
  # Debug trace of the exact prompt sent to the model.
  print('***************************************************')
  print(formatted_prompt)
  print('***************************************************')
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
  return stream

def chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai):
  """Stream a chat completion from the OpenAI API.

  ``prompt`` may be a structured JSON payload (with ``messages`` and
  ``input`` sections) from which system/user messages are extracted; if
  parsing fails it is treated as a plain user message.

  Note: ``top_p`` and ``repetition_penalty`` are accepted for signature
  parity with ``chat_huggingface`` but are not forwarded to the API.

  Args:
    prompt: JSON payload or plain user message.
    history: list of (user_prompt, bot_response) pairs.
    chat_client: OpenAI model name (e.g. ``gpt-...``).
    temperature: sampling temperature forwarded to the API.
    max_new_tokens: forwarded as ``max_tokens``.
    client_openai: an initialized ``OpenAI`` client instance.

  Returns:
    The streaming response from ``chat.completions.create``.
  """
  try:
    # Strip newlines from a working copy only, so the plain-text fallback
    # below keeps the user's original formatting intact.
    json_data = json.loads(prompt.replace('\n', ''))
    user_prompt = json_data["messages"][1]["content"]
    system_prompt = json_data["input"]["content"]
    system_style = json_data["input"]["style"]
    instructions = json_data["messages"][0]["content"]
    if instructions != '':
        system_prompt += '\n' + instructions
    if system_style != '':
        system_prompt += '\n' + system_style
  except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
    # Not the structured payload (or malformed): fall back to plain text.
    # Narrowed from a bare ``except:`` which also swallowed
    # KeyboardInterrupt/SystemExit.
    user_prompt = prompt
    system_prompt = ''
  messages = format_prompt_openai(system_prompt, user_prompt, history)
  # Debug trace of the exact message list sent to the API.
  print('***************************************************')
  print(messages)
  print('***************************************************')
  stream = client_openai.chat.completions.create(
      model=chat_client,
      stream=True,
      messages=messages,
      temperature=temperature,
      max_tokens=max_new_tokens,
  )
  return stream

def chat(prompt, history, chat_client, temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, client_openai=None):
  """Route a chat request to the right backend by model name.

  Model names starting with ``gpt`` go to the OpenAI API; everything else
  goes to the Hugging Face Inference API. Returns the backend's stream.
  """
  if chat_client.startswith('gpt'):
    return chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai)
  return chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty)