File size: 5,107 Bytes
650c805
 
 
7addd34
ced6a2a
7addd34
 
ced6a2a
7addd34
 
 
ced6a2a
 
 
 
 
 
 
 
 
7addd34
 
 
 
ced6a2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650c805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef117e
650c805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ef117e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7addd34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# utils/lora_details.py

import gradio as gr
from utils.constants import LORA_DETAILS, MODELS, LORAS

def get_lora_models():
    """Return (image, title) pairs for every LoRA entry in LORAS."""
    pairs = []
    for entry in LORAS:
        pairs.append((entry["image"], entry["title"]))
    return pairs

def upd_prompt_notes_by_index(lora_index):
    """
    Update the prompt notes label based on the selected LoRA model.

    Args:
        lora_index (int): Index into LORAS for the selected model.

    Returns:
        gr.update: Gradio update carrying the notes text for the label.
    """
    # Default text, used when the index is out of range or the entry is
    # empty. (The original left `notes` unbound for a falsy entry, which
    # raised UnboundLocalError at the return statement.)
    notes = (
        "Enter prompt description of your image. \n"
        "Using models without LoRA may take 30 minutes."
    )
    try:
        entry = LORAS[lora_index]
    except IndexError:
        entry = None
    if entry:
        notes = entry.get('notes', None)
        if notes is None:
            # No explicit notes: synthesize them from the trigger metadata.
            trigger_word = entry.get('trigger_word', "")
            trigger_position = entry.get('trigger_position', "")
            notes = f"{trigger_position} '{trigger_word}' in prompt"
    return gr.update(value=notes)

def get_trigger_words_by_index(lora_index):
    """
    Look up the trigger words for the LoRA model at the given index.

    Args:
        lora_index (int): Index of the selected LoRA model in LORAS.

    Returns:
        str: The model's trigger words, or "" when the index is out of
        range (negative indices resolve Python-style, as before).
    """
    try:
        entry = LORAS[lora_index]
    except IndexError:
        return ""
    return entry.get('trigger_word', "")

def upd_prompt_notes(model_textbox_value):
    """
    Update the prompt_notes_label with the notes from LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Updated Gradio label component with the notes.
    """
    details = LORA_DETAILS.get(model_textbox_value)
    if details is not None:
        # First entry carrying a 'notes' key wins; "" when none do.
        notes = next((item['notes'] for item in details if 'notes' in item), "")
    else:
        # Fixed grammar ("take a 30 minutes") and normalized "LoRa" to
        # "LoRA", matching the default message used elsewhere in this file.
        notes = ("Enter Prompt description of your image, \n"
                 "using models without LoRA may take 30 minutes.")
    return gr.update(value=notes)

def get_trigger_words(model_textbox_value):
    """
    Retrieve the trigger words from constants.LORA_DETAILS for the model.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        str: The trigger words associated with the model, or "" when the
        model is unknown or none of its detail entries define them.
    """
    # Unknown models fall through to the empty default, exactly as the
    # original's else branch did.
    for item in LORA_DETAILS.get(model_textbox_value, []):
        if 'trigger_words' in item:
            return item['trigger_words']
    return ""

def upd_trigger_words(model_textbox_value):
    """
    Update the trigger_words_label with the trigger words from LORA_DETAILS.

    Args:
        model_textbox_value (str): The name of the LoRA model.

    Returns:
        gr.update: Gradio update whose value is the model's trigger words.
    """
    return gr.update(value=get_trigger_words(model_textbox_value))

def approximate_token_count(prompt):
    """
    Approximate the number of tokens in a prompt from its word count.

    Parameters:
        prompt (str): The text prompt.

    Returns:
        int: The approximate token count (word count * 1.3, truncated).
    """
    # Rough average tokens per word; varies with language and tokenizer.
    tokens_per_word = 1.3
    word_count = len(prompt.split())
    return int(word_count * tokens_per_word)

def split_prompt_by_tokens(prompt, token_number):
    """
    Split a prompt into a head of roughly `token_number` tokens and a tail.

    Parameters:
        prompt (str): The text prompt.
        token_number (int): Approximate token budget for the first part.

    Returns:
        tuple[str, str]: (head, tail), each re-joined with single spaces.
    """
    words = prompt.split()
    # Rough average tokens per word; varies with language and tokenizer.
    tokens_per_word = 1.3
    # A budget of token_number tokens corresponds to roughly
    # token_number / tokens_per_word words. The original multiplied
    # instead of dividing, so the head overshot the budget by ~69%
    # (e.g. 77 tokens -> 100 words ~= 130 tokens).
    word_limit = int(token_number / tokens_per_word)
    return ' '.join(words[:word_limit]), ' '.join(words[word_limit:])

# Split prompt precisely by token count
import tiktoken

def split_prompt_precisely(prompt, max_tokens=77, model="gpt-3.5-turbo"):
    """
    Split a prompt at an exact token boundary using tiktoken.

    Parameters:
        prompt (str): The text prompt.
        max_tokens (int): Maximum tokens allowed in the first part.
        model (str): Model name used to select the tokenizer encoding.

    Returns:
        tuple[str, str]: (head, tail) decoded back to text; tail is ""
        when the whole prompt fits within max_tokens.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to the general-purpose encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    token_ids = encoding.encode(prompt)
    if len(token_ids) <= max_tokens:
        return prompt, ""

    head = encoding.decode(token_ids[:max_tokens])
    tail = encoding.decode(token_ids[max_tokens:])
    return head, tail

def is_lora_loaded(pipe, adapter_name):
    """
    Check whether a LoRA adapter with the given name is loaded in the pipeline.

    Args:
        pipe (FluxPipeline): The pipeline to inspect.
        adapter_name (str): The adapter name of the LoRA weight.

    Returns:
        bool: True if the adapter is present, False otherwise.
    """
    per_component = pipe.get_list_adapters().values()
    if any(adapter_name in names for names in per_component):
        return True
    # Also consult the pipeline-level peft_config mapping, when present.
    return adapter_name in getattr(pipe, "peft_config", {})