import os
import shutil
import torch
import torchvision.transforms as tvf
import torch.nn.functional as F
import numpy as np

from dust3r.utils.device import to_numpy
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess

from hydra.utils import instantiate
from omegaconf import OmegaConf


# Checkpoint file names, joined onto the user-supplied model directory.
_DUST3R_CKPT = "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
_MAST3R_CKPT = "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"

# Feature types that are generated analytically and never go through PCA.
_COORD_FEAT_TYPES = ('iuv', 'iuvrgb')


def _normalize_unit_range(features, eps=1e-8):
    """
    Rescale every channel of a [B, H, W, C] tensor to the range [0, 1].

    The minimum and maximum are taken over the first three dimensions, i.e. one
    pair per channel. The denominator is clamped so that a constant channel
    yields zeros instead of NaN.
    """
    lowest = features.amin(dim=(0, 1, 2), keepdim=True)
    highest = features.amax(dim=(0, 1, 2), keepdim=True)
    return (features - lowest) / (highest - lowest).clamp_min(eps)


def _image_to_numpy(image_tensor):
    """
    Convert a [1, 3, H, W] image tensor with values in [-1, 1] into an
    (H, W, 3) numpy array with values in [0, 1].
    """
    image = image_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    return (image + 1) / 2


def _true_hw(images):
    """
    Return the first two entries of ``images[0]['true_shape'][0]`` as a tuple
    of plain Python ints (used as height, width by the callers).
    """
    return tuple(int(s) for s in images[0]['true_shape'][0][:2])


def _reset_cache_dir(cache_dir):
    """
    Remove ``cache_dir`` if it exists.

    Replaces a shell ``rm -rf`` call: no shell is spawned, so paths containing
    spaces or shell metacharacters are handled safely. Errors while deleting a
    directory tree are ignored, as they were with the unchecked shell call.
    """
    if os.path.isdir(cache_dir) and not os.path.islink(cache_dir):
        shutil.rmtree(cache_dir, ignore_errors=True)
    elif os.path.lexists(cache_dir):
        os.remove(cache_dir)


def _run_sparse_ga(train_img_list, pairs, cache_dir, model, device, shared_intrinsics):
    """
    Call MASt3R's ``sparse_global_alignment`` with the fixed optimisation
    settings used throughout this module.
    """
    return sparse_global_alignment(
        train_img_list, pairs, cache_dir, model,
        lr1=0.07, niter1=500, lr2=0.014, niter2=200,
        device=device,
        # The original expression was "'depth' in 'refine'", which is always False.
        opt_depth=False,
        shared_intrinsics=shared_intrinsics,
        matching_conf_thr=5.,
    )


def _reshape_depth(arr, height, width):
    """
    Reshape a flat depth array to image layout.

    NOTE(review): depth is presumably one scalar per pixel, in which case the
    result is (height, width). If the array carries several values per pixel
    the trailing channel dimension is kept -- confirm against
    ``TSDFPostProcess.get_dense_pts3d``.
    """
    depth = arr.reshape(height, width, -1)
    return depth[..., 0] if depth.shape[-1] == 1 else depth


class TorchPCA(object):
    """
    PCA built on ``torch.pca_lowrank``.

    ``fit`` stores the per-column mean, the transposed ``V`` factor as
    ``components_`` and the singular values as ``singular_values_``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """
        Fit the projection on ``X`` (rows are samples) and return ``self``.
        """
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_.unsqueeze(0)
        # Data is centred manually above, hence center=False.
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=50)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        """
        Centre ``X`` with the fitted mean and project it onto the components.
        """
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_.T
        return projected


def pca(stacked_feat, dim):
    """
    Reduce the channel dimension of a set of feature maps with one shared PCA.

    Args:
        stacked_feat: Iterable of feature maps, each with shape [H, W, C].
        dim (int): Number of output channels.

    Returns:
        torch.Tensor: Stacked projected feature maps, each with shape [H, W, dim].
    """
    shapes = []
    flattened_feats = []
    for feat in stacked_feat:
        H, W, C = feat.shape
        shapes.append((H, W))
        flattened_feats.append(feat.reshape(H * W, C).detach())

    # A single PCA is fitted on the pixels of all feature maps together.
    fit_pca = TorchPCA(n_components=dim).fit(torch.cat(flattened_feats, dim=0))

    projected_feats = [
        fit_pca.transform(flat).reshape(H, W, dim)
        for flat, (H, W) in zip(flattened_feats, shapes)
    ]
    return torch.stack(projected_feats)


def upsampler(feature, upsampled_height, upsampled_width, max_chunk=None):
    """
    Upsample the feature tensor to the specified height and width.

    Args:
        - feature (torch.Tensor): The input tensor with size [B, H, W, C].
        - upsampled_height (int): The target height after upsampling.
        - upsampled_width (int): The target width after upsampling.
        - max_chunk (int, optional): If set, the batch is interpolated in
          slices of at most this many items and the results are concatenated.

    Returns:
        - upsampled_feature (torch.Tensor): The upsampled tensor with size
          [B, upsampled_height, upsampled_width, C].
    """
    target = (upsampled_height, upsampled_width)

    # Permute the tensor to [B, C, H, W] for interpolation
    feature = feature.permute(0, 3, 1, 2)

    # Perform the upsampling
    if max_chunk:
        upsampled_chunks = [
            F.interpolate(feature[i:i + max_chunk], size=target, mode='bilinear', align_corners=False)
            for i in range(0, len(feature), max_chunk)
        ]
        upsampled_feature = torch.cat(upsampled_chunks, dim=0)
    else:
        upsampled_feature = F.interpolate(feature, size=target, mode='bilinear', align_corners=False)

    # Permute back to [B, H, W, C]
    return upsampled_feature.permute(0, 2, 3, 1)


def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result.

    Each row of the saved figure shows one input image followed by up to three
    RGB renderings, each built from three consecutive feature channels.

    Args:
        features (torch.Tensor): Feature tensor with shape [B, H, W, C].
        images (list): List of dictionaries containing images with keys 'img'.
                       Each image tensor has shape [1, 3, H, W] and values in the range [-1, 1].
        save_dir (str): Directory to save the resulting visualization.
        dim (int): Upper channel index; channels from ``dim - 9`` onwards are shown.
        feat_type (list): List of feature types.
        file_name (str): Name of the file to save.

    Returns:
        str: Path of the saved PNG file.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt

    assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
    B, H, W, C = features.size()

    # Keep the requested channel window and normalize it to range [0, 1]
    features = _normalize_unit_range(features[..., dim-9:])
    available = features.shape[-1]

    ##### Save individual feature maps
    # (requires: import torchvision.utils as vutils)
    # # Create subdirectory for feature visualizations
    # feat_dir = os.path.join(save_dir, 'feature_maps')
    # if feat_type:
    #     feat_dir = os.path.join(feat_dir, '-'.join(feat_type))
    # os.makedirs(feat_dir, exist_ok=True)
    # for i in range(B):
    #     # Extract and save the feature map (channels 3-6)
    #     feat_map = features[i, :, :, 3:6].permute(2, 0, 1)  # [3, H, W]
    #     save_path = os.path.join(feat_dir, f'{i}_feat.png')
    #     vutils.save_image(feat_map, save_path, normalize=False)
    # return feat_dir

    ##### Save feature maps in a single image
    # squeeze=False keeps `axes` two-dimensional even when B == 1.
    fig, axes = plt.subplots(B, 4, figsize=(W*4*0.01, H*B*0.01), squeeze=False)

    for i in range(B):
        # Get the original image
        image_tensor = images[i]['img']
        assert image_tensor.dim() == 4 and image_tensor.size(0) == 1 and image_tensor.size(1) == 3, "Image tensor must have shape [1, 3, H, W]"

        ax = axes[i, 0]
        ax.imshow(_image_to_numpy(image_tensor))
        ax.axis('off')

        # Visualize each 3-dimensional feature
        for j in range(3):
            ax = axes[i, j+1]
            if (j + 1) * 3 <= available:
                # Only a complete channel triplet can be rendered as RGB
                ax.imshow(features[i, :, :, j*3:(j+1)*3].cpu().numpy())
            else:
                # Plot white image if features are not available
                ax.imshow(np.ones((H, W, 3), dtype=np.float32))
            ax.axis('off')

    # Reduce margins and spaces between images
    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    # Save the entire plot
    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    return save_path


#### Open it if you visualize feature maps in Feat2GS's teaser
# import matplotlib.colors as mcolors
# from PIL import Image

# morandi_colors = [
#     '#8AA2A9', '#C98474', '#F2D0A9', '#8D9F87', '#A7A7A7', '#D98E73', '#B24C33', '#5E7460', '#4A6B8A', '#B2CBC2',
#     '#BBC990', '#6B859E', '#B45342', '#4E0000', '#3D0000', '#2C0000', '#1B0000', '#0A0000', '#DCAC99', '#6F936B',
#     '#EBA062', '#FED273', '#9A8EB4', '#706052', '#E9E5E5', '#C4D8D2', '#F2CBBD', '#F6F9F1', '#C5CABC', '#A3968B',
#     '#5C6974', '#BE7B6E', '#C67752', '#C18830', '#8C956C', '#CAC691', '#819992', '#4D797F', '#95AEB2', '#B6C4CF',
#     '#84291C', '#B9551F', '#A96400', '#374B6C', '#C8B493', '#677D5D', '#9882A2', '#2D5F53', '#D2A0AC', '#658D9A',
#     '#9A7265', '#EFE1D2', '#DDD8D1', '#D2C6BC', '#E3C9BC', '#B8AB9F', '#D8BEA4', '#E0D4C5', '#B8B8B6', '#D0CAC3',
#     '#9AA8B5', '#BBC9B9', '#E3E8D8', '#ADB3A4', '#C5C9BB', '#A3968B', '#C2A995', '#EDE1D1', '#EDE8E1', '#EDEBE1',
#     '#CFCFCC', '#AABAC6', '#DCDEE0', '#EAE5E7', '#B7AB9F', '#F7EFE3', '#DED8CF', '#ABCA99', '#C5CD8F', '#959491',
#     '#FFE481', '#C18E99', '#B07C86', '#9F6A73', '#8E5860', '#DEAD44', '#CD9B31', '#BC891E', '#AB770B', '#9A6500',
#     '#778144', '#666F31', '#555D1E', '#444B0B', '#333900', '#67587B', '#564668', '#684563', '#573350', '#684550',
#     '#57333D', '#46212A', '#350F17', '#240004',
# ]

# def rgb_to_hsv(rgb):
#     rgb = rgb.clamp(0, 1)
#     cmax, cmax_idx = rgb.max(dim=-1)
#     cmin = rgb.min(dim=-1).values
#     diff = cmax - cmin
#     h = torch.zeros_like(cmax)
#     h[cmax_idx == 0] = (((rgb[..., 1] - rgb[..., 2]) / diff) % 6)[cmax_idx == 0]
#     h[cmax_idx == 1] = (((rgb[..., 2] - rgb[..., 0]) / diff) + 2)[cmax_idx == 1]
#     h[cmax_idx == 2] = (((rgb[..., 0] - rgb[..., 1]) / diff) + 4)[cmax_idx == 2]
#     h[diff == 0] = 0  # If cmax == cmin
#     h = h / 6
#     s = torch.zeros_like(cmax)
#     s[cmax != 0] = (diff / cmax)[cmax != 0]
#     v = cmax
#     return torch.stack([h, s, v], dim=-1)

# def hsv_to_rgb(hsv):
#     h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
#     c = v * s
#     x = c * (1 - torch.abs((h * 6) % 2 - 1))
#     m = v - c
#     rgb = torch.zeros_like(hsv)
#     mask = (h < 1/6)
#     rgb[mask] = torch.stack([c[mask], x[mask], torch.zeros_like(x[mask])], dim=-1)
#     mask = (1/6 <= h) & (h < 2/6)
#     rgb[mask] = torch.stack([x[mask], c[mask], torch.zeros_like(x[mask])], dim=-1)
#     mask = (2/6 <= h) & (h < 3/6)
#     rgb[mask] = torch.stack([torch.zeros_like(x[mask]), c[mask], x[mask]], dim=-1)
#     mask = (3/6 <= h) & (h < 4/6)
#     rgb[mask] = torch.stack([torch.zeros_like(x[mask]), x[mask], c[mask]], dim=-1)
#     mask = (4/6 <= h) & (h < 5/6)
#     rgb[mask] = torch.stack([x[mask], torch.zeros_like(x[mask]), c[mask]], dim=-1)
#     mask = (5/6 <= h)
#     rgb[mask] = torch.stack([c[mask], torch.zeros_like(x[mask]), x[mask]], dim=-1)
#     return rgb + m.unsqueeze(-1)

# def interpolate_colors(colors, n_colors):
#     # Convert colors to RGB tensor
#     rgb_colors = torch.tensor([mcolors.to_rgb(color) for color in colors])
#     # Convert RGB to HSV
#     hsv_colors = rgb_to_hsv(rgb_colors)
#     # Sort by hue
#     sorted_indices = torch.argsort(hsv_colors[:, 0])
#     sorted_hsv_colors = hsv_colors[sorted_indices]
#     # Create interpolation indices
#     indices = torch.linspace(0, len(sorted_hsv_colors) - 1, n_colors)
#     # Perform interpolation
#     interpolated_hsv = torch.stack([
#         torch.lerp(sorted_hsv_colors[int(i)],
#                    sorted_hsv_colors[min(int(i) + 1, len(sorted_hsv_colors) - 1)],
#                    i - int(i))
#         for i in indices
#     ])
#     # Convert interpolated result back to RGB
#     interpolated_rgb = hsv_to_rgb(interpolated_hsv)
#     return interpolated_rgb

# def project_to_morandi(features, morandi_colors):
#     features_flat = features.reshape(-1, 3)
#     distances = torch.cdist(features_flat, morandi_colors)
#     # Get the indices of the closest colors
#     closest_color_indices = torch.argmin(distances, dim=1)
#     # Use the closest Morandi colors directly
#     features_morandi = morandi_colors[closest_color_indices]
#     features_morandi = features_morandi.reshape(features.shape)
#     return features_morandi

# def smooth_color_transform(features, morandi_colors, smoothness=0.1):
#     features_flat = features.reshape(-1, 3)
#     distances = torch.cdist(features_flat, morandi_colors)
#     # Calculate weights
#     weights = torch.exp(-distances / smoothness)
#     weights = weights / weights.sum(dim=1, keepdim=True)
#     # Weighted average
#     features_morandi = torch.matmul(weights, morandi_colors)
#     features_morandi = features_morandi.reshape(features.shape)
#     return features_morandi

# def histogram_matching(source, template):
#     """
#     Match the histogram of the source tensor to that of the template tensor.
#     :param source: Source tensor with shape [B, H, W, 3]
#     :param template: Template tensor with shape [N, 3], where N is the number of colors
#     :return: Source tensor after histogram matching
#     """
#     def match_cumulative_cdf(source, template):
#         src_values, src_indices, src_counts = torch.unique(source, return_inverse=True, return_counts=True)
#         tmpl_values, tmpl_counts = torch.unique(template, return_counts=True)
#         src_quantiles = torch.cumsum(src_counts.float(), 0) / source.numel()
#         tmpl_quantiles = torch.cumsum(tmpl_counts.float(), 0) / template.numel()
#         idx = torch.searchsorted(tmpl_quantiles, src_quantiles)
#         idx = torch.clamp(idx, 1, len(tmpl_quantiles)-1)
#         slope = (tmpl_values[idx] - tmpl_values[idx-1]) / (tmpl_quantiles[idx] - tmpl_quantiles[idx-1])
#         interp_a_values = torch.lerp(tmpl_values[idx-1], tmpl_values[idx],
#                                      (src_quantiles - tmpl_quantiles[idx-1]) * slope)
#         return interp_a_values[src_indices].reshape(source.shape)
#     matched = torch.stack([match_cumulative_cdf(source[..., i], template[:, i]) for i in range(3)], dim=-1)
#     return matched

# def process_features(features):
#     device = features.device
#     n_colors = 1024
#     morandi_colors_tensor = interpolate_colors(morandi_colors, n_colors).to(device)
#     # morandi_colors_tensor = torch.tensor([mcolors.to_rgb(color) for color in morandi_colors]).to(device)
#     # features_morandi = project_to_morandi(features, morandi_colors_tensor)
#     # features_morandi = histogram_matching(features, morandi_colors_tensor)
#     features_morandi = smooth_color_transform(features, morandi_colors_tensor, smoothness=0.05)
#     return features_morandi.cpu().numpy()

# def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
#     import matplotlib
#     matplotlib.use('Agg')
#     import matplotlib.pyplot as plt
#     import numpy as np
#     import os
#     assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
#     B, H, W, C = features.size()
#     # Ensure features have at least 3 channels for RGB visualization
#     assert C >= 3, "Features must have at least 3 channels for RGB visualization"
#     features = features[..., :3]
#     # Normalize features to [0, 1] range
#     features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
#     features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
#     features = (features - features_min) / (features_max - features_min)
#     features_processed = process_features(features)
#     # Create the directory structure
#     vis_dir = os.path.join(save_dir, 'vis')
#     if feat_type:
#         feat_type_str = '-'.join(feat_type)
#         vis_dir = os.path.join(vis_dir, feat_type_str)
#     os.makedirs(vis_dir, exist_ok=True)
#     # Save individual images for each feature map
#     for i in range(B):
#         if file_name is None:
#             file_name = 'feat_morandi'
#         save_path = os.path.join(vis_dir, f'{file_name}_{i}.png')
#         # Convert to uint8 and save directly
#         img = Image.fromarray((features_processed[i] * 255).astype(np.uint8))
#         img.save(save_path)
#     print(f"Feature maps have been saved in the directory: {vis_dir}")
#     return vis_dir


def mv_visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result.
    (For MASt3R decoder or head features)

    Two views are laid out per figure row: view ``2*i`` occupies columns 0-3
    and view ``2*i + 1`` occupies columns 4-7. Each view is drawn as its image
    followed by three RGB renderings of consecutive channel triplets.

    Args:
        features (torch.Tensor): Feature tensor with shape [B, H, W, C].
        images (list): List of dictionaries with key 'img'; each image is
                       converted as a [1, 3, H, W] tensor in the range [-1, 1].
        save_dir (str): Directory to save the resulting visualization.
        dim (int): Upper channel index; channels from ``dim - 9`` onwards are shown.
        feat_type (list): List of feature types, appended to the file name.
        file_name (str): Name of the file to save.

    Returns:
        str: Path of the saved PNG file.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt

    B, H, W, _ = features.size()

    # Keep the requested channel window and normalize it to range [0, 1]
    features = _normalize_unit_range(features[..., dim-9:])

    rows = (B + 1) // 2  # Adjust rows for odd B
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(rows, 8, figsize=(W*8*0.01, H*rows*0.01), squeeze=False)

    for idx in range(B):
        row, slot = divmod(idx, 2)
        first_col = 4 * slot

        axes[row, first_col].imshow(_image_to_numpy(images[idx]['img']))
        axes[row, first_col].axis('off')
        for j in range(3):
            ax = axes[row, first_col + j + 1]
            ax.imshow(features[idx, :, :, j*3:(j+1)*3].cpu().numpy())
            ax.axis('off')

    # Hide unused columns in last row if B is odd
    if B % 2 != 0:
        for j in range(4, 8):
            axes[-1, j].axis('off')

    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    # Save the plot
    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    return save_path


def _renormalize(image, mean, std):
    """
    Map an image from the [-1, 1] convention to ``(x01 - mean) / std``, where
    ``x01 = (image + 1) / 2`` is the image rescaled to [0, 1].
    """
    to_unit_range = tvf.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5])
    to_target = tvf.Normalize(mean=mean, std=std)
    return to_target(to_unit_range(image))


def adjust_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalize a [-1, 1] image with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]."""
    return _renormalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def adjust_midas_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalize a [-1, 1] image with mean 0.5 and std 0.5 per channel."""
    return _renormalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])


def adjust_clip_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalize a [-1, 1] image with the mean/std constants listed below."""
    return _renormalize(
        image,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )


class UnNormalize(object):
    """
    Inverse of a per-channel ``Normalize``: computes ``x * std + mean``.

    Accepts either a batched [B, C, H, W] tensor or a single [C, H, W] tensor
    and returns a new tensor of the same shape; the input is not modified.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        batched = len(image2.shape) == 4
        if batched:
            # Bring channels to the front so the loop below walks over channels.
            image2 = image2.permute(1, 0, 2, 3)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        # Only undo the permutation when it was applied.
        return image2.permute(1, 0, 2, 3) if batched else image2


norm = tvf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

midas_norm = tvf.Normalize([0.5] * 3, [0.5] * 3)
midas_unnorm = UnNormalize([0.5] * 3, [0.5] * 3)


def generate_iuv(B, H, W):
    """
    Build a [B, H, W, 3] tensor of normalized (i, u, v) coordinates.

    ``i`` is the batch index scaled to [0, 1], ``u`` runs from 0 to 1 along the
    width and ``v`` runs from 0 to 1 along the height. For a single item
    (B == 1) the ``i`` channel is 0 rather than the NaN that 0 / 0 would give.
    """
    i_coords = torch.arange(B).view(B, 1, 1, 1).expand(B, H, W, 1).float() / max(B - 1, 1)
    u_coords = torch.linspace(0, 1, W).view(1, 1, W, 1).expand(B, H, W, 1)
    v_coords = torch.linspace(0, 1, H).view(1, H, 1, 1).expand(B, H, W, 1)
    iuv_coords = torch.cat([i_coords, u_coords, v_coords], dim=-1)
    return iuv_coords


class FeatureExtractor:
    """
    Extracts and processes features from images using VFMs for per point(per pixel).
    Supports multiple VFM features, dimensionality reduction, feature upsampling, and visualization.

    Parameters:
        images (list): List of image info.
        args: Configuration object providing the attributes listed below.
        method (str): Pointmap Init method, choose in ["dust3r", "mast3r"].

    Attributes read from ``args``:
        device (str): 'cuda'.
        feat_type (list): VFM, choose in ["dust3r", "mast3r", "dift", "dino_b16", "dinov2_b14", "radio",
                                          "clip_b16", "mae_b16", "midas_l16", "sam_base", "iuvrgb"].
        feat_dim (int): PCA dimensions.
        img_base_path (str): Training view data directory path.
        model_path (str): Model path, './submodules/mast3r/checkpoints/'.
        vis_feat (bool): Visualize and save feature maps.
        vis_key (str): Feature type to visualize(only for mast3r), choose in ["decfeat", "desc"].
        focal_avg (bool): Use averaging focal.
    """

    def __init__(self, images, args, method):
        self.images = images
        self.method = method
        self.device = args.device
        self.feat_type = args.feat_type
        self.feat_dim = args.feat_dim
        self.img_base_path = args.img_base_path
        # self.use_featup = args.use_featup
        self.model_path = args.model_path
        self.vis_feat = args.vis_feat
        self.vis_key = args.vis_key
        self.focal_avg = args.focal_avg

    def get_dust3r_feat(self, **kw):
        """
        Run DUSt3R on ``kw['pairs']`` and return the aligner's ``stacked_feat``.
        """
        model_path = os.path.join(self.model_path, _DUST3R_CKPT)
        model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(self.device)
        output = inference(kw['pairs'], model, self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        if self.vis_key:
            assert self.vis_key == 'decfeat', f"Expected vis_key to be 'decfeat', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], output=output)
        # del model, output
        # torch.cuda.empty_cache()
        return scene.stacked_feat

    def get_mast3r_feat(self, **kw):
        """
        Run MASt3R sparse global alignment on ``kw['train_img_list']`` /
        ``kw['pairs']`` and return the scene's ``stacked_feat``.

        The ``cache`` directory under ``img_base_path`` is removed first.
        """
        model_path = os.path.join(self.model_path, _MAST3R_CKPT)
        model = AsymmetricMASt3R.from_pretrained(model_path).to(self.device)
        cache_dir = os.path.join(self.img_base_path, "cache")
        _reset_cache_dir(cache_dir)
        scene = _run_sparse_ga(kw['train_img_list'], kw['pairs'], cache_dir, model,
                               device=self.device, shared_intrinsics=self.focal_avg)
        if self.vis_key:
            assert self.vis_key in ['decfeat', 'desc'], f"Expected vis_key to be 'decfeat' or 'desc', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], model=model)
        # del model
        # torch.cuda.empty_cache()
        return scene.stacked_feat

    def get_feat(self, feat_type):
        """
        Get features using Probe3D.

        Loads ``configs/backbone/{feat_type}.yaml``, instantiates the model and
        runs it on all images. Returns the output permuted to channels-last.
        """
        cfg = OmegaConf.load(f"configs/backbone/{feat_type}.yaml")
        model = instantiate(cfg.model, output="dense", return_multilayer=False)
        model = model.to(self.device)

        batch = torch.cat([i['img'] for i in self.images])
        if 'midas' in feat_type:
            image_norm = adjust_midas_norm(batch).to(self.device)
        # elif 'clip' in self.feat_type:
        #     image_norm = adjust_clip_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
        else:
            image_norm = adjust_norm(batch).to(self.device)

        with torch.no_grad():
            feats = model(image_norm).permute(0, 2, 3, 1)
        # del model
        # torch.cuda.empty_cache()
        return feats

    # def get_feat(self, feat_type):
    #     """
    #     Get features using FeatUp.
    #     """
    #     original_feat_type = feat_type
    #     use_norm = False if 'maskclip' in feat_type else True
    #     if 'featup' in original_feat_type:
    #         feat_type = feat_type.split('_featup')[0]
    #     # feat_upsampler = torch.hub.load("mhamilton723/FeatUp", feat_type, use_norm=use_norm).to(device)
    #     feat_upsampler = torch.hub.load("/home/chenyue/.cache/torch/hub/mhamilton723_FeatUp_main/", feat_type, use_norm=use_norm, source='local').to(self.device)  ## offline
    #     image_norm = adjust_norm(torch.cat([i['img'] for i in self.images])).to(self.device)
    #     image_norm = F.interpolate(image_norm, size=(224, 224), mode='bilinear', align_corners=False)
    #     if 'featup' in original_feat_type:
    #         feats = feat_upsampler(image_norm).permute(0, 2, 3, 1)
    #     else:
    #         feats = feat_upsampler.model(image_norm).permute(0, 2, 3, 1)
    #     return feats

    def get_iuvrgb(self):
        """
        Return the (i, u, v) coordinate channels concatenated with the image
        channels, as a channels-last tensor on ``self.device``.
        """
        rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
        feats = torch.cat([generate_iuv(*rgb.shape[:-1]).to(self.device), rgb], dim=-1)
        return feats

    def get_iuv(self):
        """
        Return only the (i, u, v) coordinate channels, sized like the images.
        """
        rgb = torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)
        feats = generate_iuv(*rgb.shape[:-1]).to(self.device)
        return feats

    def preprocess(self, feature, feat_dim, vis_feat=False, is_upsample=True):
        """
        Preprocess features by applying PCA, upsampling, and optionally visualizing.

        Args:
            feature (torch.Tensor): Features with shape [B, H, W, C].
            feat_dim (int or None): PCA output dimension; PCA is skipped when falsy.
            vis_feat (bool): Save a visualization of the result when True.
            is_upsample (bool): Allow upsampling to the first image's true shape.

        Returns:
            torch.Tensor: The processed feature tensor.
        """
        if feat_dim:
            feature = pca(feature, feat_dim)
        # else:
        #     feature_min = feature.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values
        #     feature_max = feature.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values
        #     feature = (feature - feature_min) / (feature_max - feature_min + 1e-6)
        #     feature = feature - feature.mean(dim=[0,1,2], keepdim=True)

        torch.cuda.empty_cache()

        # Upsample whenever the spatial size differs in height OR width.
        target_hw = _true_hw(self.images)
        if is_upsample and tuple(feature[0].shape[0:-1]) != target_hw:
            feature = upsampler(feature, *target_hw)

        print(f"Feature map size >>> height: {feature[0].shape[0]}, width: {feature[0].shape[1]}, channels: {feature[0].shape[2]}")
        if vis_feat:
            save_path = visualizer(feature, self.images, self.img_base_path, feat_type=self.feat_type)
            print(f"The encoder feature visualization has been saved at >>>>> {save_path}")

        return feature

    def vis_decfeat(self, pairs, **kw):
        """
        Visualize decoder or head(only for mast3r) features.

        Uses ``kw['output']`` when given; otherwise runs inference on ``pairs``
        with ``kw['model']``. The ``self.vis_key`` entries of both predictions
        are interleaved, reduced to 9 channels with PCA and saved as an image.
        """
        if 'output' in kw:
            output = kw['output']
        else:
            output = inference(pairs, kw['model'], self.device, batch_size=1, verbose=False)

        decfeat1 = output['pred1'][self.vis_key].detach()
        decfeat2 = output['pred2'][self.vis_key].detach()
        # decfeat1 = pca(decfeat1, 9)
        # decfeat2 = pca(decfeat2, 9)

        # Interleave the two predictions per pair, then swap the two halves.
        decfeat = torch.stack([decfeat1, decfeat2], dim=1).view(-1, *decfeat1.shape[1:])
        decfeat = torch.cat(torch.chunk(decfeat, 2)[::-1], dim=0)
        decfeat = pca(decfeat, 9)

        target_hw = _true_hw(self.images)
        if tuple(decfeat.shape[1:-1]) != target_hw:
            decfeat = upsampler(decfeat, *target_hw)

        # Reorder the images with the same half-swap that was applied to the features.
        half = len(pairs) // 2
        pair_images = [im for p in pairs[half:] + pairs[:half] for im in p]
        save_path = mv_visualizer(decfeat, pair_images, self.img_base_path,
                                  feat_type=self.feat_type, file_name=f'{self.vis_key}_pcaall_dim0-9')
        print(f"The decoder feature visualization has been saved at >>>>> {save_path}")

    def forward(self, **kw):
        """
        Extract every feature type in ``self.feat_type`` and return the result
        as numpy.

        With a single feature type the preprocessed features are returned
        directly. With several, each is min-max normalized, the non-coordinate
        ones are resized to a common grid and concatenated, reduced with PCA,
        and any 'iuv' / 'iuvrgb' features are appended before a final PCA.
        """
        single = len(self.feat_type) == 1
        vis_feat = self.vis_feat and single
        is_upsample = single

        all_feats = {}
        for feat_type in self.feat_type:
            # The PCA dimension is decided per feature type so that a
            # coordinate feature does not disable PCA for the ones after it.
            feat_dim = self.feat_dim
            if feat_type == self.method:
                feats = kw['scene'].stacked_feat
            elif feat_type in ['dust3r', 'mast3r']:
                feats = getattr(self, f"get_{feat_type}_feat")(**kw)
            elif feat_type in _COORD_FEAT_TYPES:
                feats = getattr(self, f"get_{feat_type}")()
                feat_dim = None
            else:
                feats = self.get_feat(feat_type)

            # feats = to_numpy(self.preprocess(feats))
            all_feats[feat_type] = self.preprocess(feats.detach().clone(), feat_dim, vis_feat, is_upsample)

        if len(self.feat_type) > 1:
            all_feats = {k: (v - v.min()) / (v.max() - v.min()) for k, v in all_feats.items()}

            # Common grid: image size divided by 16 (presumably the patch size -- TODO confirm).
            target_size = tuple(s // 16 for s in _true_hw(self.images))
            tmp_feats = []
            kickoff = []

            for k, v in all_feats.items():
                if k in _COORD_FEAT_TYPES:
                    # self.feat_dim -= v.shape[-1]
                    kickoff.append(v)
                else:
                    if tuple(v.shape[1:3]) != target_size:
                        v = F.interpolate(v.permute(0, 3, 1, 2), size=target_size,
                                          mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
                    tmp_feats.append(v)

            feats = self.preprocess(torch.cat(tmp_feats, dim=-1), self.feat_dim, self.vis_feat and not kickoff)

            if kickoff:
                feats = torch.cat([feats] + kickoff, dim=-1)
                feats = self.preprocess(feats, self.feat_dim, self.vis_feat, is_upsample=False)
        else:
            feats = all_feats[self.feat_type[0]]

        torch.cuda.empty_cache()
        return to_numpy(feats)

    def __call__(self, **kw):
        return self.forward(**kw)


class InitMethod:
    """
    Initialize pointmap and camera param via DUSt3R or MASt3R.

    The public ``get_model`` / ``infer`` / ``get_info`` / ``get_depth`` methods
    dispatch on ``args.method`` to the ``*_dust3r`` or ``*_mast3r`` variant.
    """

    def __init__(self, args):
        self.method = args.method
        self.n_views = args.n_views
        self.device = args.device
        self.img_base_path = args.img_base_path
        self.focal_avg = args.focal_avg
        self.tsdf_thresh = args.tsdf_thresh
        self.min_conf_thr = args.min_conf_thr
        # Any method other than 'dust3r' falls back to the MASt3R checkpoint.
        if self.method == 'dust3r':
            self.model_path = os.path.join(args.model_path, _DUST3R_CKPT)
        else:
            self.model_path = os.path.join(args.model_path, _MAST3R_CKPT)

    def get_dust3r(self):
        """Load the DUSt3R model from ``self.model_path`` onto ``self.device``."""
        return AsymmetricCroCo3DStereo.from_pretrained(self.model_path).to(self.device)

    def get_mast3r(self):
        """Load the MASt3R model from ``self.model_path`` onto ``self.device``."""
        return AsymmetricMASt3R.from_pretrained(self.model_path).to(self.device)

    def infer_dust3r(self, **kw):
        """
        Run DUSt3R inference on ``kw['pairs']`` with ``kw['model']``, optimise
        the global alignment and return the cleaned scene.
        """
        output = inference(kw['pairs'], kw['model'], self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        # The returned loss value is not used.
        compute_global_alignment(scene=scene, init="mst", niter=300, schedule='linear', lr=0.01,
                                 focal_avg=self.focal_avg, known_focal=kw.get('known_focal', None))
        scene = scene.clean_pointcloud()
        return scene

    def infer_mast3r(self, **kw):
        """
        Run MASt3R sparse global alignment on ``kw['train_img_list']`` /
        ``kw['pairs']`` with ``kw['model']`` and return the scene.

        The ``cache`` directory under ``img_base_path`` is removed first.
        """
        cache_dir = os.path.join(self.img_base_path, "cache")
        _reset_cache_dir(cache_dir)
        scene = _run_sparse_ga(kw['train_img_list'], kw['pairs'], cache_dir, kw['model'],
                               device=self.device, shared_intrinsics=self.focal_avg)
        return scene

    def get_dust3r_info(self, scene):
        """
        Return ``(imgs, focals, poses, intrinsics, pts3d, confidence_masks)``
        read from a DUSt3R scene. Sets ``scene.min_conf_thr`` as a side effect.
        """
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        # pts3d = to_numpy(scene.get_planes3d())
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def get_mast3r_info(self, scene):
        """
        Return ``(imgs, focals, poses, intrinsics, pts3d, confidence_masks)``
        read from a MASt3R scene, using TSDF post-processing for the dense
        points. Masks are ``conf > self.min_conf_thr``.
        """
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()[:, None]
        poses = to_numpy(scene.get_im_poses())
        intrinsics = to_numpy(scene.intrinsics)
        tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
        # Bring the point arrays back to image layout (H, W, 3).
        pts3d = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in pts3d]
        confidence_masks = np.array(to_numpy([c > self.min_conf_thr for c in confs]))
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def get_dust3r_depth(self, scene):
        """Return the DUSt3R scene's depth maps as numpy."""
        return to_numpy(scene.get_depthmaps())

    def get_mast3r_depth(self, scene):
        """
        Return the MASt3R scene's TSDF-processed depth maps, reshaped to the
        image layout of the first image.
        """
        imgs = to_numpy(scene.imgs)
        height, width = imgs[0].shape[:2]
        tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
        _, depthmaps, _ = to_numpy(tsdf.get_dense_pts3d(clean_depth=True))
        return [_reshape_depth(arr, height, width) for arr in depthmaps]

    def get_model(self):
        return getattr(self, f"get_{self.method}")()

    def infer(self, **kw):
        return getattr(self, f"infer_{self.method}")(**kw)

    def get_info(self, scene):
        return getattr(self, f"get_{self.method}_info")(scene)

    def get_depth(self, scene):
        return getattr(self, f"get_{self.method}_depth")(scene)