import torch
import numpy as np
import json
import imageio
from PIL import Image
from torchvision.transforms import v2
from einops import rearrange
import torchvision
import logging

from config import TEST_DATA_DIR
from camera_utils import Camera, parse_matrix, get_relative_pose

logger = logging.getLogger(__name__)


class VideoProcessor:
    def __init__(self, pipe):
        self.pipe = pipe
        self.default_height = 480
        self.default_width = 832

    def crop_and_resize(self, image, height, width):
        """Resize the image (aspect ratio preserved) so it covers the target dimensions.

        The actual crop to (height, width) is applied later by the CenterCrop transform.
        """
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return image

    def load_video_frames(self, video_path, num_frames=81, height=480, width=832):
        """Load and preprocess video frames into a (1, C, T, H, W) tensor."""
        reader = imageio.get_reader(video_path)
        frames = []

        # Create frame processor with the specified dimensions
        frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        for i in range(num_frames):
            try:
                frame = reader.get_data(i)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, height, width)
                frame = frame_process(frame)
                frames.append(frame)
            except Exception:
                # If we run out of frames, repeat the last one
                if frames:
                    frames.append(frames[-1])
                else:
                    raise ValueError("Video is too short!")

        reader.close()
        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        video_tensor = frames.unsqueeze(0)  # Add batch dimension
        return video_tensor

    def load_camera_trajectory(self, cam_type, num_frames=81):
        """Load the camera trajectory for the selected type and build relative-pose embeddings."""
        tgt_camera_path = "./camera_trajectories/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Get the camera trajectory for the selected type, sampling every 4th frame
        # (81 frames -> 21 poses)
        cam_idx = list(range(num_frames))[::4]
        traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)

        # Re-order axes, flip the second column, and rescale translations
        c2ws = []
        for c2w in traj:
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)

        # Express every pose relative to the first pose of the trajectory
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])

        pose_embedding = torch.stack(relative_poses, dim=0)  # 21 x 3 x 4
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # flatten each pose to 12 values
        camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0)  # Add batch dimension
        return camera_tensor

    def process_video(self, video_path, text_prompt, cam_type, num_frames=81, height=480, width=832,
                      seed=0, num_inference_steps=50, cfg_scale=5.0):
        """Process a video through the ReCamMaster model."""
        # Load video frames
        video_tensor = self.load_video_frames(video_path, num_frames, height, width)

        # Load camera trajectory
        camera_tensor = self.load_camera_trajectory(cam_type, num_frames)

        # Generate video with ReCamMaster
        video = self.pipe(
            prompt=[text_prompt],
            negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
            source_video=video_tensor,
            target_camera=camera_tensor,
            height=height,
            width=width,
            num_frames=num_frames,
            cfg_scale=cfg_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            tiled=True,
        )
        return video
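

# Usage sketch (illustrative only, not part of the original module): a minimal
# example of how VideoProcessor might be driven end to end. The pipeline
# constructor `load_recammaster_pipe()`, the input path, the prompt, and the
# camera index below are hypothetical placeholders, not names defined by this
# repository.
#
# if __name__ == "__main__":
#     pipe = load_recammaster_pipe()          # hypothetical pipeline constructor
#     processor = VideoProcessor(pipe)
#     video = processor.process_video(
#         video_path="input.mp4",             # hypothetical input clip
#         text_prompt="a person walking on the beach",
#         cam_type=1,                         # preset index into camera_extrinsics.json
#         num_frames=81,
#         seed=42,
#     )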