# Feat2GS: utils/feat_utils.py
# (Hugging Face file-page header: uploaded by faneggg, commit 123719b "init", 33.5 kB)
import os
import torch
import torchvision.transforms as tvf
import torch.nn.functional as F
import numpy as np
from dust3r.utils.device import to_numpy
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment
from mast3r.model import AsymmetricMASt3R
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
from hydra.utils import instantiate
from omegaconf import OmegaConf
class TorchPCA(object):
    """Minimal PCA with an sklearn-like ``fit``/``transform`` API, built on ``torch.pca_lowrank``.

    Attributes set by :meth:`fit`:
        mean_: per-column mean of the fitted data, shape ``[C]``.
        components_: principal axes as rows, shape ``[n_components, C]``.
        singular_values_: singular values returned by ``torch.pca_lowrank``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Fit the basis on ``X`` (``[N, C]``) and return ``self``."""
        self.mean_ = X.mean(dim=0)
        centred = X - self.mean_[None, :]
        # The data is centred by hand above, so pca_lowrank must not centre it again.
        _, sing_vals, basis = torch.pca_lowrank(
            centred, q=self.n_components, center=False, niter=50
        )
        self.components_ = basis.T
        self.singular_values_ = sing_vals
        return self

    def transform(self, X):
        """Project ``X`` (``[N, C]``) onto the fitted axes, giving ``[N, n_components]``."""
        centred = X - self.mean_[None, :]
        return centred @ self.components_.T
def pca(stacked_feat, dim):
    """Reduce every feature map in ``stacked_feat`` to ``dim`` channels with one shared PCA basis.

    Args:
        stacked_feat: iterable of ``[H, W, C]`` tensors (e.g. a ``[B, H, W, C]`` tensor).
        dim: number of principal components to keep.

    Returns:
        torch.Tensor of shape ``[B, H, W, dim]``. The basis is fitted on the pixels of all
        maps together, so the output channels mean the same thing for every image.
    """
    def _as_rows(feat):
        height, width, channels = feat.shape
        return feat.reshape(height * width, channels).detach()

    # One basis fitted on every pixel of every map.
    all_rows = torch.cat([_as_rows(f) for f in stacked_feat], dim=0)
    fitted = TorchPCA(n_components=dim).fit(all_rows)

    # Project each map separately and restore its spatial layout.
    return torch.stack([
        fitted.transform(_as_rows(f)).reshape(f.shape[0], f.shape[1], dim)
        for f in stacked_feat
    ])
def upsampler(feature, upsampled_height, upsampled_width, max_chunk=None):
    """
    Bilinearly resize a channels-last feature tensor.
    Args:
    - feature (torch.Tensor): input of shape [B, H, W, C].
    - upsampled_height (int): output height.
    - upsampled_width (int): output width.
    - max_chunk (int, optional): when truthy, resize at most this many batch items per
      ``F.interpolate`` call to bound peak memory; otherwise resize the whole batch at once.
    Returns:
    - torch.Tensor of shape [B, upsampled_height, upsampled_width, C].
    """
    target = (upsampled_height, upsampled_width)
    channels_first = feature.permute(0, 3, 1, 2)  # F.interpolate wants [B, C, H, W]

    def _resize(batch):
        return F.interpolate(batch, size=target, mode='bilinear', align_corners=False)

    if not max_chunk:
        resized = _resize(channels_first)
    else:
        resized = torch.cat(
            [_resize(channels_first[start:start + max_chunk])
             for start in range(0, channels_first.shape[0], max_chunk)],
            dim=0,
        )
    return resized.permute(0, 2, 3, 1)
def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result as one PNG grid.

    Row ``i`` shows ``images[i]`` followed by three RGB renderings of ``features[i]``,
    built from the channel triplets of the window ``[dim-9, dim)``. A triplet that is not
    fully available is drawn as a white tile.

    Args:
        features (torch.Tensor): Feature tensor with shape [B, H, W, C].
        images (list): List of dictionaries containing images with keys 'img'. Each image tensor has shape [1, 3, H, W]
            and values in the range [-1, 1].
        save_dir (str): Directory to save the resulting visualization.
        dim (int): Upper end of the 9-channel window that is displayed.
        feat_type (list): List of feature types (appended to the file name).
        file_name (str): Name of the file to save (without extension).

    Returns:
        str: Path of the saved PNG.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt

    assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
    B, H, W, _ = features.size()
    features = features[..., max(dim - 9, 0):dim].detach()

    # Per-channel min-max normalisation to [0, 1] over batch and space. clamp_min guards
    # against constant channels, which would otherwise become NaN (0/0).
    lo = features.amin(dim=(0, 1, 2), keepdim=True)
    hi = features.amax(dim=(0, 1, 2), keepdim=True)
    features = (features - lo) / (hi - lo).clamp_min(1e-12)
    n_avail = features.shape[-1]

    # squeeze=False keeps `axes` 2-D even for a single row (B == 1).
    fig, axes = plt.subplots(B, 4, figsize=(W*4*0.01, H*B*0.01), squeeze=False)
    for i in range(B):
        image_tensor = images[i]['img']
        assert image_tensor.dim() == 4 and image_tensor.size(0) == 1 and image_tensor.size(1) == 3, "Image tensor must have shape [1, 3, H, W]"
        # [1, 3, H, W] in [-1, 1] -> (H, W, 3) in [0, 1]
        image = (image_tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
        axes[i, 0].imshow(image)
        axes[i, 0].axis('off')

        for j in range(3):
            ax = axes[i, j + 1]
            # Only plot complete RGB triplets; imshow rejects 1- or 2-channel images.
            if (j + 1) * 3 <= n_avail:
                ax.imshow(features[i, :, :, j*3:(j+1)*3].cpu().numpy())
            else:
                ax.imshow(np.ones((H, W, 3)))
            ax.axis('off')

    # Reduce margins and spaces between images
    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    return save_path
#### Uncomment the block below to visualize feature maps with the Morandi palette used in Feat2GS's teaser
# import matplotlib.colors as mcolors
# from PIL import Image
# morandi_colors = [
# '#8AA2A9', '#C98474', '#F2D0A9', '#8D9F87', '#A7A7A7', '#D98E73', '#B24C33', '#5E7460', '#4A6B8A', '#B2CBC2',
# '#BBC990', '#6B859E', '#B45342', '#4E0000', '#3D0000', '#2C0000', '#1B0000', '#0A0000', '#DCAC99', '#6F936B',
# '#EBA062', '#FED273', '#9A8EB4', '#706052', '#E9E5E5', '#C4D8D2', '#F2CBBD', '#F6F9F1', '#C5CABC', '#A3968B',
# '#5C6974', '#BE7B6E', '#C67752', '#C18830', '#8C956C', '#CAC691', '#819992', '#4D797F', '#95AEB2', '#B6C4CF',
# '#84291C', '#B9551F', '#A96400', '#374B6C', '#C8B493', '#677D5D', '#9882A2', '#2D5F53', '#D2A0AC', '#658D9A',
# '#9A7265', '#EFE1D2', '#DDD8D1', '#D2C6BC', '#E3C9BC', '#B8AB9F', '#D8BEA4', '#E0D4C5', '#B8B8B6', '#D0CAC3',
# '#9AA8B5', '#BBC9B9', '#E3E8D8', '#ADB3A4', '#C5C9BB', '#A3968B', '#C2A995', '#EDE1D1', '#EDE8E1', '#EDEBE1',
# '#CFCFCC', '#AABAC6', '#DCDEE0', '#EAE5E7', '#B7AB9F', '#F7EFE3', '#DED8CF', '#ABCA99', '#C5CD8F', '#959491',
# '#FFE481', '#C18E99', '#B07C86', '#9F6A73', '#8E5860', '#DEAD44', '#CD9B31', '#BC891E', '#AB770B', '#9A6500',
# '#778144', '#666F31', '#555D1E', '#444B0B', '#333900', '#67587B', '#564668', '#684563', '#573350', '#684550',
# '#57333D', '#46212A', '#350F17', '#240004',
# ]
# def rgb_to_hsv(rgb):
# rgb = rgb.clamp(0, 1)
# cmax, cmax_idx = rgb.max(dim=-1)
# cmin = rgb.min(dim=-1).values
# diff = cmax - cmin
# h = torch.zeros_like(cmax)
# h[cmax_idx == 0] = (((rgb[..., 1] - rgb[..., 2]) / diff) % 6)[cmax_idx == 0]
# h[cmax_idx == 1] = (((rgb[..., 2] - rgb[..., 0]) / diff) + 2)[cmax_idx == 1]
# h[cmax_idx == 2] = (((rgb[..., 0] - rgb[..., 1]) / diff) + 4)[cmax_idx == 2]
# h[diff == 0] = 0 # If cmax == cmin
# h = h / 6
# s = torch.zeros_like(cmax)
# s[cmax != 0] = (diff / cmax)[cmax != 0]
# v = cmax
# return torch.stack([h, s, v], dim=-1)
# def hsv_to_rgb(hsv):
# h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
# c = v * s
# x = c * (1 - torch.abs((h * 6) % 2 - 1))
# m = v - c
# rgb = torch.zeros_like(hsv)
# mask = (h < 1/6)
# rgb[mask] = torch.stack([c[mask], x[mask], torch.zeros_like(x[mask])], dim=-1)
# mask = (1/6 <= h) & (h < 2/6)
# rgb[mask] = torch.stack([x[mask], c[mask], torch.zeros_like(x[mask])], dim=-1)
# mask = (2/6 <= h) & (h < 3/6)
# rgb[mask] = torch.stack([torch.zeros_like(x[mask]), c[mask], x[mask]], dim=-1)
# mask = (3/6 <= h) & (h < 4/6)
# rgb[mask] = torch.stack([torch.zeros_like(x[mask]), x[mask], c[mask]], dim=-1)
# mask = (4/6 <= h) & (h < 5/6)
# rgb[mask] = torch.stack([x[mask], torch.zeros_like(x[mask]), c[mask]], dim=-1)
# mask = (5/6 <= h)
# rgb[mask] = torch.stack([c[mask], torch.zeros_like(x[mask]), x[mask]], dim=-1)
# return rgb + m.unsqueeze(-1)
# def interpolate_colors(colors, n_colors):
# # Convert colors to RGB tensor
# rgb_colors = torch.tensor([mcolors.to_rgb(color) for color in colors])
# # Convert RGB to HSV
# hsv_colors = rgb_to_hsv(rgb_colors)
# # Sort by hue
# sorted_indices = torch.argsort(hsv_colors[:, 0])
# sorted_hsv_colors = hsv_colors[sorted_indices]
# # Create interpolation indices
# indices = torch.linspace(0, len(sorted_hsv_colors) - 1, n_colors)
# # Perform interpolation
# interpolated_hsv = torch.stack([
# torch.lerp(sorted_hsv_colors[int(i)],
# sorted_hsv_colors[min(int(i) + 1, len(sorted_hsv_colors) - 1)],
# i - int(i))
# for i in indices
# ])
# # Convert interpolated result back to RGB
# interpolated_rgb = hsv_to_rgb(interpolated_hsv)
# return interpolated_rgb
# def project_to_morandi(features, morandi_colors):
# features_flat = features.reshape(-1, 3)
# distances = torch.cdist(features_flat, morandi_colors)
# # Get the indices of the closest colors
# closest_color_indices = torch.argmin(distances, dim=1)
# # Use the closest Morandi colors directly
# features_morandi = morandi_colors[closest_color_indices]
# features_morandi = features_morandi.reshape(features.shape)
# return features_morandi
# def smooth_color_transform(features, morandi_colors, smoothness=0.1):
# features_flat = features.reshape(-1, 3)
# distances = torch.cdist(features_flat, morandi_colors)
# # Calculate weights
# weights = torch.exp(-distances / smoothness)
# weights = weights / weights.sum(dim=1, keepdim=True)
# # Weighted average
# features_morandi = torch.matmul(weights, morandi_colors)
# features_morandi = features_morandi.reshape(features.shape)
# return features_morandi
# def histogram_matching(source, template):
# """
# Match the histogram of the source tensor to that of the template tensor.
# :param source: Source tensor with shape [B, H, W, 3]
# :param template: Template tensor with shape [N, 3], where N is the number of colors
# :return: Source tensor after histogram matching
# """
# def match_cumulative_cdf(source, template):
# src_values, src_indices, src_counts = torch.unique(source, return_inverse=True, return_counts=True)
# tmpl_values, tmpl_counts = torch.unique(template, return_counts=True)
# src_quantiles = torch.cumsum(src_counts.float(), 0) / source.numel()
# tmpl_quantiles = torch.cumsum(tmpl_counts.float(), 0) / template.numel()
# idx = torch.searchsorted(tmpl_quantiles, src_quantiles)
# idx = torch.clamp(idx, 1, len(tmpl_quantiles)-1)
# slope = (tmpl_values[idx] - tmpl_values[idx-1]) / (tmpl_quantiles[idx] - tmpl_quantiles[idx-1])
# interp_a_values = torch.lerp(tmpl_values[idx-1], tmpl_values[idx],
# (src_quantiles - tmpl_quantiles[idx-1]) * slope)
# return interp_a_values[src_indices].reshape(source.shape)
# matched = torch.stack([match_cumulative_cdf(source[..., i], template[:, i]) for i in range(3)], dim=-1)
# return matched
# def process_features(features):
# device = features.device
# n_colors = 1024
# morandi_colors_tensor = interpolate_colors(morandi_colors, n_colors).to(device)
# # morandi_colors_tensor = torch.tensor([mcolors.to_rgb(color) for color in morandi_colors]).to(device)
# # features_morandi = project_to_morandi(features, morandi_colors_tensor)
# # features_morandi = histogram_matching(features, morandi_colors_tensor)
# features_morandi = smooth_color_transform(features, morandi_colors_tensor, smoothness=0.05)
# return features_morandi.cpu().numpy()
# def visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
# import matplotlib
# matplotlib.use('Agg')
# import matplotlib.pyplot as plt
# import numpy as np
# import os
# assert features.dim() == 4, "Input tensor must have 4 dimensions (B, H, W, C)"
# B, H, W, C = features.size()
# # Ensure features have at least 3 channels for RGB visualization
# assert C >= 3, "Features must have at least 3 channels for RGB visualization"
# features = features[..., :3]
# # Normalize features to [0, 1] range
# features_min = features.min(dim=0, keepdim=True).values.min(dim=1, keepdim=True).values.min(dim=2, keepdim=True).values
# features_max = features.max(dim=0, keepdim=True).values.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values
# features = (features - features_min) / (features_max - features_min)
# features_processed = process_features(features)
# # Create the directory structure
# vis_dir = os.path.join(save_dir, 'vis')
# if feat_type:
# feat_type_str = '-'.join(feat_type)
# vis_dir = os.path.join(vis_dir, feat_type_str)
# os.makedirs(vis_dir, exist_ok=True)
# # Save individual images for each feature map
# for i in range(B):
# if file_name is None:
# file_name = 'feat_morandi'
# save_path = os.path.join(vis_dir, f'{file_name}_{i}.png')
# # Convert to uint8 and save directly
# img = Image.fromarray((features_processed[i] * 255).astype(np.uint8))
# img.save(save_path)
# print(f"Feature maps have been saved in the directory: {vis_dir}")
# return vis_dir
def mv_visualizer(features, images, save_dir, dim=9, feat_type=None, file_name=None):
    """
    Visualize features and corresponding images, and save the result. (For MASt3R decoder or head features)

    Two items share each row of the grid. Item ``2k`` fills columns 0-3 with its image and
    three RGB renderings of channels ``[dim-9, dim)``, and item ``2k+1`` fills columns 4-7
    the same way. When ``B`` is odd, the right half of the last row is left blank.

    Args:
        features (torch.Tensor): Feature tensor with shape [B, H, W, C].
        images (list): B dicts whose 'img' entry is a [1, 3, H, W] tensor with values in [-1, 1].
        save_dir (str): Directory to save the resulting visualization.
        dim (int): Upper end of the 9-channel window that is displayed.
        feat_type (list): List of feature types (appended to the file name).
        file_name (str): Name of the file to save (without extension).

    Returns:
        str: Path of the saved PNG.
    """
    import matplotlib
    matplotlib.use('Agg')
    from matplotlib import pyplot as plt

    B, H, W, _ = features.size()
    features = features[..., max(dim - 9, 0):dim].detach()

    # Per-channel min-max normalisation to [0, 1]; clamp_min avoids NaN on constant channels.
    lo = features.amin(dim=(0, 1, 2), keepdim=True)
    hi = features.amax(dim=(0, 1, 2), keepdim=True)
    features = (features - lo) / (hi - lo).clamp_min(1e-12)
    n_avail = features.shape[-1]

    rows = (B + 1) // 2
    # squeeze=False keeps `axes` 2-D; without it a single row (B <= 2) gives a 1-D array
    # and `axes[row, col]` raises IndexError.
    fig, axes = plt.subplots(rows, 8, figsize=(W*8*0.01, H*rows*0.01), squeeze=False)

    def _draw(idx, row, col):
        """Draw image `idx` at (row, col) and its three feature triplets to its right."""
        image = (images[idx]['img'].detach().cpu().squeeze(0).permute(1, 2, 0).numpy() + 1) / 2
        axes[row, col].imshow(image)
        axes[row, col].axis('off')
        for j in range(3):
            ax = axes[row, col + 1 + j]
            if (j + 1) * 3 <= n_avail:
                ax.imshow(features[idx, :, :, j*3:(j+1)*3].cpu().numpy())
            else:
                ax.imshow(np.ones((H, W, 3)))
            ax.axis('off')

    for idx in range(B):
        _draw(idx, idx // 2, 4 * (idx % 2))

    if B % 2:
        # Hide the unused right half of the last row.
        for j in range(4, 8):
            axes[-1, j].axis('off')

    plt.subplots_adjust(wspace=0.005, hspace=0.005, left=0.01, right=0.99, top=0.99, bottom=0.01)

    if file_name is None:
        file_name = f'feat_dim{dim-9}-{dim}'
    if feat_type:
        feat_type_str = '-'.join(feat_type)
        file_name = file_name + f'_{feat_type_str}'
    save_path = os.path.join(save_dir, file_name + '.png')
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    return save_path
def adjust_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalise an image tensor from the [-1, 1] range to ImageNet mean/std normalisation.

    Step one maps ``x -> (x + 1) / 2`` (back to [0, 1]). Step two applies the ImageNet
    per-channel mean and std.
    """
    to_unit_range = tvf.Normalize(mean=[-1, -1, -1], std=[1 / 0.5] * 3)
    to_imagenet = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return tvf.Compose([to_unit_range, to_imagenet])(image)
def adjust_midas_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalise an image tensor from the [-1, 1] range to MiDaS-style mean=0.5/std=0.5.

    Note: the two steps cancel mathematically, since ``(x + 1) / 2`` followed by
    ``(y - 0.5) / 0.5`` gives back ``x``. The output therefore equals the input up to float
    rounding. The function exists for symmetry with the other ``adjust_*`` helpers.
    """
    to_unit_range = tvf.Normalize(mean=[-1, -1, -1], std=[1 / 0.5] * 3)
    to_midas = tvf.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
    return tvf.Compose([to_unit_range, to_midas])(image)
def adjust_clip_norm(image: torch.Tensor) -> torch.Tensor:
    """Re-normalise an image tensor from the [-1, 1] range to CLIP mean/std normalisation.

    Step one maps ``x -> (x + 1) / 2`` (back to [0, 1]). Step two applies the CLIP
    per-channel statistics.
    """
    to_unit_range = tvf.Normalize(mean=[-1, -1, -1], std=[1 / 0.5] * 3)
    to_clip = tvf.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )
    return tvf.Compose([to_unit_range, to_clip])(image)
class UnNormalize(object):
    """Undo a per-channel mean/std normalisation: ``x -> x * std + mean``.

    Accepts a single image ``[C, H, W]`` or a batch ``[B, C, H, W]`` and returns a new tensor
    of the same shape; the input tensor is not modified. Channels beyond ``len(mean)`` are
    left unchanged (``zip`` stops at the shorter sequence).
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        # For a batch, iterate over the channel axis (dim 1) through a transposed view. A
        # single [C, H, W] image already iterates over channels; previously it crashed on
        # the 4-D-only permute.
        channels = image2.transpose(0, 1) if image2.dim() == 4 else image2
        for t, m, s in zip(channels, self.mean, self.std):
            t.mul_(s).add_(m)  # in-place on a view, so image2 is updated
        return image2
# ImageNet statistics: forward transform and its inverse.
norm = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
unnorm = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# mean=0.5 / std=0.5 statistics (MiDaS-style): forward transform and its inverse.
midas_norm = tvf.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
midas_unnorm = UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def generate_iuv(B, H, W):
    """
    Build per-pixel (i, u, v) coordinate features.

    Returns a float32 tensor of shape [B, H, W, 3]:
      - channel 0: image index scaled to [0, 1] (all zeros when B == 1),
      - channel 1: column coordinate, ``linspace(0, 1, W)``,
      - channel 2: row coordinate, ``linspace(0, 1, H)``.
    """
    # max(..., 1) avoids the 0/0 -> NaN the index channel produced for a single image.
    index = torch.arange(B, dtype=torch.float32) / max(B - 1, 1)
    i_coords = index.view(B, 1, 1, 1).expand(B, H, W, 1)
    u_coords = torch.linspace(0, 1, W).view(1, 1, W, 1).expand(B, H, W, 1)
    v_coords = torch.linspace(0, 1, H).view(1, H, 1, 1).expand(B, H, W, 1)
    return torch.cat([i_coords, u_coords, v_coords], dim=-1)
class FeatureExtractor:
    """
    Extracts and processes features from images using VFMs for per point (per pixel).
    Supports multiple VFM features, dimensionality reduction, feature upsampling, and visualization.

    Parameters:
        images (list): List of image info dicts; the keys read here are 'img' and 'true_shape'.
        args: Namespace providing the options listed below.
        method (str): Pointmap Init method, choose in ["dust3r", "mast3r"].

    Options read from ``args``:
        device (str): e.g. 'cuda'.
        feat_type (list): VFM(s), choose in ["dust3r", "mast3r", "dift", "dino_b16", "dinov2_b14", "radio",
            "clip_b16", "mae_b16", "midas_l16", "sam_base", "iuv", "iuvrgb"].
        feat_dim (int): PCA dimensions (a falsy value disables PCA).
        img_base_path (str): Training view data directory path (visualizations and the MASt3R cache go here).
        model_path (str): Checkpoint directory, e.g. './submodules/mast3r/checkpoints/'.
        vis_feat (bool): Visualize and save feature maps.
        vis_key (str): Decoder output to visualize: 'decfeat', or 'desc' (mast3r only).
        focal_avg (bool): Use averaging focal (shared intrinsics for MASt3R).
    """

    # Pixel-coordinate features: built at full image resolution and never PCA-reduced on their own.
    _COORD_FEATS = ('iuv', 'iuvrgb')

    def __init__(self, images, args, method):
        self.images = images
        self.method = method
        self.device = args.device
        self.feat_type = args.feat_type
        self.feat_dim = args.feat_dim
        self.img_base_path = args.img_base_path
        self.model_path = args.model_path
        self.vis_feat = args.vis_feat
        self.vis_key = args.vis_key
        self.focal_avg = args.focal_avg

    def _true_hw(self):
        """Return (H, W) from the first image's 'true_shape' entry as Python ints."""
        height, width = self.images[0]['true_shape'][0][:2]
        return int(height), int(width)

    def _reset_cache_dir(self):
        """Delete any previous '<img_base_path>/cache' directory and return its path."""
        import shutil
        cache_dir = os.path.join(self.img_base_path, "cache")
        # shutil.rmtree instead of os.system('rm -rf ...'): no shell is involved, so paths with
        # spaces or shell metacharacters can't make it delete the wrong files.
        shutil.rmtree(cache_dir, ignore_errors=True)
        return cache_dir

    def get_dust3r_feat(self, **kw):
        """Run DUSt3R on kw['pairs'] and return the global aligner's ``stacked_feat``."""
        model_path = os.path.join(self.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
        model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(self.device)
        output = inference(kw['pairs'], model, self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        if self.vis_key:
            assert self.vis_key == 'decfeat', f"Expected vis_key to be 'decfeat', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], output=output)
        return scene.stacked_feat

    def get_mast3r_feat(self, **kw):
        """Run MASt3R sparse global alignment on kw['train_img_list'] / kw['pairs'] and return ``stacked_feat``."""
        model_path = os.path.join(self.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
        model = AsymmetricMASt3R.from_pretrained(model_path).to(self.device)
        cache_dir = self._reset_cache_dir()
        scene = sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
                                        model, lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
                                        # Depth refinement off; the former `'depth' in 'refine'` was always False.
                                        opt_depth=False, shared_intrinsics=self.focal_avg,
                                        matching_conf_thr=5.)
        if self.vis_key:
            assert self.vis_key in ['decfeat', 'desc'], f"Expected vis_key to be 'decfeat' or 'desc', but got {self.vis_key}"
            self.vis_decfeat(kw['pairs'], model=model)
        return scene.stacked_feat

    def get_feat(self, feat_type):
        """
        Get features using Probe3D.

        Loads 'configs/backbone/<feat_type>.yaml', instantiates the dense backbone and runs it
        on all images. The output is permuted to channels-last; this assumes the backbone
        returns [B, C, h, w] (TODO confirm per backbone).
        """
        cfg = OmegaConf.load(f"configs/backbone/{feat_type}.yaml")
        model = instantiate(cfg.model, output="dense", return_multilayer=False).to(self.device)
        images = torch.cat([i['img'] for i in self.images])
        # MiDaS keeps the mean=0.5/std=0.5 normalisation; every other backbone gets ImageNet stats.
        renorm = adjust_midas_norm if 'midas' in feat_type else adjust_norm
        image_norm = renorm(images).to(self.device)
        with torch.no_grad():
            feats = model(image_norm).permute(0, 2, 3, 1)
        return feats

    def _stacked_rgb(self):
        """All images concatenated and permuted to channels-last, on ``self.device``."""
        return torch.cat([i['img'] for i in self.images]).permute(0, 2, 3, 1).to(self.device)

    def get_iuvrgb(self):
        """(i, u, v) coordinates concatenated with the RGB values, shape [B, H, W, 6]."""
        rgb = self._stacked_rgb()
        return torch.cat([generate_iuv(*rgb.shape[:-1]).to(self.device), rgb], dim=-1)

    def get_iuv(self):
        """(i, u, v) coordinates only, shape [B, H, W, 3]."""
        rgb = self._stacked_rgb()
        return generate_iuv(*rgb.shape[:-1]).to(self.device)

    def preprocess(self, feature, feat_dim, vis_feat=False, is_upsample=True):
        """
        Preprocess features by applying PCA, upsampling, and optionally visualizing.

        Args:
            feature (torch.Tensor): [B, h, w, C] features.
            feat_dim (int or None): PCA target dimension; falsy skips PCA.
            vis_feat (bool): save a visualization via :func:`visualizer`.
            is_upsample (bool): resize to the images' 'true_shape' when the size differs.
        """
        if feat_dim:
            feature = pca(feature, feat_dim)
        torch.cuda.empty_cache()
        target_hw = self._true_hw()
        # Resize when either side differs. The previous elementwise `.all()` check skipped maps
        # whose height or width alone already matched.
        if is_upsample and tuple(feature.shape[1:3]) != target_hw:
            feature = upsampler(feature, *target_hw)
        print(f"Feature map size >>> height: {feature[0].shape[0]}, width: {feature[0].shape[1]}, channels: {feature[0].shape[2]}")
        if vis_feat:
            save_path = visualizer(feature, self.images, self.img_base_path, feat_type=self.feat_type)
            print(f"The encoder feature visualization has been saved at >>>>> {save_path}")
        return feature

    def vis_decfeat(self, pairs, **kw):
        """
        Visualize decoder or head (only for mast3r) features.

        Uses kw['output'] if given, otherwise runs inference with kw['model']. The
        ``self.vis_key`` entries of both predictions are PCA-reduced to 9 channels, resized
        to the image resolution and plotted with :func:`mv_visualizer`.
        """
        if 'output' in kw:
            output = kw['output']
        else:
            output = inference(pairs, kw['model'], self.device, batch_size=1, verbose=False)
        decfeat1 = output['pred1'][self.vis_key].detach()
        decfeat2 = output['pred2'][self.vis_key].detach()
        # Interleave per pair: [pair0_view1, pair0_view2, pair1_view1, ...].
        decfeat = torch.stack([decfeat1, decfeat2], dim=1).view(-1, *decfeat1.shape[1:])
        # Swap the two halves for display (same split point as torch.chunk(decfeat, 2)).
        half = (decfeat.shape[0] + 1) // 2
        decfeat = torch.cat([decfeat[half:], decfeat[:half]], dim=0)
        decfeat = pca(decfeat, 9)
        target_hw = self._true_hw()
        if tuple(decfeat.shape[1:3]) != target_hw:
            decfeat = upsampler(decfeat, *target_hw)
        # Reorder the images exactly like the feature maps. The old pairs[3:] + pairs[:3]
        # matched only when there were 6 pairs.
        flat_images = [im for p in pairs for im in p]
        pair_images = flat_images[half:] + flat_images[:half]
        save_path = mv_visualizer(decfeat, pair_images, self.img_base_path,
                                  feat_type=self.feat_type, file_name=f'{self.vis_key}_pcaall_dim0-9')
        print(f"The decoder feature visualization has been saved at >>>>> {save_path}")

    def _fuse(self, all_feats):
        """
        Merge several preprocessed feature types into one map.

        Each type is min-max scaled to [0, 1]. VFM features are resized to 1/16 of the image
        resolution, concatenated and preprocessed (PCA + upsampling). Coordinate features
        ('iuv', 'iuvrgb') are then appended, and the result is PCA-reduced once more.
        """
        # clamp_min avoids 0/0 -> NaN for a constant feature tensor.
        scaled = {k: (v - v.min()) / (v.max() - v.min()).clamp_min(1e-12) for k, v in all_feats.items()}
        height, width = self._true_hw()
        target_size = (height // 16, width // 16)
        vfm_feats, coord_feats = [], []
        for k, v in scaled.items():
            if k in self._COORD_FEATS:
                coord_feats.append(v)
                continue
            if tuple(v.shape[1:3]) != target_size:
                v = F.interpolate(v.permute(0, 3, 1, 2), size=target_size,
                                  mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
            vfm_feats.append(v)

        parts = []
        if vfm_feats:
            parts.append(self.preprocess(torch.cat(vfm_feats, dim=-1), self.feat_dim,
                                         self.vis_feat and not coord_feats))
        if not coord_feats:
            return parts[0]
        # Covers both VFM + coordinates and coordinates only (previously torch.cat([]) crashed there).
        return self.preprocess(torch.cat(parts + coord_feats, dim=-1), self.feat_dim,
                               self.vis_feat, is_upsample=False)

    def forward(self, **kw):
        """Extract every feature type in ``self.feat_type`` and return the final map as numpy."""
        single = len(self.feat_type) == 1
        vis_feat = self.vis_feat and single
        all_feats = {}
        for feat_type in self.feat_type:
            # Decide PCA per feature type. Previously an 'iuv'/'iuvrgb' entry reset feat_dim to
            # None for every feature type listed after it as well.
            feat_dim = None if feat_type in self._COORD_FEATS else self.feat_dim
            if feat_type == self.method:
                feats = kw['scene'].stacked_feat
            elif feat_type in ('dust3r', 'mast3r'):
                feats = getattr(self, f"get_{feat_type}_feat")(**kw)
            elif feat_type in self._COORD_FEATS:
                feats = getattr(self, f"get_{feat_type}")()
            else:
                feats = self.get_feat(feat_type)
            all_feats[feat_type] = self.preprocess(feats.detach().clone(), feat_dim, vis_feat, single)

        feats = all_feats[self.feat_type[0]] if single else self._fuse(all_feats)
        torch.cuda.empty_cache()
        return to_numpy(feats)

    def __call__(self, **kw):
        return self.forward(**kw)
class InitMethod:
    """
    Initialize pointmap and camera param via DUSt3R or MASt3R.

    ``args.method`` ('dust3r' or 'mast3r') selects the backend. The public ``get_model``,
    ``infer``, ``get_info`` and ``get_depth`` methods dispatch to the matching
    ``*_dust3r`` / ``*_mast3r`` implementation.
    """

    def __init__(self, args):
        self.method = args.method
        self.n_views = args.n_views
        self.device = args.device
        self.img_base_path = args.img_base_path
        self.focal_avg = args.focal_avg
        self.tsdf_thresh = args.tsdf_thresh
        self.min_conf_thr = args.min_conf_thr
        if self.method == 'dust3r':
            self.model_path = os.path.join(args.model_path, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
        else:
            self.model_path = os.path.join(args.model_path, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")

    def get_dust3r(self):
        return AsymmetricCroCo3DStereo.from_pretrained(self.model_path).to(self.device)

    def get_mast3r(self):
        return AsymmetricMASt3R.from_pretrained(self.model_path).to(self.device)

    def infer_dust3r(self, **kw):
        """Pairwise inference, MST-initialised global alignment, then point-cloud cleaning."""
        output = inference(kw['pairs'], kw['model'], self.device, batch_size=1)
        scene = global_aligner(output, device=self.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        # Optimises `scene` in place; the returned loss is not needed here.
        compute_global_alignment(scene=scene, init="mst", niter=300, schedule='linear', lr=0.01,
                                 focal_avg=self.focal_avg, known_focal=kw.get('known_focal', None))
        return scene.clean_pointcloud()

    def infer_mast3r(self, **kw):
        """Sparse global alignment on kw['train_img_list'] / kw['pairs'] with a fresh cache dir."""
        import shutil
        cache_dir = os.path.join(self.img_base_path, "cache")
        # shutil.rmtree instead of os.system('rm -rf ...'): no shell, so paths with spaces or
        # shell metacharacters are handled safely.
        shutil.rmtree(cache_dir, ignore_errors=True)
        return sparse_global_alignment(kw['train_img_list'], kw['pairs'], cache_dir,
                                       kw['model'], lr1=0.07, niter1=500, lr2=0.014, niter2=200, device=self.device,
                                       # Depth refinement off; the former `'depth' in 'refine'` was always False.
                                       opt_depth=False, shared_intrinsics=self.focal_avg,
                                       matching_conf_thr=5.)

    def get_dust3r_info(self, scene):
        """Return (imgs, focals, poses, intrinsics, pts3d, confidence_masks) of a DUSt3R scene."""
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        # Mask threshold: the scene's confidence transform applied to 1.0.
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def _mast3r_dense(self, scene):
        """TSDF-filtered dense (pts3d, depthmaps, confs) of a MASt3R scene, converted to numpy."""
        tsdf = TSDFPostProcess(scene, TSDF_thresh=self.tsdf_thresh)
        return to_numpy(tsdf.get_dense_pts3d(clean_depth=True))

    def get_mast3r_info(self, scene):
        """Return (imgs, focals, poses, intrinsics, pts3d, confidence_masks) of a MASt3R scene."""
        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()[:, None]
        poses = to_numpy(scene.get_im_poses())
        intrinsics = to_numpy(scene.intrinsics)
        pts3d, _, confs = self._mast3r_dense(scene)
        pts3d = [arr.reshape((*imgs[0].shape[:2], 3)) for arr in pts3d]
        confidence_masks = np.array(to_numpy([c > self.min_conf_thr for c in confs]))
        return imgs, focals, poses, intrinsics, pts3d, confidence_masks

    def get_dust3r_depth(self, scene):
        return to_numpy(scene.get_depthmaps())

    def get_mast3r_depth(self, scene):
        """Per-view TSDF-filtered depth maps, each reshaped to the image's (H, W)."""
        imgs = to_numpy(scene.imgs)
        _, depthmaps, _ = self._mast3r_dense(scene)
        # NOTE(review): assumes each depth map holds one value per pixel (H*W entries). The
        # previous (H, W, 3) reshape, copied from the pts3d branch, only fits pts3d-sized arrays.
        return [arr.reshape(imgs[0].shape[:2]) for arr in depthmaps]

    def get_model(self):
        return getattr(self, f"get_{self.method}")()

    def infer(self, **kw):
        return getattr(self, f"infer_{self.method}")(**kw)

    def get_info(self, scene):
        return getattr(self, f"get_{self.method}_info")(scene)

    def get_depth(self, scene):
        return getattr(self, f"get_{self.method}_depth")(scene)