diff --git a/app.py b/app.py
index 07938cf89dfdb31378296518585571ddcb025c4b..ef98df51d32956d9242e8cac3837980fc45731e2 100644
--- a/app.py
+++ b/app.py
@@ -4,68 +4,178 @@
 import torch
 import numpy as np
 import gradio as gr
+import trimesh
+import sys
+import os
+sys.path.append('vggsfm_code/')
+import shutil

-def parse_video(video_file):
-    vs = cv2.VideoCapture(video_file)
-
-    frames = []
-    while True:
-        (gotit, frame) = vs.read()
-        if frame is not None:
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frames.append(frame)
-        if not gotit:
-            break
+from vggsfm_code.hf_demo import demo_fn
+from omegaconf import DictConfig, OmegaConf
+from viz_utils.viz_fn import add_camera

-    return np.stack(frames)
+#
+from scipy.spatial.transform import Rotation
+import PIL
+import spaces

 @spaces.GPU
-def cotracker_demo(
+def vggsfm_demo(
+    input_image,
     input_video,
-    grid_size: int = 10,
-    tracks_leave_trace: bool = False,
+    query_frame_num,
+    max_query_pts
+    # grid_size: int = 10,
 ):
-    load_video = parse_video(input_video)
-    load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
+    cfg_file = "vggsfm_code/cfgs/demo.yaml"
+    cfg = OmegaConf.load(cfg_file)
+
+    max_input_image = 20
+
+    target_dir = "input_images"
+    if os.path.exists(target_dir):
+        shutil.rmtree(target_dir)
+
+    os.makedirs(target_dir)
+    target_dir_images = target_dir + "/images"
+    os.makedirs(target_dir_images)
+
+    if input_image is not None:
+        if len(input_image) < 3:
+            return None, "Please input at least three frames"
+
+        input_image = sorted(input_image)
+        input_image = input_image[:max_input_image]
+
+        # Copy files to the new directory
+        for file_name in input_image:
+            shutil.copy(file_name, target_dir_images)
+    elif input_video is not None:
+        vs = cv2.VideoCapture(input_video)
+
+        fps = vs.get(cv2.CAP_PROP_FPS)
+
+        frame_rate = 1
+        frame_interval = max(int(fps * frame_rate), 1)
+
+        video_frame_num = 0
+        count = 0
+
+        while video_frame_num <= max_input_image:
+            (gotit, frame) = vs.read()
+            count += 1
+            if not gotit:
+                break
+            if count % frame_interval == 0:
+                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
+                video_frame_num += 1
+        if video_frame_num < 3:
+            return None, "Please input at least three frames"
+    else:
+        return None, "Input format incorrect"
+
+    cfg.query_frame_num = query_frame_num
+    cfg.max_query_pts = max_query_pts
+    print(f"Files have been copied to {target_dir_images}")
+    cfg.SCENE_DIR = target_dir
+
+    predictions = demo_fn(cfg)
+
+    glbfile = vggsfm_predictions_to_glb(predictions)
+
+    print(input_image)
+    print(input_video)
+    return glbfile, "Success"

-    import time
-
-    def current_milli_time():
-        return round(time.time() * 1000)
-
-    filename = str(current_milli_time())
+
+def vggsfm_predictions_to_glb(predictions):
+    # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py
+    points3D = predictions["points3D"].cpu().numpy()
+    points3D_rgb = predictions["points3D_rgb"].cpu().numpy()
+    points3D_rgb = (points3D_rgb * 255).astype(np.uint8)

-    return os.path.join(
-        os.path.dirname(__file__), "results", f"{filename}.mp4"
-    )
+    extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy()
+    intrinsics_opencv = predictions["intrinsics_opencv"].cpu().numpy()
+    raw_image_paths = predictions["raw_image_paths"]
+    images = predictions["images"].permute(0, 2, 3, 1).cpu().numpy()
+    images = (images * 255).astype(np.uint8)
+
+    glbscene = trimesh.Scene()
+    point_cloud = trimesh.PointCloud(points3D, colors=points3D_rgb)
+    glbscene.add_geometry(point_cloud)
+
+    camera_edge_colors = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
+                          (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
+
+    frame_num = len(extrinsics_opencv)
+    extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4))
+    extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv
+    extrinsics_opencv_4x4[:, 3, 3] = 1
+    for idx in range(frame_num):
+        cam_from_world = extrinsics_opencv_4x4[idx]
+        cam_to_world = np.linalg.inv(cam_from_world)
+        cur_cam_color = camera_edge_colors[idx % len(camera_edge_colors)]
+        cur_focal = intrinsics_opencv[idx, 0, 0]
+        # cur_image_path = raw_image_paths[idx]
+        # cur_image = np.array(PIL.Image.open(cur_image_path))
+        # add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=cur_image.shape[1::-1],
+        #            focal=None, screen_width=0.3)
+        add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=(1024, 1024),
+                   focal=None, screen_width=0.35)

-app = gr.Interface(
-    title="🎨 CoTracker: It is Better to Track Together",
-    description="<div style='text-align: left;'> \
-    <p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
-    Points are sampled on a regular grid and are tracked jointly. </p> \
-    <p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
-    <ul style='display: inline-block; text-align: left;'> \
-        <li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
-        <li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
-    </ul> \
-    <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
-    </div>",
-    fn=cotracker_demo,
-    inputs=[
-        gr.Video(type="file", label="Input video", interactive=True),
-        gr.Slider(minimum=10, maximum=80, step=1, value=10, label="Number of tracks"),
-    ],
-    outputs=gr.Video(label="Video with predicted tracks"),
-    cache_examples=True,
-    allow_flagging=False,
-)
-app.queue(max_size=20, concurrency_count=1).launch(debug=True)
+    opengl_mat = np.array([[1, 0, 0, 0],
+                           [0, -1, 0, 0],
+                           [0, 0, -1, 0],
+                           [0, 0, 0, 1]])
+
+    rot = np.eye(4)
+    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+    glbscene.apply_transform(np.linalg.inv(np.linalg.inv(extrinsics_opencv_4x4[0]) @ opengl_mat @ rot))
+
+    glbfile = "glbscene.glb"
+    glbscene.export(file_obj=glbfile)
+    return glbfile
+
+
+if True:
+    demo = gr.Interface(
+        title="🎨 VGGSfM: Visual Geometry Grounded Deep Structure From Motion",
+        description="<div style='text-align: left;'> \
+            <p>Welcome to <a href='https://github.com/facebookresearch/vggsfm' target='_blank'>VGGSfM</a>!</p></div>",
+        fn=vggsfm_demo,
+        inputs=[
+            gr.File(file_count="multiple", label="Input Images", interactive=True),
+            gr.Video(label="Input video", interactive=True),
+            gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of query images"),
+            gr.Slider(minimum=512, maximum=4096, step=1, value=1024, label="Number of query points"),
+        ],
+        outputs=[gr.Model3D(label="Reconstruction"), gr.Textbox(label="Log")],
+        cache_examples=True,
+        allow_flagging=False,
+    )
+    demo.queue(max_size=20, concurrency_count=1).launch(debug=True)
+
+    # demo.launch(debug=True, share=True)
+else:
+    import glob
+    files = glob.glob('vggsfm_code/examples/cake/images/*', recursive=True)
+    vggsfm_demo(files, None, None)
+
+
+# demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
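A note on the pose handling in vggsfm_predictions_to_glb above: the (N, 3, 4) OpenCV world-to-camera extrinsics are padded to 4x4 and inverted to camera-to-world poses before add_camera draws them. A minimal self-contained numpy sketch of that step (the identity extrinsics here are placeholders, not demo output):

import numpy as np

# (N, 3, 4) world-to-camera matrices, shaped like predictions["extrinsics_opencv"]
extrinsics = np.tile(np.eye(4)[:3], (2, 1, 1))

ext_4x4 = np.zeros((len(extrinsics), 4, 4))
ext_4x4[:, :3, :4] = extrinsics          # copy rotation and translation
ext_4x4[:, 3, 3] = 1.0                   # homogeneous bottom row

cam_to_world = np.linalg.inv(ext_4x4)    # camera poses handed to add_camera
assert np.allclose(cam_to_world[0] @ ext_4x4[0], np.eye(4))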
diff --git a/debug_demo.py b/debug_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01818a7a72a1a541a5c7afd16577437fb6787b4
--- /dev/null
+++ b/debug_demo.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+def greet(name, intensity):
+    return "Hello, " + name + "!" * int(intensity)
+
+demo = gr.Interface(
+    fn=greet,
+    inputs=["text", "slider"],
+    outputs=["text"],
+)
+
+demo.launch(share=True)
+
+
+import sys
+import os
+
+sys.path.append('vggsfm_code/')
+
+from vggsfm_code.hf_demo import demo_fn
+from omegaconf import DictConfig, OmegaConf
+
+cfg_file = "vggsfm_code/cfgs/demo.yaml"
+cfg = OmegaConf.load(cfg_file)
+cfg.SCENE_DIR = "vggsfm_code/examples/cake"
+
+import pdb; pdb.set_trace()
+
+demo_fn(cfg)
+
+
diff --git a/glbscene.glb b/glbscene.glb
new file mode 100644
index 0000000000000000000000000000000000000000..fe844b11af0e1488fb3553020084e3c1e8294cf8
Binary files /dev/null and b/glbscene.glb differ
diff --git a/requirements.txt b/requirements.txt
index aee576be6d2837fb01156c7f352bf0ade0eb3dc5..49e0b91cda3c22a7ac0b016a76e512ed3f917291 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,4 @@ git+https://github.com/cvg/LightGlue.git#egg=LightGlue
 numpy==1.26.3
 pycolmap==0.6.1
 https://huggingface.co/facebook/VGGSfM/resolve/main/poselib-2.0.2-cp310-cp310-linux_x86_64.whl
-
+trimesh
diff --git a/vggsfm/.gitignore b/vggsfm_code/.gitignore
similarity index 100%
rename from vggsfm/.gitignore
rename to vggsfm_code/.gitignore
diff --git a/vggsfm/CHANGELOG.txt b/vggsfm_code/CHANGELOG.txt
similarity index 100%
rename from vggsfm/CHANGELOG.txt
rename to vggsfm_code/CHANGELOG.txt
diff --git a/vggsfm/CODE_OF_CONDUCT.md b/vggsfm_code/CODE_OF_CONDUCT.md
similarity index 100%
rename from vggsfm/CODE_OF_CONDUCT.md
rename to vggsfm_code/CODE_OF_CONDUCT.md
diff --git a/vggsfm/CONTRIBUTING.md b/vggsfm_code/CONTRIBUTING.md
similarity index 100%
rename from vggsfm/CONTRIBUTING.md
rename to vggsfm_code/CONTRIBUTING.md
diff --git a/vggsfm/LICENSE.txt b/vggsfm_code/LICENSE.txt
similarity index 100%
rename from vggsfm/LICENSE.txt
rename to vggsfm_code/LICENSE.txt
diff --git a/vggsfm/README.md b/vggsfm_code/README.md
similarity index 100%
rename from vggsfm/README.md
rename to vggsfm_code/README.md
diff --git a/vggsfm/assets/ui.png b/vggsfm_code/assets/ui.png
similarity index 100%
rename from vggsfm/assets/ui.png
rename to vggsfm_code/assets/ui.png
diff --git a/vggsfm/cfgs/demo.yaml b/vggsfm_code/cfgs/demo.yaml
similarity index 100%
rename from vggsfm/cfgs/demo.yaml
rename to vggsfm_code/cfgs/demo.yaml
diff --git a/vggsfm/demo.py b/vggsfm_code/demo.py
similarity index 100%
rename from vggsfm/demo.py
rename to vggsfm_code/demo.py
diff --git a/vggsfm/examples/apple/images/frame000007.jpg b/vggsfm_code/examples/apple/images/frame000007.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000007.jpg
rename to vggsfm_code/examples/apple/images/frame000007.jpg
diff --git a/vggsfm/examples/apple/images/frame000012.jpg b/vggsfm_code/examples/apple/images/frame000012.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000012.jpg
rename to vggsfm_code/examples/apple/images/frame000012.jpg
diff --git a/vggsfm/examples/apple/images/frame000017.jpg b/vggsfm_code/examples/apple/images/frame000017.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000017.jpg
rename to vggsfm_code/examples/apple/images/frame000017.jpg
diff --git a/vggsfm/examples/apple/images/frame000019.jpg b/vggsfm_code/examples/apple/images/frame000019.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000019.jpg
rename to vggsfm_code/examples/apple/images/frame000019.jpg
diff --git a/vggsfm/examples/apple/images/frame000024.jpg b/vggsfm_code/examples/apple/images/frame000024.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000024.jpg
rename to vggsfm_code/examples/apple/images/frame000024.jpg
diff --git a/vggsfm/examples/apple/images/frame000025.jpg b/vggsfm_code/examples/apple/images/frame000025.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000025.jpg
rename to vggsfm_code/examples/apple/images/frame000025.jpg
diff --git a/vggsfm/examples/apple/images/frame000043.jpg b/vggsfm_code/examples/apple/images/frame000043.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000043.jpg
rename to vggsfm_code/examples/apple/images/frame000043.jpg
diff --git a/vggsfm/examples/apple/images/frame000052.jpg b/vggsfm_code/examples/apple/images/frame000052.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000052.jpg
rename to vggsfm_code/examples/apple/images/frame000052.jpg
diff --git a/vggsfm/examples/apple/images/frame000070.jpg b/vggsfm_code/examples/apple/images/frame000070.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000070.jpg
rename to vggsfm_code/examples/apple/images/frame000070.jpg
diff --git a/vggsfm/examples/apple/images/frame000077.jpg b/vggsfm_code/examples/apple/images/frame000077.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000077.jpg
rename to vggsfm_code/examples/apple/images/frame000077.jpg
diff --git a/vggsfm/examples/apple/images/frame000085.jpg b/vggsfm_code/examples/apple/images/frame000085.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000085.jpg
rename to vggsfm_code/examples/apple/images/frame000085.jpg
diff --git a/vggsfm/examples/apple/images/frame000096.jpg b/vggsfm_code/examples/apple/images/frame000096.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000096.jpg
rename to vggsfm_code/examples/apple/images/frame000096.jpg
diff --git a/vggsfm/examples/apple/images/frame000128.jpg b/vggsfm_code/examples/apple/images/frame000128.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000128.jpg
rename to vggsfm_code/examples/apple/images/frame000128.jpg
diff --git a/vggsfm/examples/apple/images/frame000145.jpg b/vggsfm_code/examples/apple/images/frame000145.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000145.jpg
rename to vggsfm_code/examples/apple/images/frame000145.jpg
diff --git a/vggsfm/examples/apple/images/frame000160.jpg b/vggsfm_code/examples/apple/images/frame000160.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000160.jpg
rename to vggsfm_code/examples/apple/images/frame000160.jpg
diff --git a/vggsfm/examples/apple/images/frame000162.jpg b/vggsfm_code/examples/apple/images/frame000162.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000162.jpg
rename to vggsfm_code/examples/apple/images/frame000162.jpg
diff --git a/vggsfm/examples/apple/images/frame000168.jpg b/vggsfm_code/examples/apple/images/frame000168.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000168.jpg
rename to vggsfm_code/examples/apple/images/frame000168.jpg
diff --git a/vggsfm/examples/apple/images/frame000172.jpg b/vggsfm_code/examples/apple/images/frame000172.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000172.jpg
rename to vggsfm_code/examples/apple/images/frame000172.jpg
diff --git a/vggsfm/examples/apple/images/frame000191.jpg b/vggsfm_code/examples/apple/images/frame000191.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000191.jpg
rename to vggsfm_code/examples/apple/images/frame000191.jpg
diff --git a/vggsfm/examples/apple/images/frame000200.jpg b/vggsfm_code/examples/apple/images/frame000200.jpg
similarity index 100%
rename from vggsfm/examples/apple/images/frame000200.jpg
rename to vggsfm_code/examples/apple/images/frame000200.jpg
diff --git a/vggsfm/examples/british_museum/images/29057984_287139632.jpg b/vggsfm_code/examples/british_museum/images/29057984_287139632.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/29057984_287139632.jpg
rename to vggsfm_code/examples/british_museum/images/29057984_287139632.jpg
diff --git a/vggsfm/examples/british_museum/images/32630292_7166579210.jpg b/vggsfm_code/examples/british_museum/images/32630292_7166579210.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/32630292_7166579210.jpg
rename to vggsfm_code/examples/british_museum/images/32630292_7166579210.jpg
diff --git a/vggsfm/examples/british_museum/images/45839934_4117745134.jpg b/vggsfm_code/examples/british_museum/images/45839934_4117745134.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/45839934_4117745134.jpg
rename to vggsfm_code/examples/british_museum/images/45839934_4117745134.jpg
diff --git a/vggsfm/examples/british_museum/images/51004432_567773767.jpg b/vggsfm_code/examples/british_museum/images/51004432_567773767.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/51004432_567773767.jpg
rename to vggsfm_code/examples/british_museum/images/51004432_567773767.jpg
diff --git a/vggsfm/examples/british_museum/images/62620282_3728576515.jpg b/vggsfm_code/examples/british_museum/images/62620282_3728576515.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/62620282_3728576515.jpg
rename to vggsfm_code/examples/british_museum/images/62620282_3728576515.jpg
diff --git a/vggsfm/examples/british_museum/images/71931631_7212707886.jpg b/vggsfm_code/examples/british_museum/images/71931631_7212707886.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/71931631_7212707886.jpg
rename to vggsfm_code/examples/british_museum/images/71931631_7212707886.jpg
diff --git a/vggsfm/examples/british_museum/images/78600497_407639599.jpg b/vggsfm_code/examples/british_museum/images/78600497_407639599.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/78600497_407639599.jpg
rename to vggsfm_code/examples/british_museum/images/78600497_407639599.jpg
diff --git a/vggsfm/examples/british_museum/images/80340357_5029510336.jpg b/vggsfm_code/examples/british_museum/images/80340357_5029510336.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/80340357_5029510336.jpg
rename to vggsfm_code/examples/british_museum/images/80340357_5029510336.jpg
diff --git a/vggsfm/examples/british_museum/images/81272348_2712949069.jpg b/vggsfm_code/examples/british_museum/images/81272348_2712949069.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/81272348_2712949069.jpg
rename to vggsfm_code/examples/british_museum/images/81272348_2712949069.jpg
diff --git a/vggsfm/examples/british_museum/images/93266801_2335569192.jpg b/vggsfm_code/examples/british_museum/images/93266801_2335569192.jpg
similarity index 100%
rename from vggsfm/examples/british_museum/images/93266801_2335569192.jpg
rename to vggsfm_code/examples/british_museum/images/93266801_2335569192.jpg
diff --git a/vggsfm/examples/cake/images/frame000020.jpg b/vggsfm_code/examples/cake/images/frame000020.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000020.jpg
rename to vggsfm_code/examples/cake/images/frame000020.jpg
diff --git a/vggsfm/examples/cake/images/frame000069.jpg b/vggsfm_code/examples/cake/images/frame000069.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000069.jpg
rename to vggsfm_code/examples/cake/images/frame000069.jpg
diff --git a/vggsfm/examples/cake/images/frame000096.jpg b/vggsfm_code/examples/cake/images/frame000096.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000096.jpg
rename to vggsfm_code/examples/cake/images/frame000096.jpg
diff --git a/vggsfm/examples/cake/images/frame000112.jpg b/vggsfm_code/examples/cake/images/frame000112.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000112.jpg
rename to vggsfm_code/examples/cake/images/frame000112.jpg
diff --git a/vggsfm/examples/cake/images/frame000146.jpg b/vggsfm_code/examples/cake/images/frame000146.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000146.jpg
rename to vggsfm_code/examples/cake/images/frame000146.jpg
diff --git a/vggsfm/examples/cake/images/frame000149.jpg b/vggsfm_code/examples/cake/images/frame000149.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000149.jpg
rename to vggsfm_code/examples/cake/images/frame000149.jpg
diff --git a/vggsfm/examples/cake/images/frame000166.jpg b/vggsfm_code/examples/cake/images/frame000166.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000166.jpg
rename to vggsfm_code/examples/cake/images/frame000166.jpg
diff --git a/vggsfm/examples/cake/images/frame000169.jpg b/vggsfm_code/examples/cake/images/frame000169.jpg
similarity index 100%
rename from vggsfm/examples/cake/images/frame000169.jpg
rename to vggsfm_code/examples/cake/images/frame000169.jpg
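The new vggsfm_code/hf_demo.py below is driven entirely by an OmegaConf config, mirroring the loading pattern in app.py and debug_demo.py above. A standalone sketch of that pattern (OmegaConf.create stands in for loading vggsfm_code/cfgs/demo.yaml, which defines many more keys than shown):

from omegaconf import OmegaConf

# Stand-in for: cfg = OmegaConf.load("vggsfm_code/cfgs/demo.yaml")
cfg = OmegaConf.create({"SCENE_DIR": None, "query_frame_num": 3, "max_query_pts": 4096})

cfg.SCENE_DIR = "input_images"   # directory holding the images/ subfolder
cfg.query_frame_num = 5          # slider values from the Gradio UI
cfg.max_query_pts = 1024
print(OmegaConf.to_yaml(cfg))    # then: predictions = demo_fn(cfg)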
diff --git a/vggsfm_code/hf_demo.py b/vggsfm_code/hf_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a36d9cec308a0083c6afcb11b34bf3b159e868ee
--- /dev/null
+++ b/vggsfm_code/hf_demo.py
@@ -0,0 +1,457 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch.cuda.amp import autocast
+import hydra
+
+from omegaconf import DictConfig, OmegaConf
+from hydra.utils import instantiate
+
+from lightglue import LightGlue, SuperPoint, SIFT, ALIKED
+
+import pycolmap
+
+from visdom import Visdom
+
+
+from vggsfm.datasets.demo_loader import DemoLoader
+from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras
+
+try:
+    import poselib
+    from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib
+
+    print("Poselib is available")
+except ImportError:
+    print("Poselib is not installed. Please disable use_poselib")
+
+from vggsfm.utils.utils import (
+    set_seed_and_print,
+    farthest_point_sampling,
+    calculate_index_mappings,
+    switch_tensor_order,
+)
+
+
+def demo_fn(cfg):
+    OmegaConf.set_struct(cfg, False)
+
+    # Print configuration
+    print("Model Config:", OmegaConf.to_yaml(cfg))
+
+    torch.backends.cudnn.enabled = False
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.deterministic = True
+
+    # Set seed
+    seed_all_random_engines(cfg.seed)
+
+    # Model instantiation
+    model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model = model.to(device)
+
+    # Prepare test dataset
+    test_dataset = DemoLoader(
+        SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg
+    )
+
+    # if cfg.resume_ckpt:
+    _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin"
+
+    # Reload model
+    checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL)
+    model.load_state_dict(checkpoint, strict=True)
+    print(f"Successfully resumed from {_VGGSFM_URL}")
+
+    sequence_list = test_dataset.sequence_list
+
+    for seq_name in sequence_list:
+        print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50)
+
+        # Load the data
+        batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True)
+
+        # Send to GPU
+        images = batch["image"].to(device)
+        crop_params = batch["crop_params"].to(device)
+
+        # Unsqueeze to have batch size = 1
+        images = images.unsqueeze(0)
+        crop_params = crop_params.unsqueeze(0)
+
+        batch_size = len(images)
+
+        with torch.no_grad():
+            # Run the model
+            assert cfg.mixed_precision in ("None", "bf16", "fp16")
+            if cfg.mixed_precision == "None":
+                dtype = torch.float32
+            elif cfg.mixed_precision == "bf16":
+                dtype = torch.bfloat16
+            elif cfg.mixed_precision == "fp16":
+                dtype = torch.float16
+            else:
+                raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now")
+
+            predictions = run_one_scene(
+                model,
+                images,
+                crop_params=crop_params,
+                query_frame_num=cfg.query_frame_num,
+                image_paths=image_paths,
+                dtype=dtype,
+                cfg=cfg,
+            )
+
+    pred_cameras_PT3D = predictions["pred_cameras_PT3D"]
+
+    return predictions
+
+
+def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None):
+    """
+    images have been normalized to the range [0, 1] instead of [0, 255]
+    """
+    batch_num, frame_num, image_dim, height, width = images.shape
+    device = images.device
+    reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)
+
+    predictions = {}
+    extra_dict = {}
+
+    camera_predictor = model.camera_predictor
+    track_predictor = model.track_predictor
+    triangulator = model.triangulator
+
+    # Find the query frames
+    # First use DINO to find the most common frame among all the input frames,
+    # i.e., the one with the highest (average) cosine similarity to all others
+    # Then use farthest_point_sampling to find the next ones
+    # The number of query frames is determined by query_frame_num
+
+    with autocast(dtype=dtype):
+        query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num)
+
+    raw_image_paths = image_paths
+    image_paths = [os.path.basename(imgpath) for imgpath in image_paths]
+
+    if cfg.center_order:
+        # The code below switches the first frame (frame 0) to the most common frame
+        center_frame_index = query_frame_indexes[0]
+        center_order = calculate_index_mappings(center_frame_index, frame_num, device=device)
+
+        images, crop_params = switch_tensor_order([images, crop_params], center_order, dim=1)
+        reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0]
+
+        image_paths = [image_paths[i] for i in center_order.cpu().numpy().tolist()]
+
+        # Also update query_frame_indexes:
+        query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes]
+        query_frame_indexes[0] = 0
+
+    # only pick query_frame_num
+    query_frame_indexes = query_frame_indexes[:query_frame_num]
+
+    # Prepare image feature maps for tracker
+    fmaps_for_tracker = track_predictor.process_images_to_fmaps(images)
+
+    # Predict tracks
+    with autocast(dtype=dtype):
+        pred_track, pred_vis, pred_score = predict_tracks(
+            cfg.query_method,
+            cfg.max_query_pts,
+            track_predictor,
+            images,
+            fmaps_for_tracker,
+            query_frame_indexes,
+            frame_num,
+            device,
+            cfg,
+        )
+
+        if cfg.comple_nonvis:
+            pred_track, pred_vis, pred_score = comple_nonvis_frames(
+                track_predictor,
+                images,
+                fmaps_for_tracker,
+                frame_num,
+                device,
+                pred_track,
+                pred_vis,
+                pred_score,
+                200,
+                cfg=cfg,
+            )
+
+    torch.cuda.empty_cache()
+
+    # If necessary, force all the predictions at the padding areas as non-visible
+    if crop_params is not None:
+        boundaries = crop_params[:, :, -4:-2].abs().to(device)
+        boundaries = torch.cat([boundaries, reshaped_image.shape[-1] - boundaries], dim=-1)
+        hvis = torch.logical_and(
+            pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4]
+        )
+        wvis = torch.logical_and(
+            pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3]
+        )
+        force_vis = torch.logical_and(hvis, wvis)
+        pred_vis = pred_vis * force_vis.float()
+
+    # TODO: plot 2D matches
+    if cfg.use_poselib:
+        estimate_preliminary_cameras_fn = estimate_preliminary_cameras_poselib
+    else:
+        estimate_preliminary_cameras_fn = estimate_preliminary_cameras
+
+    # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches
+    # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC
+    # All the operations are batched and differentiable (if necessary),
+    # except when you enable use_poselib to save GPU memory
+    _, preliminary_dict = estimate_preliminary_cameras_fn(
+        pred_track,
+        pred_vis,
+        width,
+        height,
+        tracks_score=pred_score,
+        max_error=cfg.fmat_thres,
+        loopresidual=True,
+        # max_ransac_iters=cfg.max_ransac_iters,
+    )
+
+    pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num)
+
+    pred_cameras = pose_predictions["pred_cameras"]
+
+    # Conduct Triangulation and Bundle Adjustment
+    (
+        BA_cameras_PT3D,
+        extrinsics_opencv,
+        intrinsics_opencv,
+        points3D,
+        points3D_rgb,
+        reconstruction,
+        valid_frame_mask,
+    ) = triangulator(
+        pred_cameras,
+        pred_track,
+        pred_vis,
+        images,
+        preliminary_dict,
+        image_paths=image_paths,
+        crop_params=crop_params,
+        pred_score=pred_score,
+        fmat_thres=cfg.fmat_thres,
+        BA_iters=cfg.BA_iters,
+        max_reproj_error=cfg.max_reproj_error,
+        init_max_reproj_error=cfg.init_max_reproj_error,
+        cfg=cfg,
+    )
+
+    if cfg.center_order:
+        # NOTE we changed the image order previously, now we need to switch it back
+        BA_cameras_PT3D = BA_cameras_PT3D[center_order]
+        extrinsics_opencv = extrinsics_opencv[center_order]
+        intrinsics_opencv = intrinsics_opencv[center_order]
+
+    if cfg.filter_invalid_frame:
+        raw_image_paths = np.array(raw_image_paths)[valid_frame_mask.cpu().numpy().tolist()].tolist()
+        images = images[0][valid_frame_mask]
+
+    predictions["pred_cameras_PT3D"] = BA_cameras_PT3D
+    predictions["extrinsics_opencv"] = extrinsics_opencv
+    predictions["intrinsics_opencv"] = intrinsics_opencv
+    predictions["points3D"] = points3D
+    predictions["points3D_rgb"] = points3D_rgb
+    predictions["reconstruction"] = reconstruction
+    predictions["images"] = images
+    predictions["raw_image_paths"] = raw_image_paths
+    return predictions
+
+
+def predict_tracks(
+    query_method,
+    max_query_pts,
+    track_predictor,
+    images,
+    fmaps_for_tracker,
+    query_frame_indexes,
+    frame_num,
+    device,
+    cfg=None,
+):
+    pred_track_list = []
+    pred_vis_list = []
+    pred_score_list = []
+
+    for query_index in query_frame_indexes:
+        print(f"Predicting tracks with query_index = {query_index}")
+
+        # Find query_points at the query frame
+        query_points = get_query_points(images[:, query_index], query_method, max_query_pts)
+
+        # Switch so that the query_index frame stays at the first position
+        # This largely simplifies the code structure of the tracker
+        new_order = calculate_index_mappings(query_index, frame_num, device=device)
+        images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], new_order)
+
+        # Feed into track predictor
+        fine_pred_track, _, pred_vis, pred_score = track_predictor(images_feed, query_points, fmaps=fmaps_feed)
+
+        # Switch back the predictions
+        fine_pred_track, pred_vis, pred_score = switch_tensor_order([fine_pred_track, pred_vis, pred_score], new_order)
+
+        # Append predictions for different queries
+        pred_track_list.append(fine_pred_track)
+        pred_vis_list.append(pred_vis)
+        pred_score_list.append(pred_score)
+
+    pred_track = torch.cat(pred_track_list, dim=2)
+    pred_vis = torch.cat(pred_vis_list, dim=2)
+    pred_score = torch.cat(pred_score_list, dim=2)
+
+    return pred_track, pred_vis, pred_score
+
+
+def comple_nonvis_frames(
+    track_predictor,
+    images,
+    fmaps_for_tracker,
+    frame_num,
+    device,
+    pred_track,
+    pred_vis,
+    pred_score,
+    min_vis=500,
+    cfg=None,
+):
+    # if a frame has too few visible inliers, use it as a query
+    non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist()
+    last_query = -1
+    while len(non_vis_frames) > 0:
+        print("Processing non visible frames")
+        print(non_vis_frames)
+        if non_vis_frames[0] == last_query:
+            print("The non vis frame still does not have enough 2D matches")
+            pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks(
+                "sp+sift+aliked",
+                cfg.max_query_pts // 2,
+                track_predictor,
+                images,
+                fmaps_for_tracker,
+                non_vis_frames,
+                frame_num,
+                device,
+                cfg,
+            )
+            # concat predictions
+            pred_track = torch.cat([pred_track, pred_track_comple], dim=2)
+            pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2)
+            pred_score = torch.cat([pred_score, pred_score_comple], dim=2)
+            break
+
+        non_vis_query_list = [non_vis_frames[0]]
+        last_query = non_vis_frames[0]
+        pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks(
+            cfg.query_method,
+            cfg.max_query_pts,
+            track_predictor,
+            images,
+            fmaps_for_tracker,
+            non_vis_query_list,
+            frame_num,
+            device,
+            cfg,
+        )
+
+        # concat predictions
+        pred_track = torch.cat([pred_track, pred_track_comple], dim=2)
+        pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2)
+        pred_score = torch.cat([pred_score, pred_score_comple], dim=2)
+        non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist()
+    return pred_track, pred_vis, pred_score
+
+
+def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336):
+    # Downsample image to image_size x image_size
+    # because we found it is unnecessary to use high resolution
+    rgbs = F.interpolate(reshaped_image, (image_size, image_size), mode="bilinear", align_corners=True)
+    rgbs = camera_predictor._resnet_normalize_image(rgbs)
+
+    # Get the image features (patch level)
+    frame_feat = camera_predictor.backbone(rgbs, is_training=True)
+    frame_feat = frame_feat["x_norm_patchtokens"]
+    frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
+
+    # Compute the similarity matrix
+    frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
+    similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
+    similarity_matrix = similarity_matrix.mean(dim=0)
+    distance_matrix = 100 - similarity_matrix.clone()
+
+    # Ignore self-pairing
+    similarity_matrix.fill_diagonal_(-100)
+
+    similarity_sum = similarity_matrix.sum(dim=1)
+
+    # Find the most common frame
+    most_common_frame_index = torch.argmax(similarity_sum).item()
+
+    # Conduct FPS sampling
+    # Starting from the most_common_frame_index,
+    # try to find the farthest frame,
+    # then the farthest to the last found frame
+    # (frames are not allowed to be found twice)
+    fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)
+
+    return fps_idx
+
+
+def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005):
+    # Run superpoint and sift on the target frame
+    # Feel free to modify for your own use
+
+    methods = query_method.split("+")
+    pred_points = []
+
+    for method in methods:
+        if "sp" in method:
+            extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval()
+        elif "sift" in method:
+            extractor = SIFT(max_num_keypoints=max_query_num).cuda().eval()
+        elif "aliked" in method:
+            extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval()
+        else:
+            raise NotImplementedError(f"query method {method} is not supported now")
+
+        query_points = extractor.extract(query_image)["keypoints"]
+        pred_points.append(query_points)
+
+    query_points = torch.cat(pred_points, dim=1)
+
+    if query_points.shape[1] > max_query_num:
+        random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num]
+        query_points = query_points[:, random_point_indices, :]
+
+    return query_points
+
+
+def seed_all_random_engines(seed: int) -> None:
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
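find_query_frame_indexes above picks the most "common" frame via DINO patch features and then runs farthest point sampling over a frame-to-frame distance matrix. The repo imports farthest_point_sampling from vggsfm.utils.utils; the sketch below is an illustrative numpy equivalent of the idea, not the repo's implementation:

import numpy as np

def fps(distance_matrix, num_samples, start_index=0):
    selected = [start_index]
    # distance from every frame to its nearest already-selected frame
    min_dist = distance_matrix[start_index].copy()
    while len(selected) < num_samples:
        nxt = int(np.argmax(min_dist))            # farthest from the selected set
        selected.append(nxt)
        min_dist = np.minimum(min_dist, distance_matrix[nxt])
    return selected

pts = np.random.default_rng(0).random((8, 2))
dmat = np.linalg.norm(pts[:, None] - pts[None], axis=-1)
print(fps(dmat, 3))   # three mutually far-apart "frames"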
diff --git a/vggsfm/install.sh b/vggsfm_code/install.sh
similarity index 100%
rename from vggsfm/install.sh
rename to vggsfm_code/install.sh
diff --git a/vggsfm/minipytorch3d/__init__.py b/vggsfm_code/minipytorch3d/__init__.py
similarity index 100%
rename from vggsfm/minipytorch3d/__init__.py
rename to vggsfm_code/minipytorch3d/__init__.py
diff --git a/vggsfm/minipytorch3d/cameras.py b/vggsfm_code/minipytorch3d/cameras.py
similarity index 100%
rename from vggsfm/minipytorch3d/cameras.py
rename to vggsfm_code/minipytorch3d/cameras.py
diff --git a/vggsfm/minipytorch3d/device_utils.py b/vggsfm_code/minipytorch3d/device_utils.py
similarity index 100%
rename from vggsfm/minipytorch3d/device_utils.py
rename to vggsfm_code/minipytorch3d/device_utils.py
diff --git a/vggsfm/minipytorch3d/harmonic_embedding.py b/vggsfm_code/minipytorch3d/harmonic_embedding.py
similarity index 100%
rename from vggsfm/minipytorch3d/harmonic_embedding.py
rename to vggsfm_code/minipytorch3d/harmonic_embedding.py
diff --git a/vggsfm/minipytorch3d/renderer_utils.py b/vggsfm_code/minipytorch3d/renderer_utils.py
similarity index 100%
rename from vggsfm/minipytorch3d/renderer_utils.py
rename to vggsfm_code/minipytorch3d/renderer_utils.py
diff --git a/vggsfm/minipytorch3d/rotation_conversions.py b/vggsfm_code/minipytorch3d/rotation_conversions.py
similarity index 100%
rename from vggsfm/minipytorch3d/rotation_conversions.py
rename to vggsfm_code/minipytorch3d/rotation_conversions.py
diff --git a/vggsfm/minipytorch3d/transform3d.py b/vggsfm_code/minipytorch3d/transform3d.py
similarity index 100%
rename from vggsfm/minipytorch3d/transform3d.py
rename to vggsfm_code/minipytorch3d/transform3d.py
diff --git a/vggsfm/vggsfm/datasets/camera_transform.py b/vggsfm_code/vggsfm/datasets/camera_transform.py
similarity index 100%
rename from vggsfm/vggsfm/datasets/camera_transform.py
rename to vggsfm_code/vggsfm/datasets/camera_transform.py
diff --git a/vggsfm/vggsfm/datasets/demo_loader.py b/vggsfm_code/vggsfm/datasets/demo_loader.py
similarity index 100%
rename from vggsfm/vggsfm/datasets/demo_loader.py
rename to vggsfm_code/vggsfm/datasets/demo_loader.py
diff --git a/vggsfm/vggsfm/datasets/imc.py b/vggsfm_code/vggsfm/datasets/imc.py
similarity index 100%
rename from vggsfm/vggsfm/datasets/imc.py
rename to vggsfm_code/vggsfm/datasets/imc.py
diff --git a/vggsfm/vggsfm/datasets/imc_helper.py b/vggsfm_code/vggsfm/datasets/imc_helper.py
similarity index 100%
rename from vggsfm/vggsfm/datasets/imc_helper.py
rename to vggsfm_code/vggsfm/datasets/imc_helper.py
diff --git a/vggsfm/vggsfm/models/__init__.py b/vggsfm_code/vggsfm/models/__init__.py
similarity index 100%
rename from vggsfm/vggsfm/models/__init__.py
rename to vggsfm_code/vggsfm/models/__init__.py
diff --git a/vggsfm/vggsfm/models/camera_predictor.py b/vggsfm_code/vggsfm/models/camera_predictor.py
similarity index 100%
rename from vggsfm/vggsfm/models/camera_predictor.py
rename to vggsfm_code/vggsfm/models/camera_predictor.py
diff --git a/vggsfm/vggsfm/models/modules.py b/vggsfm_code/vggsfm/models/modules.py
similarity index 100%
rename from vggsfm/vggsfm/models/modules.py
rename to vggsfm_code/vggsfm/models/modules.py
diff --git a/vggsfm/vggsfm/models/track_modules/__init__.py b/vggsfm_code/vggsfm/models/track_modules/__init__.py
similarity index 100%
rename from vggsfm/vggsfm/models/track_modules/__init__.py
rename to vggsfm_code/vggsfm/models/track_modules/__init__.py
diff --git a/vggsfm/vggsfm/models/track_modules/base_track_predictor.py b/vggsfm_code/vggsfm/models/track_modules/base_track_predictor.py
similarity index 100%
rename from vggsfm/vggsfm/models/track_modules/base_track_predictor.py
rename to vggsfm_code/vggsfm/models/track_modules/base_track_predictor.py
diff --git a/vggsfm/vggsfm/models/track_modules/blocks.py b/vggsfm_code/vggsfm/models/track_modules/blocks.py
similarity index 100%
rename from vggsfm/vggsfm/models/track_modules/blocks.py
rename to vggsfm_code/vggsfm/models/track_modules/blocks.py
diff --git a/vggsfm/vggsfm/models/track_modules/refine_track.py b/vggsfm_code/vggsfm/models/track_modules/refine_track.py
similarity index 100%
rename from vggsfm/vggsfm/models/track_modules/refine_track.py
rename to vggsfm_code/vggsfm/models/track_modules/refine_track.py
diff --git a/vggsfm/vggsfm/models/track_predictor.py b/vggsfm_code/vggsfm/models/track_predictor.py
similarity index 100%
rename from vggsfm/vggsfm/models/track_predictor.py
rename to vggsfm_code/vggsfm/models/track_predictor.py
diff --git a/vggsfm/vggsfm/models/triangulator.py b/vggsfm_code/vggsfm/models/triangulator.py
similarity index 100%
rename from vggsfm/vggsfm/models/triangulator.py
rename to vggsfm_code/vggsfm/models/triangulator.py
diff --git a/vggsfm/vggsfm/models/utils.py b/vggsfm_code/vggsfm/models/utils.py
similarity index 100%
rename from vggsfm/vggsfm/models/utils.py
rename to vggsfm_code/vggsfm/models/utils.py
diff --git a/vggsfm/vggsfm/models/vggsfm.py b/vggsfm_code/vggsfm/models/vggsfm.py
similarity index 100%
rename from vggsfm/vggsfm/models/vggsfm.py
rename to vggsfm_code/vggsfm/models/vggsfm.py
diff --git a/vggsfm/vggsfm/two_view_geo/essential.py b/vggsfm_code/vggsfm/two_view_geo/essential.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/essential.py
rename to vggsfm_code/vggsfm/two_view_geo/essential.py
diff --git a/vggsfm/vggsfm/two_view_geo/estimate_preliminary.py b/vggsfm_code/vggsfm/two_view_geo/estimate_preliminary.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/estimate_preliminary.py
rename to vggsfm_code/vggsfm/two_view_geo/estimate_preliminary.py
diff --git a/vggsfm/vggsfm/two_view_geo/fundamental.py b/vggsfm_code/vggsfm/two_view_geo/fundamental.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/fundamental.py
rename to vggsfm_code/vggsfm/two_view_geo/fundamental.py
diff --git a/vggsfm/vggsfm/two_view_geo/homography.py b/vggsfm_code/vggsfm/two_view_geo/homography.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/homography.py
rename to vggsfm_code/vggsfm/two_view_geo/homography.py
diff --git a/vggsfm/vggsfm/two_view_geo/perspective_n_points.py b/vggsfm_code/vggsfm/two_view_geo/perspective_n_points.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/perspective_n_points.py
rename to vggsfm_code/vggsfm/two_view_geo/perspective_n_points.py
diff --git a/vggsfm/vggsfm/two_view_geo/pnp.py b/vggsfm_code/vggsfm/two_view_geo/pnp.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/pnp.py
rename to vggsfm_code/vggsfm/two_view_geo/pnp.py
diff --git a/vggsfm/vggsfm/two_view_geo/utils.py b/vggsfm_code/vggsfm/two_view_geo/utils.py
similarity index 100%
rename from vggsfm/vggsfm/two_view_geo/utils.py
rename to vggsfm_code/vggsfm/two_view_geo/utils.py
diff --git a/vggsfm/vggsfm/utils/metric.py b/vggsfm_code/vggsfm/utils/metric.py
similarity index 100%
rename from vggsfm/vggsfm/utils/metric.py
rename to vggsfm_code/vggsfm/utils/metric.py
diff --git a/vggsfm/vggsfm/utils/tensor_to_pycolmap.py b/vggsfm_code/vggsfm/utils/tensor_to_pycolmap.py
similarity index 100%
rename from vggsfm/vggsfm/utils/tensor_to_pycolmap.py
rename to vggsfm_code/vggsfm/utils/tensor_to_pycolmap.py
diff --git a/vggsfm/vggsfm/utils/triangulation.py b/vggsfm_code/vggsfm/utils/triangulation.py
similarity index 100%
rename from vggsfm/vggsfm/utils/triangulation.py
rename to vggsfm_code/vggsfm/utils/triangulation.py
diff --git a/vggsfm/vggsfm/utils/triangulation_helpers.py b/vggsfm_code/vggsfm/utils/triangulation_helpers.py
similarity index 100%
rename from vggsfm/vggsfm/utils/triangulation_helpers.py
rename to vggsfm_code/vggsfm/utils/triangulation_helpers.py
diff --git a/vggsfm/vggsfm/utils/utils.py b/vggsfm_code/vggsfm/utils/utils.py
similarity index 100%
rename from vggsfm/vggsfm/utils/utils.py
rename to vggsfm_code/vggsfm/utils/utils.py
diff --git a/viz_utils/__pycache__/viz_fn.cpython-310.pyc b/viz_utils/__pycache__/viz_fn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..275454792daea52b84a1b1cee8e2637219649c88
Binary files /dev/null and b/viz_utils/__pycache__/viz_fn.cpython-310.pyc differ
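viz_utils/viz_fn.py below builds camera frusta with trimesh and collects them in a Scene, the same pattern vggsfm_predictions_to_glb in app.py uses for its GLB export. A self-contained sketch of that Scene/export flow (the geometry here is random placeholder data, not a reconstruction):

import numpy as np
import trimesh

points = np.random.rand(100, 3)
colors = (np.random.rand(100, 3) * 255).astype(np.uint8)

scene = trimesh.Scene()
scene.add_geometry(trimesh.PointCloud(points, colors=colors))
# a 4-sided cone is the same primitive add_camera turns into a frustum
scene.add_geometry(trimesh.creation.cone(radius=0.05, height=0.1, sections=4))
scene.export(file_obj="sketch_scene.glb")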
diff --git a/viz_utils/viz_fn.py b/viz_utils/viz_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5143a6987cdbdfa459c5b3414c5487cacb4a17b5
--- /dev/null
+++ b/viz_utils/viz_fn.py
@@ -0,0 +1,148 @@
+import os
+import cv2
+import torch
+import numpy as np
+import gradio as gr
+
+import trimesh
+import sys
+import os
+
+# sys.path.append('vggsfm_code/')
+import shutil
+from datetime import datetime
+
+# from vggsfm_code.hf_demo import demo_fn
+# from omegaconf import DictConfig, OmegaConf
+# from viz_utils.viz_fn import add_camera
+
+from scipy.spatial.transform import Rotation
+import PIL
+
+
+def add_camera(scene, pose_c2w, edge_color, image=None,
+               focal=None, imsize=None,
+               screen_width=0.03, marker=None):
+    # learned from https://github.com/naver/dust3r/blob/main/dust3r/viz.py
+
+    opengl_mat = np.array([[1, 0, 0, 0],
+                           [0, -1, 0, 0],
+                           [0, 0, -1, 0],
+                           [0, 0, 0, 1]])
+
+    if image is not None:
+        image = np.asarray(image)
+        H, W, THREE = image.shape
+        assert THREE == 3
+        if image.dtype != np.uint8:
+            image = np.uint8(255 * image)
+    elif imsize is not None:
+        W, H = imsize
+    elif focal is not None:
+        H = W = focal / 1.1
+    else:
+        H = W = 1
+
+    if isinstance(focal, np.ndarray):
+        focal = focal[0]
+    if not focal:
+        focal = min(H, W) * 1.1  # default value
+
+    # create fake camera
+    height = max(screen_width / 10, focal * screen_width / H)
+    width = screen_width * 0.5**0.5
+    rot45 = np.eye(4)
+    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
+    rot45[2, 3] = -height  # set the tip of the cone = optical center
+    aspect_ratio = np.eye(4)
+    aspect_ratio[0, 0] = W / H
+    transform = pose_c2w @ opengl_mat @ aspect_ratio @ rot45
+    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)
+
+    # this is the image
+    if image is not None:
+        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
+        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
+        img = trimesh.Trimesh(vertices=vertices, faces=faces)
+        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
+        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
+        scene.add_geometry(img)
+
+    # this is the camera mesh
+    rot2 = np.eye(4)
+    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
+    vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
+    vertices = geotrf(transform, vertices)
+    faces = []
+    for face in cam.faces:
+        if 0 in face:
+            continue
+        a, b, c = face
+        a2, b2, c2 = face + len(cam.vertices)
+        a3, b3, c3 = face + 2 * len(cam.vertices)
+
+        # add 3 pseudo-edges
+        faces.append((a, b, b2))
+        faces.append((a, a2, c))
+        faces.append((c2, b, c))
+
+        faces.append((a, b, b3))
+        faces.append((a, a3, c))
+        faces.append((c3, b, c))
+
+    # no culling
+    faces += [(c, b, a) for a, b, c in faces]
+
+    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
+    cam.visual.face_colors[:, :3] = edge_color
+    scene.add_geometry(cam)
+
+    if marker == 'o':
+        marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
+        marker.vertices += pose_c2w[:3, 3]
+        marker.visual.face_colors[:, :3] = edge_color
+        scene.add_geometry(marker)
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+    # learned from https://github.com/naver/dust3r/blob/main/dust3r/
+
+    assert Trf.ndim >= 2
+    pts = np.asarray(pts)
+
+    # adapt shape if necessary
+    output_reshape = pts.shape[:-1]
+    ncol = ncol or pts.shape[-1]
+
+    if Trf.ndim >= 3:
+        n = Trf.ndim - 2
+        assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+        Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+        if pts.ndim > Trf.ndim:
+            # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+            pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+        elif pts.ndim == 2:
+            # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+            pts = pts[:, None, :]
+
+    if pts.shape[-1] + 1 == Trf.shape[-1]:
+        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+        pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+    elif pts.shape[-1] == Trf.shape[-1]:
+        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+        pts = pts @ Trf
+    else:
+        pts = Trf @ pts.T
+        if pts.ndim >= 2:
+            pts = pts.swapaxes(-1, -2)
+
+    if norm:
+        pts = pts / pts[..., -1:]
+        if norm != 1:
+            pts *= norm
+
+    res = pts[..., :ncol].reshape(*output_reshape, ncol)
+    return res
+
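A quick sanity check for geotrf above: for a single 4x4 rigid transform and an (N, 3) point array, it should match plain homogeneous multiplication. This sketch assumes the repo layout from this diff so that viz_utils.viz_fn is importable (with its torch/gradio/cv2 dependencies installed):

import numpy as np
from scipy.spatial.transform import Rotation
from viz_utils.viz_fn import geotrf

T = np.eye(4)
T[:3, :3] = Rotation.from_euler('z', np.deg2rad(90)).as_matrix()
T[:3, 3] = [1.0, 0.0, 0.0]

pts = np.random.rand(5, 3)
out = geotrf(T, pts)                        # (5, 3) transformed points

reference = np.c_[pts, np.ones(5)] @ T.T    # homogeneous multiply
assert np.allclose(out, reference[:, :3])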