# Spaces: Running on L40S
import colorsys
import dataclasses
import threading
import time
from pathlib import Path

import numpy as np
import scipy
import scipy.interpolate
import splines
import splines.quaternion
import torch
import viser
import viser.transforms as vt

from seva.geometry import get_preset_pose_fov
class Keyframe(object): | |
position: np.ndarray | |
wxyz: np.ndarray | |
override_fov_enabled: bool | |
override_fov_rad: float | |
aspect: float | |
override_transition_enabled: bool | |
override_transition_sec: float | None | |
def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe": | |
return Keyframe( | |
camera.position, | |
camera.wxyz, | |
override_fov_enabled=False, | |
override_fov_rad=camera.fov, | |
aspect=aspect, | |
override_transition_enabled=False, | |
override_transition_sec=None, | |
) | |
def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe": | |
return Keyframe( | |
se3.translation(), | |
se3.rotation().wxyz, | |
override_fov_enabled=False, | |
override_fov_rad=fov, | |
aspect=aspect, | |
override_transition_enabled=False, | |
override_transition_sec=None, | |
) | |
class CameraTrajectory(object): | |
def __init__( | |
self, | |
server: viser.ViserServer, | |
duration_element: viser.GuiInputHandle[float], | |
scene_scale: float, | |
scene_node_prefix: str = "/", | |
): | |
self._server = server | |
self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {} | |
self._keyframe_counter: int = 0 | |
self._spline_nodes: list[viser.SceneNodeHandle] = [] | |
self._camera_edit_panel: viser.Gui3dContainerHandle | None = None | |
self._orientation_spline: splines.quaternion.KochanekBartels | None = None | |
self._position_spline: splines.KochanekBartels | None = None | |
self._fov_spline: splines.KochanekBartels | None = None | |
self._keyframes_visible: bool = True | |
self._duration_element = duration_element | |
self._scene_node_prefix = scene_node_prefix | |
self.scene_scale = scene_scale | |
# These parameters should be overridden externally. | |
self.loop: bool = False | |
self.framerate: float = 30.0 | |
self.tension: float = 0.0 # Tension / alpha term. | |
self.default_fov: float = 0.0 | |
self.default_transition_sec: float = 0.0 | |
self.show_spline: bool = True | |
def set_keyframes_visible(self, visible: bool) -> None: | |
self._keyframes_visible = visible | |
for keyframe in self._keyframes.values(): | |
keyframe[1].visible = visible | |
def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None: | |
"""Add a new camera, or replace an old one if `keyframe_index` is passed in.""" | |
server = self._server | |
# Add a keyframe if we aren't replacing an existing one. | |
if keyframe_index is None: | |
keyframe_index = self._keyframe_counter | |
self._keyframe_counter += 1 | |
print( | |
f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}" | |
) | |
frustum_handle = server.scene.add_camera_frustum( | |
str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"), | |
fov=( | |
keyframe.override_fov_rad | |
if keyframe.override_fov_enabled | |
else self.default_fov | |
), | |
aspect=keyframe.aspect, | |
scale=0.1 * self.scene_scale, | |
color=(200, 10, 30), | |
wxyz=keyframe.wxyz, | |
position=keyframe.position, | |
visible=self._keyframes_visible, | |
) | |
self._server.scene.add_icosphere( | |
str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"), | |
radius=0.03, | |
color=(200, 10, 30), | |
) | |
def _(_) -> None: | |
if self._camera_edit_panel is not None: | |
self._camera_edit_panel.remove() | |
self._camera_edit_panel = None | |
with server.scene.add_3d_gui_container( | |
"/camera_edit_panel", | |
position=keyframe.position, | |
) as camera_edit_panel: | |
self._camera_edit_panel = camera_edit_panel | |
override_fov = server.gui.add_checkbox( | |
"Override FOV", initial_value=keyframe.override_fov_enabled | |
) | |
override_fov_degrees = server.gui.add_slider( | |
"Override FOV (degrees)", | |
5.0, | |
175.0, | |
step=0.1, | |
initial_value=keyframe.override_fov_rad * 180.0 / np.pi, | |
disabled=not keyframe.override_fov_enabled, | |
) | |
delete_button = server.gui.add_button( | |
"Delete", color="red", icon=viser.Icon.TRASH | |
) | |
go_to_button = server.gui.add_button("Go to") | |
close_button = server.gui.add_button("Close") | |
def _(_) -> None: | |
keyframe.override_fov_enabled = override_fov.value | |
override_fov_degrees.disabled = not override_fov.value | |
self.add_camera(keyframe, keyframe_index) | |
def _(_) -> None: | |
keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi | |
self.add_camera(keyframe, keyframe_index) | |
def _(event: viser.GuiEvent) -> None: | |
assert event.client is not None | |
with event.client.gui.add_modal("Confirm") as modal: | |
event.client.gui.add_markdown("Delete keyframe?") | |
confirm_button = event.client.gui.add_button( | |
"Yes", color="red", icon=viser.Icon.TRASH | |
) | |
exit_button = event.client.gui.add_button("Cancel") | |
def _(_) -> None: | |
assert camera_edit_panel is not None | |
keyframe_id = None | |
for i, keyframe_tuple in self._keyframes.items(): | |
if keyframe_tuple[1] is frustum_handle: | |
keyframe_id = i | |
break | |
assert keyframe_id is not None | |
self._keyframes.pop(keyframe_id) | |
frustum_handle.remove() | |
camera_edit_panel.remove() | |
self._camera_edit_panel = None | |
modal.close() | |
self.update_spline() | |
def _(_) -> None: | |
modal.close() | |
def _(event: viser.GuiEvent) -> None: | |
assert event.client is not None | |
client = event.client | |
T_world_current = vt.SE3.from_rotation_and_translation( | |
vt.SO3(client.camera.wxyz), client.camera.position | |
) | |
T_world_target = vt.SE3.from_rotation_and_translation( | |
vt.SO3(keyframe.wxyz), keyframe.position | |
) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5])) | |
T_current_target = T_world_current.inverse() @ T_world_target | |
for j in range(10): | |
T_world_set = T_world_current @ vt.SE3.exp( | |
T_current_target.log() * j / 9.0 | |
) | |
# Important bit: we atomically set both the orientation and | |
# the position of the camera. | |
with client.atomic(): | |
client.camera.wxyz = T_world_set.rotation().wxyz | |
client.camera.position = T_world_set.translation() | |
time.sleep(1.0 / 30.0) | |
def _(_) -> None: | |
assert camera_edit_panel is not None | |
camera_edit_panel.remove() | |
self._camera_edit_panel = None | |
self._keyframes[keyframe_index] = (keyframe, frustum_handle) | |
def update_aspect(self, aspect: float) -> None: | |
for keyframe_index, frame in self._keyframes.items(): | |
frame = dataclasses.replace(frame[0], aspect=aspect) | |
self.add_camera(frame, keyframe_index=keyframe_index) | |
def get_aspect(self) -> float: | |
"""Get W/H aspect ratio, which is shared across all keyframes.""" | |
assert len(self._keyframes) > 0 | |
return next(iter(self._keyframes.values()))[0].aspect | |
def reset(self) -> None: | |
for frame in self._keyframes.values(): | |
print(f"removing {frame[1]}") | |
frame[1].remove() | |
self._keyframes.clear() | |
self.update_spline() | |
print("camera traj reset") | |
def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray: | |
"""From a time value in seconds, compute a t value for our geometric | |
spline interpolation. An increment of 1 for the latter will move the | |
camera forward by one keyframe. | |
We use a PCHIP spline here to guarantee monotonicity. | |
""" | |
transition_times_cumsum = self.compute_transition_times_cumsum() | |
spline_indices = np.arange(transition_times_cumsum.shape[0]) | |
if self.loop: | |
# In the case of a loop, we pad the spline to match the start/end | |
# slopes. | |
interpolator = scipy.interpolate.PchipInterpolator( | |
x=np.concatenate( | |
[ | |
[-(transition_times_cumsum[-1] - transition_times_cumsum[-2])], | |
transition_times_cumsum, | |
transition_times_cumsum[-1:] + transition_times_cumsum[1:2], | |
], | |
axis=0, | |
), | |
y=np.concatenate( | |
[[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore | |
axis=0, | |
), | |
) | |
else: | |
interpolator = scipy.interpolate.PchipInterpolator( | |
x=transition_times_cumsum, y=spline_indices | |
) | |
# Clip to account for floating point error. | |
return np.clip(interpolator(time), 0, spline_indices[-1]) | |
def interpolate_pose_and_fov_rad( | |
self, normalized_t: float | |
) -> tuple[vt.SE3, float] | None: | |
if len(self._keyframes) < 2: | |
return None | |
self._fov_spline = splines.KochanekBartels( | |
[ | |
( | |
keyframe[0].override_fov_rad | |
if keyframe[0].override_fov_enabled | |
else self.default_fov | |
) | |
for keyframe in self._keyframes.values() | |
], | |
tcb=(self.tension, 0.0, 0.0), | |
endconditions="closed" if self.loop else "natural", | |
) | |
assert self._orientation_spline is not None | |
assert self._position_spline is not None | |
assert self._fov_spline is not None | |
max_t = self.compute_duration() | |
t = max_t * normalized_t | |
spline_t = float(self.spline_t_from_t_sec(np.array(t))) | |
quat = self._orientation_spline.evaluate(spline_t) | |
assert isinstance(quat, splines.quaternion.UnitQuaternion) | |
return ( | |
vt.SE3.from_rotation_and_translation( | |
vt.SO3(np.array([quat.scalar, *quat.vector])), | |
self._position_spline.evaluate(spline_t), | |
), | |
float(self._fov_spline.evaluate(spline_t)), | |
) | |
def update_spline(self) -> None: | |
num_frames = int(self.compute_duration() * self.framerate) | |
keyframes = list(self._keyframes.values()) | |
if num_frames <= 0 or not self.show_spline or len(keyframes) < 2: | |
for node in self._spline_nodes: | |
node.remove() | |
self._spline_nodes.clear() | |
return | |
transition_times_cumsum = self.compute_transition_times_cumsum() | |
self._orientation_spline = splines.quaternion.KochanekBartels( | |
[ | |
splines.quaternion.UnitQuaternion.from_unit_xyzw( | |
np.roll(keyframe[0].wxyz, shift=-1) | |
) | |
for keyframe in keyframes | |
], | |
tcb=(self.tension, 0.0, 0.0), | |
endconditions="closed" if self.loop else "natural", | |
) | |
self._position_spline = splines.KochanekBartels( | |
[keyframe[0].position for keyframe in keyframes], | |
tcb=(self.tension, 0.0, 0.0), | |
endconditions="closed" if self.loop else "natural", | |
) | |
# Update visualized spline. | |
points_array = self._position_spline.evaluate( | |
self.spline_t_from_t_sec( | |
np.linspace(0, transition_times_cumsum[-1], num_frames) | |
) | |
) | |
colors_array = np.array( | |
[ | |
colorsys.hls_to_rgb(h, 0.5, 1.0) | |
for h in np.linspace(0.0, 1.0, len(points_array)) | |
] | |
) | |
# Clear prior spline nodes. | |
for node in self._spline_nodes: | |
node.remove() | |
self._spline_nodes.clear() | |
self._spline_nodes.append( | |
self._server.scene.add_spline_catmull_rom( | |
str(Path(self._scene_node_prefix) / "camera_spline"), | |
positions=points_array, | |
color=(220, 220, 220), | |
closed=self.loop, | |
line_width=1.0, | |
segments=points_array.shape[0] + 1, | |
) | |
) | |
self._spline_nodes.append( | |
self._server.scene.add_point_cloud( | |
str(Path(self._scene_node_prefix) / "camera_spline/points"), | |
points=points_array, | |
colors=colors_array, | |
point_size=0.04, | |
) | |
) | |
def make_transition_handle(i: int) -> None: | |
assert self._position_spline is not None | |
transition_pos = self._position_spline.evaluate( | |
float( | |
self.spline_t_from_t_sec( | |
(transition_times_cumsum[i] + transition_times_cumsum[i + 1]) | |
/ 2.0, | |
) | |
) | |
) | |
transition_sphere = self._server.scene.add_icosphere( | |
str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"), | |
radius=0.04, | |
color=(255, 0, 0), | |
position=transition_pos, | |
) | |
self._spline_nodes.append(transition_sphere) | |
def _(_) -> None: | |
server = self._server | |
if self._camera_edit_panel is not None: | |
self._camera_edit_panel.remove() | |
self._camera_edit_panel = None | |
keyframe_index = (i + 1) % len(self._keyframes) | |
keyframe = keyframes[keyframe_index][0] | |
with server.scene.add_3d_gui_container( | |
"/camera_edit_panel", | |
position=transition_pos, | |
) as camera_edit_panel: | |
self._camera_edit_panel = camera_edit_panel | |
override_transition_enabled = server.gui.add_checkbox( | |
"Override transition", | |
initial_value=keyframe.override_transition_enabled, | |
) | |
override_transition_sec = server.gui.add_number( | |
"Override transition (sec)", | |
initial_value=( | |
keyframe.override_transition_sec | |
if keyframe.override_transition_sec is not None | |
else self.default_transition_sec | |
), | |
min=0.001, | |
max=30.0, | |
step=0.001, | |
disabled=not override_transition_enabled.value, | |
) | |
close_button = server.gui.add_button("Close") | |
def _(_) -> None: | |
keyframe.override_transition_enabled = ( | |
override_transition_enabled.value | |
) | |
override_transition_sec.disabled = ( | |
not override_transition_enabled.value | |
) | |
self._duration_element.value = self.compute_duration() | |
def _(_) -> None: | |
keyframe.override_transition_sec = override_transition_sec.value | |
self._duration_element.value = self.compute_duration() | |
def _(_) -> None: | |
assert camera_edit_panel is not None | |
camera_edit_panel.remove() | |
self._camera_edit_panel = None | |
(num_transitions_plus_1,) = transition_times_cumsum.shape | |
for i in range(num_transitions_plus_1 - 1): | |
make_transition_handle(i) | |
def compute_duration(self) -> float: | |
"""Compute the total duration of the trajectory.""" | |
total = 0.0 | |
for i, (keyframe, frustum) in enumerate(self._keyframes.values()): | |
if i == 0 and not self.loop: | |
continue | |
del frustum | |
total += ( | |
keyframe.override_transition_sec | |
if keyframe.override_transition_enabled | |
and keyframe.override_transition_sec is not None | |
else self.default_transition_sec | |
) | |
return total | |
def compute_transition_times_cumsum(self) -> np.ndarray: | |
"""Compute the total duration of the trajectory.""" | |
total = 0.0 | |
out = [0.0] | |
for i, (keyframe, frustum) in enumerate(self._keyframes.values()): | |
if i == 0: | |
continue | |
del frustum | |
total += ( | |
keyframe.override_transition_sec | |
if keyframe.override_transition_enabled | |
and keyframe.override_transition_sec is not None | |
else self.default_transition_sec | |
) | |
out.append(total) | |
if self.loop: | |
keyframe = next(iter(self._keyframes.values()))[0] | |
total += ( | |
keyframe.override_transition_sec | |
if keyframe.override_transition_enabled | |
and keyframe.override_transition_sec is not None | |
else self.default_transition_sec | |
) | |
out.append(total) | |
return np.array(out) | |
class GuiState: | |
preview_render: bool | |
preview_fov: float | |
preview_aspect: float | |
camera_traj_list: list | None | |
active_input_index: int | |
def define_gui(
    server: viser.ViserServer,
    init_fov: float = 75.0,
    img_wh: tuple[int, int] = (576, 576),
    **kwargs,
) -> GuiState:
    """Build the camera-trajectory GUI on `server` and return its shared state.

    Args:
        server: Viser server to add GUI controls and scene nodes to.
        init_fov: Initial value of the FOV slider, in degrees.
        img_wh: Render size as (width, height); used for the aspect ratio and
            the intrinsics of the saved trajectory.
        **kwargs: Forwarded to `CameraTrajectory` (e.g. `scene_scale`).

    Returns:
        The `GuiState` instance that the callbacks registered here mutate.

    NOTE(review): the file as received had lost all indentation and every
    decorator line. Callback registrations and folder membership of the
    controls below were restored from context (each handler is attached to
    the control it reads); confirm against upstream.
    """
    gui_state = GuiState(
        preview_render=False,
        preview_fov=0.0,
        preview_aspect=1.0,
        camera_traj_list=None,
        active_input_index=0,
    )

    with server.gui.add_folder(
        "Preset camera trajectories", order=99, expand_by_default=False
    ):
        preset_traj_dropdown = server.gui.add_dropdown(
            "Options",
            [
                "orbit",
                "spiral",
                "lemniscate",
                "zoom-out",
                "dolly zoom-out",
            ],
            initial_value="orbit",
            hint="Select a preset camera trajectory.",
        )
        preset_duration_num = server.gui.add_number(
            "Duration (sec)",
            min=1.0,
            max=60.0,
            step=0.5,
            initial_value=2.0,
        )
        preset_submit_button = server.gui.add_button(
            "Submit",
            icon=viser.Icon.PICK,
            hint="Add a new keyframe at the current pose.",
        )

        @preset_submit_button.on_click
        def _(event: viser.GuiEvent) -> None:
            duration = preset_duration_num.value
            fps = framerate_number.value
            num_frames = int(duration * fps)
            if num_frames <= 0:
                # duration * fps < 1 (FPS may be as low as 0.1): there is no
                # frame to generate and the division below would fail. Bail
                # out before touching the existing trajectory.
                return

            camera_traj.reset()
            gui_state.camera_traj_list = None

            # One keyframe per frame, so each transition lasts one frame.
            transition_sec = duration / num_frames
            transition_sec_number.value = transition_sec
            assert event.client_id is not None
            # Manual editing is locked until keyframes are cleared.
            transition_sec_number.disabled = True
            loop_checkbox.disabled = True
            add_keyframe_button.disabled = True

            camera = server.get_clients()[event.client_id].camera
            start_w2c = torch.linalg.inv(
                torch.as_tensor(
                    vt.SE3.from_rotation_and_translation(
                        vt.SO3(camera.wxyz), camera.position
                    ).as_matrix(),
                    dtype=torch.float32,
                )
            )
            look_at = torch.as_tensor(camera.look_at, dtype=torch.float32)
            up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32)
            poses, fovs = get_preset_pose_fov(
                option=preset_traj_dropdown.value,  # type: ignore
                num_frames=num_frames,
                start_w2c=start_w2c,
                look_at=look_at,
                up_direction=up_direction,
                fov=camera.fov,
            )
            assert poses is not None and fovs is not None
            for pose, fov in zip(poses, fovs):
                camera_traj.add_camera(
                    Keyframe.from_se3(
                        vt.SE3.from_matrix(pose),
                        fov=fov,
                        aspect=img_wh[0] / img_wh[1],
                    )
                )

            duration_number.value = camera_traj.compute_duration()
            camera_traj.update_spline()

    with server.gui.add_folder("Advanced", expand_by_default=False, order=100):
        transition_sec_number = server.gui.add_number(
            "Transition (sec)",
            min=0.001,
            max=30.0,
            step=0.001,
            initial_value=1.5,
            hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
        )
        framerate_number = server.gui.add_number(
            "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
        )
        framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
        duration_number = server.gui.add_number(
            "Duration (sec)",
            min=0.0,
            max=1e8,
            step=0.001,
            initial_value=0.0,
            disabled=True,
        )

        @framerate_buttons.on_click
        def _(_) -> None:
            framerate_number.value = float(framerate_buttons.value)

    fov_degree_slider = server.gui.add_slider(
        "FOV",
        initial_value=init_fov,
        min=0.1,
        max=175.0,
        step=0.01,
        hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
    )

    @fov_degree_slider.on_update
    def _(_) -> None:
        fov_radians = fov_degree_slider.value / 180.0 * np.pi
        for client in server.get_clients().values():
            client.camera.fov = fov_radians
        camera_traj.default_fov = fov_radians

        # Updating the aspect ratio will also re-render the camera frustums.
        # Could rethink this.
        camera_traj.update_aspect(img_wh[0] / img_wh[1])
        compute_and_update_preview_camera_state()

    scene_node_prefix = "/render_assets"
    base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False)
    add_keyframe_button = server.gui.add_button(
        "Add keyframe",
        icon=viser.Icon.PLUS,
        hint="Add a new keyframe at the current pose.",
    )

    @add_keyframe_button.on_click
    def _(event: viser.GuiEvent) -> None:
        assert event.client_id is not None
        camera = server.get_clients()[event.client_id].camera
        pose = vt.SE3.from_rotation_and_translation(
            vt.SO3(camera.wxyz), camera.position
        )
        print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
        print(f"camera pose {pose.as_matrix()}")

        # Add this camera to the trajectory.
        camera_traj.add_camera(
            Keyframe.from_camera(
                camera,
                aspect=img_wh[0] / img_wh[1],
            ),
        )
        duration_number.value = camera_traj.compute_duration()
        camera_traj.update_spline()

    clear_keyframes_button = server.gui.add_button(
        "Clear keyframes",
        icon=viser.Icon.TRASH,
        hint="Remove all keyframes from the render trajectory.",
    )

    @clear_keyframes_button.on_click
    def _(event: viser.GuiEvent) -> None:
        assert event.client_id is not None
        client = server.get_clients()[event.client_id]
        with client.atomic(), client.gui.add_modal("Confirm") as modal:
            client.gui.add_markdown("Clear all keyframes?")
            confirm_button = client.gui.add_button(
                "Yes", color="red", icon=viser.Icon.TRASH
            )
            exit_button = client.gui.add_button("Cancel")

            @confirm_button.on_click
            def _(_) -> None:
                camera_traj.reset()
                modal.close()

                duration_number.value = camera_traj.compute_duration()
                # Undo the locks applied by the preset handler.
                add_keyframe_button.disabled = False
                transition_sec_number.disabled = False
                transition_sec_number.value = 1.5
                loop_checkbox.disabled = False

                gui_state.camera_traj_list = None

            @exit_button.on_click
            def _(_) -> None:
                modal.close()

    play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
    pause_button = server.gui.add_button(
        "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
    )

    # Poll the play button to see if we should be playing endlessly.
    def play() -> None:
        while True:
            # "Playing" is encoded as the play button being hidden.
            while not play_button.visible:
                max_frame = int(framerate_number.value * duration_number.value)
                if max_frame > 0:
                    assert preview_frame_slider is not None
                    preview_frame_slider.value = (
                        preview_frame_slider.value + 1
                    ) % max_frame
                time.sleep(1.0 / framerate_number.value)
            time.sleep(0.1)

    # NOTE(review): this thread never returns and is not marked as a daemon,
    # so it may keep the interpreter alive at shutdown; confirm whether
    # daemon=True is wanted.
    threading.Thread(target=play).start()

    # Play the camera trajectory when the play button is pressed.
    @play_button.on_click
    def _(_) -> None:
        play_button.visible = False
        pause_button.visible = True

    # Pause the camera trajectory when the pause button is pressed.
    @pause_button.on_click
    def _(_) -> None:
        play_button.visible = True
        pause_button.visible = False

    preview_render_button = server.gui.add_button(
        "Preview render",
        hint="Show a preview of the render in the viewport.",
        icon=viser.Icon.CAMERA_CHECK,
    )
    preview_render_stop_button = server.gui.add_button(
        "Exit render preview",
        color="red",
        icon=viser.Icon.CAMERA_CANCEL,
        visible=False,
    )

    @preview_render_button.on_click
    def _(_) -> None:
        gui_state.preview_render = True
        preview_render_button.visible = False
        preview_render_stop_button.visible = True
        play_button.visible = False
        pause_button.visible = True
        preset_submit_button.disabled = True

        maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
        if maybe_pose_and_fov_rad is None:
            remove_preview_camera()
            return
        pose, fov = maybe_pose_and_fov_rad
        del fov

        # Hide all render assets when we're previewing the render.
        base_scene_node.visible = False

        # Back up and then set camera poses.
        for client in server.get_clients().values():
            camera_pose_backup_from_id[client.client_id] = (
                client.camera.position,
                client.camera.look_at,
                client.camera.up_direction,
            )
            with client.atomic():
                client.camera.wxyz = pose.rotation().wxyz
                client.camera.position = pose.translation()

    def stop_preview_render() -> None:
        """Leave render-preview mode and restore the backed-up client cameras."""
        gui_state.preview_render = False
        preview_render_button.visible = True
        preview_render_stop_button.visible = False
        play_button.visible = True
        pause_button.visible = False
        preset_submit_button.disabled = False

        # Revert camera poses.
        for client in server.get_clients().values():
            if client.client_id not in camera_pose_backup_from_id:
                continue
            cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
                client.client_id
            )
            with client.atomic():
                client.camera.position = cam_position
                client.camera.look_at = cam_look_at
                client.camera.up_direction = cam_up
            client.flush()

        # Un-hide render assets.
        base_scene_node.visible = True
        remove_preview_camera()

    @preview_render_stop_button.on_click
    def _(_) -> None:
        stop_preview_render()

    def get_max_frame_index() -> int:
        """Return the last frame index, never less than 1."""
        return max(1, int(framerate_number.value * duration_number.value) - 1)

    def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None:
        """Helper for creating the current frame # slider. This is removed and
        re-added anytime the `max` value changes."""
        preview_frame_slider = server.gui.add_slider(
            "Preview frame",
            min=0,
            max=get_max_frame_index(),
            step=1,
            initial_value=0,
            order=set_traj_button.order + 0.01,
            disabled=get_max_frame_index() == 1,
        )
        # Everything that needs a playable trajectory follows the slider.
        play_button.disabled = preview_frame_slider.disabled
        preview_render_button.disabled = preview_frame_slider.disabled
        set_traj_button.disabled = preview_frame_slider.disabled

        @preview_frame_slider.on_update
        def _(_) -> None:
            nonlocal preview_camera_handle
            maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
            if maybe_pose_and_fov_rad is None:
                return
            pose, fov_rad = maybe_pose_and_fov_rad

            preview_camera_handle = server.scene.add_camera_frustum(
                str(Path(scene_node_prefix) / "preview_camera"),
                fov=fov_rad,
                aspect=img_wh[0] / img_wh[1],
                scale=0.35,
                wxyz=pose.rotation().wxyz,
                position=pose.translation(),
                color=(10, 200, 30),
            )
            if gui_state.preview_render:
                for client in server.get_clients().values():
                    with client.atomic():
                        client.camera.wxyz = pose.rotation().wxyz
                        client.camera.position = pose.translation()

        return preview_frame_slider

    set_traj_button = server.gui.add_button(
        "Set camera trajectory",
        color="green",
        icon=viser.Icon.CHECK,
        hint="Save the camera trajectory for rendering.",
    )

    @set_traj_button.on_click
    def _(event: viser.GuiEvent) -> None:
        assert event.client is not None
        num_frames = int(framerate_number.value * duration_number.value)

        def get_intrinsics(W, H, fov_rad):
            """Build a 3x3 pinhole intrinsics matrix; focal derives from H."""
            focal = 0.5 * H / np.tan(0.5 * fov_rad)
            return np.array(
                [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
            )

        camera_traj_list = []
        for i in range(num_frames):
            maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
                i / num_frames
            )
            if maybe_pose_and_fov_rad is None:
                return
            pose, fov_rad = maybe_pose_and_fov_rad
            H = img_wh[1]
            W = img_wh[0]
            K = get_intrinsics(W, H, fov_rad)
            w2c = pose.inverse().as_matrix()
            camera_traj_list.append(
                {
                    "w2c": w2c.flatten().tolist(),
                    "K": K.flatten().tolist(),
                    "img_wh": (W, H),
                }
            )

        gui_state.camera_traj_list = camera_traj_list
        print(f"Get camera_traj_list: {gui_state.camera_traj_list}")

        stop_preview_render()

    preview_frame_slider = add_preview_frame_slider()

    loop_checkbox = server.gui.add_checkbox(
        "Loop", False, hint="Add a segment between the first and last keyframes."
    )

    @loop_checkbox.on_update
    def _(_) -> None:
        camera_traj.loop = loop_checkbox.value
        duration_number.value = camera_traj.compute_duration()

    @transition_sec_number.on_update
    def _(_) -> None:
        camera_traj.default_transition_sec = transition_sec_number.value
        duration_number.value = camera_traj.compute_duration()

    preview_camera_handle: viser.SceneNodeHandle | None = None

    def remove_preview_camera() -> None:
        """Remove the preview frustum from the scene, if one exists."""
        nonlocal preview_camera_handle
        if preview_camera_handle is not None:
            preview_camera_handle.remove()
            preview_camera_handle = None

    def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None:
        """Update the render tab state with the current preview camera pose.
        Returns current camera pose + FOV if available."""
        if preview_frame_slider is None:
            return None
        maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
            preview_frame_slider.value / get_max_frame_index()
        )
        if maybe_pose_and_fov_rad is None:
            remove_preview_camera()
            return None
        pose, fov_rad = maybe_pose_and_fov_rad
        gui_state.preview_fov = fov_rad
        gui_state.preview_aspect = camera_traj.get_aspect()
        return pose, fov_rad

    # We back up the camera poses before and after we start previewing renders.
    camera_pose_backup_from_id: dict[int, tuple] = {}

    # Update the # of frames.
    @duration_number.on_update
    @framerate_number.on_update
    def _(_) -> None:
        remove_preview_camera()  # Will be re-added when slider is updated.

        nonlocal preview_frame_slider
        old = preview_frame_slider
        assert old is not None

        preview_frame_slider = add_preview_frame_slider()
        if preview_frame_slider is not None:
            old.remove()
        else:
            preview_frame_slider = old

        camera_traj.framerate = framerate_number.value
        camera_traj.update_spline()

    camera_traj = CameraTrajectory(
        server,
        duration_number,
        scene_node_prefix=scene_node_prefix,
        **kwargs,
    )
    camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi
    camera_traj.default_transition_sec = transition_sec_number.value

    return gui_state