import gradio as gr
import subprocess
import os
import tempfile
import shutil
from pathlib import Path
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
DEFAULT_CONFIG_PATH = "configs/inference.yaml"
DEFAULT_INPUT_FILE = "examples/infer_samples.txt"
OUTPUT_DIR = Path("demo_out/gradio_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
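# Note (assumption): these paths expect the app to run from the OmniAvatar repo root,
# where configs/inference.yaml and scripts/inference.py are located.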
def generate_avatar_video(
    reference_image,
    audio_file,
    text_prompt,
    seed=42,
    num_steps=50,
    guidance_scale=4.5,
    audio_scale=None,
    overlap_frames=13,
    fps=25,
    silence_duration=0.3,
    resolution="720p",
    progress=gr.Progress()
):
    """Generate an avatar video using OmniAvatar.

    Args:
        reference_image: Path to the reference avatar image
        audio_file: Path to the audio file for lip sync
        text_prompt: Text description of the video to generate
        seed: Random seed for generation
        num_steps: Number of inference steps
        guidance_scale: Classifier-free guidance scale
        audio_scale: Audio guidance scale (falls back to guidance_scale if None or 0)
        overlap_frames: Number of overlapping frames between chunks
        fps: Frames per second
        silence_duration: Duration of silence to add before/after the audio, in seconds
        resolution: Output resolution ("480p" or "720p")
        progress: Gradio progress callback

    Returns:
        str: Path to the generated video file
    """
    try:
        progress(0.1, desc="Preparing inputs")
        # Create a temporary directory for this generation
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            # Copy input files to the temp directory
            temp_image = temp_path / "input_image.jpeg"
            temp_audio = temp_path / "input_audio.mp3"
            shutil.copy(reference_image, temp_image)
            shutil.copy(audio_file, temp_audio)
            # Create the input file for the inference script
            input_file = temp_path / "input.txt"
            # Format: prompt@@image_path@@audio_path
            with open(input_file, 'w') as f:
                f.write(f"{text_prompt}@@{temp_image}@@{temp_audio}\n")
            progress(0.2, desc="Configuring generation parameters")
            # Determine max_hw based on resolution
            max_hw = 720 if resolution == "480p" else 1280
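            # Note (assumption): max_hw appears to cap the longer output dimension,
            # i.e. 720 for 480p and 1280 for 720p in the OmniAvatar inference config.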
            # Build the command to run the inference script
            cmd = [
                "torchrun",
                "--nproc_per_node=1",
                "scripts/inference.py",
                "--config", DEFAULT_CONFIG_PATH,
                "--input_file", str(input_file),
                "-hp", f"seed={seed},num_steps={num_steps},guidance_scale={guidance_scale},"
                       f"overlap_frame={overlap_frames},fps={fps},silence_duration_s={silence_duration},"
                       f"max_hw={max_hw},use_audio=True,i2v=True"
            ]
            # Add the audio scale if specified (the UI slider uses 0 to mean "unset")
            if audio_scale:
                cmd[-1] += f",audio_scale={audio_scale}"
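            # For illustration, the assembled command resembles (values substituted):
            #   torchrun --nproc_per_node=1 scripts/inference.py \
            #       --config configs/inference.yaml --input_file /tmp/tmpXXXX/input.txt \
            #       -hp seed=42,num_steps=50,guidance_scale=4.5,overlap_frame=13,fps=25,silence_duration_s=0.3,max_hw=1280,use_audio=True,i2v=True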
            progress(0.3, desc="Running OmniAvatar generation")
            logger.info(f"Running command: {' '.join(cmd)}")
            # Run the inference script
            env = os.environ.copy()
            env['CUDA_VISIBLE_DEVICES'] = '0'  # Use the first GPU
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                env=env
            )
            # Monitor progress (simplified - in reality you'd parse the output)
            stdout_lines = []
            stderr_lines = []
            while True:
                output = process.stdout.readline()
                if output:
                    stdout_lines.append(output.strip())
                    logger.info(output.strip())
                    # Update progress based on the output
                    if "Starting video generation" in output:
                        progress(0.5, desc="Generating video frames")
                    elif "[1/" in output:  # First chunk
                        progress(0.6, desc="Processing video chunks")
                    elif "Saving video" in output:
                        progress(0.9, desc="Finalizing video")
                if process.poll() is not None:
                    break
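            # Note: only stdout is drained in the loop above; a very chatty stderr could
            # fill its pipe buffer and stall the child process. Remaining stderr is
            # collected below once the process has exited.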
            # Get any remaining output
            remaining_stdout, remaining_stderr = process.communicate()
            if remaining_stdout:
                stdout_lines.extend(remaining_stdout.strip().split('\n'))
            if remaining_stderr:
                stderr_lines.extend(remaining_stderr.strip().split('\n'))
            if process.returncode != 0:
                error_msg = '\n'.join(stderr_lines)
                logger.error(f"Inference failed with return code {process.returncode}")
                logger.error(f"Error output: {error_msg}")
                raise gr.Error(f"Video generation failed: {error_msg}")
            progress(0.95, desc="Retrieving generated video")
            # Find the generated video file.
            # The inference script saves to demo_out/{exp_name}/res_{input_file_name}_...,
            # so pick the most recent matching video file.
            generated_videos = list(Path("demo_out").rglob("result_000.mp4"))
            if not generated_videos:
                raise gr.Error("No video file was generated")
            # Get the most recent video
            latest_video = max(generated_videos, key=lambda p: p.stat().st_mtime)
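            # Assumption: the newest result_000.mp4 by mtime is this run's output;
            # concurrent generations could race on this lookup.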
            # Copy to the output directory with a unique name
            output_filename = f"avatar_video_{os.getpid()}_{torch.randint(1000, 9999, (1,)).item()}.mp4"
            output_path = OUTPUT_DIR / output_filename
            shutil.copy(latest_video, output_path)
            progress(1.0, desc="Generation complete")
            logger.info(f"Video saved to: {output_path}")
            return str(output_path)
    except Exception as e:
        logger.error(f"Error generating video: {str(e)}")
        raise gr.Error(f"Error generating video: {str(e)}")

# Create the Gradio interface
with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
    gr.Markdown("""
    # OmniAvatar - Lipsynced Avatar Video Generation

    Generate videos with lipsynced avatars using a reference image and an audio file.
    Based on Wan2.1 with OmniAvatar enhancements for audio-driven avatar animation.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            reference_image = gr.Image(
                label="Reference Avatar Image",
                type="filepath",
                elem_id="reference_image"
            )
            audio_file = gr.Audio(
                label="Speech Audio File",
                type="filepath",
                elem_id="audio_file"
            )
            text_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Describe the video scene and actions...",
                lines=3,
                value="A person speaking naturally with subtle facial expressions"
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        value=42
                    )
                    resolution = gr.Radio(
                        label="Resolution",
                        choices=["480p", "720p"],
                        value="720p"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=100,
                        step=5,
                        value=50
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=4.5
                    )
                with gr.Row():
                    audio_scale = gr.Slider(
                        label="Audio Scale (leave at 0 to use the guidance scale)",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.5,
                        value=0.0
                    )
                    overlap_frames = gr.Slider(
                        label="Overlap Frames",
                        minimum=1,
                        maximum=25,
                        step=4,
                        value=13,
                        info="Must be 1 + 4*n"
                    )
                with gr.Row():
                    fps = gr.Slider(
                        label="FPS",
                        minimum=10,
                        maximum=30,
                        step=1,
                        value=25
                    )
                    silence_duration = gr.Slider(
                        label="Silence Duration (s)",
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=0.3
                    )
            generate_btn = gr.Button(
                "🎬 Generate Avatar Video",
                variant="primary"
            )
        with gr.Column(scale=1):
            # Output component
            output_video = gr.Video(
                label="Generated Avatar Video",
                elem_id="output_video"
            )

    # Examples
    gr.Examples(
        examples=[
            [
                "examples/images/0000.jpeg",
                "examples/audios/0000.MP3",
                "A professional woman giving a presentation with confident gestures"
            ],
        ],
        inputs=[reference_image, audio_file, text_prompt],
        label="Example Inputs"
    )
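    # Note (assumption): the example assets above expect the repo to provide
    # examples/images/0000.jpeg and examples/audios/0000.MP3; adjust if they live elsewhere.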
    # Connect the generate button
    generate_btn.click(
        fn=generate_avatar_video,
        inputs=[
            reference_image,
            audio_file,
            text_prompt,
            seed,
            num_steps,
            guidance_scale,
            audio_scale,
            overlap_frames,
            fps,
            silence_duration,
            resolution
        ],
        outputs=output_video
    )
gr.Markdown(""" | |
## π Notes | |
- The reference image should be a clear frontal view of the person | |
- Audio should be clear speech without background music | |
- Generation may take several minutes depending on video length | |
- For best results, use high-quality input images and audio | |
""") | |
# Launch the app
if __name__ == "__main__":
    app.launch(share=True)