|
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from densepose.data.meshes.catalog import MeshCatalog
|
|
from densepose.structures.mesh import load_mesh_symmetry
|
|
from densepose.structures.transform_data import DensePoseTransformData
|
|
|
|
|
|
class DensePoseDataRelative:
    """
    Dense pose relative annotations that can be applied to any bounding box:
        x - normalized X coordinates [0, 255] of annotated points
        y - normalized Y coordinates [0, 255] of annotated points
        i - body part labels 0,...,24 for annotated points
        u - body part U coordinates [0, 1] for annotated points
        v - body part V coordinates [0, 1] for annotated points
        segm - 256x256 segmentation mask with values 0,...,14
    To obtain absolute x and y data wrt some bounding box one needs to first
    divide the data by 256, multiply by the respective bounding box size
    and add bounding box offset:
        x_img = x0 + x_norm * w / 256.0
        y_img = y0 + y_norm * h / 256.0
    Segmentation masks are typically sampled to get image-based masks.
    """

    # Annotation dict key: normalized X coordinates of annotated points
    X_KEY = "dp_x"
    # Annotation dict key: normalized Y coordinates of annotated points
    Y_KEY = "dp_y"
    # Annotation dict key: point U coordinates (IUV / chart-based setting)
    U_KEY = "dp_U"
    # Annotation dict key: point V coordinates (IUV / chart-based setting)
    V_KEY = "dp_V"
    # Annotation dict key: point part labels (IUV / chart-based setting)
    I_KEY = "dp_I"
    # Annotation dict key: coarse segmentation masks (RLE or per-part list)
    S_KEY = "dp_masks"
    # Annotation dict key: mesh vertex ids (CSE setting)
    VERTEX_IDS_KEY = "dp_vertex"
    # Annotation dict key: reference mesh name (CSE setting)
    MESH_NAME_KEY = "ref_model"
    # Number of coarse body parts encoded in the segmentation mask (values 1..14)
    N_BODY_PARTS = 14
    # Number of fine part labels for annotated points (values 1..24)
    N_PART_LABELS = 24
    # Side length of the normalized coordinate frame / segmentation mask
    MASK_SIZE = 256

    def __init__(self, annotation, cleanup=False):
        """
        Build relative DensePose data from a COCO-style annotation dict.

        Args:
            annotation (dict): must contain X_KEY and Y_KEY; optionally the
                IUV keys (I_KEY/U_KEY/V_KEY), the CSE keys
                (VERTEX_IDS_KEY/MESH_NAME_KEY) and the mask key (S_KEY).
            cleanup (bool): if True, remove all DensePose keys from
                `annotation` after they have been converted (mutates the dict).
        """
        self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY])
        self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY])
        # IUV (chart-based) point annotations are only stored if all three
        # keys are present
        if (
            DensePoseDataRelative.I_KEY in annotation
            and DensePoseDataRelative.U_KEY in annotation
            and DensePoseDataRelative.V_KEY in annotation
        ):
            self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY])
            self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY])
            self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY])
        # CSE (continuous surface embeddings) annotations: vertex ids on a
        # named reference mesh
        if (
            DensePoseDataRelative.VERTEX_IDS_KEY in annotation
            and DensePoseDataRelative.MESH_NAME_KEY in annotation
        ):
            self.vertex_ids = torch.as_tensor(
                annotation[DensePoseDataRelative.VERTEX_IDS_KEY], dtype=torch.long
            )
            self.mesh_id = MeshCatalog.get_mesh_id(annotation[DensePoseDataRelative.MESH_NAME_KEY])
        if DensePoseDataRelative.S_KEY in annotation:
            self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation)
        self.device = torch.device("cpu")
        if cleanup:
            DensePoseDataRelative.cleanup_annotation(annotation)

    def to(self, device):
        """
        Return a copy of this object with all tensors moved to `device`.
        Returns `self` unchanged if already on that device.
        """
        if self.device == device:
            return self
        # Bypass __init__: fields are copied over explicitly below
        new_data = DensePoseDataRelative.__new__(DensePoseDataRelative)
        new_data.x = self.x.to(device)
        new_data.y = self.y.to(device)
        # Optional tensor attributes are moved only if present
        for attr in ["i", "u", "v", "vertex_ids", "segm"]:
            if hasattr(self, attr):
                setattr(new_data, attr, getattr(self, attr).to(device))
        # mesh_id is a plain int, no device transfer needed
        if hasattr(self, "mesh_id"):
            new_data.mesh_id = self.mesh_id
        new_data.device = device
        return new_data

    @staticmethod
    def extract_segmentation_mask(annotation):
        """
        Decode the coarse segmentation mask stored under S_KEY into a
        MASK_SIZE x MASK_SIZE float tensor with part labels as values.

        The stored data can be:
          * a torch.Tensor - already decoded, returned as is;
          * a single RLE dict - decoded to a binary (0/1) mask;
          * a list of per-part RLE dicts - part i (0-based) decoded into
            label i + 1.
        """
        import pycocotools.mask as mask_utils

        poly_specs = annotation[DensePoseDataRelative.S_KEY]
        if isinstance(poly_specs, torch.Tensor):
            # data is already given as mask tensors, no need to decode them
            return poly_specs
        segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32)
        if isinstance(poly_specs, dict):
            if poly_specs:
                mask = mask_utils.decode(poly_specs)
                segm[mask > 0] = 1
        else:
            for i in range(len(poly_specs)):
                poly_i = poly_specs[i]
                if poly_i:
                    mask_i = mask_utils.decode(poly_i)
                    segm[mask_i > 0] = i + 1
        return segm

    @staticmethod
    def validate_annotation(annotation):
        """
        Check that `annotation` carries enough data to build an instance.

        Requires X/Y point coordinates plus either the full IUV key set or
        the full CSE key set.

        Returns:
            (bool, str | None): (True, None) when valid, otherwise
            (False, reason).
        """
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
        ]:
            if key not in annotation:
                return False, "no {key} data in the annotation".format(key=key)
        valid_for_iuv_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.I_KEY,
                DensePoseDataRelative.U_KEY,
                DensePoseDataRelative.V_KEY,
            ]
        )
        valid_for_cse_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.VERTEX_IDS_KEY,
                DensePoseDataRelative.MESH_NAME_KEY,
            ]
        )
        if not valid_for_iuv_setting and not valid_for_cse_setting:
            return (
                False,
                "expected either {} (IUV setting) or {} (CSE setting) annotations".format(
                    ", ".join(
                        [
                            DensePoseDataRelative.I_KEY,
                            DensePoseDataRelative.U_KEY,
                            DensePoseDataRelative.V_KEY,
                        ]
                    ),
                    ", ".join(
                        [
                            DensePoseDataRelative.VERTEX_IDS_KEY,
                            DensePoseDataRelative.MESH_NAME_KEY,
                        ]
                    ),
                ),
            )
        return True, None

    @staticmethod
    def cleanup_annotation(annotation):
        """Remove all DensePose-specific keys from `annotation` (in place)."""
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
            DensePoseDataRelative.VERTEX_IDS_KEY,
            DensePoseDataRelative.MESH_NAME_KEY,
        ]:
            if key in annotation:
                del annotation[key]

    def apply_transform(self, transforms, densepose_transform_data):
        """
        Apply image-space transforms (flips / rotations) to the annotations.

        Args:
            transforms: a detectron2 TransformList applied to the image.
            densepose_transform_data (DensePoseTransformData): symmetry
                tables used to remap labels/UVs under horizontal flips.
        """
        self._transform_pts(transforms, densepose_transform_data)
        if hasattr(self, "segm"):
            self._transform_segm(transforms, densepose_transform_data)

    def _transform_pts(self, transforms, dp_transform_data):
        # Deferred import to avoid a hard dependency at module import time
        import detectron2.data.transforms as T

        # NOTE: This mostly makes sense for HorizontalFlip and rotations;
        # an odd number of flips is a net horizontal flip
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            # coordinates live in the [0, MASK_SIZE] normalized frame
            self.x = self.MASK_SIZE - self.x
            if hasattr(self, "i"):
                self._flip_iuv_semantics(dp_transform_data)
            if hasattr(self, "vertex_ids"):
                self._flip_vertices()

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                # scale to image pixels, rotate, then scale back to the
                # normalized frame
                xy_scale = np.array((t.w, t.h)) / DensePoseDataRelative.MASK_SIZE
                xy = t.apply_coords(np.stack((self.x, self.y), axis=1) * xy_scale)
                self.x, self.y = torch.tensor(xy / xy_scale, dtype=self.x.dtype).T

    def _flip_iuv_semantics(self, dp_transform_data: DensePoseTransformData) -> None:
        # Swap part labels for left/right symmetric parts and remap the UV
        # coordinates through the per-part UV symmetry tables
        i_old = self.i.clone()
        uv_symmetries = dp_transform_data.uv_symmetries
        pt_label_symmetries = dp_transform_data.point_label_symmetries
        for i in range(self.N_PART_LABELS):
            if i + 1 in i_old:
                annot_indices_i = i_old == i + 1
                if pt_label_symmetries[i + 1] != i + 1:
                    self.i[annot_indices_i] = pt_label_symmetries[i + 1]
                # UV symmetry tables are indexed on a 256x256 grid
                u_loc = (self.u[annot_indices_i] * 255).long()
                v_loc = (self.v[annot_indices_i] * 255).long()
                self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc].to(
                    device=self.u.device
                )
                self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc].to(
                    device=self.v.device
                )

    def _flip_vertices(self):
        # Remap mesh vertex ids to their symmetric counterparts on hflip
        mesh_info = MeshCatalog[MeshCatalog.get_mesh_name(self.mesh_id)]
        # Fix: previously `mesh_symmetry` was set to None when the mesh had
        # no symmetry data and then unconditionally subscripted, raising a
        # TypeError; vertex ids are now left unchanged in that case.
        if mesh_info.symmetry is not None:
            mesh_symmetry = load_mesh_symmetry(mesh_info.symmetry)
            self.vertex_ids = mesh_symmetry["vertex_transforms"][self.vertex_ids]

    def _transform_segm(self, transforms, dp_transform_data):
        # Deferred import to avoid a hard dependency at module import time
        import detectron2.data.transforms as T

        # an odd number of flips is a net horizontal flip
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            self.segm = torch.flip(self.segm, [1])
            self._flip_segm_semantics(dp_transform_data)

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                self._transform_segm_rotation(t)

    def _flip_segm_semantics(self, dp_transform_data):
        # Swap mask labels for left/right symmetric body parts
        old_segm = self.segm.clone()
        mask_label_symmetries = dp_transform_data.mask_label_symmetries
        for i in range(self.N_BODY_PARTS):
            if mask_label_symmetries[i + 1] != i + 1:
                self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1]

    def _transform_segm_rotation(self, rotation):
        # Resize the mask to image resolution, rotate it there, then resize
        # back to the normalized MASK_SIZE frame
        self.segm = F.interpolate(self.segm[None, None, :], (rotation.h, rotation.w)).numpy()
        self.segm = torch.tensor(rotation.apply_segmentation(self.segm[0, 0]))[None, None, :]
        self.segm = F.interpolate(self.segm, [DensePoseDataRelative.MASK_SIZE] * 2)[0, 0]
|
|