thankfulcarp committed on
Commit 73ac5f9 · 1 Parent(s): 4b9047d

Big code cleanup

Files changed (1): app.py (+42 -146)
app.py CHANGED
@@ -139,20 +139,19 @@ def parse_lset_prompt(lset_prompt):
         resolved_prompt = resolved_prompt.replace(f"{{{key}}}", highlighted_value)
     return resolved_prompt
 
-def handle_lora_selection_change(preset_name: str, base_prompt: str):
+def handle_lora_selection_change(preset_name: str, current_prompt: str):
     """
-    When a preset is selected, this function combines the base_prompt (from state)
-    with the new LoRA's trigger words. This makes the update idempotent.
+    Appends the selected LoRA's trigger words to the current prompt text
+    and controls the visibility of the weight slider.
     """
-    # Default state for the slider is hidden and non-interactive
-    lora_slider_update = gr.update(visible=False, interactive=False)
-
-    # If "None" is selected, the displayed prompt is just the base prompt.
+    # If "None" is selected, do nothing to the prompt and hide the slider.
     if not preset_name or preset_name == "None":
         gr.Info("LoRA cleared.")
-        return gr.update(value=base_prompt), lora_slider_update
+        # Return the prompt unchanged, and hide the slider.
+        return gr.update(value=current_prompt), gr.update(visible=False, interactive=False)
 
     try:
+        # Fetch the trigger words from the LoRA's .lset file.
         lset_filename = f"{preset_name}.lset"
         lset_path = hf_hub_download(
             repo_id=DYNAMIC_LORA_REPO_ID,
@@ -166,38 +165,26 @@ def handle_lora_selection_change(preset_name: str, base_prompt: str):
             lset_data = json.loads(lset_content)
             lset_prompt_raw = lset_data.get("prompt")
         except json.JSONDecodeError:
-            print(f"Info: '{lset_filename}' is not JSON. Treating as plain text prompt.")
             lset_prompt_raw = lset_content
 
         if lset_prompt_raw:
-            resolved_prompt = parse_lset_prompt(lset_prompt_raw)
-            separator = ", " if base_prompt and not base_prompt.endswith((",", " ")) else ""
-            # The new prompt is always constructed from the base prompt and the new triggers.
-            new_prompt = f"{base_prompt}{separator}{resolved_prompt}".strip()
-            gr.Info(f"✅ Appended triggers from '{preset_name}'. Replace highlighted text like __this__.")
-
-            # On success, update the prompt and make the slider visible and interactive
-            new_prompt_update = gr.update(value=new_prompt)
-            lora_slider_update = gr.update(visible=True, interactive=True)
-            return new_prompt_update, lora_slider_update
+            # Append the new trigger words to the current prompt.
+            trigger_words = parse_lset_prompt(lset_prompt_raw)
+            separator = ", " if current_prompt and not current_prompt.endswith((",", " ")) else ""
+            new_prompt = f"{current_prompt}{separator}{trigger_words}".strip()
+            gr.Info(f"✅ Appended triggers from '{preset_name}'. You can now edit them.")
+            # Return the updated prompt and show the slider.
+            return gr.update(value=new_prompt), gr.update(visible=True, interactive=True)
         else:
+            # If the .lset file has no prompt, do nothing.
             gr.Info(f"ℹ️ No prompt found in '{preset_name}.lset'. Prompt unchanged.")
-            # If no triggers, the prompt is just the base prompt.
-            return gr.update(value=base_prompt), lora_slider_update
+            return gr.update(value=current_prompt), gr.update(visible=True, interactive=True)
 
     except Exception as e:
         print(f"Info: Could not process .lset for '{preset_name}'. Reason: {e}")
         gr.Warning(f"⚠️ Could not load triggers for '{preset_name}'.")
-        # On error, revert to just the base prompt.
-        return gr.update(value=base_prompt), lora_slider_update
-
-def set_base_prompt(current_prompt_text):
-    """
-    Called when the user clicks the 'Set as Base Prompt' button.
-    This updates the base prompt state and resets the LoRA selection.
-    """
-    gr.Info("New base prompt set. You can now select a LoRA to add triggers.")
-    return current_prompt_text, "None", gr.update(visible=False, interactive=False)
+        # On error, return the prompt unchanged but still show the slider.
+        return gr.update(value=current_prompt), gr.update(visible=True, interactive=True)
 
 
 def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
@@ -208,16 +195,13 @@ def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
         bool: True if a dynamic LoRA was loaded, False otherwise.
     """
     # Pre-emptive cleanup of any previously loaded dynamic adapter.
-    # This is more robust than relying only on the `finally` block of the previous run.
     try:
-        # --- FIX: Use delete_adapters to remove a specific adapter by name ---
         pipe.delete_adapters([DYNAMIC_LORA_ADAPTER_NAME])
         print("🧼 Pre-emptively unloaded previous dynamic LoRA.")
     except Exception:
         pass # No dynamic lora was present, which is a clean state.
 
     if not selected_lora or selected_lora == "None":
-        # This run uses no dynamic LoRA. Ensure only the base LoRA is active.
         pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
         print("ℹ️ No dynamic LoRA selected. Using base LoRA only.")
         return False
@@ -226,7 +210,6 @@ def _manage_lora_state(pipe, selected_lora: str, lora_weight: float) -> bool:
     print(f"🚀 Processing preset: {selected_lora} with weight {lora_weight}")
     lora_filename = None
     try:
-        # Attempt to get the real LoRA filename from the .lset file
         lset_filename = f"{selected_lora}.lset"
         lset_path = hf_hub_download(
             repo_id=DYNAMIC_LORA_REPO_ID,
@@ -268,27 +251,18 @@ def load_pipelines():
 
     print("\n🚀 Loading T2V pipeline with base LoRA...")
     try:
-        # To avoid potential cache duplication and storage issues, we load the
-        # pipeline directly, then replace the VAE with a float32 version for stability.
         t2v_pipe = DiffusionPipeline.from_pretrained(
             T2V_BASE_MODEL_ID,
-            torch_dtype=torch.bfloat16, # Load in bfloat16 to save memory
+            torch_dtype=torch.bfloat16,
         )
         print("✅ Base pipeline loaded. Overriding VAE with float32 version...")
-
-        # The VAE often works better in float32. We reload it and replace it in the pipeline.
-        # Using the specific AutoencoderKLWan class is more robust than the generic AutoModel.
         vae_fp32 = AutoencoderKLWan.from_pretrained(T2V_BASE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
         t2v_pipe.vae = vae_fp32
-
         t2v_pipe.to("cuda")
         print("✅ Pipeline configured. Loading and activating base FusionX LoRA...")
-
-        # Load and set the base LoRA that is always active
         t2v_pipe.load_lora_weights(FUSIONX_LORA_REPO, weight_name=FUSIONX_LORA_FILE, adapter_name=FUSIONX_ADAPTER_NAME)
         t2v_pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
         print("✅ T2V pipeline with base LoRA is ready.")
-
     except Exception as e:
         print(f"❌ CRITICAL ERROR: Failed to load T2V pipeline. T2V will be disabled. Reason: {e}")
         traceback.print_exc()
@@ -296,8 +270,6 @@ def load_pipelines():
 
     print("\n🤖 Loading LLM for Prompt Enhancement...")
    try:
-        # In a ZeroGPU environment, we must load models on the CPU at startup.
-        # The model will be moved to the GPU inside the decorated function.
        enhancer_pipe = pipeline("text-generation", model=ENHANCER_MODEL_ID, torch_dtype=torch.bfloat16, device="cpu")
        print("✅ LLM Prompt Enhancer loaded successfully (on CPU).")
    except Exception as e:
@@ -313,17 +285,11 @@ def load_pipelines():
 def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
     """
     Uses the loaded LLM to enhance a given prompt.
-    This function manually handles tensor placement to avoid device mismatches in a ZeroGPU environment.
     """
     if enhancer_pipeline is None:
-        print("LLM enhancer not available, returning original prompt.")
         gr.Warning("LLM enhancer is not available.")
-        return prompt
+        return prompt, gr.update(), gr.update()
 
-    # In a Hugging Face ZeroGPU Space, the GPU is provisioned on-demand for functions
-    # decorated with @spaces.GPU and de-provisioned afterward. Therefore, the model,
-    # which is loaded on the CPU at startup, must be moved to the GPU for every call.
-    # The "Moving enhancer model to CUDA..." message is expected and correct for this setup.
     if enhancer_pipeline.model.device.type != 'cuda':
         print("Moving enhancer model to CUDA for on-demand GPU execution...")
         enhancer_pipeline.model.to("cuda")
@@ -332,64 +298,31 @@ def enhance_prompt_with_llm(prompt: str, enhancer_pipeline):
     print(f"Enhancing prompt: '{prompt}'")
 
     try:
-        # 1. Get the tokenizer from the pipeline.
         tokenizer = enhancer_pipeline.tokenizer
-
-        # FIX: Set pad_token to eos_token if not set. This is a common requirement for
-        # models like Qwen2 and helps prevent warnings about attention masks.
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
-        # 2. Apply the chat template and tokenize. This returns a dictionary containing
-        # 'input_ids' and 'attention_mask' as PyTorch tensors.
-        tokenized_inputs = tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        )
+        tokenized_inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
 
-        # 3. FIX: Move each tensor in the dictionary to the CUDA device.
-        # 3. FIX: The tokenizer might return a single tensor instead of a dictionary.
-        # We handle both cases to make the code more robust.
         if isinstance(tokenized_inputs, torch.Tensor):
-            # If we get a single tensor, assume it's input_ids
             inputs_on_cuda = {"input_ids": tokenized_inputs.to("cuda")}
-            # Manually create the attention mask as it's good practice for generate()
             inputs_on_cuda["attention_mask"] = torch.ones_like(inputs_on_cuda["input_ids"])
         else:
-            # If we get a dictionary, move all its tensors to cuda
             inputs_on_cuda = {k: v.to("cuda") for k, v in tokenized_inputs.items()}
 
-        # 4. Use the model's generate() method, unpacking the dictionary to pass
-        # both `input_ids` and `attention_mask`. This resolves the warning.
-        generated_ids = enhancer_pipeline.model.generate(
-            **inputs_on_cuda,
-            max_new_tokens=256,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.95
-        )
-
-        # 5. The output from generate() includes the input tokens. We need to decode only the newly generated part.
+        generated_ids = enhancer_pipeline.model.generate(**inputs_on_cuda, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.95)
         input_token_length = inputs_on_cuda['input_ids'].shape[1]
         newly_generated_ids = generated_ids[:, input_token_length:]
-
-        # 6. Decode the new tokens back into a string.
         final_answer = tokenizer.decode(newly_generated_ids[0], skip_special_tokens=True)
 
         print(f"Enhanced prompt: '{final_answer.strip()}'")
-        # Return the enhanced prompt and also reset the LoRA dropdown
+        # The enhanced prompt overwrites the textbox. The LoRA selection is reset.
         return final_answer.strip(), "None", gr.update(visible=False, interactive=False)
     except Exception as e:
         print(f"❌ Error during prompt enhancement: {e}")
-        # Adding full traceback for better debugging in the console
         traceback.print_exc()
         gr.Warning(f"An error occurred during prompt enhancement. See console for details.")
-        return prompt, gr.update(), gr.update() # Return original prompt, don't change LoRA
+        return prompt, gr.update(), gr.update()
     finally:
-        # Explicitly empty the CUDA cache to help release GPU memory.
-        # This can help resolve intermittent issues where the GPU remains active.
         print("🧹 Clearing CUDA cache after prompt enhancement...")
         torch.cuda.empty_cache()
 
@@ -408,11 +341,12 @@ def generate_t2v_video(
     if not prompt:
         raise gr.Error("Please enter a prompt for Text-to-Video generation.")
 
+    # --- The prompt from the textbox is now the final prompt. No more combining. ---
+    final_prompt = prompt
+
     target_h = max(MOD_VALUE, (height // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (width // MOD_VALUE) * MOD_VALUE)
 
-    # Calculate the initial number of frames based on duration and model constraints.
-    # The model requires (num_frames - 1) to be divisible by 4.
     requested_frames = int(round(duration_seconds * T2V_FIXED_FPS))
     frames_minus_one = requested_frames - 1
     valid_frames_minus_one = round(frames_minus_one / 4.0) * 4
@@ -425,28 +359,23 @@ def generate_t2v_video(
     lora_loaded = False
 
     try:
-        # Manage the LoRA state and get a flag indicating if a dynamic LoRA was loaded.
-        lora_loaded = _manage_lora_state(
-            pipe=t2v_pipe,
-            selected_lora=selected_lora,
-            lora_weight=lora_weight
-        )
+        lora_loaded = _manage_lora_state(pipe=t2v_pipe, selected_lora=selected_lora, lora_weight=lora_weight)
 
         print("\n--- Starting T2V Generation ---")
-        print(f"Prompt: {prompt}")
+        print(f"Final Prompt: {final_prompt}")
         print(f"Resolution: {target_w}x{target_h}, Frames: {num_frames}, Seed: {current_seed}")
         print(f"Steps: {steps}, Guidance: 1.0 (fixed for FusionX)")
         print("---------------------------------")
 
         with torch.inference_mode():
             output_frames_list = t2v_pipe(
-                prompt=prompt, negative_prompt=negative_prompt,
+                prompt=final_prompt, negative_prompt=negative_prompt,
                 height=target_h, width=target_w, num_frames=num_frames,
                 guidance_scale=1.0, num_inference_steps=int(steps),
                 generator=torch.Generator(device="cuda").manual_seed(current_seed)
             ).frames[0]
 
-        sanitized_prompt = sanitize_prompt_for_filename(prompt)
+        sanitized_prompt = sanitize_prompt_for_filename(final_prompt)
         filename = f"t2v_{sanitized_prompt}_{current_seed}.mp4"
         temp_dir = tempfile.mkdtemp()
         video_path = os.path.join(temp_dir, filename)
@@ -457,28 +386,20 @@ def generate_t2v_video(
         return video_path, current_seed, gr.File(value=video_path, visible=True, label=download_label)
 
     except Exception as e:
-        # Broad exception to catch any error during generation and ensure cleanup still happens
         print(f"❌ An error occurred during video generation: {e}")
         traceback.print_exc()
         raise gr.Error("Video generation failed. Please check the logs for details.")
 
     finally:
-        # --- CLEANUP ---
-        # This block ensures the dynamic LoRA is removed after every run,
-        # resetting the pipeline to a clean state for the next user.
         if lora_loaded:
             print(f"🧼 Cleaning up dynamic LoRA: {selected_lora}")
             try:
-                # --- FIX: Use delete_adapters to correctly remove the adapter ---
                 t2v_pipe.delete_adapters([DYNAMIC_LORA_ADAPTER_NAME])
-                # IMPORTANT: Reset adapters back to just the base LoRA for the next run.
                 t2v_pipe.set_adapters([FUSIONX_ADAPTER_NAME], adapter_weights=[FUSIONX_LORA_WEIGHT])
                 print("✅ Cleanup complete. Pipeline reset to base LoRA state.")
             except Exception as e:
                 print(f"⚠️ Error during LoRA cleanup: {e}. State may be inconsistent.")
 
-        # Explicitly empty the CUDA cache to help release GPU memory.
-        # This can help resolve intermittent issues where the GPU remains active.
         print("🧹 Clearing CUDA cache after video generation...")
         torch.cuda.empty_cache()
 
@@ -488,11 +409,8 @@ def generate_t2v_video(
 def build_ui(t2v_pipe, enhancer_pipe, available_loras):
     """Creates and configures the Gradio UI."""
     with gr.Blocks(theme=gr.themes.Soft(), css=".main-container { max-width: 1080px; margin: auto; }") as demo:
-        # --- Add a state component to reliably store the user's base prompt ---
-        base_prompt_state = gr.State(value=DEFAULT_PROMPT_T2V)
-
         gr.Markdown("# ✨ Wan 2.1 Text-to-Video Suite with Dynamic LoRAs")
-        gr.Markdown("Generate videos from text, enhanced by the base `FusionX` LoRA and your choice of dynamic style LoRA.")
+        gr.Markdown("Generate videos from text. Edit the prompt below. Selecting a LoRA will append its triggers to your prompt.")
 
         with gr.Tabs():
             with gr.TabItem("✍️ Text-to-Video", id="t2v_tab", interactive=t2v_pipe is not None):
@@ -505,22 +423,17 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
                         label="✏️ Prompt", value=DEFAULT_PROMPT_T2V, lines=4,
                         placeholder="e.g., A cinematic drone shot flying over a futuristic city at night..."
                     )
-                    with gr.Row():
-                        t2v_enhance_btn = gr.Button(
-                            "🤖 Enhance Prompt with AI",
-                            interactive=enhancer_pipe is not None
-                        )
-                        # --- FIX: Add a button to explicitly set the base prompt ---
-                        set_base_btn = gr.Button(
-                            "📌 Set as Base Prompt"
-                        )
+                    t2v_enhance_btn = gr.Button(
+                        "🤖 Enhance Prompt with AI",
+                        interactive=enhancer_pipe is not None
+                    )
 
                     with gr.Group():
                         t2v_lora_preset = gr.Dropdown(
                             label="🎨 Dynamic Style LoRA (Optional)",
                             choices=available_loras,
                             value="None",
-                            info="Adds a secondary style LoRA. Replaces previous LoRA triggers."
+                            info="Appends style triggers to the prompt text above."
                         )
                         t2v_lora_weight = gr.Slider(
                             label="💪 LoRA Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.8,
@@ -549,40 +462,24 @@ def build_ui(t2v_pipe, enhancer_pipe, available_loras):
                    t2v_download = gr.File(label="📥 Download Video", visible=False)
 
        if t2v_pipe is not None:
-            # Create a partial function that has the enhancer_pipe "baked in".
            enhance_fn = partial(enhance_prompt_with_llm, enhancer_pipeline=enhancer_pipe)
 
-            # --- Wire up the new state-based event handlers ---
-
-            # 1. When user clicks the "Set as Base" button after a manual edit.
-            set_base_btn.click(
-                fn=set_base_prompt,
-                inputs=[t2v_prompt],
-                outputs=[base_prompt_state, t2v_lora_preset, t2v_lora_weight],
-                queue=False
-            )
-
-            # 2. When the user enhances the prompt with the LLM. This also creates a new base prompt.
+            # 1. When the user enhances the prompt with the LLM.
            t2v_enhance_btn.click(
                fn=enhance_fn,
                inputs=[t2v_prompt],
-                # The enhance function now also resets the LoRA dropdown and slider
                outputs=[t2v_prompt, t2v_lora_preset, t2v_lora_weight]
-            ).then(
-                fn=lambda p: p, # A simple function to pass the new prompt through
-                inputs=[t2v_prompt],
-                outputs=[base_prompt_state] # Update the base prompt state with the enhanced version
            )
 
-            # 3. When the user selects a LoRA from the dropdown.
+            # 2. When the user selects a LoRA from the dropdown.
            t2v_lora_preset.change(
                fn=handle_lora_selection_change,
-                # The input is now the reliable base_prompt_state, not the textbox
-                inputs=[t2v_lora_preset, base_prompt_state],
+                # Pass the current prompt text in, get the new text back out.
+                inputs=[t2v_lora_preset, t2v_prompt],
                outputs=[t2v_prompt, t2v_lora_weight]
            )
 
-            # 4. When the user clicks the final generate button.
+            # 3. When the user clicks the final generate button.
            t2v_generate_btn.click(
                fn=generate_t2v_video,
                inputs=[
@@ -599,7 +496,6 @@ if __name__ == "__main__":
 if __name__ == "__main__":
     t2v_pipe, enhancer_pipe = load_pipelines()
 
-    # Fetch LoRAs only if the main pipeline loaded successfully
     available_loras = []
     if t2v_pipe:
         available_loras = get_available_presets(DYNAMIC_LORA_REPO_ID, DYNAMIC_LORA_SUBFOLDER)
 