nftnik committed on
Commit 6aa4d81 · verified · Parent: e5e4614

Update app.py

Files changed (1)
  1. app.py +159 -269
app.py CHANGED
@@ -1,269 +1,159 @@
- import os
- import random
- import torch
- from pathlib import Path
- from PIL import Image
- import gradio as gr
- from nodes import NODE_CLASS_MAPPINGS
- import folder_paths
-
- # Configure base and output directories
- BASE_DIR = os.path.dirname(os.path.realpath(__file__))
- output_dir = os.path.join(BASE_DIR, "output")
- os.makedirs(output_dir, exist_ok=True)
- folder_paths.set_output_directory(output_dir)
-
- def import_custom_nodes():
-     """Loads custom nodes required for the workflow."""
-     import asyncio
-     import execution
-     from nodes import init_extra_nodes
-     import server
-
-     loop = asyncio.new_event_loop()
-     asyncio.set_event_loop(loop)
-
-     server_instance = server.PromptServer(loop)
-     execution.PromptQueue(server_instance)
-     init_extra_nodes()
-
- def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
-     """
-     Main function to execute the workflow and generate an image.
-     """
-     import_custom_nodes()
-
-     try:
-         with torch.inference_mode():
-             # Load CLIP
-             dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
-             dualcliploader_loaded = dualcliploader.load_clip(
-                 clip_name1="t5xxl_fp16.safetensors",
-                 clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
-                 type="flux",
-                 device="default"
-             )
-
-             # Text Encoding
-             cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
-             encoded_text = cliptextencode.encode(
-                 text=prompt,
-                 clip=dualcliploader_loaded[0]
-             )
-
-             # Load Style Model
-             stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
-             style_model = stylemodelloader.load_style_model(
-                 style_model_name="flux1-redux-dev.safetensors"
-             )
-
-             # Load CLIP Vision
-             clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
-             clip_vision = clipvisionloader.load_clip(
-                 clip_name="sigclip_vision_patch14_384.safetensors"
-             )
-
-             # Load Input Image
-             loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
-             loaded_image = loadimage.load_image(image=input_image)
-
-             # Load VAE
-             vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
-             vae = vaeloader.load_vae(vae_name="ae.safetensors")
-
-             # Load UNET
-             unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
-             unet = unetloader.load_unet(
-                 unet_name="flux1-dev.sft",
-                 weight_dtype="fp8_e4m3fn"
-             )
-
-             # Load LoRA
-             loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
-             lora_model = loraloadermodelonly.load_lora_model_only(
-                 lora_name="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
-                 strength_model=lora_weight,
-                 model=unet[0]
-             )
-
-             # Flux Guidance
-             fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
-             flux_guidance = fluxguidance.append(
-                 guidance=guidance,
-                 conditioning=encoded_text[0]
-             )
-
-             # Redux Advanced
-             reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
-             redux_result = reduxadvanced.apply_stylemodel(
-                 downsampling_factor=downsampling_factor,
-                 downsampling_function="area",
-                 mode="keep aspect ratio",
-                 weight=weight,
-                 autocrop_margin=0.1,
-                 conditioning=flux_guidance[0],
-                 style_model=style_model[0],
-                 clip_vision=clip_vision[0],
-                 image=loaded_image[0]
-             )
-
-             # Empty Latent Image
-             emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
-             empty_latent = emptylatentimage.generate(
-                 width=width,
-                 height=height,
-                 batch_size=batch_size
-             )
-
-             # KSampler
-             ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
-             sampled = ksampler.sample(
-                 seed=seed,
-                 steps=steps,
-                 cfg=1,
-                 sampler_name="euler",
-                 scheduler="simple",
-                 denoise=1,
-                 model=lora_model[0],
-                 positive=redux_result[0],
-                 negative=flux_guidance[0],
-                 latent_image=empty_latent[0]
-             )
-
-             # VAE Decode
-             vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
-             decoded = vaedecode.decode(
-                 samples=sampled[0],
-                 vae=vae[0]
-             )
-
-             # Save the image in the output directory
-             saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
-             temp_filename = f"Flux_{random.randint(0, 99999)}"
-             saveimage.save_images(
-                 filename_prefix=temp_filename,
-                 images=decoded[0]
-             )
-
-             # Add a delay to ensure the file system updates
-             import time
-             time.sleep(0.5)
-
-             # Dynamically retrieve the correct file name
-             saved_files = [f for f in os.listdir(output_dir) if f.startswith(temp_filename)]
-             if not saved_files:
-                 raise FileNotFoundError(f"Output file not found: Expected files starting with {temp_filename}")
-
-             # Get the full path of the saved file
-             temp_path = os.path.join(output_dir, saved_files[0])
-             print(f"Image saved at: {temp_path}")
-
-             # Return the saved image for Gradio display
-             output_image = Image.open(temp_path)
-             return output_image
-
-     except Exception as e:
-         print(f"Error during generation: {str(e)}")
-         return None
-
- # Gradio Interface
- with gr.Blocks() as app:
-     gr.Markdown("# FLUX Redux Image Generator")
-
-     with gr.Row():
-         with gr.Column():
-             prompt_input = gr.Textbox(
-                 label="Prompt",
-                 placeholder="Enter your prompt here...",
-                 lines=5
-             )
-             input_image = gr.Image(
-                 label="Input Image",
-                 type="filepath"
-             )
-
-             with gr.Row():
-                 with gr.Column():
-                     lora_weight = gr.Slider(
-                         minimum=0,
-                         maximum=2,
-                         step=0.1,
-                         value=0.6,
-                         label="LoRA Weight"
-                     )
-                     guidance = gr.Slider(
-                         minimum=0,
-                         maximum=20,
-                         step=0.1,
-                         value=3.5,
-                         label="Guidance"
-                     )
-                     downsampling_factor = gr.Slider(
-                         minimum=1,
-                         maximum=8,
-                         step=1,
-                         value=3,
-                         label="Downsampling Factor"
-                     )
-                     weight = gr.Slider(
-                         minimum=0,
-                         maximum=2,
-                         step=0.1,
-                         value=1.0,
-                         label="Model Weight"
-                     )
-                 with gr.Column():
-                     seed = gr.Number(
-                         value=random.randint(1, 2**64),
-                         label="Seed",
-                         precision=0
-                     )
-                     width = gr.Number(
-                         value=1024,
-                         label="Width",
-                         precision=0
-                     )
-                     height = gr.Number(
-                         value=1024,
-                         label="Height",
-                         precision=0
-                     )
-                     batch_size = gr.Number(
-                         value=1,
-                         label="Batch Size",
-                         precision=0
-                     )
-                     steps = gr.Number(
-                         value=20,
-                         label="Steps",
-                         precision=0
-                     )
-
-             generate_btn = gr.Button("Generate Image")
-
-         with gr.Column():
-             output_image = gr.Image(label="Generated Image", type="pil")
-
-     generate_btn.click(
-         fn=generate_image,
-         inputs=[
-             prompt_input,
-             input_image,
-             lora_weight,
-             guidance,
-             downsampling_factor,
-             weight,
-             seed,
-             width,
-             height,
-             batch_size,
-             steps
-         ],
-         outputs=[output_image]
-     )
-
- if __name__ == "__main__":
-     app.launch()
-
-
- #python app.py
 
+ import os
+ import random
+ import torch
+ from pathlib import Path
+ from PIL import Image
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from nodes import NODE_CLASS_MAPPINGS
+ import folder_paths
+
+ # Base and output directories
+ BASE_DIR = os.path.dirname(os.path.realpath(__file__))
+ output_dir = os.path.join(BASE_DIR, "output")
+ os.makedirs(output_dir, exist_ok=True)
+ folder_paths.set_output_directory(output_dir)
+
+ # Download the required models
+
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev",
+                 filename="flux1-redux-dev.safetensors",
+                 local_dir="models/style_models")
+
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders",
+                 filename="t5xxl_fp16.safetensors",
+                 local_dir="models/text_encoders")
+
+ hf_hub_download(repo_id="zer0int/CLIP-GmP-ViT-L-14",
+                 filename="ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
+                 local_dir="models/text_encoders")
+
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
+                 filename="ae.safetensors",
+                 local_dir="models/vae")
+
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev",
+                 filename="flux1-dev.safetensors",
+                 local_dir="models/diffusion_models")
+
+ hf_hub_download(repo_id="google/siglip-so400m-patch14-384",
+                 filename="model.safetensors",
+                 local_dir="models/clip_vision")
+
+ hf_hub_download(repo_id="nftnik/NFTNIK-FLUX.1-dev-LoRA",
+                 filename="NFTNIK_FLUX.1[dev]_LoRA.safetensors",
+                 local_dir="models/lora")
+
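+ # Editor's note: hf_hub_download returns the resolved local path and caches files,
+ # so repeated startups skip the download. The loaders below address these files by
+ # "models/..."-prefixed names; if ComfyUI's registered model folders do not already
+ # cover ./models, registering them explicitly is one option (a sketch under that
+ # assumption, not part of this commit):
+ # for kind, sub in [("style_models", "style_models"), ("text_encoders", "text_encoders"),
+ #                   ("vae", "vae"), ("diffusion_models", "diffusion_models"),
+ #                   ("clip_vision", "clip_vision"), ("loras", "lora")]:
+ #     folder_paths.add_model_folder_path(kind, os.path.join(BASE_DIR, "models", sub))
+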
+ # Import custom nodes
+ def import_custom_nodes():
+     """Load the custom nodes required by the workflow."""
+     import asyncio
+     import execution
+     from nodes import init_extra_nodes
+     import server
+
+     # PromptServer expects an asyncio event loop, even in headless use
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+
+     server_instance = server.PromptServer(loop)
+     execution.PromptQueue(server_instance)
+     init_extra_nodes()
+
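+ # Editor's note: generate_image() below calls import_custom_nodes() on every request,
+ # creating a fresh event loop and PromptServer each time. A module-level guard
+ # (a sketch, not part of this commit) would initialize only once:
+ # _NODES_READY = False
+ # def ensure_custom_nodes():
+ #     global _NODES_READY
+ #     if not _NODES_READY:
+ #         import_custom_nodes()
+ #         _NODES_READY = True
+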
+ # Main generation function
+ def generate_image(prompt, input_image, lora_weight, guidance, downsampling_factor, weight, seed, width, height, batch_size, steps):
+     import_custom_nodes()
+
+     try:
+         with torch.inference_mode():
+             # Load the CLIP text encoders
+             dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
+             dualcliploader_loaded = dualcliploader.load_clip(
+                 clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
+                 clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
+                 type="flux"
+             )
+
+             # Encode the prompt
+             cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+             encoded_text = cliptextencode.encode(
+                 text=prompt,
+                 clip=dualcliploader_loaded[0]
+             )
+
+             # Load the style model and CLIP Vision (both consumed by ReduxAdvanced below)
+             stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
+             style_model = stylemodelloader.load_style_model(
+                 style_model_name="models/style_models/flux1-redux-dev.safetensors"
+             )
+             clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+             clip_vision = clipvisionloader.load_clip(
+                 clip_name="models/clip_vision/model.safetensors"  # siglip file downloaded above
+             )
+
+             # Load the diffusion model and apply the LoRA to it (not to the style model)
+             unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+             unet = unetloader.load_unet(
+                 unet_name="models/diffusion_models/flux1-dev.safetensors",
+                 weight_dtype="fp8_e4m3fn"
+             )
+             loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+             lora_model = loraloadermodelonly.load_lora_model_only(
+                 lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
+                 strength_model=lora_weight,
+                 model=unet[0]
+             )
+
+             # Load the input image
+             loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
+             loaded_image = loadimage.load_image(image=input_image)
+
+             # Apply Flux guidance, then the Redux style conditioning
+             fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+             flux_guidance = fluxguidance.append(
+                 guidance=guidance,
+                 conditioning=encoded_text[0]
+             )
+             reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
+             redux_result = reduxadvanced.apply_stylemodel(
+                 downsampling_factor=downsampling_factor,
+                 downsampling_function="area",
+                 mode="keep aspect ratio",
+                 weight=weight,
+                 autocrop_margin=0.1,
+                 conditioning=flux_guidance[0],
+                 style_model=style_model[0],
+                 clip_vision=clip_vision[0],
+                 image=loaded_image[0]
+             )
+
+             # Sample into an empty latent
+             emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
+             empty_latent = emptylatentimage.generate(
+                 width=width,
+                 height=height,
+                 batch_size=batch_size
+             )
+             ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
+             sampled = ksampler.sample(
+                 seed=seed,
+                 steps=steps,
+                 cfg=1,  # at cfg=1 the negative conditioning is effectively unused
+                 sampler_name="euler",
+                 scheduler="simple",
+                 denoise=1,
+                 model=lora_model[0],
+                 positive=redux_result[0],
+                 negative=flux_guidance[0],
+                 latent_image=empty_latent[0]
+             )
+
+             # Decode the sampled latent and save the result
+             vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+             vae = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")
+             vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
+             decoded = vaedecode.decode(
+                 samples=sampled[0],
+                 vae=vae[0]
+             )
+
+             # decoded[0] is a [batch, height, width, channel] float tensor in [0, 1]
+             temp_filename = f"Flux_{random.randint(0, 99999)}.png"
+             temp_path = os.path.join(output_dir, temp_filename)
+             Image.fromarray((decoded[0][0].cpu().numpy() * 255).astype("uint8")).save(temp_path)
+
+             return temp_path
+     except Exception as e:
+         print(f"Error while generating image: {str(e)}")
+         return None
+
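+ # Editor's note: with batch_size > 1 the decode step above yields one image per
+ # batch entry, but only the first is saved. A sketch for saving the whole batch
+ # (assumption, not part of this commit):
+ # paths = []
+ # for i, img in enumerate(decoded[0]):
+ #     p = os.path.join(output_dir, f"Flux_{random.randint(0, 99999)}_{i}.png")
+ #     Image.fromarray((img.cpu().numpy() * 255).astype("uint8")).save(p)
+ #     paths.append(p)
+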
+ # Gradio interface
+ with gr.Blocks() as app:
+     gr.Markdown("# FLUX Redux Image Generator")
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
+             input_image = gr.Image(label="Input Image", type="filepath")
+             lora_weight = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="LoRA Weight")
+             guidance = gr.Slider(minimum=0, maximum=20, step=0.1, value=3.5, label="Guidance")
+             downsampling_factor = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Downsampling Factor")
+             weight = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.0, label="Model Weight")
+             seed = gr.Number(value=random.randint(1, 2**64), label="Seed", precision=0)
+             width = gr.Number(value=1024, label="Width", precision=0)
+             height = gr.Number(value=1024, label="Height", precision=0)
+             batch_size = gr.Number(value=1, label="Batch Size", precision=0)
+             steps = gr.Number(value=20, label="Steps", precision=0)
+             generate_btn = gr.Button("Generate Image")
+
+         with gr.Column():
+             output_image = gr.Image(label="Generated Image", type="filepath")
+
+     generate_btn.click(
+         fn=generate_image,
+         inputs=[
+             prompt_input,
+             input_image,
+             lora_weight,
+             guidance,
+             downsampling_factor,
+             weight,
+             seed,
+             width,
+             height,
+             batch_size,
+             steps
+         ],
+         outputs=[output_image]
+     )
+
+ if __name__ == "__main__":
+     app.launch()