Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| import torch | |
| from torch import nn | |
| import trimesh | |
| import math | |
| from typing import NewType | |
| from pytorch3d.structures import Meshes | |
| from pytorch3d.renderer.mesh import rasterize_meshes | |
| # Nominal alias of torch.Tensor used by the type annotations in this module. | |
| Tensor = NewType('Tensor', torch.Tensor) | |
def solid_angles(points: Tensor,
                 triangles: Tensor,
                 thresh: float = 1e-8) -> Tensor:
    """Return the signed solid angle each triangle subtends at each point.

    Uses the closed-form expression from:
        A. Van Oosterom and J. Strackee,
        "The Solid Angle of a Plane Triangle",
        IEEE Transactions on Biomedical Engineering,
        vol. BME-30, no. 2, February 1983.

    Parameters
    ----------
    points: BxQx3
        Query points.
    triangles: BxFx3x3
        Triangle corners, indexed as (batch, face, corner, xyz).
    thresh: float
        Kept for interface compatibility; the computation does not read it.

    Returns
    -------
    solid_angles: BxQxF
        Solid angle of every triangle as seen from every query point.
    """
    # Corners expressed relative to each query point: BxQxFx3x3
    rel = triangles.unsqueeze(1) - points.unsqueeze(2).unsqueeze(3)

    # Distance from the query point to each corner: BxQxFx3
    lengths = torch.norm(rel, dim=-1)

    # Per-corner vectors, each BxQxFx3 (views into ``rel``)
    v0 = rel.select(3, 0)
    v1 = rel.select(3, 1)
    v2 = rel.select(3, 2)

    # Scalar triple product v0 . (v1 x v2): BxQxF
    triple = (v0 * torch.cross(v1, v2, dim=-1)).sum(dim=-1)

    # Pairwise dot products between the corner vectors: BxQxF each
    d01 = (v0 * v1).sum(dim=-1)
    d12 = (v1 * v2).sum(dim=-1)
    d02 = (v0 * v2).sum(dim=-1)
    # Drop the large BxQxFx3x3 intermediate as early as possible.
    del v0, v1, v2, rel

    denom = (lengths.prod(dim=-1) + d01 * lengths.select(3, 2) +
             d02 * lengths.select(3, 1) + d12 * lengths.select(3, 0))
    del d01, d12, d02, lengths

    # atan2 yields half of the solid angle: BxQxF
    half_angle = torch.atan2(triple, denom)
    del triple, denom
    # Hand freed blocks back to the CUDA allocator before returning.
    torch.cuda.empty_cache()
    return 2 * half_angle
def winding_numbers(points: Tensor,
                    triangles: Tensor,
                    thresh: float = 1e-8) -> Tensor:
    """Compute generalized winding numbers of query points w.r.t. a mesh.

    References:
        Alec Jacobson, Ladislav Kavan and Olga Sorkine-Hornung,
        "Robust inside-outside segmentation using generalized winding
        numbers".

        Gavin Barill, Neil G. Dickson, Ryan Schmidt, David I.W. Levin and
        Alec Jacobson,
        "Fast Winding Numbers for Soups and Clouds", SIGGRAPH 2018.

    Parameters
    ----------
    points: BxQx3
        Query points.
    triangles: BxFx3x3
        Target triangles.
    thresh: float
        Forwarded unchanged to ``solid_angles``.

    Returns
    -------
    winding_numbers: BxQ
        Generalized winding number of every query point.
    """
    # The winding number is the sum of the per-triangle solid angles,
    # normalised by the full sphere (4 * pi).
    per_face = solid_angles(points, triangles, thresh=thresh)
    scale = 1 / (4 * math.pi)
    return scale * per_face.sum(dim=-1)
def batch_contains(verts, faces, points):
    """Test, per batch element, which points a mesh reports as contained.

    Each batch element is turned into a ``trimesh.Trimesh`` and queried with
    its ``contains`` method. All inputs are detached and moved to the CPU
    first, and the result is a CPU tensor.

    :param verts: batched mesh vertices, indexed by batch along dim 0
    :param faces: batched mesh faces, indexed by batch along dim 0
    :param points: batched query points; dim 1 is the number of points
    :return: float tensor [batch size, number of points] holding +1.0 where
        ``contains`` is true and -1.0 where it is false
    """
    batch_size, num_points = verts.shape[0], points.shape[1]

    cpu_verts = verts.detach().cpu()
    cpu_faces = faces.detach().cpu()
    cpu_points = points.detach().cpu()

    inside = torch.zeros(batch_size, num_points)
    for b in range(batch_size):
        mesh = trimesh.Trimesh(cpu_verts[b], cpu_faces[b])
        inside[b] = torch.as_tensor(mesh.contains(cpu_points[b]))

    # Map the {0, 1} indicator to {-1, +1}.
    return 2.0 * (inside - 0.5)
def dict2obj(d):
    """Recursively convert a dict into an object with attribute access.

    Every key of ``d`` becomes an attribute of a fresh, otherwise empty
    object; nested dict values are converted the same way. Any value that is
    not a dict (including lists, even lists that contain dicts) is returned
    unchanged.

    :param d: a dict, or any other value
    :return: an attribute-style object if ``d`` is a dict, otherwise ``d``
    """
    if not isinstance(d, dict):
        return d

    class C(object):
        pass

    holder = C()
    # Write straight into the instance dict, as plain attribute storage.
    vars(holder).update({key: dict2obj(d[key]) for key in d})
    return holder
def face_vertices(vertices, faces):
    """Gather the corner positions of every face.

    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of faces, 3, 3]
    """
    _, nv = vertices.shape[:2]
    bs, _ = faces.shape[:2]

    # Shift each batch element's indices so they address the flattened
    # (bs * nv) vertex array built below.
    batch_offset = torch.arange(
        bs, dtype=torch.int32, device=vertices.device) * nv
    global_faces = faces + batch_offset[:, None, None]

    flat_vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
    return flat_vertices[global_faces.long()]
class Pytorch3dRasterizer(nn.Module):
    """Rasterize meshes and interpolate per-face attributes per pixel.

    Borrowed from https://github.com/facebookresearch/pytorch3d

    Notice:
        x, y, z are in image space, normalized.
        Can only render square images for now.
    """

    def __init__(self, image_size=224):
        """Store the fixed raster settings used for every forward pass.

        :param image_size: side length, in pixels, of the square output
        """
        super().__init__()
        # NOTE(review): 'cull_backfaces' is stored here but forward() never
        # passes it to rasterize_meshes, so it currently has no effect.
        # Confirm whether that is intended before wiring it through, since
        # doing so would change the rendered output.
        self.raster_settings = dict2obj({
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': True,
            'cull_backfaces': True,
        })

    def forward(self, vertices, faces, attributes=None):
        """Rasterize the meshes and interpolate ``attributes`` at each pixel.

        :param vertices: batched mesh vertices; the x and y components are
            negated before rasterization
        :param faces: batched face indices
        :param attributes: per-face, per-corner values. It is reshaped to
            (shape[0] * shape[1], 3, D), so it must be 4-D with a third
            dimension of 3 -- presumably [batch size, number of faces, 3, D]
            as produced by ``face_vertices``. If None, no attributes are
            interpolated.
        :return: tensor [N, D + 1, H, W] holding the D interpolated
            attribute channels followed by one visibility-mask channel, or
            [N, 1, H, W] (visibility mask only) when ``attributes`` is None
        """
        # Flip the sign of x and y on a copy so the caller's tensor is left
        # untouched (presumably to match the rasterizer's axis convention).
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())

        settings = self.raster_settings
        # Only the face index map and the barycentric coordinates are used;
        # the depth buffer and distances are discarded.
        pix_to_face, _, bary_coords, _ = rasterize_meshes(
            meshes_screen,
            image_size=settings.image_size,
            blur_radius=settings.blur_radius,
            faces_per_pixel=settings.faces_per_pixel,
            bin_size=settings.bin_size,
            max_faces_per_bin=settings.max_faces_per_bin,
            perspective_correct=settings.perspective_correct,
        )

        # 1.0 where a pixel is covered by some face, 0.0 elsewhere.
        vismask = (pix_to_face > -1).float()

        # The signature advertises attributes=None; without this guard the
        # code below would fail on ``attributes.shape``.
        if attributes is None:
            return vismask[:, :, :, 0][:, None, :, :]

        D = attributes.shape[-1]
        # Flatten batch and face dims so faces can be indexed by
        # pix_to_face: (batch * faces, 3, D).
        attributes = attributes.clone()
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
                                     3, attributes.shape[-1])

        N, H, W, K, _ = bary_coords.shape
        # Uncovered pixels carry index -1; point them at face 0 so gather()
        # receives valid indices, then zero their values afterwards.
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0

        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        # Barycentric blend of the three corner values of each hit face.
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.

        # Keep the first face per pixel and move channels first: N x D x H x W
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals