|
|
|
|
|
|
|
|
|
import pickle
|
|
from functools import lru_cache
|
|
from typing import Dict, Optional, Tuple
|
|
import torch
|
|
|
|
from detectron2.utils.file_io import PathManager
|
|
|
|
from densepose.data.meshes.catalog import MeshCatalog, MeshInfo
|
|
|
|
|
|
def _maybe_copy_to_device(
|
|
attribute: Optional[torch.Tensor], device: torch.device
|
|
) -> Optional[torch.Tensor]:
|
|
if attribute is None:
|
|
return None
|
|
return attribute.to(device)
|
|
|
|
|
|
class Mesh:
    """
    Triangular mesh with optional auxiliary data (geodesic distances, symmetry
    mapping, texture coordinates).

    Any attribute not passed to the constructor is loaded lazily from
    ``mesh_info`` the first time its property is accessed, directly onto
    ``self.device``.
    """

    def __init__(
        self,
        vertices: Optional[torch.Tensor] = None,
        faces: Optional[torch.Tensor] = None,
        geodists: Optional[torch.Tensor] = None,
        symmetry: Optional[Dict[str, torch.Tensor]] = None,
        texcoords: Optional[torch.Tensor] = None,
        mesh_info: Optional[MeshInfo] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            vertices (tensor [N, 3] of float32): vertex coordinates in 3D
            faces (tensor [M, 3] of long): triangular face represented as 3
                vertex indices
            geodists (tensor [N, N] of float32): geodesic distances from
                vertex `i` to vertex `j` (optional, default: None)
            symmetry (dict: str -> tensor): various mesh symmetry data:
                - "vertex_transforms": vertex mapping under horizontal flip,
                  tensor of size [N] of type long; vertex `i` is mapped to
                  vertex `tensor[i]` (optional, default: None)
            texcoords (tensor [N, 2] of float32): texture coordinates, i.e. global
                and normalized mesh UVs (optional, default: None)
            mesh_info (MeshInfo type): necessary to load the attributes on-the-go,
                can be used instead of passing all the variables one by one
            device (torch.device): device of the Mesh. If not provided, will use
                the device of the first provided tensor (vertices, faces,
                geodists, texcoords, then symmetry), falling back to CPU
        """
        self._vertices = vertices
        self._faces = faces
        self._geodists = geodists
        self._symmetry = symmetry
        self._texcoords = texcoords
        self.mesh_info = mesh_info
        self.device = device

        # Either explicit vertices or a MeshInfo to lazily load them from is required.
        assert self._vertices is not None or self.mesh_info is not None

        all_fields = [self._vertices, self._faces, self._geodists, self._texcoords]

        # Infer the device from the first available tensor when not given explicitly.
        if self.device is None:
            for field in all_fields:
                if field is not None:
                    self.device = field.device
                    break
        if self.device is None and symmetry is not None:
            for key in symmetry:
                self.device = symmetry[key].device
                break
        self.device = torch.device("cpu") if self.device is None else self.device

        assert all(var.device == self.device for var in all_fields if var is not None)
        if symmetry:
            assert all(symmetry[key].device == self.device for key in symmetry)
        # Explicit `is not None` checks: taking the truthiness of a tensor with zero
        # elements or more than one element raises RuntimeError.
        if texcoords is not None and vertices is not None:
            assert len(vertices) == len(texcoords)

    def to(self, device: torch.device) -> "Mesh":
        """
        Return a new Mesh whose already-materialized tensors are copied to ``device``.

        Attributes that have not been loaded yet stay ``None`` in the copy; since the
        copy keeps ``mesh_info`` and uses ``device``, they are loaded lazily onto
        ``device`` when first accessed.
        """
        device_symmetry = self._symmetry
        if device_symmetry:
            device_symmetry = {key: value.to(device) for key, value in device_symmetry.items()}
        return Mesh(
            _maybe_copy_to_device(self._vertices, device),
            _maybe_copy_to_device(self._faces, device),
            _maybe_copy_to_device(self._geodists, device),
            device_symmetry,
            _maybe_copy_to_device(self._texcoords, device),
            self.mesh_info,
            device,
        )

    @property
    def vertices(self):
        """Vertex coordinates; loaded from ``mesh_info.data`` on first access if unset."""
        if self._vertices is None and self.mesh_info is not None:
            self._vertices = load_mesh_data(self.mesh_info.data, "vertices", self.device)
        return self._vertices

    @property
    def faces(self):
        """Face vertex indices; loaded from ``mesh_info.data`` on first access if unset."""
        if self._faces is None and self.mesh_info is not None:
            self._faces = load_mesh_data(self.mesh_info.data, "faces", self.device)
        return self._faces

    @property
    def geodists(self):
        """Geodesic distances; loaded from ``mesh_info.geodists`` on first access if unset."""
        if self._geodists is None and self.mesh_info is not None:
            self._geodists = load_mesh_auxiliary_data(self.mesh_info.geodists, self.device)
        return self._geodists

    @property
    def symmetry(self):
        """Symmetry data dict; loaded from ``mesh_info.symmetry`` on first access if unset."""
        if self._symmetry is None and self.mesh_info is not None:
            self._symmetry = load_mesh_symmetry(self.mesh_info.symmetry, self.device)
        return self._symmetry

    @property
    def texcoords(self):
        """Texture coordinates; loaded from ``mesh_info.texcoords`` on first access if unset."""
        if self._texcoords is None and self.mesh_info is not None:
            self._texcoords = load_mesh_auxiliary_data(self.mesh_info.texcoords, self.device)
        return self._texcoords

    def get_geodists(self):
        """
        Return geodesic distances, falling back to ``_compute_geodists`` when they
        are neither provided nor loadable from ``mesh_info``.
        """
        if self.geodists is None:
            # `geodists` is a read-only property (no setter); store the computed
            # value in its backing field instead of assigning to the property.
            self._geodists = self._compute_geodists()
        return self.geodists

    def _compute_geodists(self):
        """Placeholder: geodesic distance computation is not implemented; returns None."""
        # TODO: compute geodesic distances on the mesh
        geodists = None
        return geodists
|
|
|
|
|
|
def load_mesh_data(
    mesh_fpath: str, field: str, device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Load one field of a pickled mesh data file as a ``torch.float`` tensor.

    Args:
        mesh_fpath (str): path to a pickle file, opened through PathManager; the
            unpickled object is indexed by ``field``
        field (str): key of the field to load, e.g. "vertices" or "faces"
        device (torch.device, optional): device to move the loaded tensor to

    Returns:
        tensor with the contents of ``field``, converted to ``torch.float`` and
        moved to ``device``
    """
    # NOTE(review): every field, including integer-valued data such as "faces",
    # is converted to torch.float here; callers that index with faces may need
    # to cast to long.
    # NOTE: pickle.load can execute arbitrary code; only load trusted mesh files.
    with PathManager.open(mesh_fpath, "rb") as hFile:
        return torch.as_tensor(pickle.load(hFile)[field], dtype=torch.float).to(device)
|
|
|
|
|
|
def load_mesh_auxiliary_data(
    fpath: str, device: Optional[torch.device] = None
) -> Optional[torch.Tensor]:
    """
    Load a pickled array-like object (e.g. geodesic distances or texture
    coordinates) as a ``torch.float`` tensor.

    The path is first resolved to a local copy via ``PathManager.get_local_path``.

    Args:
        fpath (str): path to the pickle file
        device (torch.device, optional): device to move the loaded tensor to

    Returns:
        the unpickled data converted to a ``torch.float`` tensor on ``device``
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    local_path = PathManager.get_local_path(fpath)
    with PathManager.open(local_path, "rb") as stream:
        payload = pickle.load(stream)
    as_float = torch.as_tensor(payload, dtype=torch.float)
    return as_float.to(device)
|
|
|
|
|
|
@lru_cache()
def load_mesh_symmetry(
    symmetry_fpath: str, device: Optional[torch.device] = None
) -> Optional[Dict[str, torch.Tensor]]:
    """
    Load mesh symmetry data from a pickle file.

    Results are memoized with ``lru_cache``, so repeated calls with equal
    arguments return the same dict object; callers should not mutate it.

    Args:
        symmetry_fpath (str): path to a pickle file whose unpickled object has
            a "vertex_transforms" entry
        device (torch.device, optional): device to move the tensors to

    Returns:
        dict with key "vertex_transforms" mapped to a ``torch.long`` tensor
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with PathManager.open(symmetry_fpath, "rb") as stream:
        raw = pickle.load(stream)
        transforms = torch.as_tensor(raw["vertex_transforms"], dtype=torch.long)
        return {"vertex_transforms": transforms.to(device)}
|
|
|
|
|
|
@lru_cache()
def create_mesh(mesh_name: str, device: Optional[torch.device] = None) -> Mesh:
    """
    Build a lazily-loaded Mesh for the catalog entry ``mesh_name``.

    Memoized with ``lru_cache``: equal ``(mesh_name, device)`` arguments return the
    same Mesh instance.
    """
    info = MeshCatalog[mesh_name]
    return Mesh(mesh_info=info, device=device)
|
|
|