import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal
import gradio as gr
import httpx
import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange
from gradio import networking
from gradio.context import LocalContext
from gradio.tunneling import CERTIFICATE_PATH, Tunnel
from seva.eval import (
IS_TORCH_NIGHTLY,
chunk_input_and_test,
create_transforms_simple,
infer_prior_stats,
run_one_scene,
transform_img_and_K,
)
from seva.geometry import (
DEFAULT_FOV_RAD,
get_default_intrinsics,
get_preset_pose_fov,
normalize_scene,
)
from seva.gui import define_gui
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model
device = "cuda:0"
# Constants.
WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1
ADVANCE_EXAMPLE_MAP = [
(
"assets/advance/blue-car.jpg",
["assets/advance/blue-car.jpg"],
),
(
"assets/advance/garden-4_0.jpg",
[
"assets/advance/garden-4_0.jpg",
"assets/advance/garden-4_1.jpg",
"assets/advance/garden-4_2.jpg",
"assets/advance/garden-4_3.jpg",
],
),
(
"assets/advance/vgg-lab-4_0.png",
[
"assets/advance/vgg-lab-4_0.png",
"assets/advance/vgg-lab-4_1.png",
"assets/advance/vgg-lab-4_2.png",
"assets/advance/vgg-lab-4_3.png",
],
),
(
"assets/advance/telebooth-2_0.jpg",
[
"assets/advance/telebooth-2_0.jpg",
"assets/advance/telebooth-2_1.jpg",
],
),
(
"assets/advance/backyard-7_0.jpg",
[
"assets/advance/backyard-7_0.jpg",
"assets/advance/backyard-7_1.jpg",
"assets/advance/backyard-7_2.jpg",
"assets/advance/backyard-7_3.jpg",
"assets/advance/backyard-7_4.jpg",
"assets/advance/backyard-7_5.jpg",
"assets/advance/backyard-7_6.jpg",
],
),
]
if IS_TORCH_NIGHTLY:
COMPILE = True
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
COMPILE = False
# Shared global variables across sessions.
DUST3R = Dust3rPipeline(device=device) # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
"H": 576,
"W": 576,
"T": 21,
"C": 4,
"f": 8,
"options": {},
}
SERVERS = {}
ABORT_EVENTS = {}
if COMPILE:
MODEL = torch.compile(MODEL)
CONDITIONER = torch.compile(CONDITIONER)
AE = torch.compile(AE)
class SevaRenderer(object):
def __init__(self, server: viser.ViserServer):
self.server = server
self.gui_state = None
def preprocess(
self, input_img_path_or_tuples: list[tuple[str, None]] | str
) -> tuple[dict, dict, dict]:
# Simply hardcode these such that aspect ratio is always kept and
# shorter side is resized to 576. This is only to make GUI option fewer
# though, changing it still works.
shorter: int = 576
# Has to be 64 multiple for the network.
shorter = round(shorter / 64) * 64
if isinstance(input_img_path_or_tuples, str):
# Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
input_imgs = torch.as_tensor(
iio.imread(input_img_path_or_tuples) / 255.0, dtype=torch.float32
)[None, ..., :3]
input_imgs = transform_img_and_K(
input_imgs.permute(0, 3, 1, 2),
shorter,
K=None,
size_stride=64,
)[0].permute(0, 2, 3, 1)
input_Ks = get_default_intrinsics(
aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
)
input_c2ws = torch.eye(4)[None]
# Simulate a small time interval such that gradio can update
# propgress properly.
time.sleep(0.1)
return (
{
"input_imgs": input_imgs,
"input_Ks": input_Ks,
"input_c2ws": input_c2ws,
"input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
"points": [np.zeros((0, 3))],
"point_colors": [np.zeros((0, 3))],
"scene_scale": 1.0,
},
gr.update(visible=False),
gr.update(),
)
else:
# Assume `Advance` demo mode: use dust3r to extract camera parameters and points.
img_paths = [p for (p, _) in input_img_path_or_tuples]
(
input_imgs,
input_Ks,
input_c2ws,
points,
point_colors,
) = DUST3R.infer_cameras_and_points(img_paths)
num_inputs = len(img_paths)
if num_inputs == 1:
input_imgs, input_Ks, input_c2ws, points, point_colors = (
input_imgs[:1],
input_Ks[:1],
input_c2ws[:1],
points[:1],
point_colors[:1],
)
input_imgs = [img[..., :3] for img in input_imgs]
# Normalize the scene.
point_chunks = [p.shape[0] for p in points]
point_indices = np.cumsum(point_chunks)[:-1]
input_c2ws, points, _ = normalize_scene( # type: ignore
input_c2ws,
np.concatenate(points, 0),
camera_center_method="poses",
)
points = np.split(points, point_indices, 0)
# Scale camera and points for viewport visualization.
scene_scale = np.median(
np.ptp(np.concatenate([input_c2ws[:, :3, 3], *points], 0), -1)
)
input_c2ws[:, :3, 3] /= scene_scale
points = [point / scene_scale for point in points]
input_imgs = [
torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_imgs
]
input_Ks = torch.as_tensor(input_Ks)
input_c2ws = torch.as_tensor(input_c2ws)
new_input_imgs, new_input_Ks = [], []
for img, K in zip(input_imgs, input_Ks):
img = rearrange(img, "h w c -> 1 c h w")
# If you don't want to keep aspect ratio and want to always center crop, use this:
# img, K = transform_img_and_K(img, (shorter, shorter), K=K[None])
img, K = transform_img_and_K(img, shorter, K=K[None], size_stride=64)
assert isinstance(K, torch.Tensor)
K = K / K.new_tensor([img.shape[-1], img.shape[-2], 1])[:, None]
new_input_imgs.append(img)
new_input_Ks.append(K)
input_imgs = torch.cat(new_input_imgs, 0)
input_imgs = rearrange(input_imgs, "b c h w -> b h w c")[..., :3]
input_Ks = torch.cat(new_input_Ks, 0)
return (
{
"input_imgs": input_imgs,
"input_Ks": input_Ks,
"input_c2ws": input_c2ws,
"input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
"points": points,
"point_colors": point_colors,
"scene_scale": scene_scale,
},
gr.update(visible=False),
gr.update()
if num_inputs <= 10
else gr.update(choices=["interp"], value="interp"),
)
def visualize_scene(self, preprocessed: dict):
server = self.server
server.scene.reset()
server.gui.reset()
set_bkgd_color(server)
(
input_imgs,
input_Ks,
input_c2ws,
input_wh,
points,
point_colors,
scene_scale,
) = (
preprocessed["input_imgs"],
preprocessed["input_Ks"],
preprocessed["input_c2ws"],
preprocessed["input_wh"],
preprocessed["points"],
preprocessed["point_colors"],
preprocessed["scene_scale"],
)
W, H = input_wh
server.scene.set_up_direction(-input_c2ws[..., :3, 1].mean(0).numpy())
# Use first image as default fov.
assert input_imgs[0].shape[:2] == (H, W)
if H > W:
init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 0, 0].item()))
else:
init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 1, 1].item()))
init_fov_deg = float(init_fov / np.pi * 180.0)
frustum_nodes, pcd_nodes = [], []
for i in range(len(input_imgs)):
K = input_Ks[i]
frustum = server.scene.add_camera_frustum(
f"/scene_assets/cameras/{i}",
fov=2 * np.arctan(1 / (2 * K[1, 1].item())),
aspect=W / H,
scale=0.1 * scene_scale,
image=(input_imgs[i].numpy() * 255.0).astype(np.uint8),
wxyz=vt.SO3.from_matrix(input_c2ws[i, :3, :3].numpy()).wxyz,
position=input_c2ws[i, :3, 3].numpy(),
)
def get_handler(frustum):
def handler(event: viser.GuiEvent) -> None:
assert event.client_id is not None
client = server.get_clients()[event.client_id]
with client.atomic():
client.camera.position = frustum.position
client.camera.wxyz = frustum.wxyz
# Set look_at as the projected origin onto the
# frustum's forward direction.
look_direction = vt.SO3(frustum.wxyz).as_matrix()[:, 2]
position_origin = -frustum.position
client.camera.look_at = (
frustum.position
+ np.dot(look_direction, position_origin)
/ np.linalg.norm(position_origin)
* look_direction
)
return handler
frustum.on_click(get_handler(frustum)) # type: ignore
frustum_nodes.append(frustum)
pcd = server.scene.add_point_cloud(
f"/scene_assets/points/{i}",
points[i],
point_colors[i],
point_size=0.01 * scene_scale,
point_shape="circle",
)
pcd_nodes.append(pcd)
with server.gui.add_folder("Scene scale", expand_by_default=False, order=200):
camera_scale_slider = server.gui.add_slider(
"Log camera scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
)
@camera_scale_slider.on_update
def _(_) -> None:
for i in range(len(frustum_nodes)):
frustum_nodes[i].scale = (
0.1 * scene_scale * 10**camera_scale_slider.value
)
point_scale_slider = server.gui.add_slider(
"Log point scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
)
@point_scale_slider.on_update
def _(_) -> None:
for i in range(len(pcd_nodes)):
pcd_nodes[i].point_size = (
0.01 * scene_scale * 10**point_scale_slider.value
)
self.gui_state = define_gui(
server,
init_fov=init_fov_deg,
img_wh=input_wh,
scene_scale=scene_scale,
)
def get_target_c2ws_and_Ks_from_gui(self, preprocessed: dict):
input_wh = preprocessed["input_wh"]
W, H = input_wh
gui_state = self.gui_state
assert gui_state is not None and gui_state.camera_traj_list is not None
target_c2ws, target_Ks = [], []
for item in gui_state.camera_traj_list:
target_c2ws.append(item["w2c"])
assert item["img_wh"] == input_wh
K = np.array(item["K"]).reshape(3, 3) / np.array([W, H, 1])[:, None]
target_Ks.append(K)
target_c2ws = torch.as_tensor(
np.linalg.inv(np.array(target_c2ws).reshape(-1, 4, 4))
)
target_Ks = torch.as_tensor(np.array(target_Ks).reshape(-1, 3, 3))
return target_c2ws, target_Ks
def get_target_c2ws_and_Ks_from_preset(
self,
preprocessed: dict,
preset_traj: Literal[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
],
num_frames: int,
zoom_factor: float | None,
):
img_wh = preprocessed["input_wh"]
start_c2w = preprocessed["input_c2ws"][0]
start_w2c = torch.linalg.inv(start_c2w)
look_at = torch.tensor([0, 0, 10])
start_fov = DEFAULT_FOV_RAD
target_c2ws, target_fovs = get_preset_pose_fov(
preset_traj,
num_frames,
start_w2c,
look_at,
-start_c2w[:3, 1],
start_fov,
spiral_radii=[1.0, 1.0, 0.5],
zoom_factor=zoom_factor,
)
target_c2ws = torch.as_tensor(target_c2ws)
target_fovs = torch.as_tensor(target_fovs)
target_Ks = get_default_intrinsics(
target_fovs, # type: ignore
aspect_ratio=img_wh[0] / img_wh[1],
)
return target_c2ws, target_Ks
def export_output_data(self, preprocessed: dict, output_dir: str):
input_imgs, input_Ks, input_c2ws, input_wh = (
preprocessed["input_imgs"],
preprocessed["input_Ks"],
preprocessed["input_c2ws"],
preprocessed["input_wh"],
)
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
num_inputs = len(input_imgs)
num_targets = len(target_c2ws)
input_imgs = (input_imgs.cpu().numpy() * 255.0).astype(np.uint8)
input_c2ws = input_c2ws.cpu().numpy()
input_Ks = input_Ks.cpu().numpy()
target_c2ws = target_c2ws.cpu().numpy()
target_Ks = target_Ks.cpu().numpy()
img_whs = np.array(input_wh)[None].repeat(len(input_imgs) + len(target_Ks), 0)
os.makedirs(output_dir, exist_ok=True)
img_paths = []
for i, img in enumerate(input_imgs):
iio.imwrite(img_path := osp.join(output_dir, f"{i:03d}.png"), img)
img_paths.append(img_path)
for i in range(num_targets):
iio.imwrite(
img_path := osp.join(output_dir, f"{i + num_inputs:03d}.png"),
np.zeros((input_wh[1], input_wh[0], 3), dtype=np.uint8),
)
img_paths.append(img_path)
# Convert from OpenCV to OpenGL camera format.
all_c2ws = np.concatenate([input_c2ws, target_c2ws])
all_Ks = np.concatenate([input_Ks, target_Ks])
all_c2ws = all_c2ws @ np.diag([1, -1, -1, 1])
create_transforms_simple(output_dir, img_paths, img_whs, all_c2ws, all_Ks)
split_dict = {
"train_ids": list(range(num_inputs)),
"test_ids": list(range(num_inputs, num_inputs + num_targets)),
}
with open(
osp.join(output_dir, f"train_test_split_{num_inputs}.json"), "w"
) as f:
json.dump(split_dict, f, indent=4)
gr.Info(f"Output data saved to {output_dir}", duration=1)
def render(
self,
preprocessed: dict,
session_hash: str,
seed: int,
chunk_strategy: str,
cfg: float,
preset_traj: Literal[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
]
| None,
num_frames: int | None,
zoom_factor: float | None,
camera_scale: float,
):
render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
render_dir = osp.join(WORK_DIR, render_name)
input_imgs, input_Ks, input_c2ws, (W, H) = (
preprocessed["input_imgs"],
preprocessed["input_Ks"],
preprocessed["input_c2ws"],
preprocessed["input_wh"],
)
num_inputs = len(input_imgs)
if preset_traj is None:
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
else:
assert num_frames is not None
assert num_inputs == 1
input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
preprocessed, preset_traj, num_frames, zoom_factor
)
all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
all_Ks = (
torch.cat([input_Ks, target_Ks], 0)
* input_Ks.new_tensor([W, H, 1])[:, None]
)
num_targets = len(target_c2ws)
input_indices = list(range(num_inputs))
target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
# Get anchor cameras.
T = VERSION_DICT["T"]
version_dict = copy.deepcopy(VERSION_DICT)
num_anchors = infer_prior_stats(
T,
num_inputs,
num_total_frames=num_targets,
version_dict=version_dict,
)
# infer_prior_stats modifies T in-place.
T = version_dict["T"]
assert isinstance(num_anchors, int)
anchor_indices = np.linspace(
num_inputs,
num_inputs + num_targets - 1,
num_anchors,
).tolist()
anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
# Create image conditioning.
all_imgs_np = (
F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
* 255.0
).astype(np.uint8)
image_cond = {
"img": all_imgs_np,
"input_indices": input_indices,
"prior_indices": anchor_indices,
}
# Create camera conditioning (K is unnormalized).
camera_cond = {
"c2w": all_c2ws,
"K": all_Ks,
"input_indices": list(range(num_inputs + num_targets)),
}
# Run rendering.
num_steps = 50
options_ori = VERSION_DICT["options"]
options = copy.deepcopy(options_ori)
options["chunk_strategy"] = chunk_strategy
options["video_save_fps"] = 30.0
options["beta_linear_start"] = 5e-6
options["log_snr_shift"] = 2.4
options["guider_types"] = [1, 2]
options["cfg"] = [
float(cfg),
3.0 if num_inputs >= 9 else 2.0,
] # We define semi-dense-view regime to have 9 input views.
options["camera_scale"] = camera_scale
options["num_steps"] = num_steps
options["cfg_min"] = 1.2
options["encoding_t"] = 1
options["decoding_t"] = 1
assert session_hash in ABORT_EVENTS
abort_event = ABORT_EVENTS[session_hash]
abort_event.clear()
options["abort_event"] = abort_event
task = "img2trajvid"
# Get number of first pass chunks.
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
chunk_strategy_first_pass = options.get(
"chunk_strategy_first_pass", "gt-nearest"
)
num_chunks_0 = len(
chunk_input_and_test(
T_first_pass,
input_c2ws,
anchor_c2ws,
input_indices,
image_cond["prior_indices"],
options={**options, "sampler_verbose": False},
task=task,
chunk_strategy=chunk_strategy_first_pass,
gt_input_inds=list(range(input_c2ws.shape[0])),
)[1]
)
# Get number of second pass chunks.
anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
anchor_indices = np.array(input_indices + anchor_indices)[
anchor_argsort
].tolist()
gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
anchor_argsort
]
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
chunk_strategy = options.get("chunk_strategy", "nearest")
num_chunks_1 = len(
chunk_input_and_test(
T_second_pass,
anchor_c2ws_second_pass,
target_c2ws,
anchor_indices,
target_indices,
options={**options, "sampler_verbose": False},
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=gt_input_inds,
)[1]
)
second_pass_pbar = gr.Progress().tqdm(
iterable=None,
desc="Second pass sampling",
total=num_chunks_1 * num_steps,
)
first_pass_pbar = gr.Progress().tqdm(
iterable=None,
desc="First pass sampling",
total=num_chunks_0 * num_steps,
)
video_path_generator = run_one_scene(
task=task,
version_dict={
"H": H,
"W": W,
"T": T,
"C": VERSION_DICT["C"],
"f": VERSION_DICT["f"],
"options": options,
},
model=MODEL,
ae=AE,
conditioner=CONDITIONER,
denoiser=DENOISER,
image_cond=image_cond,
camera_cond=camera_cond,
save_path=render_dir,
use_traj_prior=True,
traj_prior_c2ws=anchor_c2ws,
traj_prior_Ks=anchor_Ks,
seed=seed,
gradio=True,
first_pass_pbar=first_pass_pbar,
second_pass_pbar=second_pass_pbar,
abort_event=abort_event,
)
output_queue = queue.Queue()
blocks = LocalContext.blocks.get()
event_id = LocalContext.event_id.get()
def worker():
# gradio doesn't support threading with progress intentionally, so
# we need to hack this.
LocalContext.blocks.set(blocks)
LocalContext.event_id.set(event_id)
for i, video_path in enumerate(video_path_generator):
if i == 0:
output_queue.put(
(
video_path,
gr.update(),
gr.update(),
gr.update(),
)
)
elif i == 1:
output_queue.put(
(
video_path,
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
)
else:
gr.Error("More than two passes during rendering.")
thread = threading.Thread(target=worker, daemon=True)
thread.start()
while thread.is_alive() or not output_queue.empty():
if abort_event.is_set():
thread.join()
abort_event.clear()
yield (
gr.update(),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
time.sleep(0.1)
while not output_queue.empty():
yield output_queue.get()
# This is basically a copy of the original `networking.setup_tunnel` function,
# but it also returns the tunnel object for proper cleanup.
def setup_tunnel(
local_host: str, local_port: int, share_token: str, share_server_address: str | None
) -> tuple[str, Tunnel]:
share_server_address = (
networking.GRADIO_SHARE_SERVER_ADDRESS
if share_server_address is None
else share_server_address
)
if share_server_address is None:
try:
response = httpx.get(networking.GRADIO_API_SERVER, timeout=30)
payload = response.json()[0]
remote_host, remote_port = payload["host"], int(payload["port"])
certificate = payload["root_ca"]
Path(CERTIFICATE_PATH).parent.mkdir(parents=True, exist_ok=True)
with open(CERTIFICATE_PATH, "w") as f:
f.write(certificate)
except Exception as e:
raise RuntimeError(
"Could not get share link from Gradio API Server."
) from e
else:
remote_host, remote_port = share_server_address.split(":")
remote_port = int(remote_port)
tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
address = tunnel.start_tunnel()
return address, tunnel
def set_bkgd_color(server: viser.ViserServer | viser.ClientHandle):
server.scene.set_background_image(np.array([[[39, 39, 42]]], dtype=np.uint8))
def start_server_and_abort_event(request: gr.Request):
server = viser.ViserServer()
@server.on_client_connect
def _(client: viser.ClientHandle):
# Force dark mode that blends well with gradio's dark theme.
client.gui.configure_theme(
dark_mode=True,
show_share_button=False,
control_layout="collapsible",
)
set_bkgd_color(client)
print(f"Starting server {server.get_port()}")
server_url, tunnel = setup_tunnel(
local_host=server.get_host(),
local_port=server.get_port(),
share_token=secrets.token_urlsafe(32),
share_server_address=None,
)
SERVERS[request.session_hash] = (server, tunnel)
if server_url is None:
raise gr.Error(
"Failed to get a viewport URL. Please check your network connection."
)
# Give it enough time to start.
time.sleep(1)
ABORT_EVENTS[request.session_hash] = threading.Event()
return (
SevaRenderer(server),
gr.HTML(
f'',
container=True,
),
request.session_hash,
)
def stop_server_and_abort_event(request: gr.Request):
if request.session_hash in SERVERS:
print(f"Stopping server {request.session_hash}")
server, tunnel = SERVERS.pop(request.session_hash)
server.stop()
tunnel.kill()
if request.session_hash in ABORT_EVENTS:
print(f"Setting abort event {request.session_hash}")
ABORT_EVENTS[request.session_hash].set()
# Give it enough time to abort jobs.
time.sleep(5)
ABORT_EVENTS.pop(request.session_hash)
def set_abort_event(request: gr.Request):
if request.session_hash in ABORT_EVENTS:
print(f"Setting abort event {request.session_hash}")
ABORT_EVENTS[request.session_hash].set()
def get_advance_examples(selection: gr.SelectData):
index = selection.index
return (
gr.Gallery(ADVANCE_EXAMPLE_MAP[index][1], visible=True),
gr.update(visible=True),
gr.update(visible=True),
gr.Gallery(visible=False),
)
def get_preamble():
gr.Markdown("""
# Stable Virtual Camera
Welcome to the demo of Stable Virtual Camera (Seva)! Given any number of input views and their cameras, this demo will allow you to generate novel views of a scene at any target camera of interest.
We provide two ways to use our demo (selected by the tab below, documented [here](https://github.com/Stability-AI/stable-virtual-camera/blob/main/docs/GR_USAGE.md)):
1. **[Basic](https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705)**: Given a single image, you can generate a video following one of our preset camera trajectories.
2. **[Advanced](https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba)**: Given any number of input images, you can generate a video following any camera trajectory of your choice by our key-frame-based interface.
> This is a research preview and comes with a few [limitations](https://stable-virtual-camera.github.io/#limitations):
> - Limited quality in certain subjects due to training data, including humans, animals, and dynamic textures.
> - Limited quality in some highly ambiguous scenes and camera trajectories, including extreme views and collision into objects.
""")
# Make sure that gradio uses dark theme.
_APP_JS = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
}
}
"""
def main(server_port: int | None = None, share: bool = True):
with gr.Blocks(js=_APP_JS) as app:
renderer = gr.State()
session_hash = gr.State()
_ = get_preamble()
with gr.Tabs():
with gr.Tab("Basic"):
render_btn = gr.Button("Render video", interactive=False, render=False)
with gr.Row():
with gr.Column():
with gr.Group():
# Initially disable the Preprocess Images button until an image is selected.
preprocess_btn = gr.Button("Preprocess images", interactive=False)
preprocess_progress = gr.Textbox(
label="",
visible=False,
interactive=False,
)
with gr.Group():
input_imgs = gr.Image(
type="filepath",
label="Input",
height=200,
)
_ = gr.Examples(
examples=sorted(glob("assets/basic/*")),
inputs=[input_imgs],
label="Example",
)
chunk_strategy = gr.Dropdown(
["interp", "interp-gt"],
label="Chunk strategy",
render=False,
)
preprocessed = gr.State()
# Enable the Preprocess Images button only if an image is selected.
input_imgs.change(
lambda img: gr.update(interactive=bool(img)),
inputs=input_imgs,
outputs=preprocess_btn,
)
preprocess_btn.click(
lambda r, *args: [
*r.preprocess(*args),
gr.update(interactive=True),
],
inputs=[renderer, input_imgs],
outputs=[
preprocessed,
preprocess_progress,
chunk_strategy,
render_btn,
],
show_progress_on=[preprocess_progress],
concurrency_limit=1,
concurrency_id="gpu_queue",
)
preprocess_btn.click(
lambda: gr.update(visible=True),
outputs=[preprocess_progress],
)
with gr.Row():
preset_traj = gr.Dropdown(
choices=[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
],
label="Preset trajectory",
value="orbit",
)
num_frames = gr.Slider(30, 150, 80, label="#Frames")
zoom_factor = gr.Slider(
step=0.01, label="Zoom factor", visible=False
)
with gr.Row():
seed = gr.Number(value=23, label="Random seed")
chunk_strategy.render()
cfg = gr.Slider(1.0, 7.0, value=4.0, label="CFG value")
with gr.Row():
camera_scale = gr.Slider(
0.1,
15.0,
value=2.0,
label="Camera scale",
)
def default_cfg_preset_traj(traj):
# These are just some hand-tuned values that we
# found work the best.
if traj in ["zoom-out", "move-down"]:
value = 5.0
elif traj in [
"orbit",
"dolly zoom-out",
"move-backward",
"move-up",
"move-left",
"move-right",
]:
value = 4.0
else:
value = 3.0
return value
preset_traj.change(
default_cfg_preset_traj,
inputs=[preset_traj],
outputs=[cfg],
)
preset_traj.change(
lambda traj: gr.update(
value=(
10.0 if "dolly" in traj or "pan" in traj else 2.0
)
),
inputs=[preset_traj],
outputs=[camera_scale],
)
def zoom_factor_preset_traj(traj):
visible = traj in [
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
]
is_zoomin = traj.endswith("zoom-in")
if is_zoomin:
minimum = 0.1
maximum = 0.5
value = 0.28
else:
minimum = 1.2
maximum = 3
value = 1.5
return gr.update(
visible=visible,
minimum=minimum,
maximum=maximum,
value=value,
)
preset_traj.change(
zoom_factor_preset_traj,
inputs=[preset_traj],
outputs=[zoom_factor],
)
with gr.Column():
with gr.Group():
abort_btn = gr.Button("Abort rendering", visible=False)
render_btn.render()
render_progress = gr.Textbox(
label="", visible=False, interactive=False
)
output_video = gr.Video(
label="Output", interactive=False, autoplay=True, loop=True
)
render_btn.click(
lambda r, *args: (yield from r.render(*args)),
inputs=[
renderer,
preprocessed,
session_hash,
seed,
chunk_strategy,
cfg,
preset_traj,
num_frames,
zoom_factor,
camera_scale,
],
outputs=[
output_video,
render_btn,
abort_btn,
render_progress,
],
show_progress_on=[render_progress],
concurrency_id="gpu_queue",
)
render_btn.click(
lambda: [
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
],
outputs=[render_btn, abort_btn, render_progress],
)
abort_btn.click(set_abort_event)
with gr.Tab("Advanced"):
render_btn = gr.Button("Render video", interactive=False, render=False)
viewport = gr.HTML(container=True, render=False)
gr.Timer(0.1).tick(
lambda renderer: gr.update(
interactive=renderer is not None
and renderer.gui_state is not None
and renderer.gui_state.camera_traj_list is not None
),
inputs=[renderer],
outputs=[render_btn],
)
with gr.Row():
viewport.render()
with gr.Row():
with gr.Column():
with gr.Group():
# Initially disable the Preprocess Images button until images are selected.
preprocess_btn = gr.Button("Preprocess images", interactive=False)
preprocess_progress = gr.Textbox(
label="",
visible=False,
interactive=False,
)
with gr.Group():
input_imgs = gr.Gallery(
interactive=True,
label="Input",
columns=4,
height=200,
)
# Define example images (gradio doesn't support variable length
# examples so we need to hack it).
example_imgs = gr.Gallery(
[e[0] for e in ADVANCE_EXAMPLE_MAP],
allow_preview=False,
preview=False,
label="Example",
columns=20,
rows=1,
height=115,
)
example_imgs_expander = gr.Gallery(
visible=False,
interactive=False,
label="Example",
preview=True,
columns=20,
rows=1,
)
chunk_strategy = gr.Dropdown(
["interp-gt", "interp"],
label="Chunk strategy",
value="interp-gt",
render=False,
)
with gr.Row():
example_imgs_backer = gr.Button(
"Go back", visible=False
)
example_imgs_confirmer = gr.Button(
"Confirm", visible=False
)
example_imgs.select(
get_advance_examples,
outputs=[
example_imgs_expander,
example_imgs_confirmer,
example_imgs_backer,
example_imgs,
],
)
example_imgs_confirmer.click(
lambda x: (
x,
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
gr.update(interactive=bool(x))
),
inputs=[example_imgs_expander],
outputs=[
input_imgs,
example_imgs_expander,
example_imgs_confirmer,
example_imgs_backer,
example_imgs,
preprocess_btn
],
)
example_imgs_backer.click(
lambda: (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
),
outputs=[
example_imgs_expander,
example_imgs_confirmer,
example_imgs_backer,
example_imgs,
],
)
preprocessed = gr.State()
preprocess_btn.click(
lambda r, *args: r.preprocess(*args),
inputs=[renderer, input_imgs],
outputs=[
preprocessed,
preprocess_progress,
chunk_strategy,
],
show_progress_on=[preprocess_progress],
concurrency_id="gpu_queue",
)
preprocess_btn.click(
lambda: gr.update(visible=True),
outputs=[preprocess_progress],
)
preprocessed.change(
lambda r, *args: r.visualize_scene(*args),
inputs=[renderer, preprocessed],
)
with gr.Row():
seed = gr.Number(value=23, label="Random seed")
chunk_strategy.render()
cfg = gr.Slider(1.0, 7.0, value=3.0, label="CFG value")
with gr.Row():
camera_scale = gr.Slider(
0.1,
15.0,
value=2.0,
label="Camera scale (useful for single-view input)",
)
with gr.Group():
output_data_dir = gr.Textbox(label="Output data directory")
output_data_btn = gr.Button("Export output data")
output_data_btn.click(
lambda r, *args: r.export_output_data(*args),
inputs=[renderer, preprocessed, output_data_dir],
)
with gr.Column():
with gr.Group():
abort_btn = gr.Button("Abort rendering", visible=False)
render_btn.render()
render_progress = gr.Textbox(
label="", visible=False, interactive=False
)
output_video = gr.Video(
label="Output", interactive=False, autoplay=True, loop=True
)
render_btn.click(
lambda r, *args: (yield from r.render(*args)),
inputs=[
renderer,
preprocessed,
session_hash,
seed,
chunk_strategy,
cfg,
gr.State(),
gr.State(),
gr.State(),
camera_scale,
],
outputs=[
output_video,
render_btn,
abort_btn,
render_progress,
],
show_progress_on=[render_progress],
concurrency_id="gpu_queue",
)
render_btn.click(
lambda: [
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
],
outputs=[render_btn, abort_btn, render_progress],
)
abort_btn.click(set_abort_event)
# Register the session initialization and cleanup functions.
app.load(
start_server_and_abort_event,
outputs=[renderer, viewport, session_hash],
)
app.unload(stop_server_and_abort_event)
app.queue(max_size=5).launch(
share=share,
server_port=server_port,
show_error=True,
allowed_paths=[WORK_DIR],
# Badget rendering will be broken otherwise.
ssr_mode=False,
)
if __name__ == "__main__":
tyro.cli(main)