# Redux / app.py
# Last change: "Update app.py" by nftnik (commit aa3232a, verified)
import os
import sys
import random
import torch
from pathlib import Path
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
import logging
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes, SaveImage # <-- Node SaveImage
from comfy import model_management
import folder_paths
# 1. Configure logging for debugging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# 2. Paths and imports
# Note: this append runs after the `nodes` / `comfy` / `folder_paths` imports
# at the top of the file, so it only affects imports performed later.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

# 3. Directories
# Output and model folders live next to this file and are created on demand.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
models_dir = os.path.join(BASE_DIR, "models")
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
# generate_image() later joins output_dir with the file name reported by SaveImage.
folder_paths.set_output_directory(output_dir)

# 4. Register the model folders and make sure each one exists
MODEL_FOLDERS = ["style_models", "text_encoders", "vae", "unet", "clip_vision"]
for model_folder in MODEL_FOLDERS:
    folder_path = os.path.join(models_dir, model_folder)
    os.makedirs(folder_path, exist_ok=True)
    folder_paths.add_model_folder_path(model_folder, folder_path)
    logger.info(f"Pasta de modelo configurada: {model_folder}")

# 5. CUDA diagnostics (log-only; nothing here changes state)
logger.info(f"Python version: {sys.version}")
logger.info(f"Torch version: {torch.__version__}")
logger.info(f"CUDA disponível: {torch.cuda.is_available()}")
logger.info(f"Quantidade de GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    logger.info(f"GPU atual: {torch.cuda.get_device_name(0)}")

# 6. ComfyUI initialisation
# Best effort: a failure while loading extra nodes is logged as a warning and
# start-up continues.
# NOTE(review): if `init_extra_nodes` is an async function in the vendored
# ComfyUI version, this bare call would not actually run it - confirm.
logger.info("Inicializando ComfyUI...")
try:
    init_extra_nodes()
except Exception as e:
    logger.warning(f"Aviso na inicialização de nós extras: {str(e)}")
    logger.info("Continuando mesmo com avisos nos nós extras...")
# 7. Helper Functions
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is used only when the direct lookup raises ``KeyError``
    (i.e. ``obj`` is a mapping without ``index`` as a key). Any other
    exception, such as ``IndexError`` on a sequence, propagates unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        # Mapping-style result: the payload is stored under "result".
        value = obj["result"][index]
    return value
def verify_file_exists(folder: str, filename: str) -> bool:
    """Report whether ``models_dir/folder/filename`` exists on disk.

    Logs an error with the full path when the file is missing.
    """
    candidate = os.path.join(models_dir, folder, filename)
    if os.path.exists(candidate):
        return True
    logger.error(f"Arquivo não encontrado: {candidate}")
    return False
# 8. Model downloads
logger.info("Baixando modelos necessários...")

# One row per checkpoint, in download order:
# (Hub repo id, file name inside the repo, sub-folder of models_dir).
_MODEL_DOWNLOADS = (
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "style_models"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "text_encoders"),
    (
        "zer0int/CLIP-GmP-ViT-L-14",
        "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
        "text_encoders",
    ),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "unet"),
    ("Comfy-Org/sigclip_vision_384", "sigclip_vision_patch14_384.safetensors", "clip_vision"),
)

try:
    for _repo_id, _filename, _subfolder in _MODEL_DOWNLOADS:
        hf_hub_download(
            repo_id=_repo_id,
            filename=_filename,
            local_dir=os.path.join(models_dir, _subfolder),
        )
except Exception as e:
    # Any download failure is fatal: log it and abort start-up.
    logger.error(f"Erro ao baixar modelos: {str(e)}")
    raise
# 9. Model initialisation
# Every loader is looked up in NODE_CLASS_MAPPINGS, instantiated and called
# once at import time. The results are kept in module-level globals and are
# later indexed with [0] by generate_image(). Any failure is logged and
# re-raised, which aborts start-up.
logger.info("Inicializando modelos...")
try:
    with torch.no_grad():
        # CLIP (two text encoders loaded together, type "flux")
        logger.info("Carregando CLIP...")
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        CLIP_MODEL = dualcliploader.load_clip(
            clip_name1="t5xxl_fp16.safetensors",
            clip_name2="ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
            type="flux"
        )
        if CLIP_MODEL is None:
            raise ValueError("Falha ao carregar CLIP model")

        # CLIP Vision
        logger.info("Carregando CLIP Vision...")
        clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
        CLIP_VISION = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )
        if CLIP_VISION is None:
            raise ValueError("Falha ao carregar CLIP Vision model")

        # Style model (Redux)
        logger.info("Carregando Style Model...")
        stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
        STYLE_MODEL = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )
        if STYLE_MODEL is None:
            raise ValueError("Falha ao carregar Style Model")

        # VAE
        logger.info("Carregando VAE...")
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        VAE_MODEL = vaeloader.load_vae(
            vae_name="ae.safetensors"
        )
        if VAE_MODEL is None:
            raise ValueError("Falha ao carregar VAE model")

        # UNET
        logger.info("Carregando UNET...")
        unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        UNET_MODEL = unetloader.load_unet(
            unet_name="flux1-dev.safetensors",
            weight_dtype="fp8_e4m3fn"  # adjust if needed
        )
        if UNET_MODEL is None:
            raise ValueError("Falha ao carregar UNET model")

        # Move the loaded models to the GPU. For each loader result the first
        # element is used; if it exposes a `patcher` attribute that object is
        # passed instead of the element itself. STYLE_MODEL is not part of
        # this list.
        logger.info("Carregando modelos na GPU...")
        model_loaders = [CLIP_MODEL, VAE_MODEL, CLIP_VISION, UNET_MODEL]
        model_management.load_models_gpu([
            loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
            for loader in model_loaders
        ])
        logger.info("Modelos carregados com sucesso")
except Exception as e:
    logger.error(f"Erro ao inicializar modelos: {str(e)}")
    raise
# 10. Generation function
@spaces.GPU
def generate_image(
    prompt, input_image, lora_weight, guidance, downsampling_factor,
    weight, seed, width, height, batch_size, steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the FLUX Redux pipeline once and return the saved image path.

    Pipeline: encode the prompt, load the reference image, apply Flux
    guidance, apply the Redux style model, sample from an empty latent,
    decode with the VAE and save through the SaveImage node.

    Args:
        prompt: Text passed to CLIPTextEncode.
        input_image: File path of the reference image (passed to LoadImage).
        lora_weight: Accepted because the UI wires it in, but not used by
            the pipeline below.
        guidance: Value passed to FluxGuidance.
        downsampling_factor: Passed to ReduxAdvanced.
        weight: Style weight passed to ReduxAdvanced.
        seed: Sampler seed.
        width: Latent width passed to EmptyLatentImage.
        height: Latent height passed to EmptyLatentImage.
        batch_size: Latent batch size passed to EmptyLatentImage.
        steps: Number of sampler steps.
        progress: Gradio progress tracker; never referenced in the body.

    Returns:
        Path of the first saved image inside ``output_dir``, or ``None`` if
        any step fails (the failure is logged with its traceback).
    """
    try:
        with torch.no_grad():
            logger.info(f"Iniciando geração com prompt: {prompt}")

            # Fail fast when no image was supplied, before spending time on
            # text encoding; otherwise the image loader fails later with a
            # much less obvious error.
            if not input_image:
                raise ValueError("Erro ao carregar a imagem de entrada")

            # Encode the text prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=CLIP_MODEL[0]
            )

            # Load the reference image
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)
            if loaded_image is None:
                raise ValueError("Erro ao carregar a imagem de entrada")
            logger.info("Imagem carregada com sucesso")

            # Flux guidance on top of the text conditioning
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            flux_guidance = fluxguidance.append(
                guidance=guidance,
                conditioning=encoded_text[0]
            )

            # Redux: blend the reference image into the conditioning
            reduxadvanced = NODE_CLASS_MAPPINGS["ReduxAdvanced"]()
            redux_result = reduxadvanced.apply_stylemodel(
                downsampling_factor=downsampling_factor,
                downsampling_function="area",
                mode="keep aspect ratio",
                weight=weight,
                conditioning=flux_guidance[0],
                style_model=STYLE_MODEL[0],
                clip_vision=CLIP_VISION[0],
                image=loaded_image[0]
            )

            # Empty latent to sample from
            emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
            empty_latent = emptylatentimage.generate(
                width=width,
                height=height,
                batch_size=batch_size
            )

            # Sampling. The guided text conditioning is reused as the
            # negative input.
            # NOTE(review): with cfg=1 the negative is presumably ignored
            # by the sampler - confirm.
            logger.info("Iniciando sampling...")
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            sampled = ksampler.sample(
                seed=seed,
                steps=steps,
                cfg=1,
                sampler_name="euler",
                scheduler="simple",
                denoise=1,
                model=UNET_MODEL[0],
                positive=redux_result[0],
                negative=flux_guidance[0],
                latent_image=empty_latent[0]
            )

            # VAE decode
            logger.info("Decodificando imagem...")
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=sampled[0],
                vae=VAE_MODEL[0]
            )

            # Save through the SaveImage node; only the first entry of its
            # result is returned to the UI.
            logger.info("Salvando imagem via node SaveImage...")
            decoded_tensor = decoded[0]
            saveimage_node = NODE_CLASS_MAPPINGS["SaveImage"]()
            result_dict = saveimage_node.save_images(
                filename_prefix="Flux_",
                images=decoded_tensor
            )
            saved_path = os.path.join(output_dir, result_dict["ui"]["images"][0]["filename"])
            logger.info(f"Imagem salva em: {saved_path}")
            return saved_path
    except Exception as e:
        # Same message as before, but logger.exception also records the
        # traceback so the failing node can be identified from the logs.
        logger.exception(f"Erro ao gerar imagem: {str(e)}")
        return None
# 11. Gradio interface
# Left column: prompt, reference image, parameter controls and the button.
# Right column: the generated image. The click handler passes the controls to
# generate_image() in the same order as its parameters.
with gr.Blocks() as app:
    gr.Markdown("# FLUX Redux Image Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            input_image = gr.Image(
                label="Input Image",
                type="filepath"
            )
            with gr.Row():
                with gr.Column():
                    # Forwarded to generate_image(), which currently does
                    # not use it.
                    lora_weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=0.6,
                        label="LoRA Weight"
                    )
                    guidance = gr.Slider(
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=3.5,
                        label="Guidance"
                    )
                    downsampling_factor = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=3,
                        label="Downsampling Factor"
                    )
                    weight = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.1,
                        value=1.0,
                        label="Model Weight"
                    )
                with gr.Column():
                    # random.randint is inclusive on both ends, so the old
                    # upper bound of 2**64 could produce a value outside the
                    # 64-bit seed range. The value also travels through the
                    # browser as a double, which is exact only up to
                    # 2**53 - 1, so the default is kept within that range.
                    # It is drawn once, when the UI is built.
                    seed = gr.Number(
                        value=random.randint(1, 2**53 - 1),
                        label="Seed",
                        precision=0
                    )
                    width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0
                    )
                    height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0
                    )
                    batch_size = gr.Number(
                        value=1,
                        label="Batch Size",
                        precision=0
                    )
                    steps = gr.Number(
                        value=20,
                        label="Steps",
                        precision=0
                    )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="filepath")
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            input_image,
            lora_weight,
            guidance,
            downsampling_factor,
            weight,
            seed,
            width,
            height,
            batch_size,
            steps
        ],
        outputs=[output_image]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()