ReCamMaster / app.py
jbilcke-hf's picture
jbilcke-hf HF Staff
Update app.py
2932acc verified
raw
history blame
6.07 kB
import gradio as gr
import torch
import os
import tempfile
import shutil
import imageio
import logging
from pathlib import Path
# Import from our modules
from model_loader import ModelLoader, MODELS_ROOT_DIR
from video_processor import VideoProcessor
from config import CAMERA_TRANSFORMATIONS, TEST_DATA_DIR
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Process-wide model state shared by every request.
# The loader is constructed here, but its weights are only loaded later via
# ``model_loader.load_models()``.
model_loader = ModelLoader()
# Built lazily by ``init_video_processor()`` once the models are loaded.
video_processor: "VideoProcessor | None" = None
def init_video_processor():
    """Lazily create the module-wide ``VideoProcessor``.

    The processor is built from ``model_loader.pipe`` the first time this is
    called after the models have been loaded. Later calls reuse it.

    Returns:
        bool: True if a processor is available after the call, else False.
    """
    global video_processor
    if not model_loader.is_loaded or video_processor is not None:
        # Either there is nothing to build from yet, or it already exists:
        # just report the current state.
        return video_processor is not None
    video_processor = VideoProcessor(model_loader.pipe)
    return video_processor is not None
def extract_frames_from_video(video_path, output_dir, max_frames=81):
    """Write exactly ``max_frames`` PNG frames of ``video_path`` into ``output_dir``.

    Videos longer than ``max_frames`` are truncated. Shorter videos are padded
    by repeating their last frame. Frames are named ``frame_0000.png``,
    ``frame_0001.png``, and so on.

    Args:
        video_path: Path to a video file readable by ``imageio.get_reader``.
        output_dir: Directory to write the frames into; created if missing.
        max_frames: Number of frames to write (must be non-negative).

    Returns:
        ``(num_frames_written, fps)``. ``fps`` is taken from the reader
        metadata, or is ``None`` if the container does not report it.

    Raises:
        ValueError: If ``max_frames > 0`` and no frame could be read.
    """
    import itertools  # local import keeps this fix self-contained

    os.makedirs(output_dir, exist_ok=True)

    reader = imageio.get_reader(video_path)
    try:
        # Some containers omit 'fps' from their metadata; don't crash on that.
        fps = reader.get_meta_data().get('fps')
        # Decode only the frames that will actually be written.
        frames = list(itertools.islice(reader, max_frames))
    finally:
        # Release the reader even if metadata access or decoding fails.
        reader.close()

    if not frames and max_frames > 0:
        # Without this check the padding step below fails with an opaque IndexError.
        raise ValueError(f"No frames could be read from {video_path}")

    # If we have fewer than required frames, repeat the last frame
    if len(frames) < max_frames:
        logger.info(f"Video has {len(frames)} frames, padding to {max_frames} frames")
        frames.extend([frames[-1]] * (max_frames - len(frames)))

    for index, frame in enumerate(frames):
        frame_path = os.path.join(output_dir, f"frame_{index:04d}.png")
        imageio.imwrite(frame_path, frame)

    return len(frames), fps
def generate_recammaster_video(
    video_file,
    text_prompt,
    camera_type,
    progress=gr.Progress()
):
    """Re-render an uploaded video along a new camera trajectory with ReCamMaster.

    Args:
        video_file: Value delivered by the ``gr.Video`` input. This is normally
            a filepath string; objects exposing ``.name`` (tempfile-style
            wrappers) are also accepted.
        text_prompt: Text description of the video, forwarded to the processor.
        camera_type: Key into ``CAMERA_TRANSFORMATIONS`` selecting the movement.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(output_path, status_message)``. ``output_path`` is ``None`` when
        generation fails or preconditions are not met.
    """
    if not model_loader.is_loaded:
        return None, "Error: Models not loaded! Please load models first."
    if not init_video_processor():
        return None, "Error: Failed to initialize video processor."
    if video_file is None:
        return None, "Please upload a video file."

    # gr.Video hands the callback a filepath string (which has no ``.name``).
    # Accept path-likes directly and fall back to ``.name`` for file wrappers.
    # Note: pathlib.Path.name is only the basename, hence the isinstance check.
    if isinstance(video_file, (str, os.PathLike)):
        source_path = video_file
    else:
        source_path = video_file.name

    try:
        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            progress(0.1, desc="Processing input video...")
            # Copy uploaded video to temp directory
            input_video_path = os.path.join(temp_dir, "input.mp4")
            shutil.copy(source_path, input_video_path)

            # Extract frames
            # NOTE(review): the extracted/padded frames are never passed on;
            # process_video() below reads input_video_path directly, and the
            # result is only used for logging. Confirm whether this step is needed.
            progress(0.2, desc="Extracting video frames...")
            num_frames, fps = extract_frames_from_video(input_video_path, os.path.join(temp_dir, "frames"))
            logger.info(f"Extracted {num_frames} frames at {fps} fps")

            # Process with ReCamMaster
            progress(0.3, desc="Processing with ReCamMaster...")
            output_video = video_processor.process_video(
                input_video_path,
                text_prompt,
                camera_type
            )

            # Save output video
            # NOTE(review): output fps is fixed at 30 regardless of the source fps.
            progress(0.9, desc="Saving output video...")
            output_path = os.path.join(temp_dir, "output.mp4")
            from diffsynth import save_video
            save_video(output_video, output_path, fps=30, quality=5)

            # temp_dir is deleted on exit, so copy the result to a file that
            # outlives it. Close the handle right away: only the path is needed.
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
                final_output_path = tmp_file.name
            shutil.copy(output_path, final_output_path)

            progress(1.0, desc="Done!")
            transformation_name = CAMERA_TRANSFORMATIONS.get(str(camera_type), "Unknown")
            status_msg = f"Successfully generated video with '{transformation_name}' camera movement!"
            return final_output_path, status_msg
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback in the logs.
        logger.exception(f"Error generating video: {str(e)}")
        return None, f"Error: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="ReCamMaster Demo") as demo:
    # Plain string literal: the header has no placeholders, so no f-string (F541).
    gr.Markdown("""
# 🎥 ReCamMaster
ReCamMaster allows you to re-capture videos with novel camera trajectories.
Upload a video and select a camera transformation to see the magic!
""")
    with gr.Row():
        with gr.Column():
            # Video input section
            with gr.Group():
                gr.Markdown("### Step 1: Upload Video")
                video_input = gr.Video(label="Input Video")
                text_prompt = gr.Textbox(
                    label="Text Prompt (describe your video)",
                    placeholder="A person walking in the street",
                    value="A dynamic scene"
                )
            # Camera selection
            with gr.Group():
                gr.Markdown("### Step 2: Select Camera Movement")
                camera_type = gr.Radio(
                    # (label, value) pairs: the user sees the description and
                    # the callback receives the CAMERA_TRANSFORMATIONS key.
                    choices=[(v, k) for k, v in CAMERA_TRANSFORMATIONS.items()],
                    label="Camera Transformation",
                    value="1"
                )
            # Generate button
            generate_btn = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            # Output section
            output_video = gr.Video(label="Output Video")
            status_output = gr.Textbox(label="Generation Status", interactive=False)
    # Example videos
    gr.Markdown("### Example Videos")
    gr.Examples(
        examples=[
            [f"{TEST_DATA_DIR}/videos/case0.mp4", "A person dancing", "1"],
            [f"{TEST_DATA_DIR}/videos/case1.mp4", "A scenic view", "5"],
        ],
        inputs=[video_input, text_prompt, camera_type],
    )
    # Event handlers
    generate_btn.click(
        fn=generate_recammaster_video,
        inputs=[video_input, text_prompt, camera_type],
        outputs=[output_video, status_output]
    )
def _main():
    """Load model weights, then launch the Gradio demo with ``share=True``."""
    # Weights are loaded before serving so requests see model_loader.is_loaded.
    model_loader.load_models()
    demo.launch(share=True)


if __name__ == "__main__":
    _main()