prithivMLmods commited on
Commit
7b24dc9
·
verified ·
1 Parent(s): 52ffd64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -81
app.py CHANGED
@@ -35,9 +35,6 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
35
  from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
36
  from diffusers.utils import export_to_ply
37
 
38
- from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
39
- from modelscope import snapshot_download
40
-
41
  # Global constants and helper functions
42
 
43
  MAX_SEED = np.iinfo(np.int32).max
@@ -340,49 +337,53 @@ def save_image(img: Image.Image) -> str:
340
  img.save(unique_name)
341
  return unique_name
342
 
343
- # NEW: Global setup for Wan Video Pipeline
344
- wan_pipe = None
345
- def get_wan_pipe():
346
- global wan_pipe
347
- if wan_pipe is None:
348
- snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
349
- model_manager = ModelManager(device="cpu")
350
- model_manager.load_models(
351
- [
352
- "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
353
- "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
354
- "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
355
- ],
356
- torch_dtype=torch.bfloat16,
357
- )
358
- wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
359
- wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)
360
- return wan_pipe
361
-
362
- @spaces.GPU(duration=120, enable_queue=True)
363
- def generate_video_fn(
364
  prompt: str,
365
  negative_prompt: str = "",
 
366
  seed: int = 1,
367
- num_inference_steps: int = 50,
 
 
 
368
  randomize_seed: bool = False,
 
 
 
369
  ):
370
- """
371
- Generate a video from text using the Wan pipeline.
372
- Returns a tuple of (video_file_path, used_seed).
373
- """
374
  seed = int(randomize_seed_fn(seed, randomize_seed))
375
- pipe = get_wan_pipe()
376
- video = pipe(
377
- prompt=prompt,
378
- negative_prompt=negative_prompt,
379
- num_inference_steps=num_inference_steps,
380
- seed=seed,
381
- tiled=True
382
- )
383
- video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
384
- save_video(video, video_path, fps=15, quality=5)
385
- return video_path, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  # Text-to-3D Generation using the ShapE Pipeline
388
 
@@ -423,7 +424,7 @@ def detect_objects(image: np.ndarray):
423
 
424
  return Image.fromarray(annotated_image)
425
 
426
- # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @wan commands
427
 
428
  @spaces.GPU
429
  def generate(
@@ -443,7 +444,6 @@ def generate(
443
  - "@web": triggers a web search or webpage visit.
444
  - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
445
  - "@yolo": triggers object detection using YOLO.
446
- - "@wan": triggers video generation using the Wan pipeline.
447
  """
448
  text = input_dict["text"]
449
  files = input_dict.get("files", [])
@@ -539,42 +539,6 @@ def generate(
539
  yield gr.Image(result_img)
540
  return
541
 
542
- # --- Wan Video Generation branch ---
543
- if text.strip().lower().startswith("@wan"):
544
- prompt = text[len("@wan"):].strip()
545
- yield "🎞️ Generating video..."
546
- # If a video file is attached, perform video-to-video generation.
547
- if files and len(files) > 0:
548
- try:
549
- input_video_path = files[0]
550
- video_data = VideoData(input_video_path, height=480, width=832)
551
- except Exception as e:
552
- yield f"Error loading video: {str(e)}"
553
- return
554
- pipe = get_wan_pipe()
555
- video = pipe(
556
- prompt=prompt,
557
- negative_prompt="",
558
- input_video=video_data,
559
- denoising_strength=0.7,
560
- num_inference_steps=50,
561
- seed=randomize_seed_fn(1, True),
562
- tiled=True
563
- )
564
- video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
565
- save_video(video, video_path, fps=24, quality=5)
566
- yield gr.Video(video_path)
567
- else:
568
- video_path, used_seed = generate_video_fn(
569
- prompt=prompt,
570
- negative_prompt="",
571
- seed=1,
572
- num_inference_steps=50,
573
- randomize_seed=True,
574
- )
575
- yield gr.Video(video_path)
576
- return
577
-
578
  # --- Text and TTS branch ---
579
  tts_prefix = "@tts"
580
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -665,20 +629,19 @@ demo = gr.ChatInterface(
665
  examples=[
666
  ["@tts2 What causes rainbows to form?"],
667
  ["@image Chocolate dripping from a donut"],
668
- ["@wan A documentary-style shot of a lively puppy running"],
669
  ["@3d A birthday cupcake with cherry"],
670
  [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
671
  [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
672
  ["@rAgent Explain how a binary search algorithm works."],
673
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
674
- ["@tts1 Explain Tower of Hanoi"]
675
  ],
676
  cache_examples=False,
677
  type="messages",
678
  description=DESCRIPTION,
679
  css=css,
680
  fill_height=True,
681
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@tts1-♀, @tts2-♂, @image-image gen, @3d-3d mesh gen, @rAgent-coding, @web-websearch, @yolo-object detection, @wan-video gen, default-{text gen}{image-text-text}"),
682
  stop_btn="Stop Generation",
683
  multimodal=True,
684
  )
 
35
  from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
36
  from diffusers.utils import export_to_ply
37
 
 
 
 
38
  # Global constants and helper functions
39
 
40
  MAX_SEED = np.iinfo(np.int32).max
 
337
  img.save(unique_name)
338
  return unique_name
339
 
340
+ @spaces.GPU(duration=60, enable_queue=True)
341
+ def generate_image_fn(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  prompt: str,
343
  negative_prompt: str = "",
344
+ use_negative_prompt: bool = False,
345
  seed: int = 1,
346
+ width: int = 1024,
347
+ height: int = 1024,
348
+ guidance_scale: float = 3,
349
+ num_inference_steps: int = 25,
350
  randomize_seed: bool = False,
351
+ use_resolution_binning: bool = True,
352
+ num_images: int = 1,
353
+ progress=gr.Progress(track_tqdm=True),
354
  ):
355
+ """Generate images using the SDXL pipeline."""
 
 
 
356
  seed = int(randomize_seed_fn(seed, randomize_seed))
357
+ generator = torch.Generator(device=device).manual_seed(seed)
358
+
359
+ options = {
360
+ "prompt": [prompt] * num_images,
361
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
362
+ "width": width,
363
+ "height": height,
364
+ "guidance_scale": guidance_scale,
365
+ "num_inference_steps": num_inference_steps,
366
+ "generator": generator,
367
+ "output_type": "pil",
368
+ }
369
+ if use_resolution_binning:
370
+ options["use_resolution_binning"] = True
371
+
372
+ images = []
373
+ # Process in batches
374
+ for i in range(0, num_images, BATCH_SIZE):
375
+ batch_options = options.copy()
376
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
377
+ if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
378
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
379
+ if device.type == "cuda":
380
+ with torch.autocast("cuda", dtype=torch.float16):
381
+ outputs = sd_pipe(**batch_options)
382
+ else:
383
+ outputs = sd_pipe(**batch_options)
384
+ images.extend(outputs.images)
385
+ image_paths = [save_image(img) for img in images]
386
+ return image_paths, seed
387
 
388
  # Text-to-3D Generation using the ShapE Pipeline
389
 
 
424
 
425
  return Image.fromarray(annotated_image)
426
 
427
+ # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, and @yolo commands
428
 
429
  @spaces.GPU
430
  def generate(
 
444
  - "@web": triggers a web search or webpage visit.
445
  - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
446
  - "@yolo": triggers object detection using YOLO.
 
447
  """
448
  text = input_dict["text"]
449
  files = input_dict.get("files", [])
 
539
  yield gr.Image(result_img)
540
  return
541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  # --- Text and TTS branch ---
543
  tts_prefix = "@tts"
544
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
 
629
  examples=[
630
  ["@tts2 What causes rainbows to form?"],
631
  ["@image Chocolate dripping from a donut"],
 
632
  ["@3d A birthday cupcake with cherry"],
633
  [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
634
  [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
635
  ["@rAgent Explain how a binary search algorithm works."],
636
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
637
+ ["@tts1 Explain Tower of Hanoi"],
638
  ],
639
  cache_examples=False,
640
  type="messages",
641
  description=DESCRIPTION,
642
  css=css,
643
  fill_height=True,
644
+ textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="@tts1-♀, @tts2-♂, @image-image gen, @3d-3d mesh gen, @rAgent-coding, @web-websearch, @yolo-object detection, default-{text gen}{image-text-text}"),
645
  stop_btn="Stop Generation",
646
  multimodal=True,
647
  )