Update app.py
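
Removes the Wan2.1 text-to-video path (the diffsynth and modelscope imports, the get_wan_pipe loader, generate_video_fn, the @wan chat branch, and its example prompt) and adds a batched SDXL image helper, generate_image_fn; the examples list and the gr.MultimodalTextbox configuration are updated accordingly.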
app.py CHANGED
@@ -35,9 +35,6 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
 from diffusers.utils import export_to_ply
 
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download
-
 # Global constants and helper functions
 
 MAX_SEED = np.iinfo(np.int32).max
@@ -340,49 +337,53 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-
-
-def get_wan_pipe():
-    global wan_pipe
-    if wan_pipe is None:
-        snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
-        model_manager = ModelManager(device="cpu")
-        model_manager.load_models(
-            [
-                "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
-                "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
-                "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
-            ],
-            torch_dtype=torch.bfloat16,
-        )
-        wan_pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-        wan_pipe.enable_vram_management(num_persistent_param_in_dit=None)
-    return wan_pipe
-
-@spaces.GPU(duration=120, enable_queue=True)
-def generate_video_fn(
+@spaces.GPU(duration=60, enable_queue=True)
+def generate_image_fn(
     prompt: str,
     negative_prompt: str = "",
+    use_negative_prompt: bool = False,
     seed: int = 1,
-    num_inference_steps: int = 50,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    num_inference_steps: int = 25,
     randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
+    num_images: int = 1,
+    progress=gr.Progress(track_tqdm=True),
 ):
-    """
-    Generate a video from text using the Wan pipeline.
-    Returns a tuple of (video_file_path, used_seed).
-    """
+    """Generate images using the SDXL pipeline."""
     seed = int(randomize_seed_fn(seed, randomize_seed))
-    pipe = get_wan_pipe()
-    video = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        num_inference_steps=num_inference_steps,
-        seed=seed,
-        tiled=True,
-    )
-    video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
-    save_video(video, video_path, fps=24, quality=5)
-    return video_path, seed
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    options = {
+        "prompt": [prompt] * num_images,
+        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+        "width": width,
+        "height": height,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "output_type": "pil",
+    }
+    if use_resolution_binning:
+        options["use_resolution_binning"] = True
+
+    images = []
+    # Process in batches
+    for i in range(0, num_images, BATCH_SIZE):
+        batch_options = options.copy()
+        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        if device.type == "cuda":
+            with torch.autocast("cuda", dtype=torch.float16):
+                outputs = sd_pipe(**batch_options)
+        else:
+            outputs = sd_pipe(**batch_options)
+        images.extend(outputs.images)
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
 
 # Text-to-3D Generation using the ShapE Pipeline
 
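For orientation, here is how the new helper might be invoked, for example from the app's @image branch. This is a hypothetical call, not part of the commit: only the signature comes from the hunk above, and sd_pipe, device, BATCH_SIZE, and save_image are module-level names in app.py.

# Hypothetical usage of generate_image_fn; argument names mirror the
# signature added above, the concrete values are illustrative only.
image_paths, used_seed = generate_image_fn(
    prompt="Chocolate dripping from a donut",
    negative_prompt="blurry, low quality",
    use_negative_prompt=True,
    randomize_seed=True,  # let randomize_seed_fn pick a fresh seed
    num_images=2,         # generated in BATCH_SIZE-sized chunks
)
# image_paths: list of file paths written by save_image;
# used_seed: the seed the images were actually generated with.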
@@ -423,7 +424,7 @@ def detect_objects(image: np.ndarray):
 
     return Image.fromarray(annotated_image)
 
-# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent,
+# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, and @yolo commands
 
 @spaces.GPU
 def generate(
@@ -443,7 +444,6 @@ def generate(
     - "@web": triggers a web search or webpage visit.
     - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
     - "@yolo": triggers object detection using YOLO.
-    - "@wan": triggers video generation using the Wan pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
@@ -539,42 +539,6 @@ def generate(
         yield gr.Image(result_img)
         return
 
-    # --- Wan Video Generation branch ---
-    if text.strip().lower().startswith("@wan"):
-        prompt = text[len("@wan"):].strip()
-        yield "🎞️ Generating video..."
-        # If a video file is attached, perform video-to-video generation.
-        if files and len(files) > 0:
-            try:
-                input_video_path = files[0]
-                video_data = VideoData(input_video_path, height=480, width=832)
-            except Exception as e:
-                yield f"Error loading video: {str(e)}"
-                return
-            pipe = get_wan_pipe()
-            video = pipe(
-                prompt=prompt,
-                negative_prompt="",
-                input_video=video_data,
-                denoising_strength=0.7,
-                num_inference_steps=50,
-                seed=randomize_seed_fn(1, True),
-                tiled=True
-            )
-            video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
-            save_video(video, video_path, fps=24, quality=5)
-            yield gr.Video(video_path)
-        else:
-            video_path, used_seed = generate_video_fn(
-                prompt=prompt,
-                negative_prompt="",
-                seed=1,
-                num_inference_steps=50,
-                randomize_seed=True,
-            )
-            yield gr.Video(video_path)
-        return
-
     # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -665,20 +629,19 @@ demo = gr.ChatInterface(
     examples=[
         ["@tts2 What causes rainbows to form?"],
         ["@image Chocolate dripping from a donut"],
-        ["@wan A documentary-style shot of a lively puppy running"],
         ["@3d A birthday cupcake with cherry"],
         [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
        [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
         ["@rAgent Explain how a binary search algorithm works."],
         ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
-        ["@tts1 Explain Tower of Hanoi"]
+        ["@tts1 Explain Tower of Hanoi"],
     ],
     cache_examples=False,
     type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="@tts1-♀, @tts2-♂, @image-image gen, @3d-3d mesh gen, @rAgent-coding, @web-websearch, @yolo-object detection, default-{text gen}{image-text-text}"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
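
A note on the batching pattern used by generate_image_fn: the num_images-long prompt list is consumed in BATCH_SIZE-sized slices, so the final batch may be shorter. A standalone sketch of the slicing (the BATCH_SIZE value below is illustrative; app.py defines its own constant):

# Illustration of the slicing loop in generate_image_fn.
BATCH_SIZE = 2                # illustrative value for this demo
prompts = ["a donut"] * 5     # stands in for options["prompt"]
for i in range(0, len(prompts), BATCH_SIZE):
    print(prompts[i:i + BATCH_SIZE])
# Prints ['a donut', 'a donut'] twice, then the short tail ['a donut'].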