import openai
import gradio as gr
import os
from dotenv import load_dotenv
import tiktoken

from typing import Dict, Generator, List, Union

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

load_dotenv()

OPENAI_KEY = os.getenv("OPENAI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL")
if not HF_MODEL:
    raise ValueError("HF_MODEL environment variable is not set")
TOKENIZER = AutoTokenizer.from_pretrained(HF_MODEL)

HF_CLIENT = InferenceClient(
    HF_MODEL,
    token=HF_TOKEN
)
OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)

HF_GENERATE_KWARGS = {
    'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    'max_new_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
    'top_p': float(os.getenv("TOP_P", 0.6)),
    'repetition_penalty': float(os.getenv("REP_PENALTY", 1.2)),
    # bool() on a non-empty string is always True, so parse the flag explicitly
    'do_sample': os.getenv("DO_SAMPLE", "true").lower() in ("1", "true", "yes")
}

OAI_GENERATE_KWARGS = {
    'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    'max_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
    'top_p': float(os.getenv("TOP_P", 0.6)),
    'frequency_penalty': max(-2, min(float(os.getenv("FREQ_PENALTY", 0)), 2))
}
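
# Example .env for reference (illustrative values only; the model ids shown
# here are assumptions, not requirements of this module):
#   OPENAI_API_KEY=sk-...
#   OPENAI_MODEL=gpt-4o-mini
#   HF_TOKEN=hf_...
#   HF_MODEL=mistralai/Mistral-7B-Instruct-v0.2
#   TEMPERATURE=0.9
#   MAX_NEW_TOKENS=256
#   TOP_P=0.6
#   REP_PENALTY=1.2
#   FREQ_PENALTY=0
#   DO_SAMPLE=true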


def format_prompt(message: str, api_kind: str) -> Union[List[Dict[str, str]], str]:
    """
    Formats the given message using a chat template.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): LLM API provider ("openai" or "hf").
    Returns:
        The message list unchanged for "openai", or the string produced by
        applying the tokenizer's chat template for "hf".
    """

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, str]] = [{'role': 'user', 'content': message}]

    if api_kind == "openai":
        return messages
    elif api_kind == "hf":
        return TOKENIZER.apply_chat_template(messages, tokenize=False)
    else:
        raise ValueError(f"API kind '{api_kind}' is not supported")


def generate_hf(prompt: str, history: str) -> Generator[str, None, None]:
    """
    Generate a sequence of tokens based on a given prompt and history using
    the Hugging Face InferenceClient.

    Args:
        prompt (str): The prompt for the text generation.
        history (str): Context or history for the text generation (currently unused).
    Returns:
        Generator[str, None, None]: A generator yielding the accumulated text
                                    after each streamed token. Raises gr.Error
                                    if the request fails.
    """

    formatted_prompt = format_prompt(prompt, "hf")
    try:
        stream = HF_CLIENT.text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False
        )
        output = ""
        for response in stream:
            output += response.token.text
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            raise gr.Error(f"Too many requests: {str(e)}")
        elif "Authorization header is invalid" in str(e):
            raise gr.Error("Authentication error: HF token was either not provided or incorrect")
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}")


def generate_openai(prompt: str, history: str) -> Generator[str, None, None]:
    """
    Generate a sequence of tokens based on a given prompt and history using
    the OpenAI client.

    Args:
        prompt (str): The initial prompt for the text generation.
        history (str): Context or history for the text generation (currently unused).
    Returns:
        Generator[str, None, None]: A generator yielding the accumulated text
                                    after each streamed chunk. Raises gr.Error
                                    if the request fails.
    """
    formatted_prompt = format_prompt(prompt, "openai")

    OPENAI_MODEL = os.getenv("OPENAI_MODEL")

    if not OPENAI_MODEL:
        raise ValueError("OPENAI_MODEL environment variable is not set")

    try:
        stream = OAI_CLIENT.chat.completions.create(
            model=OPENAI_MODEL,
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True
        )
        output = ""
        for chunk in stream:
            if chunk.choices[0].delta.content:
                output += chunk.choices[0].delta.content
                yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            raise gr.Error("ERROR: Too many requests on OpenAI client")
        elif "You didn't provide an API key" in str(e):
            raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect")
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}")


def get_max_length(texts: List[str]) -> int:
    """Return the token count of the longest text, measured with tiktoken's
    cl100k_base encoding."""
    encoding = tiktoken.get_encoding("cl100k_base")
    max_len = 0
    for text in texts:
        max_len = max(max_len, len(encoding.encode(text)))
    return max_len
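

if __name__ == "__main__":
    # Minimal wiring sketch (an assumption about intended use, not part of the
    # original module): gr.ChatInterface calls its fn with (message, history),
    # which matches the generator signatures above. Swap in generate_openai to
    # exercise the OpenAI path instead.
    demo = gr.ChatInterface(fn=generate_hf)
    demo.queue().launch()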