import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tqdm import tqdm
import gc
from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
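# Registry of selectable LoRAs. Each entry records where the weights live on the Hub,
# how they are merged ("standard" = diffusers load_lora_weights/fuse_lora,
# "manual_fuse" = fuse_lora_manual below, "none" = base model), which reference image
# the UI should ask for ("edit" vs. "style"), and how the user prompt is wrapped
# before being passed to the pipeline.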
LORA_CONFIG = {
"None": {
"repo_id": None,
"filename": None,
"type": "edit",
"method": "none",
"prompt_template": "{prompt}",
"description": "Use the base Qwen-Image-Edit model without any LoRA.",
},
"InStyle (Style Transfer)": {
"repo_id": "peteromallet/Qwen-Image-Edit-InStyle",
"filename": "InStyle-0.5.safetensors",
"type": "style",
"method": "manual_fuse",
"prompt_template": "Make an image in this style of {prompt}",
"description": "Transfers the style from a reference image to a new image described by the prompt.",
},
"InScene (In-Scene Editing)": {
"repo_id": "flymy-ai/qwen-image-edit-inscene-lora",
"filename": "flymy_qwen_image_edit_inscene_lora.safetensors",
"type": "edit",
"method": "standard",
"prompt_template": "{prompt}",
"description": "Improves in-scene editing, object positioning, and camera perspective changes.",
},
"Face Segmentation": {
"repo_id": "TsienDragon/qwen-image-edit-lora-face-segmentation",
"filename": "pytorch_lora_weights.safetensors",
"type": "edit",
"method": "standard",
"prompt_template": "change the face to face segmentation mask",
"description": "Transforms a facial image into a precise segmentation mask.",
},
"Object Remover": {
"repo_id": "valiantcat/Qwen-Image-Edit-Remover-General-LoRA",
"filename": "qwen-edit-remover.safetensors",
"type": "edit",
"method": "standard",
"prompt_template": "Remove {prompt}",
"description": "Removes objects from an image while maintaining background consistency.",
},
}
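# Load the base Qwen-Image-Edit pipeline once at startup in bfloat16 and keep it
# resident on the GPU (or the CPU if no GPU is available).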
print("Initializing model...")
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = QwenImageEditPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit",
torch_dtype=dtype
).to(device)
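# Swap in the local QwenImageTransformer2DModel implementation and its
# FlashAttention-3 double-stream attention processor for faster attention on
# supported GPUs.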
pipe.transformer.__class__ = QwenImageTransformer2DModel
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
# Keep a pristine copy of the transformer weights so any fused LoRA can be undone
# later. state_dict() alone returns tensors that share storage with the live
# parameters, and in-place fusion would silently corrupt the backup, so clone to CPU.
original_transformer_state_dict = {
    k: v.detach().clone().cpu() for k, v in pipe.transformer.state_dict().items()
}
print("Base model loaded and ready.")
def fuse_lora_manual(transformer, lora_state_dict, alpha=1.0):
key_mapping = {}
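    # Pair up lora_A ("down") and lora_B ("up") matrices by the module they target,
    # stripping the "diffusion_model." prefix used in the checkpoint keys.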
for key in lora_state_dict.keys():
base_key = key.replace('diffusion_model.', '').rsplit('.lora_', 1)[0]
if base_key not in key_mapping:
key_mapping[base_key] = {}
if 'lora_A' in key:
key_mapping[base_key]['down'] = lora_state_dict[key]
elif 'lora_B' in key:
key_mapping[base_key]['up'] = lora_state_dict[key]
for name, module in tqdm(transformer.named_modules(), desc="Fusing layers"):
if name in key_mapping and isinstance(module, torch.nn.Linear):
lora_weights = key_mapping[name]
if 'down' in lora_weights and 'up' in lora_weights:
device = module.weight.device
dtype = module.weight.dtype
lora_down = lora_weights['down'].to(device, dtype=dtype)
lora_up = lora_weights['up'].to(device, dtype=dtype)
merged_delta = lora_up @ lora_down
module.weight.data += alpha * merged_delta
return transformer
def load_and_fuse_lora(lora_name):
"""Carrega uma LoRA, funde-a ao modelo e retorna o pipeline modificado."""
config = LORA_CONFIG[lora_name]
print("Resetting transformer to original state...")
pipe.transformer.load_state_dict(original_transformer_state_dict)
if config["method"] == "none":
print("No LoRA selected. Using base model.")
return
print(f"Loading LoRA: {lora_name}")
lora_path = hf_hub_download(repo_id=config["repo_id"], filename=config["filename"])
if config["method"] == "standard":
print("Using standard loading method...")
pipe.load_lora_weights(lora_path)
print("Fusing LoRA into the model...")
pipe.fuse_lora()
elif config["method"] == "manual_fuse":
print("Using manual fusion method...")
lora_state_dict = load_file(lora_path)
pipe.transformer = fuse_lora_manual(pipe.transformer, lora_state_dict)
gc.collect()
torch.cuda.empty_cache()
print(f"LoRA '{lora_name}' is now active.")
@spaces.GPU(duration=60)
def infer(
lora_name,
input_image,
style_image,
prompt,
seed,
randomize_seed,
true_guidance_scale,
num_inference_steps,
progress=gr.Progress(track_tqdm=True),
):
if not lora_name:
raise gr.Error("Please select a LoRA model.")
config = LORA_CONFIG[lora_name]
if config["type"] == "style":
if style_image is None:
raise gr.Error("Style Transfer LoRA requires a Style Reference Image.")
image_for_pipeline = style_image
else: # 'edit'
if input_image is None:
raise gr.Error("This LoRA requires an Input Image.")
image_for_pipeline = input_image
    # A user prompt is only required when the template actually interpolates it.
    if not prompt and "{prompt}" in config["prompt_template"]:
raise gr.Error("A text prompt is required for this LoRA.")
load_and_fuse_lora(lora_name)
final_prompt = config["prompt_template"].format(prompt=prompt)
if randomize_seed:
seed = random.randint(0, np.iinfo(np.int32).max)
generator = torch.Generator(device=device).manual_seed(int(seed))
print("--- Running Inference ---")
print(f"LoRA: {lora_name}")
print(f"Prompt: {final_prompt}")
print(f"Seed: {seed}, Steps: {num_inference_steps}, CFG: {true_guidance_scale}")
with torch.inference_mode():
result_image = pipe(
image=image_for_pipeline,
prompt=final_prompt,
negative_prompt=" ",
num_inference_steps=int(num_inference_steps),
generator=generator,
true_cfg_scale=true_guidance_scale,
).images[0]
    # Only "standard" LoRAs are fused through diffusers and can be unfused here; manual
    # fusion is undone by the weight restore at the start of the next LoRA load.
    if config["method"] == "standard":
        pipe.unfuse_lora()
gc.collect()
torch.cuda.empty_cache()
return result_image, seed
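# UI callback: when the LoRA selection changes, show its description and toggle the
# image and prompt inputs that the chosen LoRA actually needs.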
def on_lora_change(lora_name):
config = LORA_CONFIG[lora_name]
is_style_lora = config["type"] == "style"
return {
lora_description: gr.Markdown(visible=True, value=f"**Description:** {config['description']}"),
input_image_box: gr.Image(visible=not is_style_lora),
style_image_box: gr.Image(visible=is_style_lora),
        prompt_box: gr.Textbox(visible=("{prompt}" in config["prompt_template"]))
}
with gr.Blocks(css="#col-container { margin: 0 auto; max-width: 1024px; }") as demo:
with gr.Column(elem_id="col-container"):
gr.HTML('')
gr.Markdown("