import gradio as gr 
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime
# Send log records at INFO level and above to the console.
logging.basicConfig(level=logging.INFO)

# The URIAL in-context prompt is downloaded once, when this module is imported.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Use a context manager so the HTTP response is always closed, and bound the
# wait so a stalled connection cannot hang start-up forever.
with urllib.request.urlopen(URIAL_URL, timeout=30) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
# Generation is cut off as soon as any of these markers shows up in the output.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Per-client request counts (keyed by client host) and the time the counts
# were last reset; both are read and updated by respond().
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()


def _extract_token(chunk):
    """Return the text carried by one streamed chunk, or "" when it has none."""
    choices = getattr(chunk, "choices", None)
    if not choices:
        return ""
    choice = choices[0]
    if hasattr(choice, "delta"):
        delta = choice.delta
        try:
            token = delta["content"]
        except (KeyError, TypeError):
            # Role-only chunks have no "content" key, and attribute-style
            # delta objects are not subscriptable.
            token = getattr(delta, "content", None)
    else:
        token = getattr(choice, "text", None)
    return token or ""


def _hide_partial_stop(text):
    """Return *text* without a trailing newline-plus-one-or-two-quotes fragment.

    Such a tail may be the beginning of the triple-quote stop marker, so it is
    hidden from the user until the following chunks show what it really is.
    """
    if text.endswith('\n""'):
        return text[:-2]
    if text.endswith('\n"'):
        return text[:-1]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    api_key,
    request:gr.Request
):
    """Stream a completion for *message*, yielding the reply accumulated so far.

    Args:
        message: The new user message.
        history: Earlier turns; passed to ``urial_template`` together with
            ``message`` and the module-level URIAL prompt.
        max_tokens: Forwarded to ``openai_base_request``.
        temperature: Forwarded to ``openai_base_request``.
        top_p: Forwarded to ``openai_base_request``.
        rp: Repetition penalty from the UI. Currently ignored: the value is
            overridden with 1.0 before the request is made.
        model_name: Display name, translated with ``model_name_mapping``.
        api_key: Used only when it is exactly 64 characters long; otherwise
            ``None`` is passed to ``openai_base_request``.
        request: Gradio request; its client host keys the daily counter.

    Yields:
        The response text accumulated so far, or a single notice when the
        client has used up its 100 requests for the current 24-hour window.
    """
    global LAST_UPDATE_TIME, addr_limit_counter
    # NOTE(review): the UI value is deliberately left overridden, as in the
    # original code - confirm whether the "Repetition penalty" box should apply.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    if not (api_key and len(api_key) == 64):
        api_key = None

    # Start a fresh counting window once more than 24 hours have passed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    # NOTE(review): the limit also applies to callers who supply their own
    # key, although the notice suggests otherwise - confirm intended policy.
    if addr_limit_counter.setdefault(host_addr, 0) >= 100:
        # This function is a generator, so the notice has to be yielded; a
        # returned value would never reach the chat window.
        yield "You have reached the limit of 100 requests for today. Please use your own API key."
        return

    infer_request = openai_base_request(prompt=prompt, model=_model_name,
                                        temperature=temperature,
                                        max_tokens=max_tokens,
                                        top_p=top_p,
                                        repetition_penalty=rp,
                                        stop=STOP_STRS, api_key=api_key)
    # .get() tolerates the counter having been reset by another request since
    # the check above.
    addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info("Requesting chat completion from OpenAI API with model %s", _model_name)
    logging.info("addr_limit_counter: %s; Last update time: %s;", addr_limit_counter, LAST_UPDATE_TIME)

    # `response` keeps the raw text. Partial stop markers are hidden only in
    # what is yielded, so a legitimate quote at the start of a line is never
    # lost and a stop marker split across chunks is still recognised.
    response = ""
    for msg in infer_request:
        candidate = response + _extract_token(msg)
        stop_positions = [pos for pos in map(candidate.find, STOP_STRS) if pos != -1]
        if stop_positions:
            # Keep whatever precedes the earliest stop marker, then finish.
            yield _hide_partial_stop(candidate[:min(stop_positions)])
            break
        response = candidate
        yield _hide_partial_stop(response)
 
# Page layout: header and model picker in the first column, API key and
# generation parameters in the second, and the chat interface underneath.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(value=HEADER_MD)
            model_name = gr.Radio(
                choices=[
                    "Llama-3.1-405B-FP8",
                    "Llama-3-70B",
                    "Llama-3-8B",
                    "Mistral-7B-v0.1",
                    "Mixtral-8x22B",
                    "Qwen1.5-72B",
                    "Yi-34B",
                    "Llama-2-7B",
                    "Llama-2-70B",
                    "OLMO",
                ],
                value="Llama-3.1-405B-FP8",
                label="Base LLM name",
            )
        with gr.Column():
            # The key field is created hidden (visible=False) but is still
            # wired into respond() as an additional input.
            api_key = gr.Textbox(
                label="🔑 APIKey",
                placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
                visible=False,
            )
            with gr.Accordion(label="⚙️ Parameters for Base LLM", open=True):
                with gr.Row():
                    # NOTE(review): these are free-text boxes, so respond()
                    # presumably receives strings - confirm that
                    # openai_base_request converts them to numbers.
                    max_tokens = gr.Textbox(value=256, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
    # The extra inputs are listed in the order of respond()'s parameters that
    # follow (message, history).
    chat = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            max_tokens,
            temperature,
            top_p,
            rp,
            model_name,
            api_key,
        ],
    )
    # Adjust the chatbot component that ChatInterface created.
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True
    chat.chatbot.label = "Chat with Base LLMs via URIAL"

# Start the Gradio app only when this file is executed as a script, not when
# it is imported; show_api=False is passed through to launch().
if __name__ == "__main__": 
    demo.launch(show_api=False)