import types
import random
import spaces
import os
import torch
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
from src.pipeline_wan_nag import NAGWanPipeline
from src.transformer_wan_nag import NagWanTransformer3DModel

# Dummy constants (replace with actual model values)
MOD_VALUE = 32
DEFAULT_DURATION_SECONDS = 4
DEFAULT_STEPS = 4
DEFAULT_SEED = 2025
DEFAULT_H_SLIDER_VALUE = 480
DEFAULT_W_SLIDER_VALUE = 832
NEW_FORMULA_MAX_AREA = 480.0 * 832.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"

MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Load the Wan2.1 VAE in float32 and the NAG pipeline in bfloat16
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = NAGWanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
pipe.to("cuda")

# Patch transformer methods
pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward

# --- Predefined LoRAs ---
AVAILABLE_LORAS = [
    {
        "label": "CausVid LoRA",
        "repo_id": "Kijai/WanVideo_comfy",
        "filename": "Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
        "adapter_name": "causvid_lora",
        "default_weight": 0.95,
        "scale_blocks": ["blocks.0"],
    },
    {
        "label": "Detail Enhancer V1",
        "repo_id": "vrgamedevgirl84/Wan14BT2VFusioniX",
        "filename": "OtherLoRa's/DetailEnhancerV1.safetensors",
        "adapter_name": "mps_lora",
        "default_weight": 0.7,
    },
]


def load_loras_from_ui(selected_labels, weights, custom_repo=None, custom_file=None, custom_weight=1.0):
    """Download, load, and fuse the LoRAs selected in the UI.

    `weights` is expected in AVAILABLE_LORAS order (one slider per predefined LoRA);
    the custom_* arguments are optional because the custom-LoRA UI is currently disabled.
    """
    lora_adapters = []
    lora_weights = []
    selected_configs = []
    for label in selected_labels:
        idx = next((i for i, l in enumerate(AVAILABLE_LORAS) if l["label"] == label), None)
        if idx is None:
            continue
        config = AVAILABLE_LORAS[idx].copy()
        # Index weights by catalogue position, not selection order, so each LoRA
        # picks up its own slider value.
        config["weight"] = float(weights[idx])
        selected_configs.append(config)

    # if custom_repo and custom_file:
    #     adapter_name = os.path.splitext(os.path.basename(custom_file))[0]
    #     selected_configs.append({
    #         "repo_id": custom_repo,
    #         "filename": custom_file,
    #         "adapter_name": adapter_name,
    #         "weight": float(custom_weight),
    #     })

    for config in selected_configs:
        snapshot_path = snapshot_download(
            repo_id=config["repo_id"],
            allow_patterns=[config["filename"]],
            repo_type="model",
        )
        lora_path = os.path.join(snapshot_path, config["filename"])
        pipe.load_lora_weights(lora_path, adapter_name=config["adapter_name"])

        # Optionally down-scale the LoRA B matrices of specific transformer blocks.
        if config.get("scale_blocks"):
            for name, param in pipe.transformer.named_parameters():
                if "lora_B" in name and any(b in name for b in config["scale_blocks"]):
                    param.data *= 0.25

        lora_adapters.append(config["adapter_name"])
        lora_weights.append(config["weight"])

    if lora_adapters:
        pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
        pipe.fuse_lora()
        print(f"✅ Fused LoRAs: {lora_adapters}")
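
# Hedged sketch (not part of the original app): load_loras_from_ui() fuses the selected LoRAs
# into the base weights, and that fused state persists across Gradio calls, so changing the
# selection between runs can stack adapters or fail on a duplicate adapter_name. The helper
# below assumes NAGWanPipeline also inherits diffusers' standard LoRA methods unfuse_lora()
# and unload_lora_weights(), alongside the load_lora_weights / set_adapters / fuse_lora calls
# used above.
def reset_loras():
    # Undo the in-place fusion, then drop the loaded adapters, so the next
    # load_loras_from_ui() call starts from clean base weights.
    pipe.unfuse_lora()
    pipe.unload_lora_weights()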

# def get_duration(
#     prompt,
#     nag_negative_prompt, nag_scale,
#     height, width, duration_seconds,
#     steps,
#     seed, randomize_seed,
#     compare,
# ):
#     duration = int(duration_seconds) * int(steps) * 2.25 + 5
#     if compare:
#         duration *= 2
#     return duration


@spaces.GPU(duration=200)
def generate_video(prompt, nag_negative_prompt, nag_scale,
                   height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE,
                   duration_seconds=DEFAULT_DURATION_SECONDS,
                   steps=DEFAULT_STEPS,
                   seed=DEFAULT_SEED, randomize_seed=False,
                   compare=True):
    # Snap the requested resolution to the model's stride and clamp the frame count.
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    with torch.inference_mode():
        nag_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            nag_scale=nag_scale,
            nag_tau=3.5,
            nag_alpha=0.5,
            height=target_h,
            width=target_w,
            num_frames=num_frames,
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        nag_video_path = tmpfile.name
    export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)

    if compare:
        # Baseline pass with default NAG settings, using the same seed for comparison.
        baseline_output_frames_list = pipe(
            prompt=prompt,
            nag_negative_prompt=nag_negative_prompt,
            height=target_h,
            width=target_w,
            num_frames=num_frames,
            guidance_scale=0.,
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            baseline_video_path = tmpfile.name
        export_to_video(baseline_output_frames_list, baseline_video_path, fps=FIXED_FPS)
    else:
        baseline_video_path = None

    return nag_video_path, baseline_video_path, current_seed


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan2.1-T2V-14B + NAG + LoRA Control")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            nag_negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NAG_NEGATIVE_PROMPT)
            nag_scale = gr.Slider(1., 20., value=11., step=0.25, label="NAG Scale")
            compare = gr.Checkbox(label="Compare with baseline", value=True)

            with gr.Accordion("Advanced", open=False):
                steps_slider = gr.Slider(1, 8, value=DEFAULT_STEPS, label="Inference Steps")
                duration_seconds_input = gr.Slider(1, 5, value=DEFAULT_DURATION_SECONDS, label="Duration (seconds)")
                seed_input = gr.Slider(0, MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed")
                randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
                height_input = gr.Slider(SLIDER_MIN_H, SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label="Height")
                width_input = gr.Slider(SLIDER_MIN_W, SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label="Width")

            with gr.Accordion("LoRA Settings", open=False):
                lora_selector = gr.CheckboxGroup([l["label"] for l in AVAILABLE_LORAS], label="Select Predefined LoRAs")
                lora_sliders = [
                    gr.Slider(0.0, 1.5, value=l["default_weight"], label=f"{l['label']} Weight")
                    for l in AVAILABLE_LORAS
                ]
                # custom_repo = gr.Textbox(label="Custom Repo ID (optional)", placeholder="e.g. my-user/my-repo")
my_model.safetensors") # custom_weight = gr.Slider(0.0, 1.5, value=1.0, label="Custom Weight") generate_button = gr.Button("Generate Video") with gr.Column(): nag_video_output = gr.Video(label="Video with NAG") baseline_video_output = gr.Video(label="Baseline Video") def generate_wrapper(*args): selected_labels = args[-5] if not isinstance(selected_labels, list): selected_labels = [] # Ensure it's iterable even if empty or NaN lora_weights = args[-4:-4 + len(AVAILABLE_LORAS)] if selected_labels: load_loras_from_ui(selected_labels, lora_weights) return generate_video(*args[:-5]) inputs = [ prompt, nag_negative_prompt, nag_scale, height_input, width_input, duration_seconds_input, steps_slider, seed_input, randomize_seed_checkbox, compare, lora_selector # ✅ CheckboxGroup - must be BEFORE sliders ] + lora_sliders generate_button.click( fn=generate_wrapper, inputs=inputs, outputs=[nag_video_output, baseline_video_output, seed_input], ) if __name__ == "__main__": demo.queue().launch()