nftnik committed on
Commit 0845b5a (verified)
1 Parent(s): 16764be

Update app.py

Files changed (1)
  1. app.py +181 -264
app.py CHANGED
@@ -1,352 +1,267 @@
  import os
  import random
- import sys
  import torch
- import gradio as gr
  from pathlib import Path

  from huggingface_hub import hf_hub_download
- import spaces
- from typing import Union, Sequence, Mapping, Any
- from comfy import model_management
  from nodes import NODE_CLASS_MAPPINGS
-
- # 1. Path and import configuration
- current_dir = os.path.dirname(os.path.abspath(__file__))
- comfyui_path = os.path.join(current_dir, "ComfyUI")
- sys.path.append(comfyui_path)
-
- # 2. ComfyUI imports
  import folder_paths
- from nodes import init_extra_nodes

- # 3. Directory configuration
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
  output_dir = os.path.join(BASE_DIR, "output")
- models_dir = os.path.join(BASE_DIR, "models")
  os.makedirs(output_dir, exist_ok=True)
- os.makedirs(models_dir, exist_ok=True)
  folder_paths.set_output_directory(output_dir)

- # 4. CUDA diagnostics
- print("Python version:", sys.version)
- print("Torch version:", torch.__version__)
- print("CUDA available:", torch.cuda.is_available())
- print("GPU count:", torch.cuda.device_count())
- if torch.cuda.is_available():
-     print("Current GPU:", torch.cuda.get_device_name(0))
-
- # 5. ComfyUI initialization
- print("Initializing ComfyUI...")
- init_extra_nodes()
-
- # 6. Helper functions
- def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
-     try:
-         return obj[index]
-     except KeyError:
-         return obj["result"][index]
-
- def find_path(name: str, path: str = None) -> str:
-     if path is None:
-         path = os.getcwd()
-     if name in os.listdir(path):
-         path_name = os.path.join(path, name)
-         print(f"{name} found: {path_name}")
-         return path_name
-     parent_directory = os.path.dirname(path)
-     if parent_directory == path:
-         return None
-     return find_path(name, parent_directory)
-
- def add_comfyui_directory_to_sys_path() -> None:
-     comfyui_path = find_path("ComfyUI")
-     if comfyui_path is not None and os.path.isdir(comfyui_path):
-         sys.path.append(comfyui_path)
-         print(f"'{comfyui_path}' added to sys.path")
-
- def add_extra_model_paths() -> None:
-     try:
-         from main import load_extra_path_config
-     except ImportError:
-         from utils.extra_config import load_extra_path_config
-     extra_model_paths = find_path("extra_model_paths.yaml")
-     if extra_model_paths is not None:
-         load_extra_path_config(extra_model_paths)
-     else:
-         print("Could not find the extra_model_paths config file.")
-
- # 7. Path initialization
- add_comfyui_directory_to_sys_path()
- add_extra_model_paths()

- def import_custom_nodes() -> None:
      import asyncio
      import execution
      import server
      loop = asyncio.new_event_loop()
      asyncio.set_event_loop(loop)
      server_instance = server.PromptServer(loop)
      execution.PromptQueue(server_instance)
      init_extra_nodes()

- # 8. Model download
- def download_models():
-     print("Downloading models...")
-     models = [
-         ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
-         ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
-         ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "text_encoders"),
-         ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
-         ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "diffusion_models"),
-         ("google/siglip-so400m-patch14-384", "model.safetensors", "clip_vision")
-     ]

-     for repo_id, filename, model_type in models:
-         try:
-             model_dir = os.path.join(models_dir, model_type)
-             os.makedirs(model_dir, exist_ok=True)
-             print(f"Downloading {filename} from {repo_id}...")
-             hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir)
-             # Add the directory to folder_paths
-             folder_paths.add_model_folder_path(model_type, model_dir)
-         except Exception as e:
-             print(f"Error downloading {filename} from {repo_id}: {str(e)}")
-             continue

- # 9. Model download and initialization
- print("Downloading models...")
- download_models()

- print("Initializing models...")
- import_custom_nodes()

- # Global variables for preloaded models and constants
- intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
- CONST_1024 = intconstant.get_value(value=1024)

- # Load CLIP
- dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
- CLIP_MODEL = dualcliploader.load_clip(
-     clip_name1="t5xxl_fp16.safetensors",
-     clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
-     type="flux"
- )

- # Load VAE
- vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
- VAE_MODEL = vaeloader.load_vae(
-     vae_name="ae.safetensors"
- )

- # Load CLIP Vision
- clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
- CLIP_VISION_MODEL = clipvisionloader.load_clip(
-     clip_name="model.safetensors"
- )

- # Load Style Model
- stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
- STYLE_MODEL = stylemodelloader.load_style_model(
-     style_model_name="flux1-redux-dev.safetensors"
- )

- # Initialize samplers
- ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
- SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

- # Initialize other nodes
- cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
- loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
- vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
- fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
- instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
- clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
- stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
- emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
- basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
- basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
- randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
- samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
- vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
- saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
- getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
- depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
- cr_text = NODE_CLASS_MAPPINGS["CR Text"]()

- model_loaders = [CLIP_MODEL, VAE_MODEL, CLIP_VISION_MODEL, STYLE_MODEL]

- model_management.load_models_gpu([
-     loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
- ])

- @spaces.GPU
- def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps, progress=gr.Progress(track_tqdm=True)) -> str:
-     with torch.inference_mode():
-         # Set up CLIP
-         clip_switch = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
-
-         # Encode text
-         text_encoded = cliptextencode.encode(
-             text=prompt,
-             clip=get_value_at_index(CLIP_MODEL, 0),
-         )
-
-         # Process input image
-         loaded_image = loadimage.load_image(image=input_image)
-
-         # Get image size
-         size_info = getimagesizeandcount.getsize(
-             image=get_value_at_index(loaded_image, 0)
-         )
-
-         # Encode VAE
-         vae_encoded = vaeencode.encode(
-             pixels=get_value_at_index(size_info, 0),
-             vae=get_value_at_index(VAE_MODEL, 0),
-         )
-
-         # Apply Flux guidance
-         flux_guided = fluxguidance.append(
-             guidance=guidance,
-             conditioning=get_value_at_index(text_encoded, 0),
-         )
-
-         # Set up empty latent
-         empty_latent = emptylatentimage.generate(
-             width=width,
-             height=height,
-             batch_size=batch_size
-         )
-
-         # Set up guidance
-         guided = basicguider.get_guider(
-             model=get_value_at_index(unet_model, 0),
-             conditioning=get_value_at_index(loaded_image, 0)
-         )
-
-         # Set up scheduler
-         schedule = basicscheduler.get_sigmas(
-             scheduler="simple",
-             steps=steps,
-             denoise=1,
-             model=get_value_at_index(unet_model, 0),
-         )
-
-         # Generate random noise
-         noise = randomnoise.get_noise(noise_seed=seed)
-
-         # Sample
-         sampled = samplercustomadvanced.sample(
-             noise=get_value_at_index(noise, 0),
-             guider=get_value_at_index(guided, 0),
-             sampler=get_value_at_index(SAMPLER, 0),
-             sigmas=get_value_at_index(schedule, 0),
-             latent_image=get_value_at_index(empty_latent, 0)
-         )
-
-         # Decode VAE
-         decoded = vaedecode.decode(
-             samples=get_value_at_index(sampled, 0),
-             vae=get_value_at_index(VAE_MODEL, 0),
-         )
-
-         # Save image
-         saved = saveimage.save_images(
-             filename_prefix=get_value_at_index(clip_switch, 0),
-             images=get_value_at_index(decoded, 0),
-         )
-
-         saved_path = f"output/{saved['ui']['images'][0]['filename']}"
-
-         return saved_path

- # Create Gradio interface
- examples = [
-     ["", "mona.png", 0.5, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
-     ["a woman looking at a house catching fire on the background", "disaster Girl.png", 0.6, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
-     ["Istanbul aerial, dramatic photography", "Natasha.png", 0.5, 3.5, 3, 1.0, random.randint(1, 2**64), 1024, 1024, 1, 20],
- ]

- output_image = gr.Image(label="Generated image")

  with gr.Blocks() as app:
-     gr.Markdown("# FLUX Redux Image generator")

      with gr.Row():
-         with gr.Column():
-             prompt_input = gr.Textbox(
                  label="Prompt",
                  placeholder="Enter your prompt here...",
                  lines=5
              )

-             with gr.Row():
-                 with gr.Column():
-                     lora_weight = gr.Slider(
                          minimum=0,
                          maximum=2,
                          step=0.1,
                          value=0.6,
                          label="LoRA Weight"
                      )
-                     guidance = gr.Slider(
                          minimum=0,
                          maximum=20,
                          step=0.1,
                          value=3.5,
                          label="Guidance"
                      )
-                     downsampling_factor = gr.Slider(
-                         minimum=0,
                          maximum=8,
                          step=1,
                          value=3,
-                         label="Downsampling factor"
                      )
-                     weight = gr.Slider(
                          minimum=0,
                          maximum=2,
                          step=0.1,
                          value=1.0,
-                         label="Model weight"
                      )
-                     seed = gr.Number(
                          value=random.randint(1, 2**64),
-                         label="seed",
                          precision=0
                      )
-                     width = gr.Number(
                          value=1024,
-                         label="width",
                          precision=0
                      )
-                     height = gr.Number(
                          value=1024,
-                         label="height",
                          precision=0
                      )
-                     batch_size = gr.Number(
                          value=1,
-                         label="batch size",
                          precision=0
                      )
-                     steps = gr.Number(
                          value=20,
-                         label="steps",
                          precision=0
                      )
-
-         with gr.Column():
-             input_image = gr.Image(
-                 label="Input Image",
-                 type="filepath"
-             )

-             generate_btn = gr.Button("Generate image")

-         with gr.Column():
-             output_image.render()
-
      generate_btn.click(
          fn=generate_image,
          inputs=[
@@ -366,4 +281,6 @@ with gr.Blocks() as app:
      )

  if __name__ == "__main__":
      app.launch(share=True)
 
  import os
  import random
  import torch
  from pathlib import Path
+ from PIL import Image
+ import gradio as gr
  from huggingface_hub import hf_hub_download
  from nodes import NODE_CLASS_MAPPINGS
  import folder_paths

+ # Configure base and output directories
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
  output_dir = os.path.join(BASE_DIR, "output")
  os.makedirs(output_dir, exist_ok=True)
  folder_paths.set_output_directory(output_dir)

+ # Download models
+ def download_models():
+     models = [
+         ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
+         ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
+         ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors", "text_encoders"),
+         ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
+         ("black-forest-labs/FLUX.1-dev", "flux1-dev.sft", "diffusion_models"),
+         ("google/siglip-so400m-patch14-384", "model.safetensors", "clip_vision"),
+         ("black-forest-labs/FLUX.1-Redux-dev", "NFTNIK_FLUX.1[dev]_LoRA.safetensors", "lora")
+     ]
+
+     for repo_id, filename, model_type in models:
+         model_dir = os.path.join(BASE_DIR, "models", model_type)
+         os.makedirs(model_dir, exist_ok=True)
+         print(f"Downloading {filename} from {repo_id}...")
+         hf_hub_download(repo_id=repo_id, filename=filename, local_dir=model_dir)
+         folder_paths.add_model_folder_path(model_type, model_dir)

+ # Load custom nodes
+ def import_custom_nodes():
      import asyncio
      import execution
+     from nodes import init_extra_nodes
      import server
+
      loop = asyncio.new_event_loop()
      asyncio.set_event_loop(loop)
+
      server_instance = server.PromptServer(loop)
      execution.PromptQueue(server_instance)
      init_extra_nodes()

+ # Main function to execute the workflow and generate an image
+ def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
+     import_custom_nodes()

+     try:
+         with torch.inference_mode():
+             # Load CLIP
+             dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
+             dualcliploader_loaded = dualcliploader.load_clip(
+                 clip_name1="t5xxl_fp16.safetensors",
+                 clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
+                 type="flux",
+                 device="default"
+             )

+             # Text Encoding
+             cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+             encoded_text = cliptextencode.encode(
+                 text=prompt,
+                 clip=dualcliploader_loaded[0]
+             )

+             # Load Style Model
+             stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
+             style_model = stylemodelloader.load_style_model(
+                 style_model_name="flux1-redux-dev.safetensors"
+             )

+             # Load CLIP Vision
+             clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+             clip_vision = clipvisionloader.load_clip(
+                 clip_name="model.safetensors"
+             )

+             # Load Input Image
+             loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
+             loaded_image = loadimage.load_image(image=input_image)

+             # Load VAE
+             vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+             vae = vaeloader.load_vae(vae_name="ae.safetensors")

+             # Load UNET
+             unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+             unet = unetloader.load_unet(
+                 unet_name="flux1-dev.sft",
+                 weight_dtype="fp8_e4m3fn"
+             )

+             # Load LoRA
+             loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+             lora_model = loraloadermodelonly.load_lora_model_only(
+                 lora_name="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
+                 strength_model=lora_weight,
+                 model=unet[0]
+             )

+             # Flux Guidance
+             fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+             flux_guidance = fluxguidance.append(
+                 guidance=guidance,
+                 conditioning=encoded_text[0]
+             )

+             # Redux Advanced
+             reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
+             redux_result = reduxadvanced.apply_stylemodel(
+                 downsampling_factor=downsampling_factor,
+                 downsampling_function="area",
+                 mode="keep aspect ratio",
+                 weight=weight,
+                 autocrop_margin=0.1,
+                 conditioning=flux_guidance[0],
+                 style_model=style_model[0],
+                 clip_vision=clip_vision[0],
+                 image=loaded_image[0]
+             )

+             # Empty Latent Image
+             emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
+             empty_latent = emptylatentimage.generate(
+                 width=width,
+                 height=height,
+                 batch_size=batch_size
+             )

+             # KSampler
+             ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
+             sampled = ksampler.sample(
+                 seed=seed,
+                 steps=steps,
+                 cfg=1,
+                 sampler_name="euler",
+                 scheduler="simple",
+                 denoise=1,
+                 model=lora_model[0],
+                 positive=redux_result[0],
+                 negative=flux_guidance[0],
+                 latent_image=empty_latent[0]
+             )

+             # VAE Decode
+             vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
+             decoded = vaedecode.decode(
+                 samples=sampled[0],
+                 vae=vae[0]
+             )
+
+             # Save the image in the output directory
+             saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
+             temp_filename = f"Flux_{random.randint(0, 99999)}"
+             saveimage.save_images(
+                 filename_prefix=temp_filename,
+                 images=decoded[0]
+             )
+
+             # Add a delay to ensure the file system updates
+             import time
+             time.sleep(0.5)

+             # Dynamically retrieve the correct file name
+             saved_files = [f for f in os.listdir(output_dir) if f.startswith(temp_filename)]
+             if not saved_files:
+                 raise FileNotFoundError(f"Output file not found: Expected files starting with {temp_filename}")

+             # Get the full path of the saved file
+             temp_path = os.path.join(output_dir, saved_files[0])
+             print(f"Image saved at: {temp_path}")

+             # Return the saved image for Gradio display
+             output_image = Image.open(temp_path)
+             return output_image
+
+     except Exception as e:
+         print(f"Error during generation: {str(e)}")
+         return None
+
+ # Gradio Interface
  with gr.Blocks() as app:
+     gr.Markdown("# FLUX Redux Image Generator")

      with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
                  label="Prompt",
                  placeholder="Enter your prompt here...",
                  lines=5
              )
+             input_image = gr.Image(
+                 label="Input Image",
+                 type="filepath"
+             )

+             with gr.Row():
+                 with gr.Column():
+                     lora_weight = gr.Slider(
                          minimum=0,
                          maximum=2,
                          step=0.1,
                          value=0.6,
                          label="LoRA Weight"
                      )
+                     guidance = gr.Slider(
                          minimum=0,
                          maximum=20,
                          step=0.1,
                          value=3.5,
                          label="Guidance"
                      )
+                     downsampling_factor = gr.Slider(
+                         minimum=1,
                          maximum=8,
                          step=1,
                          value=3,
+                         label="Downsampling Factor"
                      )
+                     weight = gr.Slider(
                          minimum=0,
                          maximum=2,
                          step=0.1,
                          value=1.0,
+                         label="Model Weight"
                      )
+                 with gr.Column():
+                     seed = gr.Number(
                          value=random.randint(1, 2**64),
+                         label="Seed",
                          precision=0
                      )
+                     width = gr.Number(
                          value=1024,
+                         label="Width",
                          precision=0
                      )
+                     height = gr.Number(
                          value=1024,
+                         label="Height",
                          precision=0
                      )
+                     batch_size = gr.Number(
                          value=1,
+                         label="Batch Size",
                          precision=0
                      )
+                     steps = gr.Number(
                          value=20,
+                         label="Steps",
                          precision=0
                      )

+             generate_btn = gr.Button("Generate Image")

+         with gr.Column():
+             output_image = gr.Image(label="Generated Image", type="pil")
+
      generate_btn.click(
          fn=generate_image,
          inputs=[

      )

  if __name__ == "__main__":
+     # Download models if they don't exist
+     download_models()
      app.launch(share=True)
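
As a sanity check on this revision, the two new entry points can be exercised directly. The sketch below is illustrative only: it assumes the Space's root is the working directory with the ComfyUI checkout importable, `test.png` and the output filename are hypothetical placeholders, and the parameter values simply mirror the UI defaults from the diff. `generate_image` returns a `PIL.Image` on success and `None` on failure:

    # Minimal smoke test for the revised app.py (a sketch, with placeholder paths).
    import random
    from app import download_models, generate_image

    download_models()  # fetch all checkpoints into models/<type>/ via hf_hub_download

    result = generate_image(
        prompt="Istanbul aerial, dramatic photography",  # prompt reused from the old examples list
        input_image="test.png",                          # hypothetical local image path
        lora_weight=0.6,
        guidance=3.5,
        downsampling_factor=3,
        weight=1.0,
        seed=random.randint(1, 2**64),
        width=1024,
        height=1024,
        batch_size=1,
        steps=20,
    )
    if result is not None:
        result.save("smoke_test.png")  # hypothetical output name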