# utils/lora_details.py

import gradio as gr
import tiktoken

from utils.constants import LORA_DETAILS, MODELS, LORAS

def get_lora_models():
    """Return (image, title) pairs for the LoRA selection gallery."""
    return [(item["image"], item["title"]) for item in LORAS]

def upd_prompt_notes_by_index(lora_index):
    """
    Updates the prompt notes label based on the selected LoRA model.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    try:
        notes = LORAS[lora_index].get('notes')
        if notes is None:
            trigger_word = LORAS[lora_index].get('trigger_word', "")
            trigger_position = LORAS[lora_index].get('trigger_position', "")
            notes = f"{trigger_position} '{trigger_word}' in prompt"
    except IndexError:
        notes = (
            "Enter a prompt description of your image. \n"
            "Using models without LoRA may take 30 minutes."
        )
    return gr.update(value=notes)
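
# Illustrative wiring (a sketch with hypothetical component names;
# gr.SelectData.index is the index of the gallery item the user clicked):
#
#     def on_lora_select(evt: gr.SelectData):
#         return upd_prompt_notes_by_index(evt.index)
#
#     lora_gallery = gr.Gallery(value=get_lora_models())
#     lora_gallery.select(on_lora_select, outputs=prompt_notes_label)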

def get_trigger_words_by_index(lora_index):
    """
    Retrieves the trigger words from LORAS for the specified index.

    Args:
        lora_index (int): The index of the selected LoRA model.

    Returns:
        str: The trigger words associated with the model, or an empty string if not found.
    """
    try:
        trigger_words = LORAS[lora_index].get('trigger_word', "")
    except IndexError:
        trigger_words = ""
    return trigger_words

def upd_prompt_notes(model_textbox_value):
    """
    Updates the prompt_notes_label with the notes from LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    notes = ""
    if model_textbox_value in LORA_DETAILS:
        lora_detail_list = LORA_DETAILS[model_textbox_value]
        for item in lora_detail_list:
            if 'notes' in item:
                notes = item['notes']
                break
    else:
        notes = (
            "Enter a prompt description of your image. \n"
            "Using models without LoRA may take 30 minutes."
        )
    return gr.update(value=notes)

def get_trigger_words(model_textbox_value):
    """
    Retrieves the trigger words from constants.LORA_DETAILS for the specified model.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: The trigger words associated with the model, or an empty string if not found.
    """
    trigger_words = ""
    if model_textbox_value in LORA_DETAILS:
        lora_detail_list = LORA_DETAILS[model_textbox_value]
        for item in lora_detail_list:
            if 'trigger_words' in item:
                trigger_words = item['trigger_words']
                break
    return trigger_words

def upd_trigger_words(model_textbox_value):
    """
    Updates the trigger_words_label with the trigger words from LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the trigger words.
    """
    trigger_words = get_trigger_words(model_textbox_value)
    return gr.update(value=trigger_words)
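
# Illustrative wiring (a sketch with hypothetical component names): refresh
# both labels whenever the model textbox changes.
#
#     model_textbox.change(upd_prompt_notes, inputs=model_textbox,
#                          outputs=prompt_notes_label)
#     model_textbox.change(upd_trigger_words, inputs=model_textbox,
#                          outputs=trigger_words_label)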

def approximate_token_count(prompt):
    """
    Approximates the number of tokens in a prompt based on word count.

    Args:
        prompt (str): The text prompt.

    Returns:
        int: The approximate number of tokens.
    """
    words = prompt.split()
    # Average tokens per word (can vary based on language and model)
    tokens_per_word = 1.3
    return int(len(words) * tokens_per_word)
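
# Illustrative check of the heuristic above (English-like prompts assumed;
# the 1.3 tokens-per-word ratio is an average, not an exact count):
#
#     approximate_token_count("a red fox running in deep snow")
#     # 7 words * 1.3 = 9.1 -> 9 tokens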

def split_prompt_by_tokens(prompt, token_number):
    """Approximately split a prompt after token_number tokens, using word count."""
    words = prompt.split()
    # One word averages ~1.3 tokens, so token_number tokens ≈ token_number / 1.3 words.
    tokens_per_word = 1.3
    split_at = int(token_number / tokens_per_word)
    return ' '.join(words[:split_at]), ' '.join(words[split_at:])
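
# Illustrative usage (hypothetical prompt text):
#
#     head, tail = split_prompt_by_tokens(prompt_text, 77)
#     # head holds roughly the first 77 tokens' worth of words; tail the rest.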

# Split a prompt precisely by token count using tiktoken.

def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Splits a prompt at exactly max_tokens tokens using tiktoken.

    Args:
        prompt (str): The text prompt.
        max_tokens (int): The maximum number of tokens in the first part.
        model (str): The model whose tokenizer should be used.

    Returns:
        tuple[str, str]: The first max_tokens tokens decoded back to text,
        and the remaining text ("" if the prompt already fits).
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to a general-purpose encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(prompt)
    if len(tokens) <= max_tokens:
        return prompt, ""
    # Split the token list at max_tokens and decode each half back to text.
    split_tokens = tokens[:max_tokens]
    remaining_tokens = tokens[max_tokens:]
    split_prompt = encoding.decode(split_tokens)
    remaining_prompt = encoding.decode(remaining_tokens)
    return split_prompt, remaining_prompt
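
# Illustrative usage (hypothetical prompt text; 77 is the CLIP text-encoder
# context limit commonly used by diffusion models):
#
#     head, overflow = split_prompt_precisely(long_prompt, max_tokens=77)
#     # head fits within the token budget; overflow can be shown to the user
#     # or passed to a second text encoder that accepts longer input.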

def is_lora_loaded(pipe, adapter_name):
    """
    Check if a LoRA weight with the given adapter name is already loaded in the pipeline.

    Args:
        pipe (FluxPipeline): The pipeline to check.
        adapter_name (str): The adapter name of the LoRA weight.

    Returns:
        bool: True if the LoRA weight is loaded, False otherwise.
    """
    # get_list_adapters() maps each pipeline component to its adapter names.
    adapter_list = pipe.get_list_adapters()
    for component_adapters in adapter_list.values():
        if adapter_name in component_adapters:
            return True
    # Fall back to the PEFT config dict, which some pipelines expose directly.
    if adapter_name in getattr(pipe, "peft_config", {}):
        return True
    return False
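
# Illustrative usage (hypothetical repo and adapter name; load_lora_weights()
# and set_adapters() are the diffusers APIs for attaching and activating LoRAs):
#
#     if not is_lora_loaded(pipe, "toy_style"):
#         pipe.load_lora_weights("some-user/toy-style-lora", adapter_name="toy_style")
#         pipe.set_adapters(["toy_style"])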