import os

import cv2
import glm
import numpy as np
import scipy.ndimage
import torch
import torchvision
import torchvision.transforms.v2 as T
from rembg import remove
from tqdm import tqdm

from models.lrm.utils import render_utils


def load_mipmap(env_path):
    """Load a pre-filtered environment mipmap: one diffuse map plus six
    specular mip levels, each stored as a .pth tensor under env_path."""
    diffuse_path = os.path.join(env_path, "diffuse.pth")
    diffuse = torch.load(diffuse_path, map_location=torch.device('cpu'))

    specular = []
    for i in range(6):
        specular_path = os.path.join(env_path, f"specular_{i}.pth")
        specular_tensor = torch.load(specular_path, map_location=torch.device('cpu'))
        specular.append(specular_tensor)
    return [specular, diffuse]
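

# A minimal usage sketch for load_mipmap. The directory name is hypothetical;
# any folder containing diffuse.pth and specular_0.pth ... specular_5.pth
# produced by the environment pre-filtering step will do.
def _example_load_mipmap(env_path="env_mipmap/studio"):
    specular, diffuse = load_mipmap(env_path)
    assert len(specular) == 6  # one tensor per specular mip level
    return specular, diffuse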


def get_background(img_tensor):
    """
    Args:
        img_tensor: input image tensor of shape (B, 3, H, W), values in [0, 1].
    Returns:
        mask_tensor: binary foreground mask of shape (B, 1, H, W), values in {0, 1}.
    """
    B, C, H, W = img_tensor.shape
    assert C == 3, "Input tensor must have 3 channels (RGB)."

    # Convert to (B, H, W, 3) uint8 for rembg.
    img_numpy = (img_tensor.permute(0, 2, 3, 1) * 255).byte().cpu().numpy()

    masks = []
    for i in range(B):
        # rembg returns a soft alpha matte; binarize it at 127.
        mask = remove(img_numpy[i], only_mask=True)
        mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        masks.append(mask_binary[..., None])

    masks = np.stack(masks, axis=0)
    mask_tensor = torch.from_numpy(masks).permute(0, 3, 1, 2).float() / 255.0
    return mask_tensor
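

# Usage sketch: segment a batch of images loaded with torchvision. The file
# name below is hypothetical.
def _example_get_background(path="input.png"):
    img = torchvision.io.read_image(path)[:3].float() / 255.0  # (3, H, W) in [0, 1]
    mask = get_background(img.unsqueeze(0))                    # (1, 1, H, W)
    return mask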


def get_render_cameras_video(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
    """
    Get the rendering camera parameters for an M-frame orbit video.

    `elevation` is a polar angle in degrees measured from the +y axis; it may
    be a scalar or a (start, end) tuple, in which case two half-orbits are
    generated, one per elevation. `batch_size` and `is_flexicubes` are kept
    for interface compatibility and are currently unused.
    """
    train_res = [512, 512]
    cam_near_far = [0.1, 1000.0]
    fovy = np.deg2rad(fov)
    proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
    all_mv = []
    all_mvp = []
    all_campos = []

    def append_camera(azimuth, polar):
        # Spherical-to-Cartesian with y up (angles in radians), then build
        # the look-at view, view-projection, and camera-position tensors.
        z = radius * np.cos(azimuth) * np.sin(polar)
        x = radius * np.sin(azimuth) * np.sin(polar)
        y = radius * np.cos(polar)

        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        mv = torch.from_numpy(np.array(view_matrix))
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]
        all_mv.append(mv[None, ...].cuda())
        all_mvp.append(mvp[None, ...].cuda())
        all_campos.append(campos[None, ...].cuda())

    if isinstance(elevation, tuple):
        # Two half-orbits, one per elevation.
        for elevation_deg in elevation:
            polar = np.deg2rad(elevation_deg)
            for i in range(M // 2):
                append_camera(2 * np.pi * i / (M // 2), polar)
    else:
        # Single full orbit. The degrees-to-radians conversion is needed here
        # too (the original code passed degrees straight to np.sin/np.cos).
        polar = np.deg2rad(elevation)
        for i in range(M):
            append_camera(2 * np.pi * i / M, polar)

    all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
    all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
    all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
    return all_mv, all_mvp, all_campos
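

# Usage sketch: a 120-frame turntable split across two elevations (values are
# illustrative). Tensors land on the default CUDA device, matching the
# .cuda() calls above.
def _example_video_cameras():
    mv, mvp, campos = get_render_cameras_video(M=120, radius=4.0, elevation=(60.0, 120.0), fov=50)
    # mv/mvp: (1, 120, 4, 4); campos: (1, 120, 3)
    return mv, mvp, campos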


def get_render_cameras_frames(batch_size=1, radius=4.0, azimuths=0, elevations=20.0, fov=30):
    """
    Get the rendering camera parameters for a set of (azimuth, elevation)
    views given in degrees, plus the world-to-camera matrix of a canonical
    front-facing "identity" view. `batch_size` is kept for interface
    compatibility and is currently unused.
    """
    train_res = [512, 512]
    cam_near_far = [0.1, 1000.0]
    fovy = np.deg2rad(fov)
    proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
    all_mv = []
    all_mvp = []
    all_campos = []
    # Convert elevation above the horizon to a polar angle from the +y axis.
    elevations = 90 - elevations

    def look_at_matrix(azimuth, polar):
        # Spherical-to-Cartesian with y up; both angles in radians.
        z = radius * np.cos(azimuth) * np.sin(polar)
        x = radius * np.sin(azimuth) * np.sin(polar)
        y = radius * np.cos(polar)

        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        return torch.from_numpy(np.array(glm.lookAt(eye, at, up)))

    def append_camera(azimuth, polar):
        mv = look_at_matrix(azimuth, polar)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]
        all_mv.append(mv[None, ...].cuda())
        all_mvp.append(mvp[None, ...].cuda())
        all_campos.append(campos[None, ...].cuda())

    if isinstance(elevations, (np.ndarray, torch.Tensor)):
        if isinstance(elevations, torch.Tensor):
            elevations = elevations.cpu().numpy()
        if isinstance(azimuths, torch.Tensor):
            azimuths = azimuths.cpu().numpy()
        azimuths = np.deg2rad(azimuths)
        elevations = np.deg2rad(elevations)
        for azi, ele in zip(azimuths, elevations):
            append_camera(azi, ele)
    else:
        # Scalar case; the degrees-to-radians conversion was missing here in
        # the original code.
        append_camera(np.deg2rad(azimuths), np.deg2rad(elevations))

    # Canonical front view (azimuth 0, polar angle 90 degrees), used as the
    # reference frame for normal-map transforms. As above, the angles must be
    # converted to radians before the trigonometric calls.
    identity_mv = look_at_matrix(np.deg2rad(0.0), np.deg2rad(90.0))

    all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
    all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
    all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
    return all_mv, all_mvp, all_campos, identity_mv
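

# Usage sketch: cameras for four specific views given in degrees (values are
# illustrative).
def _example_frame_cameras():
    azimuths = np.array([0.0, 90.0, 180.0, 270.0])
    elevations = np.array([20.0, 20.0, 20.0, 20.0])
    mv, mvp, campos, identity_mv = get_render_cameras_frames(
        radius=4.0, azimuths=azimuths, elevations=elevations, fov=30)
    return mv, mvp, campos, identity_mv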


def worldNormal2camNormal(rot_w2c, normal_map_world):
    """Rotate an (H, W, >=3) world-space normal map into camera space with
    the 3x3 world-to-camera rotation rot_w2c."""
    normal_map_world = normal_map_world[..., :3]

    # Flatten to (H*W, 3), rotate, and restore the spatial layout.
    normal_map_flat = normal_map_world.contiguous().view(-1, 3)
    normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float())
    normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape)

    return normal_map_camera


def trans_normal(normal, RT_w2c, RT_w2c_target):
    """
    :param normal: (H,W,3), torch tensor, range [-1,1], in the RT_w2c camera frame
    :param RT_w2c: (4,4), torch tensor, world to camera
    :param RT_w2c_target: (4,4), torch tensor, world to camera
    :return: normal_target_cam: (H,W,3), torch tensor, range [-1,1]
    """
    # Relative rotation from the source camera frame to the target camera frame.
    relative_RT = torch.matmul(RT_w2c_target[:3, :3], torch.linalg.inv(RT_w2c[:3, :3]))
    normal_target_cam = worldNormal2camNormal(relative_RT[:3, :3], normal)

    return normal_target_cam
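

# Sanity-check sketch for trans_normal: transforming between identical camera
# frames is a no-op.
def _example_trans_normal():
    normal = torch.nn.functional.normalize(torch.randn(8, 8, 3), dim=-1)
    RT = torch.eye(4)
    out = trans_normal(normal, RT, RT)
    assert torch.allclose(out, normal, atol=1e-6)
    return out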


def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1,
                  is_flexicubes=False, render_mv=None, local_normal=False, identity_mv=None):
    """
    Render frames from triplanes.
    """
    frames = []
    albedos = []
    pbr_spec_lights = []
    pbr_diffuse_lights = []
    normals = []
    alphas = []
    # Step by chunk_size so each camera is rendered exactly once (the original
    # stepped by 1, which duplicates frames whenever chunk_size > 1).
    for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
        out = model.forward_geometry(
            planes,
            render_cameras[:, i:i+chunk_size],
            camera_pos[:, i:i+chunk_size],
            [[env]*chunk_size],
            [[materials]*chunk_size],
            render_size=render_size,
        )
        frame = out['pbr_img']
        albedo = out['albedo']
        pbr_spec_light = out['pbr_spec_light']
        pbr_diffuse_light = out['pbr_diffuse_light']
        normal = out['normal']
        alpha = out['mask']

        if local_normal:
            # Re-express the identity-frame normals in the current view's
            # camera frame.
            target_w2c = render_mv[0, i, :3, :3]
            identity_w2c = identity_mv[:3, :3]
            normal = trans_normal(normal.squeeze(0), identity_w2c.cuda(), target_w2c.cuda())
            normal = normal / torch.norm(normal, dim=-1, keepdim=True)

            # Flip x to match the expected normal-map convention and
            # composite onto a white background using the alpha mask.
            background_normal = torch.tensor([1, 1, 1], dtype=torch.float32, device=normal.device)
            normal = normal.unsqueeze(0)
            normal[..., 0] *= -1
            normal = normal * alpha.squeeze(0).permute(0, 2, 3, 1) + background_normal * (1 - alpha.squeeze(0).permute(0, 2, 3, 1))
        frames.append(frame)
        albedos.append(albedo)
        pbr_spec_lights.append(pbr_spec_light)
        pbr_diffuse_lights.append(pbr_diffuse_light)
        normals.append(normal)
        alphas.append(alpha)

    frames = torch.cat(frames, dim=1)[0]
    alphas = torch.cat(alphas, dim=1)[0]
    albedos = torch.cat(albedos, dim=1)[0]
    pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
    pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
    normals = torch.cat(normals, dim=0).permute(0, 3, 1, 2)[:, :3]
    return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
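

# Usage sketch (hypothetical model, planes, and paths): render a turntable
# with a pre-filtered environment map. The materials tuple assumes a
# (metallic, roughness) ordering, which is not confirmed by this module.
def _example_render_frames(model, planes):
    mv, mvp, campos = get_render_cameras_video(M=120, radius=4.0, elevation=(60.0, 120.0))
    env = load_mipmap("env_mipmap/studio")  # hypothetical path
    materials = (0.0, 0.9)                  # assumed (metallic, roughness)
    return render_frames(model, planes, mvp, campos, env, materials,
                         render_size=512, chunk_size=1, render_mv=mv)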


def mask_fix(mask, erode_dilate=0, smooth=0, remove_isolated_pixels=0, blur=0, fill_holes=0):
    """Morphologically clean up a batch of (H, W) masks. Negative
    erode_dilate erodes, positive dilates; fill_holes and
    remove_isolated_pixels give the kernel sizes for grey closing and opening;
    smooth and blur give Gaussian kernel sizes (binarized and soft,
    respectively)."""
    masks = []
    for m in mask:
        if erode_dilate != 0:
            if erode_dilate < 0:
                m = torch.from_numpy(scipy.ndimage.grey_erosion(m.cpu().numpy(), size=(-erode_dilate, -erode_dilate)))
            else:
                m = torch.from_numpy(scipy.ndimage.grey_dilation(m.cpu().numpy(), size=(erode_dilate, erode_dilate)))

        if fill_holes > 0:
            m = torch.from_numpy(scipy.ndimage.grey_closing(m.cpu().numpy(), size=(fill_holes, fill_holes)))

        if remove_isolated_pixels > 0:
            m = torch.from_numpy(scipy.ndimage.grey_opening(m.cpu().numpy(), size=(remove_isolated_pixels, remove_isolated_pixels)))

        if smooth > 0:
            # gaussian_blur requires an odd kernel size.
            if smooth % 2 == 0:
                smooth += 1
            m = T.functional.gaussian_blur((m > 0.5).unsqueeze(0), smooth).squeeze(0)

        if blur > 0:
            if blur % 2 == 0:
                blur += 1
            m = T.functional.gaussian_blur(m.float().unsqueeze(0), blur).squeeze(0)

        masks.append(m.float())

    masks = torch.stack(masks, dim=0).float()
    return masks
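

# Usage sketch: tighten a rembg mask before compositing (parameter values are
# illustrative). The channel dimension is squeezed because mask_fix operates
# on (H, W) masks.
def _example_mask_fix(img_batch):
    mask = get_background(img_batch)  # (B, 1, H, W)
    return mask_fix(mask.squeeze(1), erode_dilate=-2, fill_holes=5, blur=3)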


class NormalTransfer:
    def __init__(self):
        # World-to-camera [R|t] of the canonical (identity) view: camera on
        # the +x axis at radius 4.5, looking at the origin.
        self.identity_w2c = torch.tensor([
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 4.5]]).float()

    def look_at(self, camera_position, target_position, up_vector=np.array([0, 0, 1])):
        """Look-at helper for a z-up world; returns the 4x4 matrix used as a
        world-to-camera pose by generate_target_pose."""
        forward = camera_position - target_position
        forward = forward / np.linalg.norm(forward)

        right = np.cross(up_vector, forward)
        right = right / np.linalg.norm(right)

        up = np.cross(forward, right)

        rotation_matrix = np.array([right, up, forward]).T

        translation_matrix = np.eye(4)
        translation_matrix[:3, 3] = -camera_position

        rotation_homogeneous = np.eye(4)
        rotation_homogeneous[:3, :3] = rotation_matrix

        w2c = rotation_homogeneous @ translation_matrix
        return w2c

    def generate_target_pose(self, azimuths_deg, elevations_deg, radius=4.5):
        """Build world-to-camera matrices for cameras on a z-up sphere of the
        given radius, one per (azimuth, elevation) pair in degrees."""
        if isinstance(azimuths_deg, torch.Tensor):
            azimuths_deg = azimuths_deg.cpu().numpy()
        if isinstance(elevations_deg, torch.Tensor):
            elevations_deg = elevations_deg.cpu().numpy()
        azimuths = np.deg2rad(azimuths_deg)
        elevations = np.deg2rad(elevations_deg)

        # Spherical-to-Cartesian with z up.
        x = radius * np.cos(azimuths) * np.cos(elevations)
        y = radius * np.sin(azimuths) * np.cos(elevations)
        z = radius * np.sin(elevations)
        camera_positions = np.stack([x, y, z], axis=-1)

        target_position = np.array([0, 0, 0])

        w2c_matrices = [self.look_at(cam_pos, target_position) for cam_pos in camera_positions]
        w2c_matrices = np.stack(w2c_matrices, axis=0)
        return w2c_matrices

    def convert_to_blender(self, pose):
        """Convert an OpenGL-convention w2c pose to an OpenCV-convention
        world-to-camera [R|t] (the layout Blender exports use).
        Note: modifies `pose` in place."""
        w2c_opengl = pose
        # Swap the y and z rows (z-up to y-up), then flip the new y row.
        w2c_opengl[[1, 2], :] = w2c_opengl[[2, 1], :]
        w2c_opengl[1] *= -1
        R = w2c_opengl[:3, :3]
        t = w2c_opengl[:3, 3]

        # Re-derive the OpenCV world-to-camera rotation and translation.
        cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
        R = R.T
        t = -R @ t
        R_world2cv = cam_rec @ R
        t_world2cv = cam_rec @ t

        RT = np.concatenate([R_world2cv, t_world2cv[:, None]], 1)
        return RT

    def worldNormal2camNormal(self, rot_w2c, normal_map_world):
        """Rotate an (H, W, >=3) world-space normal map into camera space
        with the 3x3 world-to-camera rotation rot_w2c."""
        normal_map_world = normal_map_world[..., :3]

        normal_map_flat = normal_map_world.contiguous().view(-1, 3)
        normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float())
        normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape)

        return normal_map_camera

    def trans_normal(self, normal, RT_w2c, RT_w2c_target):
        """
        :param normal: (H,W,3), torch tensor, range [-1,1]
        :param RT_w2c: (4,4), torch tensor, world to camera
        :param RT_w2c_target: (4,4), torch tensor, world to camera
        :return: normal_target_cam: (H,W,3), torch tensor, range [-1,1]
        """
        relative_RT = torch.matmul(RT_w2c_target[:3, :3], torch.linalg.inv(RT_w2c[:3, :3]))
        normal_target_cam = self.worldNormal2camNormal(relative_RT[:3, :3], normal)

        return normal_target_cam

    def trans_local_2_global(self, normal_local, azimuths_deg, elevations_deg, radius=4.5, for_lotus=True):
        """
        :param normal_local: (B,H,W,3), torch tensor, range [-1,1]
        :param azimuths_deg: (B,), numpy array, range [0,360]
        :param elevations_deg: (B,), numpy array, range [-90,90]
        :param radius: float, default 4.5
        :return: global_normal: (B,H,W,3), torch tensor, range [-1,1]
        """
        assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c

        # Per-view world-to-camera poses, converted to the Blender/OpenCV
        # convention.
        target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius)
        target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float()
        global_normal = []

        # Rotate each view's normals from its own camera frame into the
        # canonical (identity) frame.
        for i in range(normal_local.shape[0]):
            normal_local_i = normal_local[i]
            normal_zero123 = self.trans_normal(normal_local_i, target_w2c[i], identity_w2c)
            global_normal.append(normal_zero123)

        global_normal = torch.stack(global_normal, dim=0)
        if for_lotus:
            global_normal[..., 0] *= -1
        global_normal = global_normal / torch.norm(global_normal, dim=-1, keepdim=True)
        return global_normal

    def trans_global_2_local(self, normal_global, azimuths_deg, elevations_deg, radius=4.5):
        """
        :param normal_global: (B,H,W,3), torch tensor, range [-1,1]
        :param azimuths_deg: (B,), numpy array, range [0,360]
        :param elevations_deg: (B,), numpy array, range [-90,90]
        :param radius: float, default 4.5
        :return: local_normal: (B,H,W,3), torch tensor, range [-1,1]
        """
        assert normal_global.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c

        target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius)
        target_w2c = torch.from_numpy(target_w2c).float()
        local_normal = []

        # Rotate the canonical-frame normals into each view's camera frame.
        for i in range(normal_global.shape[0]):
            normal = self.trans_normal(normal_global[i], identity_w2c, target_w2c[i])
            local_normal.append(normal)

        local_normal = torch.stack(local_normal, dim=0)
        local_normal = local_normal / torch.norm(local_normal, dim=-1, keepdim=True)
        return local_normal
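

# Usage sketch for NormalTransfer: map multi-view local normals into the
# canonical frame and back. Shapes and angles are illustrative; note that the
# two directions use slightly different pose conventions, matching the
# methods above.
def _example_normal_transfer():
    nt = NormalTransfer()
    azimuths = np.array([0.0, 90.0, 180.0, 270.0])
    elevations = np.array([20.0, 20.0, 20.0, 20.0])
    local = torch.nn.functional.normalize(torch.randn(4, 64, 64, 3), dim=-1)
    global_n = nt.trans_local_2_global(local, azimuths, elevations, for_lotus=False)
    local_again = nt.trans_global_2_local(global_n, azimuths, elevations)
    return global_n, local_again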