import os
import glob
import time
import threading
import argparse
from typing import List, Optional
import copy

import numpy as np
import torch
from tqdm.auto import tqdm
import viser
import viser.transforms as viser_tf
import cv2
import requests

try:
    import onnxruntime
except ImportError:
    print("onnxruntime not found. Sky segmentation may not work.")

from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from vggt.utils.pose_enc import pose_encoding_to_extri_intri


def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    Args:
        pred_dict (dict):
            {
                "images": (S, 3, H, W) - Input images,
                "world_points": (S, H, W, 3),
                "world_points_conf": (S, H, W),
                "depth": (S, H, W, 1),
                "depth_conf": (S, H, W),
                "extrinsic": (S, 3, 4),
                "intrinsic": (S, 3, 3),
            }
        port (int): Port number for the viser server.
        init_conf_threshold (float): Initial percentage of low-confidence points to filter out.
        use_point_map (bool): Whether to visualize the predicted world points directly instead of points unprojected from the depth map.
        background_mode (bool): Whether to run the server in a background thread.
        mask_sky (bool): Whether to apply sky segmentation to filter out sky points.
        image_folder (str): Path to the folder containing input images.
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    images = pred_dict["images"]
    world_points_map = pred_dict["world_points"]
    conf_map = pred_dict["world_points_conf"]

    depth_map = pred_dict["depth"]
    depth_conf = pred_dict["depth_conf"]

    extrinsics_cam = pred_dict["extrinsic"]
    intrinsics_cam = pred_dict["intrinsic"]

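    # Build the point cloud either by unprojecting the depth map with the predicted
    # cameras (default) or by using the regressed world point map directly.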
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

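    # Optionally zero out the confidence of sky pixels so the confidence slider filters them out.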
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

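    # Convert images from (S, 3, H, W) to (S, H, W, 3) for per-point colors,
    # then flatten points, colors, and confidences for the point cloud.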
    colors = images.transpose(0, 2, 3, 1)
    S, H, W, _ = world_points.shape

    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf = conf.reshape(-1)

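    # Invert the predicted extrinsics to get camera-to-world transforms;
    # only the top 3x4 rows are needed for visualization.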
    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]

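    # Recenter everything at the scene centroid so the point cloud and cameras sit near the origin.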
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

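    # Record which frame each flattened point came from so the GUI can filter points per frame.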
    frame_indices = np.repeat(np.arange(S), H * W)

    gui_show_frames = server.gui.add_checkbox(
        "Show Cameras",
        initial_value=True,
    )

    gui_points_conf = server.gui.add_slider(
        "Confidence Percent",
        min=0,
        max=100,
        step=0.1,
        initial_value=init_conf_threshold,
    )

    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All",
    )

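    # Create the point cloud, initially keeping only points above the chosen confidence percentile.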
    init_threshold_val = np.percentile(conf, init_conf_threshold)
    init_conf_mask = conf > init_threshold_val
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.001,
        point_shape="circle",
    )

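    # Keep handles to the camera frames and frustums so the checkbox callback can toggle their visibility.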
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics: np.ndarray, images_: np.ndarray) -> None:
        """
        Add camera frames and frustums to the scene.
        extrinsics: (S, 3, 4)
        images_: (S, 3, H, W)
        """
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

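        # Clicking a frustum snaps every connected client's viewpoint to that camera.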
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        img_ids = range(S)
        for img_id in tqdm(img_ids):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = viser_tf.SE3.from_matrix(cam2world_3x4)

            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

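            # Convert the image from (3, H, W) float in [0, 1] to (H, W, 3) uint8 for the frustum texture.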
            img = images_[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

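            # Approximate the field of view with fy = 1.1 * h rather than reading it from the
            # predicted intrinsics. If you want the exact value, you could instead use something
            # like fy = intrinsics_cam[img_id, 1, 1] (assuming the intrinsics match this image
            # resolution) together with the same fov = 2 * np.arctan2(h / 2, fy) formula.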
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0,
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update the point cloud based on current GUI selections."""
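        # Re-derive the confidence threshold from the slider percentage and combine it with the frame filter.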
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf, current_percentage)
        conf_mask = conf > threshold_val

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        """Toggle visibility of camera frames and frustums."""
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    visualize_frames(cam_to_world, images)

    print("Starting viser server...")

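    # In background mode, return immediately and let a daemon thread idle;
    # otherwise block the main thread so the server stays alive until interrupted.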
    if background_mode:

        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        while True:
            time.sleep(0.01)

    return server


def download_file_from_url(url, filename):
    """Downloads a file from a Hugging Face model repo, handling redirects."""
    try:
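        # Request the file without following redirects, then stream the download
        # from the redirect target given in the Location header.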
        response = requests.get(url, allow_redirects=False)
        response.raise_for_status()

        if response.status_code == 302:
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        else:
            print(f"Unexpected status code: {response.status_code}")
            return

        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")


def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf (np.ndarray): Confidence scores with shape (S, H, W)
        image_folder (str): Path to the folder containing input images

    Returns:
        np.ndarray: Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape
    sky_masks_dir = image_folder.rstrip("/") + "_sky_masks"
    os.makedirs(sky_masks_dir, exist_ok=True)

    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
        )

    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
    image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
    sky_mask_list = []

    print("Generating sky masks...")
    for i, image_path in enumerate(tqdm(image_files[:S])):
        image_name = os.path.basename(image_path)
        mask_filepath = os.path.join(sky_masks_dir, image_name)

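        # Reuse a cached mask from a previous run when available; otherwise run the model and save the result.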
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H))

        sky_mask_list.append(sky_mask)

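    # Stack the masks and zero out confidence wherever the mask is 0 (sky).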
    sky_mask_array = np.array(sky_mask_list)
    sky_mask_binary = (sky_mask_array > 0.01).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf


def segment_sky(image_path, onnx_session, mask_filename=None):
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask

    Returns:
        np.ndarray: Binary mask where 255 indicates non-sky regions
    """
    assert mask_filename is not None
    image = cv2.imread(image_path)

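    # Run the model at 320x320 and resize the result back to the original image size.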
    result_map = run_skyseg(onnx_session, [320, 320], image)
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

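    # Threshold into a binary mask (255 = non-sky, 0 = sky) and cache it to disk.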
    output_mask = np.zeros_like(result_map_original)
    output_mask[result_map_original < 1] = 1
    output_mask = output_mask.astype(np.uint8) * 255
    os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
    cv2.imwrite(mask_filename, output_mask)
    return output_mask


def run_skyseg(onnx_session, input_size, image):
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        np.ndarray: Segmentation mask
    """
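    # Preprocess: resize to the model input size, convert BGR to RGB, normalize with
    # ImageNet mean/std, and reshape to a (1, 3, H, W) float32 batch.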
    temp_image = copy.deepcopy(image)
    resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

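    # Squeeze the batch dimension and min-max normalize the result to uint8 in [0, 255].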
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    onnx_result = (onnx_result - min_value) / (max_value - min_value)
    onnx_result *= 255
    onnx_result = onnx_result.astype("uint8")

    return onnx_result


parser = argparse.ArgumentParser(description="VGGT demo with viser for 3D visualization")
parser.add_argument(
    "--image_folder", type=str, default="examples/kitchen/images/", help="Path to folder containing images"
)
parser.add_argument("--use_point_map", action="store_true", help="Use point map instead of depth-based points")
parser.add_argument("--background_mode", action="store_true", help="Run the viser server in background mode")
parser.add_argument("--port", type=int, default=8080, help="Port number for the viser server")
parser.add_argument(
    "--conf_threshold", type=float, default=25.0, help="Initial percentage of low-confidence points to filter out"
)
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")

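# Example invocation (assuming this script is saved as demo_viser.py and the sample
# images live under examples/kitchen/images/):
#   python demo_viser.py --image_folder examples/kitchen/images/ --conf_threshold 25 --mask_sky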
def main():
    """
    Main function for the VGGT demo with viser for 3D visualization.

    This function:
    1. Loads the VGGT model
    2. Processes input images from the specified folder
    3. Runs inference to generate 3D points and camera poses
    4. Optionally applies sky segmentation to filter out sky points
    5. Visualizes the results using viser

    Command-line arguments:
    --image_folder: Path to folder containing input images
    --use_point_map: Use point map instead of depth-based points
    --background_mode: Run the viser server in background mode
    --port: Port number for the viser server
    --conf_threshold: Initial percentage of low-confidence points to filter out
    --mask_sky: Apply sky segmentation to filter out sky points
    """
    args = parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("Initializing and loading VGGT model...")
    model = VGGT()
    _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
    model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

    model.eval()
    model = model.to(device)

    print(f"Loading images from {args.image_folder}...")
    image_names = glob.glob(os.path.join(args.image_folder, "*"))
    print(f"Found {len(image_names)} images")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

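    # Run inference without gradients, using bfloat16 autocast on CUDA.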
    print("Running inference...")
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            predictions = model(images)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

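    # Move predictions to CPU, convert to numpy, and drop the batch dimension.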
    print("Processing model outputs...")
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            predictions[key] = predictions[key].cpu().numpy().squeeze(0)

    if args.use_point_map:
        print("Visualizing 3D points from point map")
    else:
        print("Visualizing 3D points by unprojecting the depth map with the predicted cameras")

    if args.mask_sky:
        print("Sky segmentation enabled - will filter out sky points")

    print("Starting viser visualization...")

    viser_server = viser_wrapper(
        predictions,
        port=args.port,
        init_conf_threshold=args.conf_threshold,
        use_point_map=args.use_point_map,
        background_mode=args.background_mode,
        mask_sky=args.mask_sky,
        image_folder=args.image_folder,
    )
    print("Visualization complete")


if __name__ == "__main__":
    main()