prithivMLmods committed
Commit a57530c · verified · 1 Parent(s): ccc8b48

Update app.py

Files changed (1)
  1. app.py +83 -44
app.py CHANGED
@@ -35,6 +35,11 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
 from diffusers.utils import export_to_ply
 
+# NEW IMPORTS FOR TEXT-TO-VIDEO FEATURE
+import torch  # already imported above; included here for clarity
+from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
+from modelscope import snapshot_download
+
 # Global constants and helper functions
 
 MAX_SEED = np.iinfo(np.int32).max
@@ -337,53 +342,49 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-@spaces.GPU(duration=60, enable_queue=True)
-def generate_image_fn(
+# NEW: Global setup for Wan Video Pipeline
+wan_pipe = None
+def get_wan_pipe():
+    global wan_pipe
+    if wan_pipe is None:
+        snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
+        model_manager = ModelManager(device="cpu")
+        model_manager.load_models(
+            [
+                "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
+                "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
+                "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
+            ],
+            torch_dtype=torch.bfloat16,
+        )
+        wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+        wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)
+    return wan_pipe
+
+@spaces.GPU(duration=120, enable_queue=True)
+def generate_video_fn(
     prompt: str,
     negative_prompt: str = "",
-    use_negative_prompt: bool = False,
     seed: int = 1,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 3,
-    num_inference_steps: int = 25,
+    num_inference_steps: int = 50,
    randomize_seed: bool = False,
-    use_resolution_binning: bool = True,
-    num_images: int = 1,
-    progress=gr.Progress(track_tqdm=True),
 ):
-    """Generate images using the SDXL pipeline."""
+    """
+    Generate a video from text using the Wan pipeline.
+    Returns a tuple of (video_file_path, used_seed).
+    """
     seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    options = {
-        "prompt": [prompt] * num_images,
-        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-        "width": width,
-        "height": height,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "generator": generator,
-        "output_type": "pil",
-    }
-    if use_resolution_binning:
-        options["use_resolution_binning"] = True
-
-    images = []
-    # Process in batches
-    for i in range(0, num_images, BATCH_SIZE):
-        batch_options = options.copy()
-        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
-            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        if device.type == "cuda":
-            with torch.autocast("cuda", dtype=torch.float16):
-                outputs = sd_pipe(**batch_options)
-        else:
-            outputs = sd_pipe(**batch_options)
-        images.extend(outputs.images)
-    image_paths = [save_image(img) for img in images]
-    return image_paths, seed
+    pipe = get_wan_pipe()
+    video = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        num_inference_steps=num_inference_steps,
+        seed=seed,
+        tiled=True
+    )
+    video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+    save_video(video, video_path, fps=15, quality=5)
+    return video_path, seed
 
 # Text-to-3D Generation using the ShapE Pipeline
 
@@ -424,7 +425,7 @@ def detect_objects(image: np.ndarray):
 
     return Image.fromarray(annotated_image)
 
-# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, and @yolo commands
+# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @wan commands
 
 @spaces.GPU
 def generate(
@@ -444,6 +445,7 @@ def generate(
     - "@web": triggers a web search or webpage visit.
     - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
     - "@yolo": triggers object detection using YOLO.
+    - "@wan": triggers video generation using the Wan pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
@@ -539,6 +541,42 @@ def generate(
         yield gr.Image(result_img)
         return
 
+    # --- Wan Video Generation branch ---
+    if text.strip().lower().startswith("@wan"):
+        prompt = text[len("@wan"):].strip()
+        yield "🎞️ Generating video..."
+        # If a video file is attached, perform video-to-video generation.
+        if files and len(files) > 0:
+            try:
+                input_video_path = files[0]
+                video_data = VideoData(input_video_path, height=480, width=832)
+            except Exception as e:
+                yield f"Error loading video: {str(e)}"
+                return
+            pipe = get_wan_pipe()
+            video = pipe(
+                prompt=prompt,
+                negative_prompt="",
+                input_video=video_data,
+                denoising_strength=0.7,
+                num_inference_steps=50,
+                seed=randomize_seed_fn(1, True),
+                tiled=True
+            )
+            video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+            save_video(video, video_path, fps=24, quality=5)
+            yield gr.Video(video_path)
+        else:
+            video_path, used_seed = generate_video_fn(
+                prompt=prompt,
+                negative_prompt="",
+                seed=1,
+                num_inference_steps=50,
+                randomize_seed=True,
+            )
+            yield gr.Video(video_path)
+        return
+
     # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -629,19 +667,20 @@ demo = gr.ChatInterface(
     examples=[
         ["@tts2 What causes rainbows to form?"],
         ["@image Chocolate dripping from a donut"],
+        ["@wan A documentary-style shot of a lively puppy running"],
         ["@3d A birthday cupcake with cherry"],
         [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
         [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
         ["@rAgent Explain how a binary search algorithm works."],
         ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
-        ["@tts1 Explain Tower of Hanoi"],
+        ["@tts1 Explain Tower of Hanoi"]
     ],
     cache_examples=False,
     type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="@tts1-♀, @tts2-♂, @image-image gen, @3d-3d mesh gen, @rAgent-coding, @web-websearch, @yolo-object detection, default-{text gen}{image-text-text}"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@tts1-♀, @tts2-♂, @image-image gen, @3d-3d mesh gen, @rAgent-coding, @web-websearch, @yolo-object detection, @wan-video gen, default-{text gen}{image-text-text}"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
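
For reference, below is a minimal standalone sketch of the text-to-video path this commit wires up, assembled only from the calls visible in the diff (snapshot_download, ModelManager.load_models, WanVideoPipeline.from_model_manager, enable_vram_management, save_video). It assumes the diffsynth and modelscope packages are installed and a CUDA device is available; the prompt, seed, and output filename are illustrative and not part of the commit.

import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video
from modelscope import snapshot_download

# Fetch the Wan2.1-T2V-1.3B weights, mirroring get_wan_pipe() in the diff.
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")

# Load the DiT, T5 text encoder, and VAE on CPU, then build the CUDA pipeline.
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    [
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16,
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Text-to-video, as in generate_video_fn(); prompt, seed, and output path are illustrative.
video = pipe(
    prompt="A documentary-style shot of a lively puppy running",
    negative_prompt="",
    num_inference_steps=50,
    seed=0,
    tiled=True,
)
save_video(video, "wan_t2v_demo.mp4", fps=15, quality=5)

In the app itself the same path is reached by starting a chat message with @wan; attaching a video file instead routes through the video-to-video branch, which wraps the input in VideoData(height=480, width=832) and runs the pipeline with denoising_strength=0.7.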