|
|
|
|
|
|
|
|
|
from typing import Any, List
|
|
from torch import nn
|
|
|
|
from detectron2.config import CfgNode
|
|
from detectron2.structures import Instances
|
|
|
|
from .cycle_pix2shape import PixToShapeCycleLoss
|
|
from .cycle_shape2shape import ShapeToShapeCycleLoss
|
|
from .embed import EmbeddingLoss
|
|
from .embed_utils import CseAnnotationsAccumulator
|
|
from .mask_or_segm import MaskOrSegmentationLoss
|
|
from .registry import DENSEPOSE_LOSS_REGISTRY
|
|
from .soft_embed import SoftEmbeddingLoss
|
|
from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches
|
|
|
|
|
|
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
    """
    DensePose CSE loss.

    Combines, each scaled by a weight read from the config:
      - a segmentation loss (`MaskOrSegmentationLoss`), reported as "loss_densepose_S";
      - one embedding loss per mesh, reported as "loss_densepose_E<mesh>";
      - optionally a shape-to-shape cycle loss ("loss_shape2shape");
      - optionally a pixel-to-shape cycle loss ("loss_pix2shape").
    """

    # Embedding loss classes selectable via
    # cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME, keyed by class name.
    _EMBED_LOSS_REGISTRY = {
        EmbeddingLoss.__name__: EmbeddingLoss,
        SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
    }

    def __init__(self, cfg: CfgNode):
        """
        Initialize CSE loss from configuration options.

        Args:
            cfg (CfgNode): configuration options; reads the loss weights,
                the embedding loss name and the cycle-loss switches under
                cfg.MODEL.ROI_DENSEPOSE_HEAD
        """
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
        # The cycle-loss weights and modules are only created when enabled;
        # `__call__` / `produce_fake_losses` check the same flags before use.
        self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
        self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss named by
        cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME.

        Args:
            cfg (CfgNode): configuration options, also passed to the loss constructor
        Returns:
            An instance of the selected embedding loss class.
        Raises:
            KeyError: if the configured name is not a key of `_EMBED_LOSS_REGISTRY`;
                the message lists the accepted names.
        """
        loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME
        # Keep the lookup alone in the `try` so a KeyError raised by the loss
        # constructor itself is not misreported as an unknown loss name.
        try:
            loss_cls = cls._EMBED_LOSS_REGISTRY[loss_name]
        except KeyError:
            raise KeyError(
                f"Unknown CSE embedding loss {loss_name!r}; "
                f"expected one of {sorted(cls._EMBED_LOSS_REGISTRY)}"
            ) from None
        return loss_cls(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute the weighted CSE losses.

        Args:
            proposals_with_gt (List[Instances]): proposals matched with ground truth
            densepose_predictor_outputs: predictor outputs; must expose an
                `embedding` attribute whose shape has exactly 4 dimensions,
                the last two being used as the spatial size (H, W)
            embedder (nn.Module): module passed to the embedding and cycle losses
        Returns:
            Dict with "loss_densepose_S", one "loss_densepose_E<mesh>" entry per
            mesh returned by the embedding loss, and, when enabled,
            "loss_shape2shape" and "loss_pix2shape". Each term is multiplied by
            its configured weight. If there are no proposals, or no packed
            annotations can be extracted from them, the placeholder losses of
            `produce_fake_losses` are returned instead.
        """
        if len(proposals_with_gt) == 0:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        accumulator = CseAnnotationsAccumulator()
        packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
        if packed_annotations is None:
            # No usable annotations among the matches: emit placeholder losses.
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        # Spatial size of the predicted embedding map (unpacking requires rank 4).
        h, w = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(packed_annotations, (h, w))

        meshid_to_embed_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        embed_loss_dict = {
            f"loss_densepose_E{meshid}": self.w_embed * embed_loss
            for meshid, embed_loss in meshid_to_embed_losses.items()
        }
        all_loss_dict = {
            "loss_densepose_S": self.w_segm
            * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
        return all_loss_dict

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Build placeholder losses with the same keys as `__call__`.

        Values come from the `fake_value` / `fake_values` methods of the
        component losses and are NOT multiplied by the configured weights.
        (Presumably these are zero-valued terms that keep the outputs and
        embedder in the graph; confirm in the component losses.)

        Args:
            densepose_predictor_outputs: predictor outputs
            embedder (nn.Module): embedder module
        Returns:
            Dict with "loss_densepose_S", one "loss_densepose_E<mesh>" entry per
            mesh returned by the embedding loss, and, when enabled,
            "loss_shape2shape" and "loss_pix2shape".
        """
        meshname_to_embed_losses = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        embed_loss_dict = {
            f"loss_densepose_E{mesh_name}": embed_loss
            for mesh_name, embed_loss in meshname_to_embed_losses.items()
        }
        all_loss_dict = {
            "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return all_loss_dict
|
|
|