|
|
|
|
|
|
|
|
|
import random
|
|
from typing import Optional, Tuple
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.config import CfgNode
|
|
from detectron2.structures import Instances
|
|
|
|
from densepose.converters.base import IntTupleBox
|
|
|
|
from .densepose_cse_base import DensePoseCSEBaseSampler
|
|
|
|
|
|
class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn using confidence value estimates.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model
          use_gt_categories (bool): forwarded unchanged to the base sampler
              constructor
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
          confidence_channel (str): name of the attribute of the DensePose
              predictor output that holds the confidences used for sampling;
              e.g. "coarse_segm_confidence" (confidences for coarse segmentation)
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category (default: 8)
          search_count_multiplier (float or None): if not None, the total number
              of the most confident estimates of a given class to consider is
              defined as
              `min(max(search_count_multiplier * count_per_class, count_per_class), N)`,
              where `N` is the total number of estimates of the class; cannot be
              specified together with `search_proportion` (default: None)
          search_proportion (float or None): if not None, the total number of
              the most confident estimates of a given class to consider is
              defined as `min(max(search_proportion * N, count_per_class), N)`,
              where `N` is the total number of estimates of the class; cannot be
              specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(cfg, use_gt_categories, embedder, count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
            f"and search_proportion (={search_proportion})"
        )

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): a tensor whose axis 1 has length k, where
                k is the number of points labeled with part_id; row `values[0]`
                holds the confidences used for ranking (larger = more confident)
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int) or torch.Tensor: indices of values (along axis 1) selected
            as a sample; a plain list when `count == k`, otherwise a tensor
        """
        k = values.shape[1]
        if k == count:
            # Every point is needed: no ranking or randomness required.
            index_sample = list(range(k))
        else:
            # Ascending sort: the most confident (largest) values end up at the
            # tail of the sorted index list.
            _, sorted_confidence_indices = torch.sort(values[0])
            if self.search_count_multiplier is not None:
                search_count = int(count * self.search_count_multiplier)
            elif self.search_proportion is not None:
                search_count = int(k * self.search_proportion)
            else:
                search_count = count
            # The search pool must contain at least `count` candidates, otherwise
            # random.sample raises ValueError (e.g. search_count_multiplier < 1),
            # and it cannot exceed the number of available points.
            search_count = min(max(search_count, count), k)
            # Draw `count` distinct positions uniformly from the
            # top-`search_count` most confident points.
            sample_from_top = random.sample(range(search_count), count)
            index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
        return index_sample

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance whose `pred_densepose` field is
                expected to be a `DensePoseEmbeddingPredictorOutputWithConfidences`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (torch.Tensor): a tensor of shape [D, H, W],
                DensePose CSE Embeddings
            other_values (torch.Tensor): a tensor of shape [1, H, W],
                DensePose CSE confidence taken from the `confidence_channel`
                attribute, bilinearly resized to the box size and moved to CPU
        """
        _, _, w, h = bbox_xywh
        densepose_output = instance.pred_densepose
        # Reuse the base sampler's mask and embeddings; its third output is
        # replaced by the selected confidence channel.
        mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
        other_values = F.interpolate(
            getattr(densepose_output, self.confidence_channel),
            size=(h, w),
            mode="bilinear",
            # Same value PyTorch uses implicitly for "bilinear"; stated explicitly.
            align_corners=False,
        )[0].cpu()
        return mask, embeddings, other_values
|
|
|