# utils/lora_details.py

import gradio as gr
from utils.constants import LORA_DETAILS, MODELS, LORAS

def get_lora_models():
    """
    Collect an ``(image, title)`` pair for every entry in LORAS.

    Returns:
        list[tuple]: One ``(entry["image"], entry["title"])`` tuple per LORAS
        entry, in the same order as LORAS.
    """
    models = []
    for entry in LORAS:
        models.append((entry["image"], entry["title"]))
    return models

def upd_prompt_notes_by_index(lora_index):
    """
    Updates the prompt notes label based on the selected LoRA model.

    The notes come from the LORAS entry at ``lora_index``: its ``notes`` value
    if present, otherwise a hint built from ``trigger_position`` and
    ``trigger_word``. When the index is out of range, not an index (e.g.
    ``None``), or the entry is empty/falsy, a generic prompt hint is used.

    Args:
        lora_index (int): Index into LORAS of the selected LoRA model.

    Returns:
        gr.update: Gradio update carrying the notes text as ``value``.
    """
    default_notes = (
        "Enter prompt description of your image. \n"
        "Using models without LoRA may take 30 minutes."
    )
    # NOTE(review): negative indices wrap to the end of LORAS (normal list
    # indexing) instead of falling back to the default notes — confirm
    # callers never pass e.g. -1 to mean "no LoRA".
    try:
        lora = LORAS[lora_index]
    except (IndexError, TypeError):
        # Out-of-range or missing (None) index: no LoRA selected.
        return gr.update(value=default_notes)

    # Previously a falsy entry left `notes` unassigned (UnboundLocalError).
    if not lora:
        return gr.update(value=default_notes)

    notes = lora.get('notes', None)
    if notes is None:
        trigger_word = lora.get('trigger_word', "")
        trigger_position = lora.get('trigger_position', "")
        notes = f"{trigger_position} '{trigger_word}' in prompt"
    return gr.update(value=notes)

def get_trigger_words_by_index(lora_index):
    """
    Look up the ``trigger_word`` value of the LORAS entry at ``lora_index``.

    Args:
        lora_index (int): Position of the selected LoRA model in LORAS.

    Returns:
        str: The entry's ``trigger_word`` value, or an empty string when the
        entry has no such key or the index is out of range.
    """
    try:
        entry = LORAS[lora_index]
    except IndexError:
        return ""
    return entry.get('trigger_word', "")

def upd_prompt_notes(model_textbox_value):
    """
    Updates the prompt_notes_label with the notes from LORA_DETAILS.

    For a model listed in LORA_DETAILS, the ``notes`` value of the first
    detail entry that has one is used (empty string if none do). For an
    unknown model a generic prompt hint is shown instead.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    if model_textbox_value not in LORA_DETAILS:
        # Same wording as the fallback in upd_prompt_notes_by_index; the old
        # text read "may take a 30 minutes".
        return gr.update(
            value=(
                "Enter prompt description of your image. \n"
                "Using models without LoRA may take 30 minutes."
            )
        )

    notes = next(
        (item['notes'] for item in LORA_DETAILS[model_textbox_value] if 'notes' in item),
        "",
    )
    return gr.update(value=notes)

def get_trigger_words(model_textbox_value):
    """
    Retrieves the trigger words from constants.LORA_DETAILS for the specified model.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: The ``trigger_words`` value of the first detail entry for the
        model that defines one, or an empty string if the model is not in
        LORA_DETAILS or none of its entries define trigger words.
    """
    if model_textbox_value not in LORA_DETAILS:
        return ""
    for item in LORA_DETAILS[model_textbox_value]:
        if 'trigger_words' in item:
            return item['trigger_words']
    return ""

def upd_trigger_words(model_textbox_value):
    """
    Build a Gradio update for the trigger-words label of the given model.

    The text is whatever ``get_trigger_words`` returns for the model name.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Gradio update whose ``value`` is the trigger-words string.
    """
    return gr.update(value=get_trigger_words(model_textbox_value))

def approximate_token_count(prompt, *, tokens_per_word=1.35):
    """
    Approximates the number of tokens in a prompt based on word count.

    Parameters:
        prompt (str): The text prompt; words are whitespace-separated.
        tokens_per_word (float): Average number of tokens per word. This can
            vary by language and model; the default keeps the historical 1.35.

    Returns:
        int: The approximate number of tokens, truncated toward zero.
    """
    word_count = len(prompt.split())
    return int(word_count * tokens_per_word)

def split_prompt_by_tokens(prompt, token_number, *, tokens_per_word=1.3):
    """
    Split a prompt into a head of roughly ``token_number`` tokens and the rest.

    Token counts are estimated from whitespace-separated words rather than a
    real tokenizer, so the split always falls on a word boundary.

    Args:
        prompt (str): The text prompt.
        token_number (int): Approximate token budget for the head. Values
            <= 0 yield an empty head.
        tokens_per_word (float): Average tokens per word (must be positive).
            This can vary by language and model.

    Returns:
        tuple[str, str]: ``(head, tail)``, each re-joined with single spaces.
    """
    words = prompt.split()
    # A budget of N tokens covers about N / tokens_per_word words. The old
    # code multiplied instead, overshooting the budget by ~tokens_per_word**2.
    word_budget = max(0, int(token_number / tokens_per_word))
    return ' '.join(words[:word_budget]), ' '.join(words[word_budget:])

# Split prompt precisely by token count
import tiktoken

def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Split a prompt by tiktoken token count into a head and a remainder.

    The encoding is ``tiktoken.encoding_for_model(model)``; model names that
    tiktoken does not recognise fall back to the ``cl100k_base`` encoding.

    Args:
        prompt (str): The text prompt.
        max_tokens (int): If the prompt encodes to at most this many tokens it
            is returned unchanged; otherwise the head holds at most
            ``max_tokens - 1`` tokens.
        model (str): Model name used to pick the tiktoken encoding.

    Returns:
        tuple[str, str]: ``(head, remainder)``; ``remainder`` is ``""`` when
        no split was needed.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    tokens = encoding.encode(prompt)
    if len(tokens) <= max_tokens:
        return prompt, ""

    # NOTE(review): a prompt of exactly max_tokens tokens is returned whole,
    # but a longer one gets a head of max_tokens - 1 tokens. The reason for
    # the one-token reserve is not visible here — confirm it is intended.
    split_point = max(0, max_tokens - 1)

    # Byte-level BPE tokens can end partway through a multi-byte UTF-8
    # character; decoding such a cut (decode uses errors="replace") would put
    # U+FFFD into both halves. Back off to a token boundary that is also a
    # character boundary. If the prefix decodes cleanly, the remainder does too.
    boundary = split_point
    while boundary > 0 and not _is_complete_utf8(encoding.decode_bytes(tokens[:boundary])):
        boundary -= 1
    if boundary > 0:
        split_point = boundary

    split_prompt = encoding.decode(tokens[:split_point])
    remaining_prompt = encoding.decode(tokens[split_point:])
    return split_prompt, remaining_prompt


def _is_complete_utf8(data):
    """Return True if *data* decodes as UTF-8 with no truncated trailing character."""
    try:
        data.decode("utf-8")
    except UnicodeDecodeError:
        return False
    return True

def is_lora_loaded(pipe, adapter_name):
    """
    Report whether a LoRA adapter with the given name is present on a pipeline.

    The name is first searched for in every value of the mapping returned by
    ``pipe.get_list_adapters()``. It is then checked against the pipeline's
    ``peft_config`` attribute, which is treated as empty when absent.

    Args:
        pipe (FluxPipeline): The pipeline to inspect.
        adapter_name (str): The adapter name of the LoRA weight.

    Returns:
        bool: True if the adapter name is found in either place, else False.
    """
    per_component = pipe.get_list_adapters()
    found_in_components = any(
        adapter_name in names for names in per_component.values()
    )
    # Short-circuit: peft_config is only consulted when the component scan misses.
    return found_in_components or adapter_name in getattr(pipe, "peft_config", {})