Surn committed on
Commit 650c805 · 1 Parent(s): 061d802

Working Version with negative prompts and dynamic trigger words

utils/ai_generator.py ADDED
@@ -0,0 +1,169 @@
+# utils/ai_generator.py
+
+import os
+import time  # Used for retry delays between attempts
+import torch
+import random
+from utils.ai_generator_diffusers_flux import generate_ai_image_local
+from pathlib import Path
+from huggingface_hub import InferenceClient
+import requests
+import io
+from PIL import Image
+from tempfile import NamedTemporaryFile
+import utils.constants as constants
+
+def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512):
+    # Initialize the InferenceClient; the client object is not callable directly,
+    # and text_to_image() returns a PIL image, not an HTTP response.
+    client = InferenceClient()
+    image = client.text_to_image(text, model=model_name)
+    # Resize the image to the requested dimensions
+    return image.resize((image_width, image_height))
+
+def generate_ai_image(
+    map_option,
+    prompt_textbox_value,
+    neg_prompt_textbox_value,
+    model,
+    lora_weights=None,
+    *args,
+    **kwargs
+):
+    seed = random.randint(1, 99999)
+    if torch.cuda.is_available():
+        print("Local GPU available. Generating image locally.")
+        return generate_ai_image_local(
+            map_option,
+            prompt_textbox_value,
+            neg_prompt_textbox_value,
+            model,
+            lora_weights=lora_weights,
+            seed=seed
+        )
+    else:
+        print("No local GPU available. Sending request to Hugging Face API.")
+        return generate_ai_image_remote(
+            map_option,
+            prompt_textbox_value,
+            neg_prompt_textbox_value,
+            model
+        )
+
+def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=896, num_inference_steps=50, guidance_scale=3.5, seed=777):
+    max_retries = 3
+    retry_delay = 4  # Initial delay in seconds
+
+    try:
+        if map_option != "Prompt":
+            prompt = constants.PROMPTS[map_option]
+            # Convert the negative prompt string to a list
+            negative_prompt_str = constants.NEGATIVE_PROMPTS.get(map_option, "")
+            negative_prompt = [p.strip() for p in negative_prompt_str.split(',') if p.strip()]
+        else:
+            prompt = prompt_textbox_value
+            # Convert the negative prompt string to a list
+            negative_prompt = [p.strip() for p in neg_prompt_textbox_value.split(',') if p.strip()] if neg_prompt_textbox_value else []
+
+        print("Remotely generating image with the following parameters:")
+        print(f"Prompt: {prompt}")
+        print(f"Negative Prompt: {negative_prompt}")
+        print(f"Height: {height}")
+        print(f"Width: {width}")
+        print(f"Number of Inference Steps: {num_inference_steps}")
+        print(f"Guidance Scale: {guidance_scale}")
+        print(f"Seed: {seed}")
+
+        for attempt in range(1, max_retries + 1):
+            try:
+                if os.getenv("IS_SHARED_SPACE") == "True":
+                    client = InferenceClient(
+                        model,
+                        token=constants.HF_API_TOKEN
+                    )
+                    # text_to_image() takes keyword arguments rather than a raw
+                    # JSON payload, and returns a PIL image directly.
+                    image = client.text_to_image(
+                        prompt,
+                        negative_prompt=", ".join(negative_prompt) if negative_prompt else None,
+                        height=height,
+                        width=width,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale
+                        # Optional: add 'scheduler' and 'seed' if the backend supports them
+                    )
+                    break  # Exit the retry loop on success
+                else:
+                    API_URL = f"https://api-inference.huggingface.co/models/{model}"
+                    headers = {
+                        "Authorization": f"Bearer {constants.HF_API_TOKEN}",
+                        "Content-Type": "application/json"
+                    }
+                    payload = {
+                        "inputs": prompt,
+                        "parameters": {
+                            "negative_prompt": ", ".join(negative_prompt),
+                            "guidance_scale": guidance_scale,
+                            "num_inference_steps": num_inference_steps,
+                            "width": width,
+                            "height": height,
+                            "max_sequence_length": 512,
+                            # Optional: Add 'scheduler' if needed
+                            "seed": seed
+                        }
+                    }
+
+                    print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
+                    response = requests.post(API_URL, headers=headers, json=payload, timeout=300)  # 300-second timeout
+                    if response.status_code == 200:
+                        image_bytes = response.content
+                        image = Image.open(io.BytesIO(image_bytes))
+                        break  # Exit the retry loop on success
+                    elif response.status_code == 400:
+                        # Handle 400 Bad Request specifically
+                        print(f"Bad Request (400): {response.text}")
+                        print("Check your request parameters and payload format.")
+                        return None  # Do not retry on 400 errors
+                    elif response.status_code in [429, 504]:
+                        print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
+                        if attempt < max_retries:
+                            time.sleep(retry_delay)
+                            retry_delay *= 2  # Exponential backoff
+                        else:
+                            response.raise_for_status()  # Raise exception after max retries
+                    else:
+                        print(f"Received unexpected status code {response.status_code}: {response.text}")
+                        response.raise_for_status()
+            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
+                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
+                if attempt < max_retries:
+                    time.sleep(retry_delay)
+                    retry_delay *= 2  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception after max retries
+            except requests.exceptions.RequestException as req_error:
+                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
+                if attempt < max_retries:
+                    time.sleep(retry_delay)
+                    retry_delay *= 2  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception after max retries
+        else:
+            # The loop ran out of attempts without a successful break
+            print("Max retries exceeded. Failed to generate image.")
+            return None
+
+        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
+            image.save(tmp.name, format="PNG")
+            constants.temp_files.append(tmp.name)
+            print(f"Image saved to {tmp.name}")
+            return tmp.name
+
+    except Exception as e:
+        print(f"Error generating AI image: {e}")
+        return None
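
For reference, a minimal usage sketch of the dispatcher above (the prompt strings are illustrative; generate_ai_image() takes the local diffusers path when CUDA is available and otherwise falls back to the Inference API):

    from utils.ai_generator import generate_ai_image

    # map_option="Prompt" makes the function use the textbox values instead of
    # the canned constants.PROMPTS entries.
    path = generate_ai_image(
        map_option="Prompt",
        prompt_textbox_value="hexagon tabletop map, top-down, eight colors",
        neg_prompt_textbox_value="blurry, watermark, text",
        model="black-forest-labs/FLUX.1-dev",
    )
    if path:
        print(f"Generated PNG written to {path}")
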
utils/ai_generator_diffusers_flux.py ADDED
@@ -0,0 +1,269 @@
+# utils/ai_generator_diffusers_flux.py
+import os
+import torch
+import accelerate
+import transformers
+import safetensors
+import xformers
+from diffusers import FluxPipeline
+from diffusers.utils import load_image
+# from huggingface_hub import hf_hub_download
+from PIL import Image
+from tempfile import NamedTemporaryFile
+from src.condition import Condition
+import utils.constants as constants
+from utils.image_utils import (
+    crop_and_resize_image,
+)
+from utils.version_info import (
+    versions_html,
+    get_torch_info,
+    get_diffusers_version,
+    get_transformers_version,
+    get_xformers_version
+)
+from utils.lora_details import get_trigger_words
+from utils.color_utils import detect_color_format
+# import utils.misc as misc
+from pathlib import Path
+import warnings
+warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
+# print(torch.__version__)       # Ensure it's 2.0 or newer
+# print(torch.cuda.is_available())  # Ensure CUDA is available
+
+def generate_image_from_text(
+    text,
+    model_name="black-forest-labs/FLUX.1-dev",
+    lora_weights=None,
+    conditioned_image=None,
+    image_width=1344,
+    image_height=848,
+    guidance_scale=3.5,
+    num_inference_steps=50,
+    seed=0,
+    additional_parameters=None
+):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"device:{device}\nmodel_name:{model_name}\n")
+    pipe = FluxPipeline.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+    ).to(device)
+    pipe.enable_model_cpu_offload()
+    # Load and apply LoRA weights
+    if lora_weights:
+        for lora_weight in lora_weights:
+            lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
+            if lora_configs:
+                for config in lora_configs:
+                    weight_name = config.get("weight_name")
+                    adapter_name = config.get("adapter_name")
+                    pipe.load_lora_weights(
+                        lora_weight,
+                        weight_name=weight_name,
+                        adapter_name=adapter_name,
+                        use_auth_token=constants.HF_API_TOKEN
+                    )
+            else:
+                pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
+    generator = torch.Generator(device=device).manual_seed(seed)
+    conditions = []
+    if conditioned_image is not None:
+        conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
+        condition = Condition("subject", conditioned_image)
+        conditions.append(condition)
+    generate_params = {
+        "prompt": text,
+        "height": image_height,
+        "width": image_width,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "conditions": conditions if conditions else None
+    }
+    if additional_parameters:
+        generate_params.update(additional_parameters)
+    generate_params = {k: v for k, v in generate_params.items() if v is not None}
+    result = pipe(**generate_params)
+    image = result.images[0]
+    return image
+
+def generate_image_lowmem(
+    text,
+    neg_prompt=None,
+    model_name="black-forest-labs/FLUX.1-dev",
+    lora_weights=None,
+    conditioned_image=None,
+    image_width=1344,
+    image_height=848,
+    guidance_scale=3.5,
+    num_inference_steps=50,
+    seed=0,
+    true_cfg_scale=1.0,
+    additional_parameters=None
+):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"device:{device}\nmodel_name:{model_name}\n")
+    print(f"\n {get_torch_info()}\n")
+    # Disable gradient calculations
+    with torch.no_grad():
+        # Initialize the pipeline inside the context manager
+        pipe = FluxPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32  # bfloat16 on CUDA, float32 on CPU
+        ).to(device)
+        # Optionally, don't use CPU offload if not necessary
+        pipe.enable_model_cpu_offload()
+        # alternative version that may be more efficient
+        # pipe.enable_sequential_cpu_offload()
+        flash_attention_enabled = torch.backends.cuda.flash_sdp_enabled()
+        if not flash_attention_enabled:
+            # Enable xFormers memory-efficient attention (optional)
+            pipe.enable_xformers_memory_efficient_attention()
+            print("\nEnabled xFormers memory-efficient attention.\n")
+        else:
+            pipe.attn_implementation = "flash_attention_2"
+            print("\nEnabled flash_attention_2.\n")
+        pipe.enable_vae_tiling()
+        # Load LoRA weights
+        if lora_weights:
+            for lora_weight in lora_weights:
+                lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
+                if lora_configs:
+                    for config in lora_configs:
+                        # Load LoRA weights with optional weight_name and adapter_name
+                        weight_name = config.get("weight_name")
+                        adapter_name = config.get("adapter_name")
+                        if weight_name and adapter_name:
+                            pipe.load_lora_weights(
+                                lora_weight,
+                                weight_name=weight_name,
+                                adapter_name=adapter_name,
+                                use_auth_token=constants.HF_API_TOKEN
+                            )
+                        else:
+                            pipe.load_lora_weights(
+                                lora_weight,
+                                use_auth_token=constants.HF_API_TOKEN
+                            )
+
+                        # Apply 'pipe' configurations if present
+                        if 'pipe' in config:
+                            pipe_config = config['pipe']
+                            for method_name, params in pipe_config.items():
+                                method = getattr(pipe, method_name, None)
+                                if method:
+                                    print(f"Applying pipe method: {method_name} with params: {params}")
+                                    method(**params)
+                                else:
+                                    print(f"Method {method_name} not found in pipe.")
+                else:
+                    pipe.load_lora_weights(lora_weight, use_auth_token=constants.HF_API_TOKEN)
+        generator = torch.Generator(device=device).manual_seed(seed)
+        conditions = []
+        if conditioned_image is not None:
+            conditioned_image = crop_and_resize_image(conditioned_image, 1024, 1024)
+            condition = Condition("subject", conditioned_image)
+            conditions.append(condition)
+        if neg_prompt is not None:
+            true_cfg_scale = 1.1  # Enable true CFG so the negative prompt takes effect
+        generate_params = {
+            "prompt": text,
+            "negative_prompt": neg_prompt,
+            "true_cfg_scale": true_cfg_scale,
+            "height": image_height,
+            "width": image_width,
+            "guidance_scale": guidance_scale,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "conditions": conditions if conditions else None
+        }
+        if additional_parameters:
+            generate_params.update(additional_parameters)
+        generate_params = {k: v for k, v in generate_params.items() if v is not None}
+        # Generate the image
+        result = pipe(**generate_params)
+        image = result.images[0]
+        # Clean up
+        del result
+        del conditions
+        del generator
+        # Delete the pipeline and clear cache
+        del pipe
+        torch.cuda.empty_cache()
+        print(torch.cuda.memory_summary(device=None, abbreviated=False))
+    return image
+
+def generate_ai_image_local(
+    map_option,
+    prompt_textbox_value,
+    neg_prompt_textbox_value,
+    model="black-forest-labs/FLUX.1-dev",
+    lora_weights=None,
+    conditioned_image=None,
+    height=512,
+    width=896,
+    num_inference_steps=50,
+    guidance_scale=3.5,
+    seed=777
+):
+    try:
+        if map_option != "Prompt":
+            prompt = constants.PROMPTS[map_option]
+            negative_prompt = constants.NEGATIVE_PROMPTS.get(map_option, "")
+        else:
+            prompt = prompt_textbox_value
+            negative_prompt = neg_prompt_textbox_value or ""
+        # full_prompt = f"{prompt} {negative_prompt}"
+        additional_parameters = {}
+        if lora_weights:
+            for lora_weight in lora_weights:
+                lora_configs = constants.LORA_DETAILS.get(lora_weight, [])
+                for config in lora_configs:
+                    if 'parameters' in config:
+                        additional_parameters.update(config['parameters'])
+                    elif 'trigger_words' in config:
+                        trigger_words = get_trigger_words(lora_weight)
+                        prompt = f"{trigger_words} {prompt}"
+        for key, value in additional_parameters.items():
+            if key in ['height', 'width', 'num_inference_steps', 'max_sequence_length']:
+                additional_parameters[key] = int(value)
+            elif key in ['guidance_scale', 'true_cfg_scale']:
+                additional_parameters[key] = float(value)
+        height = additional_parameters.get('height', height)
+        width = additional_parameters.get('width', width)
+        num_inference_steps = additional_parameters.get('num_inference_steps', num_inference_steps)
+        guidance_scale = additional_parameters.get('guidance_scale', guidance_scale)
+        print("Generating image with the following parameters:")
+        print(f"Model: {model}")
+        print(f"LoRA Weights: {lora_weights}")
+        print(f"Prompt: {prompt}")
+        print(f"Neg Prompt: {negative_prompt}")
+        print(f"Height: {height}")
+        print(f"Width: {width}")
+        print(f"Number of Inference Steps: {num_inference_steps}")
+        print(f"Guidance Scale: {guidance_scale}")
+        print(f"Seed: {seed}")
+        print(f"Additional Parameters: {additional_parameters}")
+        image = generate_image_lowmem(
+            text=prompt,
+            model_name=model,
+            neg_prompt=negative_prompt,
+            lora_weights=lora_weights,
+            conditioned_image=conditioned_image,
+            image_width=width,
+            image_height=height,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            additional_parameters=additional_parameters
+        )
+        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
+            image.save(tmp.name, format="PNG")
+            constants.temp_files.append(tmp.name)
+            print(f"Image saved to {tmp.name}")
+            return tmp.name
+    except Exception as e:
+        print(f"Error generating AI image: {e}")
+        return None
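
Both LoRA loaders above treat each constants.LORA_DETAILS entry as a list of small config dicts. A hypothetical entry illustrating the keys the code reads (the repo id and values are placeholders, not part of this commit); note that generate_ai_image_local() checks 'parameters' and 'trigger_words' in an if/elif, so they belong in separate items:

    LORA_DETAILS = {
        "some-user/some-flux-lora": [                      # hypothetical repo id
            {"weight_name": "lora.safetensors",            # forwarded to pipe.load_lora_weights()
             "adapter_name": "maps"},
            {"trigger_words": "hexmap style"},             # prepended to the prompt
            {"notes": "Works best around guidance_scale 3.5"},
            {"parameters": {"num_inference_steps": 28}},   # merged into the generation kwargs
            # Each key under 'pipe' names a pipeline method called with the given kwargs
            {"pipe": {"set_adapters": {"adapter_names": ["maps"],
                                       "adapter_weights": [0.8]}}},
        ]
    }
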
utils/constants.py CHANGED
@@ -15,6 +15,8 @@ os.environ['XFORMERS_FORCE_DISABLE_TRITON']= '1'
 os.environ["HF_TOKEN"] = ""
 HF_API_TOKEN = os.getenv("HF_TOKEN")
 default_lut_example_img = "./LUT/daisy.jpg"
+os.environ["HF_TOKEN"] = ""
+HF_API_TOKEN = os.getenv("HF_TOKEN")
 
 PROMPTS = {
     "Map1": "eight_color (tabletop_map built from small hexagon pieces) as ((empty black on all sides), barren alien_world_map), with light_blue_is_rivers and brown_is_mountains and red_is_volcano and [white_is_snow at the top and bottom of map] as (four_color background: light_blue, green, tan, brown), horizontal_gradient is (brown to tan to green to light_blue to blue) and vertical_gradient is (white to blue to (green, tan and red) to blue to white), (middle is dark, no_reflections, no_shadows), ((partial hexes on edges and sides are black))",
utils/lora_details.py ADDED
@@ -0,0 +1,59 @@
+# utils/lora_details.py
+
+import gradio as gr
+from utils.constants import LORA_DETAILS
+
+def upd_prompt_notes(model_textbox_value):
+    """
+    Updates the prompt_notes_label with the notes from LORA_DETAILS.
+
+    Args:
+        model_textbox_value (str): The name of the LoRA model.
+
+    Returns:
+        gr.update: Updated Gradio label component with the notes.
+    """
+    notes = ""
+    if model_textbox_value in LORA_DETAILS:
+        lora_detail_list = LORA_DETAILS[model_textbox_value]
+        for item in lora_detail_list:
+            if 'notes' in item:
+                notes = item['notes']
+                break
+    else:
+        notes = "Enter Prompt description of your image"
+    return gr.update(value=notes)
+
+def get_trigger_words(model_textbox_value):
+    """
+    Retrieves the trigger words from constants.LORA_DETAILS for the specified model.
+
+    Args:
+        model_textbox_value (str): The name of the LoRA model.
+
+    Returns:
+        str: The trigger words associated with the model, or an empty string if not found.
+    """
+    trigger_words = ""
+    if model_textbox_value in LORA_DETAILS:
+        lora_detail_list = LORA_DETAILS[model_textbox_value]
+        for item in lora_detail_list:
+            if 'trigger_words' in item:
+                trigger_words = item['trigger_words']
+                break
+    else:
+        trigger_words = ""
+    return trigger_words
+
+def upd_trigger_words(model_textbox_value):
+    """
+    Updates the trigger_words_label with the trigger words from LORA_DETAILS.
+
+    Args:
+        model_textbox_value (str): The name of the LoRA model.
+
+    Returns:
+        gr.update: Updated Gradio label component with the trigger words.
+    """
+    trigger_words = get_trigger_words(model_textbox_value)
+    return gr.update(value=trigger_words)
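
A sketch of how these helpers would typically be wired into a Gradio UI (the component names below are assumptions for illustration, not part of this commit):

    import gradio as gr
    from utils.lora_details import upd_prompt_notes, upd_trigger_words

    with gr.Blocks() as demo:
        model_textbox = gr.Textbox(label="LoRA model")          # assumed component name
        trigger_words_label = gr.Label(label="Trigger words")   # assumed component name
        prompt_notes_label = gr.Label(label="Prompt notes")     # assumed component name
        # Refresh both labels whenever the model name changes
        model_textbox.change(upd_trigger_words, inputs=model_textbox, outputs=trigger_words_label)
        model_textbox.change(upd_prompt_notes, inputs=model_textbox, outputs=prompt_notes_label)

    demo.launch()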