import os |
import glob |
import time |
import threading |
import argparse |
from typing import List, Optional |
import copy |
import numpy as np |
import torch |
from tqdm.auto import tqdm |
import viser |
import viser.transforms as viser_tf |
import cv2 |
import requests |
try: |
import onnxruntime |
except ImportError: |
print("onnxruntime not found. Sky segmentation may not work.") |
from vggt.models.vggt import VGGT |
from vggt.utils.load_fn import load_and_preprocess_images |
from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map |
from vggt.utils.pose_enc import pose_encoding_to_extri_intri |
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: str = None,
):
    """
    Serve an interactive viser scene with the predicted point cloud and cameras.

    The scene is re-centred on the mean of all points. Three GUI controls are
    exposed: a camera visibility checkbox, a confidence-percentile slider and a
    per-frame selector for the point cloud.

    Args:
        pred_dict (dict): Model predictions with the entries
            "images" (S, 3, H, W), "world_points" (S, H, W, 3),
            "world_points_conf" (S, H, W), "depth" (S, H, W, 1),
            "depth_conf" (S, H, W), "extrinsic" (S, 3, 4) and
            "intrinsic" (S, 3, 3).
        port (int): Port the viser server listens on.
        init_conf_threshold (float): Percentile of lowest-confidence points
            hidden when the scene is first shown.
        use_point_map (bool): Show "world_points" directly instead of points
            unprojected from "depth" with the camera parameters.
        background_mode (bool): Return the server right away instead of
            blocking the calling thread.
        mask_sky (bool): Multiply the confidence by a sky mask so sky pixels
            get zero confidence; only applied when ``image_folder`` is given.
        image_folder (str): Folder holding the input images (used for sky masking).

    Returns:
        The viser server. Only reached when ``background_mode`` is true;
        otherwise this call loops forever.
    """
    print(f"Starting viser server on port {port}")
    server = viser.ViserServer(host="", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Read every prediction up front so a missing entry fails immediately.
    rgb_frames = pred_dict["images"]
    predicted_points = pred_dict["world_points"]
    predicted_points_conf = pred_dict["world_points_conf"]
    depth = pred_dict["depth"]
    depth_confidence = pred_dict["depth_conf"]
    cam_extrinsics = pred_dict["extrinsic"]
    cam_intrinsics = pred_dict["intrinsic"]

    # Choose the geometry source together with its matching confidence map.
    if use_point_map:
        cloud, confidence = predicted_points, predicted_points_conf
    else:
        cloud = unproject_depth_map_to_point_map(depth, cam_extrinsics, cam_intrinsics)
        confidence = depth_confidence

    if mask_sky and image_folder is not None:
        confidence = apply_sky_segmentation(confidence, image_folder)

    # Channels-last colours line up one-to-one with the flattened points.
    channels_last = rgb_frames.transpose(0, 2, 3, 1)
    S, H, W, _ = cloud.shape

    xyz = cloud.reshape(-1, 3)
    rgb = (channels_last.reshape(-1, 3) * 255).astype(np.uint8)
    confidence = confidence.reshape(-1)

    # Invert the extrinsics and keep the top 3x4 part as camera-to-world poses.
    cam_to_world = closed_form_inverse_se3(cam_extrinsics)[:, :3, :]

    # Shift points and camera translations so the scene mean sits at the origin.
    scene_center = np.mean(xyz, axis=0)
    xyz_centered = xyz - scene_center
    cam_to_world[..., -1] -= scene_center

    # Frame id of every flattened point, used by the frame selector.
    frame_of_point = np.repeat(np.arange(S), H * W)

    gui_show_frames = server.gui.add_checkbox(
        "Show Cameras",
        initial_value=True,
    )
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent",
        min=0,
        max=100,
        step=0.1,
        initial_value=init_conf_threshold,
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All",
    )

    def confidence_mask(percent: float) -> np.ndarray:
        """Select points whose confidence is strictly above the given percentile."""
        return confidence > np.percentile(confidence, percent)

    initial_mask = confidence_mask(init_conf_threshold)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=xyz_centered[initial_mask],
        colors=rgb[initial_mask],
        point_size=0.001,
        point_shape="circle",
    )

    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None:
        """
        Rebuild the camera axes and image frustums in the scene.

        extrinsics: (S, 3, 4) poses passed to ``SE3.from_matrix``
        images_: (S, 3, H, W)
        """
        # Drop whatever was drawn by a previous call.
        for stale in (*frames, *frustums):
            stale.remove()
        frames.clear()
        frustums.clear()

        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            # Clicking a frustum moves every connected client to that camera pose.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for idx in tqdm(range(S)):
            pose = viser_tf.SE3.from_matrix(extrinsics[idx])
            axis_handle = server.scene.add_frame(
                f"frame_{idx}",
                wxyz=pose.rotation().wxyz,
                position=pose.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(axis_handle)

            # Channels-last uint8 image shown inside the frustum.
            thumbnail = (images_[idx].transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = thumbnail.shape[:2]

            # Field of view from a fixed focal length of 1.1 * image height.
            focal_y = 1.1 * h
            frustum_handle = server.scene.add_camera_frustum(
                f"frame_{idx}/frustum",
                fov=2 * np.arctan2(h / 2, focal_y),
                aspect=w / h,
                scale=0.05,
                image=thumbnail,
                line_width=1.0,
            )
            frustums.append(frustum_handle)
            attach_callback(frustum_handle, axis_handle)

    def update_point_cloud() -> None:
        """Push the points matching the slider and frame selector to the scene."""
        visible = confidence_mask(gui_points_conf.value)
        choice = gui_frame_selector.value
        if choice != "All":
            visible = visible & (frame_of_point == int(choice))
        point_cloud.points = xyz_centered[visible]
        point_cloud.colors = rgb[visible]

    def _refresh(_) -> None:
        update_point_cloud()

    # Both controls trigger the same point-cloud refresh.
    gui_points_conf.on_update(_refresh)
    gui_frame_selector.on_update(_refresh)

    @gui_show_frames.on_update
    def _(_) -> None:
        """Show or hide all camera axes and frustums."""
        for handle in (*frames, *frustums):
            handle.visible = gui_show_frames.value

    visualize_frames(cam_to_world, rgb_frames)

    print("Starting viser server...")
    if not background_mode:
        # Foreground mode: block the calling thread forever.
        while True:
            time.sleep(0.01)

    def server_loop():
        while True:
            time.sleep(0.001)

    threading.Thread(target=server_loop, daemon=True).start()
    return server
def download_file_from_url(url, filename, timeout=30.0):
    """
    Download ``url`` to ``filename``, following HTTP redirects.

    The body is streamed in 8 KiB chunks into ``<filename>.part`` and renamed
    to ``filename`` only after the transfer has finished, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        url: Address of the file to fetch.
        filename: Destination path on disk.
        timeout: Seconds to wait for the connection and for each read.

    Returns:
        None. Request failures are printed rather than raised; errors raised
        while writing the file propagate to the caller.
    """
    partial_path = f"{filename}.part"
    try:
        # Let requests follow any redirect chain (301/302/307/308) and accept
        # a direct 200 response as well; the context manager releases the
        # streamed connection when done.
        with requests.get(url, stream=True, allow_redirects=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file under its final name only once it is complete.
        os.replace(partial_path, filename)
        print(f"Downloaded {filename} successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
    finally:
        # Remove a leftover partial file after any failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
    """
    Zero the confidence of pixels that the sky segmentation marks as sky.

    Masks are cached as image files in ``<image_folder>_sky_masks`` under the
    same base name as the input image. The ``skyseg.onnx`` model is downloaded
    into the current working directory when it is not already there.

    Args:
        conf (np.ndarray): Confidence scores with shape (S, H, W)
        image_folder (str): Path to the folder containing input images

    Returns:
        np.ndarray: Confidence scores multiplied by the binary non-sky mask

    Raises:
        ValueError: If the folder holds fewer than S files.
    """
    S, H, W = conf.shape
    sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
    os.makedirs(sky_masks_dir, exist_ok=True)

    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
        )
    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")

    # Sorted order pairs the i-th file with the i-th confidence map.
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))[:S]
    if len(image_files) != S:
        # Fail clearly instead of relying on (or silently surviving) a
        # shape mismatch in the multiplication below.
        raise ValueError(
            f"Expected {S} images in {image_folder} but found {len(image_files)}"
        )

    sky_mask_list = []
    print("Generating sky masks...")
    for image_path in tqdm(image_files):
        mask_filepath = os.path.join(sky_masks_dir, os.path.basename(image_path))

        sky_mask = None
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        if sky_mask is None:
            # No cached mask, or the cached file could not be decoded.
            sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

        # Bring the mask to the resolution of the confidence map.
        if sky_mask.shape[:2] != (H, W):
            sky_mask = cv2.resize(sky_mask, (W, H))
        sky_mask_list.append(sky_mask)

    sky_mask_array = np.array(sky_mask_list)
    # Any non-zero mask value counts as non-sky and keeps its confidence.
    sky_mask_binary = (sky_mask_array > 0.01).astype(np.float32)
    conf = conf * sky_mask_binary
    print("Sky segmentation applied successfully")
    return conf
def segment_sky(image_path, onnx_session, mask_filename=None):
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask; when None the mask is
            returned without being written to disk

    Returns:
        np.ndarray: uint8 mask at the input image resolution where 255
        indicates non-sky regions and 0 indicates sky

    Raises:
        ValueError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")

    result_map = run_skyseg(onnx_session, [320, 320], image)
    # Scale the model-resolution map back to the original image size.
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

    # Pixels scoring below 1 (exactly 0 in the uint8 map) become 255.
    # NOTE(review): only the lowest-scoring pixels are kept as non-sky;
    # confirm this threshold against the model's output range.
    output_mask = (result_map_original < 1).astype(np.uint8) * 255

    if mask_filename is not None:
        mask_dir = os.path.dirname(mask_filename)
        # A bare filename has an empty dirname, which os.makedirs rejects.
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        cv2.imwrite(mask_filename, output_mask)
    return output_mask
def run_skyseg(onnx_session, input_size, image):
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        np.ndarray: uint8 segmentation map, the squeezed model output
        min-max scaled to the range 0..255 (all zeros when the output is constant)
    """
    width, height = input_size[0], input_size[1]

    # cv2.resize takes dsize as (width, height) and does not modify its input.
    resized = cv2.resize(image, dsize=(width, height))
    x = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)

    # Normalize with per-channel mean and standard deviation constants.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    x = (x / 255 - mean) / std

    # HWC -> 1xCxHxW. Adding the batch axis (instead of reshaping to
    # (-1, 3, width, height)) keeps pixels in place for non-square inputs.
    x = np.ascontiguousarray(x.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})
    onnx_result = np.array(onnx_result).squeeze()

    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    value_range = max_value - min_value
    if value_range == 0:
        # A constant output would divide by zero; return an all-zero map.
        return np.zeros(onnx_result.shape, dtype=np.uint8)

    onnx_result = (onnx_result - min_value) / value_range
    onnx_result *= 255
    return onnx_result.astype("uint8")
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line interface of the viser demo."""
    cli = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
    # Arguments are registered in the order they appear in the help output.
    cli.add_argument(
        "--image_folder",
        type=str,
        default="examples/kitchen/images/",
        help="Path to folder containing images",
    )
    cli.add_argument(
        "--use_point_map",
        action="store_true",
        help="Use point map instead of depth-based points",
    )
    cli.add_argument(
        "--background_mode",
        action="store_true",
        help="Run the viser server in background mode",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port number for the viser server",
    )
    cli.add_argument(
        "--conf_threshold",
        type=float,
        default=25.0,
        help="Initial percentage of low-confidence points to filter out",
    )
    cli.add_argument(
        "--mask_sky",
        action="store_true",
        help="Apply sky segmentation to filter out sky points",
    )
    return cli


# Module-level parser instance shared with main().
parser = _build_parser()
def main():
    """
    Main function for the VGGT demo with viser for 3D visualization.

    This function:
    1. Loads the VGGT model
    2. Processes input images from the specified folder
    3. Runs inference to generate 3D points and camera poses
    4. Optionally applies sky segmentation to filter out sky points
    5. Visualizes the results using viser

    Command-line arguments:
    --image_folder: Path to folder containing input images
    --use_point_map: Use point map instead of depth-based points
    --background_mode: Run the viser server in background mode
    --port: Port number for the viser server
    --conf_threshold: Initial percentage of low-confidence points to filter out
    --mask_sky: Apply sky segmentation to filter out sky points
    """
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("Initializing and loading VGGT model...")
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model.eval()
    model = model.to(device)

    print(f"Loading images from {args.image_folder}...")
    # Sort so the frame order is deterministic and matches the sorted file
    # order that apply_sky_segmentation uses to pair masks with frames.
    image_names = sorted(glob.glob(os.path.join(args.image_folder, "*")))
    print(f"Found {len(image_names)} images")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Mixed precision is only enabled on CUDA. bfloat16 needs compute
    # capability 8.0 or newer, so older GPUs fall back to float16.
    use_amp = device == "cuda"
    amp_dtype = torch.bfloat16
    if use_amp and torch.cuda.get_device_capability()[0] < 8:
        amp_dtype = torch.float16

    print("Running inference...")
    with torch.no_grad(), torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        predictions = model(images)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    print("Processing model outputs...")
    # Move tensors to NumPy and drop the leading batch dimension; existing
    # keys are reassigned only, so iterating over items() is safe.
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            predictions[key] = value.cpu().numpy().squeeze(0)

    if args.use_point_map:
        print("Visualizing 3D points from point map")
    else:
        print("Visualizing 3D points by unprojecting depth map by cameras")

    if args.mask_sky:
        print("Sky segmentation enabled - will filter out sky points")

    print("Starting viser visualization...")
    # Blocks forever unless --background_mode is set; the returned handle is
    # kept referenced until main() returns.
    viser_server = viser_wrapper(
        predictions,
        port=args.port,
        init_conf_threshold=args.conf_threshold,
        use_point_map=args.use_point_map,
        background_mode=args.background_mode,
        mask_sky=args.mask_sky,
        image_folder=args.image_folder,
    )
    print("Visualization complete")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()