nftnik commited on
Commit
02c8fdb
·
verified ·
1 Parent(s): f8708de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -24
app.py CHANGED
@@ -3,13 +3,18 @@ import sys
3
  import random
4
  import torch
5
  from pathlib import Path
6
- from PIL import Image
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
  import spaces
10
  from typing import Union, Sequence, Mapping, Any
11
  import logging
12
 
 
 
 
 
 
13
  # Configurar logging para debug
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
  logger = logging.getLogger(__name__)
@@ -18,11 +23,6 @@ logger = logging.getLogger(__name__)
18
  current_dir = os.path.dirname(os.path.abspath(__file__))
19
  sys.path.append(current_dir)
20
 
21
- # 2. Imports do ComfyUI
22
- import folder_paths
23
- from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
24
- from comfy import model_management
25
-
26
  # 3. Configuração de Diretórios
27
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
28
  output_dir = os.path.join(BASE_DIR, "output")
@@ -109,8 +109,6 @@ except Exception as e:
109
  # 8. Inicialização dos Modelos
110
  logger.info("Inicializando modelos...")
111
  try:
112
- # Use torch.no_grad() em vez de torch.inference_mode()
113
- # para evitar o erro de version counter.
114
  with torch.no_grad():
115
  # CLIP
116
  logger.info("Carregando CLIP...")
@@ -155,7 +153,7 @@ try:
155
  unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
156
  UNET_MODEL = unetloader.load_unet(
157
  unet_name="flux1-dev.safetensors",
158
- weight_dtype="fp8_e4m3fn" # Ajuste a seu hardware, se necessário
159
  )
160
  if UNET_MODEL is None:
161
  raise ValueError("Falha ao carregar UNET model")
@@ -175,14 +173,13 @@ except Exception as e:
175
  @spaces.GPU
176
  def generate_image(
177
  prompt, input_image, lora_weight, guidance, downsampling_factor,
178
- weight, seed, width, height, batch_size, steps,
179
  progress=gr.Progress(track_tqdm=True)
180
  ):
181
  try:
182
- # Aqui também: no_grad() para evitar cálculo de gradientes
183
  with torch.no_grad():
184
  logger.info(f"Iniciando geração com prompt: {prompt}")
185
-
186
  # Codificar texto
187
  cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
188
  encoded_text = cliptextencode.encode(
@@ -217,7 +214,7 @@ def generate_image(
217
  image=loaded_image[0]
218
  )
219
 
220
- # Criar latente vazio
221
  emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
222
  empty_latent = emptylatentimage.generate(
223
  width=width,
@@ -249,16 +246,37 @@ def generate_image(
249
  vae=VAE_MODEL[0]
250
  )
251
 
252
- # Salvar imagem
253
- temp_filename = f"Flux_{random.randint(0, 99999)}.png"
254
- temp_path = os.path.join(output_dir, temp_filename)
255
- try:
256
- Image.fromarray((decoded[0] * 255).astype("uint8")).save(temp_path)
257
- logger.info(f"Imagem salva em: {temp_path}")
258
- return temp_path
259
- except Exception as e:
260
- logger.error(f"Erro ao salvar imagem: {str(e)}")
261
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  except Exception as e:
264
  logger.error(f"Erro ao gerar imagem: {str(e)}")
@@ -361,5 +379,4 @@ with gr.Blocks() as app:
361
  )
362
 
363
  if __name__ == "__main__":
364
- # Ajuste caso queira compartilhar publicamente, exemplo: app.launch(server_name="0.0.0.0", share=True)
365
  app.launch()
 
3
  import random
4
  import torch
5
  from pathlib import Path
6
+ import numpy as np
7
  import gradio as gr
8
  from huggingface_hub import hf_hub_download
9
  import spaces
10
  from typing import Union, Sequence, Mapping, Any
11
  import logging
12
 
13
+ # Adicione se ainda não tiver
14
+ from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes, SaveImage # <-- Node SaveImage
15
+ from comfy import model_management
16
+ import folder_paths
17
+
18
  # Configurar logging para debug
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
  logger = logging.getLogger(__name__)
 
23
  current_dir = os.path.dirname(os.path.abspath(__file__))
24
  sys.path.append(current_dir)
25
 
 
 
 
 
 
26
  # 3. Configuração de Diretórios
27
  BASE_DIR = os.path.dirname(os.path.realpath(__file__))
28
  output_dir = os.path.join(BASE_DIR, "output")
 
109
  # 8. Inicialização dos Modelos
110
  logger.info("Inicializando modelos...")
111
  try:
 
 
112
  with torch.no_grad():
113
  # CLIP
114
  logger.info("Carregando CLIP...")
 
153
  unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
154
  UNET_MODEL = unetloader.load_unet(
155
  unet_name="flux1-dev.safetensors",
156
+ weight_dtype="fp8_e4m3fn" # ajuste se preciso
157
  )
158
  if UNET_MODEL is None:
159
  raise ValueError("Falha ao carregar UNET model")
 
173
  @spaces.GPU
174
  def generate_image(
175
  prompt, input_image, lora_weight, guidance, downsampling_factor,
176
+ weight, seed, width, height, batch_size, steps,
177
  progress=gr.Progress(track_tqdm=True)
178
  ):
179
  try:
 
180
  with torch.no_grad():
181
  logger.info(f"Iniciando geração com prompt: {prompt}")
182
+
183
  # Codificar texto
184
  cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
185
  encoded_text = cliptextencode.encode(
 
214
  image=loaded_image[0]
215
  )
216
 
217
+ # Empty Latent
218
  emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
219
  empty_latent = emptylatentimage.generate(
220
  width=width,
 
246
  vae=VAE_MODEL[0]
247
  )
248
 
249
+ # ======================== SALVAR IMAGEM USANDO O NODE SaveImage ======================
250
+ logger.info("Salvando imagem via node SaveImage...")
251
+
252
+ # 1. Pegue a saída do decode (tensor)
253
+ decoded_tensor = decoded[0] # se 'decoded' for um dict/tuple, ajuste conforme preciso
254
+
255
+ # 2. Instancia o SaveImage
256
+ saveimage_node = NODE_CLASS_MAPPINGS["SaveImage"]()
257
+
258
+ # 3. Usa o método save_images
259
+ # 'filename_prefix' é o prefixo do arquivo de saída
260
+ result_dict = saveimage_node.save_images(
261
+ filename_prefix="FluxRedux", # ou algo dinâmico se preferir
262
+ images=decoded_tensor
263
+ )
264
+
265
+ # 4. Normalmente, o node 'save_images' retorna um dicionário contendo:
266
+ # {
267
+ # 'ui': {
268
+ # 'images': [
269
+ # {'filename': 'FluxRedux_12345.png', 'subfolder': ''},
270
+ # ...
271
+ # ]
272
+ # },
273
+ # ...
274
+ # }
275
+ # Assim, para pegar o nome do arquivo salvo:
276
+ saved_path = os.path.join(output_dir, result_dict["ui"]["images"][0]["filename"])
277
+
278
+ logger.info(f"Imagem salva em: {saved_path}")
279
+ return saved_path
280
 
281
  except Exception as e:
282
  logger.error(f"Erro ao gerar imagem: {str(e)}")
 
379
  )
380
 
381
  if __name__ == "__main__":
 
382
  app.launch()