Commit 73ac5f9 · Parent(s): 4b9047d
Big code cleanup
app.py (CHANGED)

@@ -139,20 +139,19 @@ def parse_lset_prompt(lset_prompt):
         resolved_prompt = resolved_prompt.replace(f"{{{key}}}", highlighted_value)
     return resolved_prompt
 
-def handle_lora_selection_change(preset_name: str, base_prompt: str):
+def handle_lora_selection_change(preset_name: str, current_prompt: str):
     """
-
-
+    Appends the selected LoRA's trigger words to the current prompt text
+    and controls the visibility of the weight slider.
     """
-    #
-    lora_slider_update = gr.update(visible=False, interactive=False)
-
-    # If "None" is selected, the displayed prompt is just the base prompt.
+    # If "None" is selected, do nothing to the prompt and hide the slider.
     if not preset_name or preset_name == "None":
         gr.Info("LoRA cleared.")
-
+        # Return the prompt unchanged, and hide the slider.
+        return gr.update(value=current_prompt), gr.update(visible=False, interactive=False)
 
     try:
+        # Fetch the trigger words from the LoRA's .lset file.
         lset_filename = f"{preset_name}.lset"
         lset_path = hf_hub_download(
             repo_id=DYNAMIC_LORA_REPO_ID,
@@ -166,38 +165,26 @@ def handle_lora_selection_change(preset_name: str, base_prompt: str):
             lset_data = json.loads(lset_content)
             lset_prompt_raw = lset_data.get("prompt")
         except json.JSONDecodeError:
-            print(f"Info: '{lset_filename}' is not JSON. Treating as plain text prompt.")
             lset_prompt_raw = lset_content
 
         if lset_prompt_raw:
-
-
-
-            new_prompt = f"{
-            gr.Info(f"✅ Appended triggers from '{preset_name}'.
-
-
-            new_prompt_update = gr.update(value=new_prompt)
-            lora_slider_update = gr.update(visible=True, interactive=True)
-            return new_prompt_update, lora_slider_update
+            # Append the new trigger words to the current prompt.
+            trigger_words = parse_lset_prompt(lset_prompt_raw)
+            separator = ", " if current_prompt and not current_prompt.endswith((",", " ")) else ""
+            new_prompt = f"{current_prompt}{separator}{trigger_words}".strip()
+            gr.Info(f"✅ Appended triggers from '{preset_name}'. You can now edit them.")
+            # Return the updated prompt and show the slider.
+            return gr.update(value=new_prompt), gr.update(visible=True, interactive=True)
         else:
+            # If the .lset file has no prompt, do nothing.
             gr.Info(f"ℹ️ No prompt found in '{preset_name}.lset'. Prompt unchanged.")
-
-            return gr.update(value=base_prompt), lora_slider_update
+            return gr.update(value=current_prompt), gr.update(visible=True, interactive=True)
 
     except Exception as e:
         print(f"Info: Could not process .lset for '{preset_name}'. Reason: {e}")
         gr.Warning(f"⚠️ Could not load triggers for '{preset_name}'.")
-        # On error,
-        return gr.update(value=
-
-def set_base_prompt(current_prompt_text):
-    """
-    Called when the user clicks the 'Set as Base Prompt' button.
-    This updates the base prompt state and resets the LoRA selection.
-    """
-    gr.Info("New base prompt set. You can now select a LoRA to add triggers.")
-    return current_prompt_text, "None", gr.update(visible=False, interactive=False)
+        # On error, return the prompt unchanged but still show the slider.
+        return gr.update(value=current_prompt), gr.update(visible=True, interactive=True)
 
 
 def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
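
For reference, the separator rule introduced in the hunk above can be exercised on its own. This is a minimal sketch with illustrative prompt and trigger values, independent of Gradio and the .lset files:

def append_triggers(current_prompt: str, trigger_words: str) -> str:
    # Same rule as in handle_lora_selection_change: insert ", " only when the
    # prompt is non-empty and does not already end with a comma or a space.
    separator = ", " if current_prompt and not current_prompt.endswith((",", " ")) else ""
    return f"{current_prompt}{separator}{trigger_words}".strip()

print(append_triggers("a cat surfing at sunset", "pixel art style"))    # a cat surfing at sunset, pixel art style
print(append_triggers("a cat surfing at sunset, ", "pixel art style"))  # a cat surfing at sunset, pixel art style
print(append_triggers("", "pixel art style"))                           # pixel art style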
@@ -208,16 +195,13 @@ def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
         bool: True if a dynamic LoRA was loaded, False otherwise.
     """
     # Pre-emptive cleanup of any previously loaded dynamic adapter.
-    # This is more robust than relying only on the `finally` block of the previous run.
     try:
-        # --- FIX: Use delete_adapters to remove a specific adapter by name ---
         pipe.delete_adapters([DYNAMIC_LORA_ADAPTER_NAME])
         print("🧼 Pre-emptively unloaded previous dynamic LoRA.")
     except Exception:
         pass # No dynamic lora was present, which is a clean state.
 
     if not selected_lora or selected_lora == "None":
-        # This run uses no dynamic LoRA. Ensure only the base LoRA is active.
         pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
         print("ℹ️ No dynamic LoRA selected. Using base LoRA only.")
         return False
@@ -226,7 +210,6 @@ def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
     print(f"🚀 Processing preset: {selected_lora} with weight {lora_weight}")
     lora_filename = None
     try:
-        # Attempt to get the real LoRA filename from the .lset file
         lset_filename = f"{selected_lora}.lset"
         lset_path = hf_hub_download(
             repo_id=DYNAMIC_LORA_REPO_ID,
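
The two _manage_lora_state hunks above keep the delete_adapters / set_adapters pattern from diffusers. Below is a condensed sketch of that pattern, not the function itself: `pipe` is assumed to be a LoRA-capable diffusers pipeline, and the adapter names and base weight are illustrative stand-ins for the module-level constants in app.py:

def swap_dynamic_lora(pipe, lora_repo_id, lora_filename, lora_weight,
                      base_adapter="fusionx_lora", dynamic_adapter="dynamic_lora",
                      base_weight=1.0):
    # Pre-emptively drop any dynamic adapter left over from a previous run.
    try:
        pipe.delete_adapters([dynamic_adapter])
    except Exception:
        pass  # Nothing was loaded; already in a clean state.

    if not lora_filename:
        # No dynamic LoRA requested: keep only the always-on base adapter active.
        pipe.set_adapters([base_adapter], adapter_weights=[base_weight])
        return False

    # Load the requested LoRA and activate it alongside the base adapter.
    pipe.load_lora_weights(lora_repo_id, weight_name=lora_filename, adapter_name=dynamic_adapter)
    pipe.set_adapters([base_adapter, dynamic_adapter], adapter_weights=[base_weight, lora_weight])
    return True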
@@ -268,27 +251,18 @@ def load_pipelines():
 
     print("\n🚀 Loading T2V pipeline with base LoRA...")
     try:
-        # To avoid potential cache duplication and storage issues, we load the
-        # pipeline directly, then replace the VAE with a float32 version for stability.
         t2v_pipe = DiffusionPipeline.from_pretrained(
             T2V_BASE_MODEL_ID,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch.bfloat16,
         )
         print("✅ Base pipeline loaded. Overriding VAE with float32 version...")
-
-        # The VAE often works better in float32. We reload it and replace it in the pipeline.
-        # Using the specific AutoencoderKLWan class is more robust than the generic AutoModel.
         vae_fp32 = AutoencoderKLWan.from_pretrained(T2V_BASE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
         t2v_pipe.vae = vae_fp32
-
         t2v_pipe.to("cuda")
         print("✅ Pipeline configured. Loading and activating base FusionX LoRA...")
-
-        # Load and set the base LoRA that is always active
         t2v_pipe.load_lora_weights(FUSIONX_LORA_REPO, weight_name=FUSIONX_LORA_FILE, adapter_name=FUSIONX_ADAPTER_NAME)
         t2v_pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
         print("✅ T2V pipeline with base LoRA is ready.")
-
     except Exception as e:
         print(f"❌ CRITICAL ERROR: Failed to load T2V pipeline. T2V will be disabled. Reason: {e}")
         traceback.print_exc()
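
A stripped-down sketch of the load order this hunk keeps: a bfloat16 pipeline, then a float32 AutoencoderKLWan swapped in before moving to CUDA. The repo id below is a placeholder; the real value of T2V_BASE_MODEL_ID is defined elsewhere in app.py and is not shown in this diff:

import torch
from diffusers import AutoencoderKLWan, DiffusionPipeline

T2V_BASE_MODEL_ID = "<wan-2.1-t2v-repo>"  # placeholder, see app.py

pipe = DiffusionPipeline.from_pretrained(T2V_BASE_MODEL_ID, torch_dtype=torch.bfloat16)
# The VAE is reloaded in float32 and swapped in for numerical stability.
pipe.vae = AutoencoderKLWan.from_pretrained(T2V_BASE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe.to("cuda")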
@@ -296,8 +270,6 @@ def load_pipelines():
 
     print("\n🤖 Loading LLM for Prompt Enhancement...")
     try:
-        # In a ZeroGPU environment, we must load models on the CPU at startup.
-        # The model will be moved to the GPU inside the decorated function.
         enhancer_pipe = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device="cpu")
         print("✅ LLM Prompt Enhancer loaded successfully (on CPU).")
     except Exception as e:
@@ -313,17 +285,11 @@ def load_pipelines():
 def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
     """
     Uses the loaded LLM to enhance a given prompt.
-    This function manually handles tensor placement to avoid device mismatches in a ZeroGPU environment.
     """
     if enhancer_pipeline is None:
-        print("LLM enhancer not available, returning original prompt.")
         gr.Warning("LLM enhancer is not available.")
-        return prompt
+        return prompt, gr.update(), gr.update()
 
-    # In a Hugging Face ZeroGPU Space, the GPU is provisioned on-demand for functions
-    # decorated with @spaces.GPU and de-provisioned afterward. Therefore, the model,
-    # which is loaded on the CPU at startup, must be moved to the GPU for every call.
-    # The "Moving enhancer model to CUDA..." message is expected and correct for this setup.
     if enhancer_pipeline.model.device.type != 'cuda':
         print("Moving enhancer model to CUDA for on-demand GPU execution...")
         enhancer_pipeline.model.to("cuda")
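
The comments removed in this hunk described the ZeroGPU pattern the code relies on: load on CPU at startup, move to CUDA inside the GPU-decorated call. A minimal sketch of that pattern is below; the decorator usage and the model id are illustrative assumptions, not lines from app.py:

import spaces  # ZeroGPU helper available in Hugging Face Spaces
import torch
from transformers import pipeline

# Loaded on CPU at import time, because no GPU is attached yet.
enhancer = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct",
                    torch_dtype=torch.bfloat16, device="cpu")

@spaces.GPU
def enhance(prompt: str) -> str:
    # The GPU exists only for the duration of this call, so the model is moved here.
    if enhancer.model.device.type != "cuda":
        enhancer.model.to("cuda")
    out = enhancer(prompt, max_new_tokens=64, do_sample=True)
    return out[0]["generated_text"]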
@@ -332,64 +298,31 @@ def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
     print(f"Enhancing prompt: '{prompt}'")
 
     try:
-        # 1. Get the tokenizer from the pipeline.
         tokenizer = enhancer_pipeline.tokenizer
-
-        # FIX: Set pad_token to eos_token if not set. This is a common requirement for
-        # models like Qwen2 and helps prevent warnings about attention masks.
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
-        # 2. Apply the chat template and tokenize. This returns a dictionary containing
-        # 'input_ids' and 'attention_mask' as PyTorch tensors.
-        tokenized_inputs = tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        )
+        tokenized_inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
 
-        # 3. FIX: Move each tensor in the dictionary to the CUDA device.
-        # 3. FIX: The tokenizer might return a single tensor instead of a dictionary.
-        # We handle both cases to make the code more robust.
         if isinstance(tokenized_inputs, torch.Tensor):
-            # If we get a single tensor, assume it's input_ids
             inputs_on_cuda = {"input_ids": tokenized_inputs.to("cuda")}
-            # Manually create the attention mask as it's good practice for generate()
             inputs_on_cuda["attention_mask"] = torch.ones_like(inputs_on_cuda["input_ids"])
         else:
-            # If we get a dictionary, move all its tensors to cuda
             inputs_on_cuda = {k: v.to("cuda") for k, v in tokenized_inputs.items()}
 
-
-        # both `input_ids` and `attention_mask`. This resolves the warning.
-        generated_ids = enhancer_pipeline.model.generate(
-            **inputs_on_cuda,
-            max_new_tokens=256,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.95
-        )
-
-        # 5. The output from generate() includes the input tokens. We need to decode only the newly generated part.
+        generated_ids = enhancer_pipeline.model.generate(**inputs_on_cuda, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
         input_token_length = inputs_on_cuda['input_ids'].shape[1]
         newly_generated_ids = generated_ids[:, input_token_length:]
-
-        # 6. Decode the new tokens back into a string.
         final_answer = tokenizer.decode(newly_generated_ids[0], skip_special_tokens=True)
 
         print(f"Enhanced prompt: '{final_answer.strip()}'")
-        #
+        # The enhanced prompt overwrites the textbox. The LoRA selection is reset.
         return final_answer.strip(), "None", gr.update(visible=False, interactive=False)
     except Exception as e:
         print(f"❌ Error during prompt enhancement: {e}")
-        # Adding full traceback for better debugging in the console
         traceback.print_exc()
         gr.Warning(f"An error occurred during prompt enhancement. See console for details.")
-        return prompt, gr.update(), gr.update()
+        return prompt, gr.update(), gr.update()
     finally:
-        # Explicitly empty the CUDA cache to help release GPU memory.
-        # This can help resolve intermittent issues where the GPU remains active.
         print("🧹 Clearing CUDA cache after prompt enhancement...")
         torch.cuda.empty_cache()
 
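
The collapsed tokenize/generate/decode calls above follow the usual transformers chat-template flow. A compact, self-contained sketch of that flow is shown below; the model id and the `messages` list are examples, not the values defined in app.py:

import torch
from transformers import pipeline

enhancer = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16)
tokenizer = enhancer.tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

messages = [{"role": "user", "content": "Rewrite as a cinematic video prompt: a cat surfing"}]
inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
if isinstance(inputs, torch.Tensor):
    # Some tokenizers return a bare input_ids tensor instead of a dict.
    inputs = {"input_ids": inputs, "attention_mask": torch.ones_like(inputs)}
inputs = {k: v.to(enhancer.model.device) for k, v in inputs.items()}

generated = enhancer.model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.95)
# generate() returns prompt + completion, so decode only the newly generated tokens.
new_tokens = generated[:, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens[0], skip_special_tokens=True))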
@@ -408,11 +341,12 @@ def generate_t2v_video(
     if not prompt:
         raise gr.Error("Please enter a prompt for Text-to-Video generation.")
 
+    # --- The prompt from the textbox is now the final prompt. No more combining. ---
+    final_prompt = prompt
+
     target_h = max(MOD_VALUE, (height // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (width // MOD_VALUE) * MOD_VALUE)
 
-    # Calculate the initial number of frames based on duration and model constraints.
-    # The model requires (num_frames - 1) to be divisible by 4.
     requested_frames = int(round(duration_seconds * T2V_FIXED_FPS))
     frames_minus_one = requested_frames - 1
     valid_frames_minus_one = round(frames_minus_one / 4.0) * 4
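
The comment removed here documented the constraint that (num_frames - 1) must be divisible by 4. The arithmetic that stays behind can be checked in isolation; T2V_FIXED_FPS below is an illustrative value, and any clamping app.py applies afterwards is outside this hunk:

T2V_FIXED_FPS = 16  # illustrative; the real constant lives elsewhere in app.py

def snapped_frame_count(duration_seconds: float) -> int:
    requested_frames = int(round(duration_seconds * T2V_FIXED_FPS))
    valid_frames_minus_one = round((requested_frames - 1) / 4.0) * 4
    return int(valid_frames_minus_one) + 1

for secs in (1.0, 2.3, 5.0):
    n = snapped_frame_count(secs)
    print(secs, n, (n - 1) % 4 == 0)  # last column is always True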
@@ -425,28 +359,23 @@ def generate_t2v_video(
     lora_loaded = False
 
     try:
-
-        lora_loaded = _manage_lora_state(
-            pipe=t2v_pipe,
-            selected_lora=selected_lora,
-            lora_weight=lora_weight
-        )
+        lora_loaded = _manage_lora_state(pipe=t2v_pipe, selected_lora=selected_lora, lora_weight=lora_weight)
 
         print("\n--- Starting T2V Generation ---")
-        print(f"Prompt: {
+        print(f"Final Prompt: {final_prompt}")
         print(f"Resolution: {target_w}x{target_h}, Frames: {num_frames}, Seed: {current_seed}")
         print(f"Steps: {steps}, Guidance: 1.0 (fixed for FusionX)")
         print("---------------------------------")
 
         with torch.inference_mode():
             output_frames_list = t2v_pipe(
-                prompt=
+                prompt=final_prompt, negative_prompt=negative_prompt,
                 height=target_h, width=target_w, num_frames=num_frames,
                 guidance_scale=1.0, num_inference_steps=int(steps),
                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
             ).frames[0]
 
-        sanitized_prompt = sanitize_prompt_for_filename(
+        sanitized_prompt = sanitize_prompt_for_filename(final_prompt)
         filename = f"t2v_{sanitized_prompt}_{current_seed}.mp4"
         temp_dir = tempfile.mkdtemp()
         video_path = os.path.join(temp_dir, filename)
@@ -457,28 +386,20 @@ def generate_t2v_video(
         return video_path, current_seed, gr.File(value=video_path, visible=True, label=download_label)
 
     except Exception as e:
-        # Broad exception to catch any error during generation and ensure cleanup still happens
         print(f"❌ An error occurred during video generation: {e}")
         traceback.print_exc()
         raise gr.Error("Video generation failed. Please check the logs for details.")
 
     finally:
-        # --- CLEANUP ---
-        # This block ensures the dynamic LoRA is removed after every run,
-        # resetting the pipeline to a clean state for the next user.
         if lora_loaded:
             print(f"🧼 Cleaning up dynamic LoRA: {selected_lora}")
             try:
-                # --- FIX: Use delete_adapters to correctly remove the adapter ---
                 t2v_pipe.delete_adapters([DYNAMIC_LORA_ADAPTER_NAME])
-                # IMPORTANT: Reset adapters back to just the base LoRA for the next run.
                 t2v_pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
                 print("✅ Cleanup complete. Pipeline reset to base LoRA state.")
             except Exception as e:
                 print(f"⚠️ Error during LoRA cleanup: {e}. State may be inconsistent.")
 
-        # Explicitly empty the CUDA cache to help release GPU memory.
-        # This can help resolve intermittent issues where the GPU remains active.
         print("🧹 Clearing CUDA cache after video generation...")
         torch.cuda.empty_cache()
 
@@ -488,11 +409,8 @@ def generate_t2v_video(
 def build_ui(t2v_pipe, enhancer_pipe, available_loras):
     """Creates and configures the Gradio UI."""
     with gr.Blocks(theme=gr.themes.Soft(), css=".main-container { max-width: 1080px; margin: auto; }") as demo:
-        # --- Add a state component to reliably store the user's base prompt ---
-        base_prompt_state = gr.State(value=DEFAULT_PROMPT_T2V)
-
         gr.Markdown("# ✨ Wan 2.1 Text-to-Video Suite with Dynamic LoRAs")
-        gr.Markdown("Generate videos from text
+        gr.Markdown("Generate videos from text. Edit the prompt below. Selecting a LoRA will append its triggers to your prompt.")
 
         with gr.Tabs():
             with gr.TabItem("✍️ Text-to-Video", id="t2v_tab", interactive=t2v_pipe is not None):
@@ -505,22 +423,17 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
                         label="✏️ Prompt", value=DEFAULT_PROMPT_T2V, lines=4,
                         placeholder="e.g., A cinematic drone shot flying over a futuristic city at night..."
                     )
-
-
-
-
-                    )
-                    # --- FIX: Add a button to explicitly set the base prompt ---
-                    set_base_btn = gr.Button(
-                        "📌 Set as Base Prompt"
-                    )
+                    t2v_enhance_btn = gr.Button(
+                        "🤖 Enhance Prompt with AI",
+                        interactive=enhancer_pipe is not None
+                    )
 
                     with gr.Group():
                         t2v_lora_preset = gr.Dropdown(
                             label="🎨 Dynamic Style LoRA (Optional)",
                             choices=available_loras,
                             value="None",
-                            info="
+                            info="Appends style triggers to the prompt text above."
                         )
                         t2v_lora_weight = gr.Slider(
                             label="💪 LoRA Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.8,
@@ -549,40 +462,24 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
                     t2v_download = gr.File(label="📥 Download Video", visible=False)
 
         if t2v_pipe is not None:
-            # Create a partial function that has the enhancer_pipe "baked in".
            enhance_fn = partial(enhance_prompt_with_llm, enhancer_pipeline=enhancer_pipe)
 
-            #
-
-            # 1. When user clicks the "Set as Base" button after a manual edit.
-            set_base_btn.click(
-                fn=set_base_prompt,
-                inputs=[t2v_prompt],
-                outputs=[base_prompt_state, t2v_lora_preset, t2v_lora_weight],
-                queue=False
-            )
-
-            # 2. When the user enhances the prompt with the LLM. This also creates a new base prompt.
+            # 1. When the user enhances the prompt with the LLM.
             t2v_enhance_btn.click(
                 fn=enhance_fn,
                 inputs=[t2v_prompt],
-                # The enhance function now also resets the LoRA dropdown and slider
                 outputs=[t2v_prompt, t2v_lora_preset, t2v_lora_weight]
-            ).then(
-                fn=lambda p: p, # A simple function to pass the new prompt through
-                inputs=[t2v_prompt],
-                outputs=[base_prompt_state] # Update the base prompt state with the enhanced version
             )
 
-            #
+            # 2. When the user selects a LoRA from the dropdown.
             t2v_lora_preset.change(
                 fn=handle_lora_selection_change,
-                #
-                inputs=[t2v_lora_preset,
+                # Pass the current prompt text in, get the new text back out.
+                inputs=[t2v_lora_preset, t2v_prompt],
                 outputs=[t2v_prompt, t2v_lora_weight]
             )
 
-            #
+            # 3. When the user clicks the final generate button.
             t2v_generate_btn.click(
                 fn=generate_t2v_video,
                 inputs=[
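
The rewired events above follow the standard Gradio pattern of reading the prompt textbox as an input and writing it back as an output. Below is a self-contained sketch of that wiring; the component names, choices, and the trigger text are illustrative, not taken from app.py:

import gradio as gr

def on_preset_change(preset, prompt):
    # Mirror of the app's behaviour: append a trigger phrase and toggle the slider.
    if not preset or preset == "None":
        return gr.update(value=prompt), gr.update(visible=False, interactive=False)
    return gr.update(value=f"{prompt}, {preset} style".strip(", ")), gr.update(visible=True, interactive=True)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt", value="a cat surfing at sunset")
    preset = gr.Dropdown(label="Style LoRA", choices=["None", "pixel art", "watercolor"], value="None")
    weight = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="LoRA Weight", visible=False)
    preset.change(fn=on_preset_change, inputs=[preset, prompt], outputs=[prompt, weight])

# demo.launch()  # uncomment to run locally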
@@ -599,7 +496,6 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
 if __name__ == "__main__":
     t2v_pipe, enhancer_pipe = load_pipelines()
 
-    # Fetch LoRAs only if the main pipeline loaded successfully
     available_loras = []
     if t2v_pipe:
         available_loras = get_available_presets(DYNAMIC_LORA_REPO_ID, DYNAMIC_LORA_SUBFOLDER)