# MoGe: utils3d/torch/rasterization.py
from typing import *
import torch
import nvdiffrast.torch as dr
from . import utils, transforms, mesh
from ._helpers import batched
__all__ = [
'RastContext',
'rasterize_triangle_faces',
'warp_image_by_depth',
'warp_image_by_forward_flow',
]
class RastContext:
"""
    Create a rasterization context: a thin wrapper around nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext.
"""
    def __init__(self, nvd_ctx: Optional[Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Optional[Union[str, torch.device]] = None):
if nvd_ctx is not None:
self.nvd_ctx = nvd_ctx
return
if backend == 'gl':
self.nvd_ctx = dr.RasterizeGLContext(device=device)
elif backend == 'cuda':
self.nvd_ctx = dr.RasterizeCudaContext(device=device)
else:
raise ValueError(f'Unknown backend: {backend}')
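
# Example usage (a minimal sketch; assumes nvdiffrast is installed and a CUDA device is available):
#   ctx = RastContext(backend='cuda', device='cuda:0')   # CUDA rasterizer, no OpenGL context needed
#   ctx = RastContext(backend='gl')                      # OpenGL rasterizer (default backend)
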
def rasterize_triangle_faces(
ctx: RastContext,
vertices: torch.Tensor,
faces: torch.Tensor,
attr: torch.Tensor,
width: int,
height: int,
    model: Optional[torch.Tensor] = None,
    view: Optional[torch.Tensor] = None,
    projection: Optional[torch.Tensor] = None,
antialiasing: Union[bool, List[int]] = True,
diff_attrs: Union[None, List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Rasterize a mesh with vertex attributes.
Args:
        ctx (RastContext): rasterization context
        vertices (torch.Tensor): (B, N, 2 or 3 or 4) vertex positions; 2D and 3D inputs are promoted to homogeneous coordinates
        faces (torch.Tensor): (T, 3) triangle face indices
        attr (torch.Tensor): (B, N, C) vertex attributes to interpolate
width (int): width of the output image
height (int): height of the output image
model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity).
view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity).
projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity).
antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased.
diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None.
    Returns:
        image (torch.Tensor): (B, C, H, W) interpolated attributes
        depth (torch.Tensor): (B, H, W) screen-space depth, ranging from 0 (near) to 1 (far).
            NOTE: Empty pixels will have depth 1, i.e. the far plane.
        image_dr (torch.Tensor, optional): (B, *, H, W) screen-space derivatives of the attributes selected by diff_attrs, only returned if diff_attrs is not None.
"""
assert vertices.ndim == 3
assert faces.ndim == 2
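    # Promote 2D or 3D vertices to homogeneous clip-space coordinates (x, y, z, w)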
if vertices.shape[-1] == 2:
vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1)
elif vertices.shape[-1] == 3:
vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
elif vertices.shape[-1] == 4:
pass
else:
raise ValueError(f'Wrong shape of vertices: {vertices.shape}')
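    # Compose the model-view-projection matrix; matrices left as None are treated as identity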
mvp = projection if projection is not None else torch.eye(4).to(vertices)
if view is not None:
mvp = mvp @ view
if model is not None:
mvp = mvp @ model
pos_clip = vertices @ mvp.transpose(-1, -2)
faces = faces.contiguous()
attr = attr.contiguous()
rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True)
image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs)
    if antialiasing is True:
image = dr.antialias(image, rast_out, pos_clip, faces)
elif isinstance(antialiasing, list):
aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces)
image[..., antialiasing] = aa_image
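    # Flip vertically (the rasterizer's output is bottom-up) and convert from (B, H, W, C) to (B, C, H, W)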
image = image.flip(1).permute(0, 3, 1, 2)
depth = rast_out[..., 2].flip(1)
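    # Remap NDC depth from [-1, 1] to [0, 1]; uncovered pixels rasterize to z == 0 and are assigned depth 1 (far plane)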
depth = (depth * 0.5 + 0.5) * (depth > 0).float() + (depth == 0).float()
if diff_attrs is not None:
image_dr = image_dr.flip(1).permute(0, 3, 1, 2)
return image, depth, image_dr
return image, depth
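
# Example (a minimal sketch; shapes follow the docstring above):
#   ctx = RastContext(backend='cuda')
#   vertices = torch.tensor([[[-0.5, -0.5, 0.0], [0.5, -0.5, 0.0], [0.0, 0.5, 0.0]]], device='cuda')   # (1, 3, 3)
#   faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')                                 # (1, 3)
#   attr = torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]], device='cuda')                    # per-vertex RGB
#   image, depth = rasterize_triangle_faces(ctx, vertices, faces, attr, width=256, height=256)
#   # image: (1, 3, 256, 256), depth: (1, 256, 256)
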
def texture(
    ctx: RastContext,
    uv: torch.Tensor,
    uv_da: torch.Tensor,
    texture: torch.Tensor,
) -> torch.Tensor:
    """Sample a (B, Ht, Wt, C) texture at (B, H, W, 2) uv coordinates, using uv_da for mip level selection."""
    # nvdiffrast's texture sampler takes the texture first and does not need the rasterization context
    return dr.texture(texture, uv, uv_da)
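
# Example (a minimal sketch): sample a texture with the screen-space uv derivatives from rasterization
#   colors = texture(ctx, uv, uv_da, tex)   # tex: (B, Ht, Wt, C), uv: (B, H, W, 2) in [0, 1] -> colors: (B, H, W, C)
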
def warp_image_by_depth(
ctx: RastContext,
depth: torch.FloatTensor,
    image: Optional[torch.FloatTensor] = None,
    mask: Optional[torch.BoolTensor] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
*,
    extrinsics_src: Optional[torch.FloatTensor] = None,
    extrinsics_tgt: Optional[torch.FloatTensor] = None,
    intrinsics_src: Optional[torch.FloatTensor] = None,
    intrinsics_tgt: Optional[torch.FloatTensor] = None,
near: float = 0.1,
far: float = 100.0,
antialiasing: bool = True,
backslash: bool = False,
padding: int = 0,
return_uv: bool = False,
return_dr: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
"""
Warp image by depth.
    NOTE: if batch size is 1, the image mesh is triangulated taking depth into account, yielding less distorted results.
    Otherwise, the image mesh is triangulated with a fixed pattern to allow batched rendering.
Args:
ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
depth (torch.Tensor): (B, H, W) linear depth
        image (torch.Tensor, optional): (B, C, H, W) image to warp. None to warp image-space uv instead. Defaults to None.
        mask (torch.BoolTensor, optional): (B, H, W) mask of valid source pixels. Defaults to None.
        width (int, optional): width of the output image. None to use the same as depth. Defaults to None.
        height (int, optional): height of the output image. None to use the same as depth. Defaults to None.
extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None.
extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None.
intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None.
intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None.
near (float, optional): near plane. Defaults to 0.1.
far (float, optional): far plane. Defaults to 100.0.
antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
padding (int, optional): padding of the image. Defaults to 0.
return_uv (bool, optional): whether to return the uv. Defaults to False.
return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False.
    Returns:
        image (torch.FloatTensor): (B, C, H, W) rendered image
        depth (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf
        mask (torch.BoolTensor): (B, H, W) mask of valid pixels
        uv (torch.FloatTensor): (B, 2, H, W) image-space uv, only returned if return_uv is True
        dr (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv, only returned if return_dr is True
"""
assert depth.ndim == 3
batch_size = depth.shape[0]
if width is None:
width = depth.shape[-1]
if height is None:
height = depth.shape[-2]
if image is not None:
assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}'
if extrinsics_src is None:
extrinsics_src = torch.eye(4).to(depth)
if extrinsics_tgt is None:
extrinsics_tgt = torch.eye(4).to(depth)
if intrinsics_src is None:
intrinsics_src = intrinsics_tgt
if intrinsics_tgt is None:
intrinsics_tgt = intrinsics_src
assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."
view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)
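    # Optionally pad depth and image with a replicated one-pixel border so the warped result can extrapolate past the source frame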
if padding > 0:
        in_height, in_width = depth.shape[-2], depth.shape[-1]
        uv, faces = utils.image_mesh(width=in_width + 2, height=in_height + 2)
        # Map the padded mesh's uv back into the unpadded image's uv space, per axis
        uv[:, 0] = (uv[:, 0] - 1 / (in_width + 2)) * ((in_width + 2) / in_width)
        uv[:, 1] = (uv[:, 1] - 1 / (in_height + 2)) * ((in_height + 2) / in_height)
        # Push the outer ring of vertices outward by `padding` pixels
        uv_ = uv.clone().reshape(in_height + 2, in_width + 2, 2)
        uv_[0, :, 1] -= padding / in_height
        uv_[-1, :, 1] += padding / in_height
        uv_[:, 0, 0] -= padding / in_width
        uv_[:, -1, 0] += padding / in_width
        uv_ = uv_.reshape(-1, 2)
depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate')
if image is not None:
image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate')
uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device)
pts = transforms.unproject_cv(
uv_,
depth.flatten(-2, -1),
extrinsics_src,
intrinsics_src,
)
else:
uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2])
if mask is not None:
depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device))
uv, faces = uv.to(depth.device), faces.to(depth.device)
pts = transforms.unproject_cv(
uv,
depth.flatten(-2, -1),
extrinsics_src,
intrinsics_src,
)
# triangulate
if batch_size == 1:
faces = mesh.triangulate(faces, vertices=pts[0])
else:
faces = mesh.triangulate(faces, backslash=backslash)
# rasterize attributes
diff_attrs = None
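    # Attribute channel layout: [image channels | uv (2 channels, if requested) | mask (1 channel, if provided)]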
if image is not None:
attr = image.permute(0, 2, 3, 1).flatten(1, 2)
if return_dr or return_uv:
if return_dr:
diff_attrs = [image.shape[1], image.shape[1]+1]
if return_uv and antialiasing:
antialiasing = list(range(image.shape[1]))
attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1)
else:
attr = uv.expand(batch_size, -1, -1)
if antialiasing:
print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m")
if return_uv:
return_uv = False
print("\033[93mWarning: image is None, return_uv is ignored.\033[0m")
if return_dr:
diff_attrs = [0, 1]
if mask is not None:
attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1)
rast = rasterize_triangle_faces(
ctx,
pts,
faces,
attr,
width,
height,
view=view_tgt,
        projection=perspective_tgt,
antialiasing=antialiasing,
diff_attrs=diff_attrs,
)
if return_dr:
output_image, screen_depth, output_dr = rast
else:
output_image, screen_depth = rast
output_mask = screen_depth < 1.0
if mask is not None:
output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :]
output_mask &= (rast_mask > 0.9999).reshape(-1, height, width)
if (return_dr or return_uv) and image is not None:
output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :]
output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask
output_image = output_image * output_mask.unsqueeze(1)
outs = [output_image, output_depth, output_mask]
if return_uv:
outs.append(output_uv)
if return_dr:
outs.append(output_dr)
return tuple(outs)
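
# Example (a minimal sketch; the source camera defaults to identity, intrinsics are normalized):
#   ctx = RastContext(backend='cuda')
#   depth = torch.full((1, 480, 640), 2.0, device='cuda')       # a flat plane 2 units away
#   image = torch.rand(1, 3, 480, 640, device='cuda')
#   intrinsics = torch.tensor([[0.8, 0.0, 0.5], [0.0, 0.8, 0.5], [0.0, 0.0, 1.0]], device='cuda')
#   extrinsics_tgt = ...                                        # (4, 4) target camera pose, e.g. a small lateral shift
#   warped, warped_depth, valid = warp_image_by_depth(
#       ctx, depth, image,
#       extrinsics_tgt=extrinsics_tgt, intrinsics_src=intrinsics, intrinsics_tgt=intrinsics,
#   )
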
def warp_image_by_forward_flow(
ctx: RastContext,
image: torch.FloatTensor,
flow: torch.FloatTensor,
    depth: Optional[torch.FloatTensor] = None,
*,
antialiasing: bool = True,
backslash: bool = False,
) -> Tuple[torch.FloatTensor, torch.BoolTensor]:
"""
Warp image by forward flow.
    NOTE: if batch size is 1, the image mesh is triangulated taking depth into account, yielding less distorted results.
    Otherwise, the image mesh is triangulated with a fixed pattern to allow batched rendering.
Args:
ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
image (torch.Tensor): (B, C, H, W) image
flow (torch.Tensor): (B, 2, H, W) forward flow
        depth (torch.Tensor, optional): (B, H, W) linear depth. If None, a constant depth of 1 is used for all pixels. Defaults to None.
antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
Returns:
image: (torch.FloatTensor): (B, C, H, W) rendered image
mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
"""
assert image.ndim == 4, f'Wrong shape of image: {image.shape}'
batch_size, _, height, width = image.shape
if depth is None:
depth = torch.ones_like(flow[:, 0])
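    # Use an arbitrary fixed camera (identity pose, 45-degree fov): unprojecting and reprojecting with the same camera cancels out,
    # so the camera choice only affects occlusion ordering via the depth buffer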
extrinsics = torch.eye(4).to(image)
fov = torch.deg2rad(torch.tensor([45.0], device=image.device))
intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0]
view = transforms.extrinsics_to_view(extrinsics)
perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100)
uv, faces = utils.image_mesh(width=width, height=height)
uv, faces = uv.to(image.device), faces.to(image.device)
uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2)
pts = transforms.unproject_cv(
uv,
depth.flatten(-2, -1),
extrinsics,
intrinsics,
)
# triangulate
if batch_size == 1:
faces = mesh.triangulate(faces, vertices=pts[0])
else:
faces = mesh.triangulate(faces, backslash=backslash)
# rasterize attributes
attr = image.permute(0, 2, 3, 1).flatten(1, 2)
rast = rasterize_triangle_faces(
ctx,
pts,
faces,
attr,
width,
height,
view=view,
        projection=perspective,
antialiasing=antialiasing,
)
output_image, screen_depth = rast
output_mask = screen_depth < 1.0
output_image = output_image * output_mask.unsqueeze(1)
outs = [output_image, output_mask]
return tuple(outs)
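
# Example (a minimal sketch; note the flow is added to uv in [0, 1], i.e. normalized image coordinates):
#   ctx = RastContext(backend='cuda')
#   image = torch.rand(1, 3, 240, 320, device='cuda')
#   flow = torch.zeros(1, 2, 240, 320, device='cuda')   # zero flow reproduces the input image
#   warped, valid = warp_image_by_forward_flow(ctx, image, flow)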