from typing import Any, Dict, List

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix

from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper


class EmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, the loss is computed as cross-entropy over
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on squared distances between estimated vertex embeddings
    and mesh vertex embeddings.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config
        """
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for squared distances between
        observed vertex embeddings and all mesh vertex embeddings given
        ground truth vertex IDs.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                    N = number of instances (= sum N_i, where N_i is the number of
                        instances on image i)
                    D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                    S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each kind of annotation data packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name
        """
        losses = {}
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into the estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue
            # estimated embeddings at valid points
            # -> a tensor of shape [J, D]
            vertex_embeddings_i = normalize_embeddings(
                interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],
                )[j_valid, :]
            )
            # ground truth vertex ids for valid points
            # -> a tensor of shape [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings of all mesh vertices
            # -> a tensor of shape [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # unnormalized scores for valid points
            # -> a tensor of shape [J, K]
            scores = squared_euclidean_distance_matrix(
                vertex_embeddings_i, mesh_vertex_embeddings
            ) / (-self.embdist_gauss_sigma)
            losses[mesh_name] = F.cross_entropy(scores, vertex_indices_i, ignore_index=-1)

        # meshes with no annotated points in this batch still get a zero-valued loss,
        # so that all embedder parameters participate in the backward pass
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module):
        # produce zero-valued losses for all meshes known to the embedder
        losses = {}
        for mesh_name in embedder.mesh_names:
            losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
        return losses

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str):
        # a zero-valued loss that still depends on the predictor outputs and the
        # embedder, so gradients flow to all parameters even without annotations
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
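

if __name__ == "__main__":
    # A minimal, self-contained sketch of the core computation in
    # EmbeddingLoss.__call__ above, using random tensors in place of real
    # predictor outputs and a mesh embedder. Plain torch ops (F.normalize,
    # torch.cdist) stand in for the densepose helpers; the sizes and sigma
    # value here are illustrative, not taken from any config.
    J, K, D = 7, 100, 16
    sigma = 0.1
    pixel_embeddings = F.normalize(torch.randn(J, D), dim=1)  # embeddings at annotated pixels, [J, D]
    mesh_embeddings = F.normalize(torch.randn(K, D), dim=1)  # embeddings of all mesh vertices, [K, D]
    gt_vertex_ids = torch.randint(0, K, (J,))  # ground truth vertex id per annotated pixel, [J]
    scores = torch.cdist(pixel_embeddings, mesh_embeddings) ** 2 / (-sigma)  # [J, K]
    print("example embedding loss:", F.cross_entropy(scores, gt_vertex_ids).item())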