|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from detectron2.structures import Instances
|
|
|
|
# Model output dict for one entry of the list passed to
# PredictionToGroundTruthSampler.__call__; the sampler reads its "instances"
# entry (a detectron2 `Instances` object).
ModelOutput = Dict[str, Any]

# Same dict after sampling: "instances" fields have been rewritten in place
# and a "dataset" key has been added.
SampledData = Dict[str, Any]
|
|
|
|
|
|
@dataclass
class _Sampler:
    """
    Registry entry describing how a single `Instances` field is sampled by
    `PredictionToGroundTruthSampler`.

    Attributes:
        src (str): name of the field to read from; it is removed after
            sampling unless it is the same as ``dst``
        dst (Optional[str]): name of the field to write to; when None,
            nothing is written and ``src`` is only removed
        func (Optional[Callable: Any -> Any]): sampling function, invoked
            with the whole `Instances` object; when None, the ``src`` value
            is copied to ``dst`` by reference
    """

    src: str  # source field name
    dst: Optional[str]  # destination field name, or None for delete-only
    func: Optional[Callable[[Any], Any]]  # sampling function, or None to copy
|
|
|
|
|
|
class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.

    Samplers are keyed by their ``(prediction_attr, gt_attr)`` pair. By
    default, ``pred_boxes`` is copied to ``gt_boxes`` and ``pred_classes`` to
    ``gt_classes`` (by reference), and ``scores`` is deleted.
    """

    def __init__(self, dataset_name: str = ""):
        """
        Args:
            dataset_name (str): value stored under the "dataset" key of every
                sampled output
        """
        self.dataset_name = dataset_name
        # (prediction_attr, gt_attr) -> _Sampler
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # no destination field: "scores" is only removed from the instances
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling.

        The input is modified in place. For each entry, every registered
        sampler whose destination is not None and whose source field is
        present on ``entry["instances"]`` writes its destination field: a
        reference copy of the source when ``func`` is None, otherwise
        ``func(instances)``. Afterwards every registered source field that
        differs from its destination is removed, and ``entry["dataset"]`` is
        set to ``self.dataset_name``.

        Args:
            model_output (List[Dict[str, Any]]): list of model output dicts,
                each holding an `Instances` object under "instances"
        Returns:
            List[Dict[str, Any]]: the same list object, with its entries
                updated in place
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]

            # Sample every destination before deleting any source, so that
            # several samplers can read the same source field.
            for sampler in self._samplers.values():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    instances.set(sampler.dst, sampler.func(instances))

            # Drop consumed prediction fields; fields sampled onto themselves
            # (src == dst) are kept.
            for sampler in self._samplers.values():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)

            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field.

        Registering again with the same ``(prediction_attr, gt_attr)`` pair
        replaces the previously registered sampler.

        Args:
            prediction_attr (str): field to sample from; it is removed after
                sampling unless it equals ``gt_attr``
            gt_attr (Optional[str]): field to store the sampled value to; if
                None, nothing is sampled and ``prediction_attr`` is only removed
            func (Optional[Callable: Any -> Any]): sampler function, called
                with the whole `Instances` object; if None, the value of
                ``prediction_attr`` is copied by reference
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field.

        Args:
            prediction_attr (str): source field the sampler was registered with
            gt_attr (Optional[str]): destination field the sampler was
                registered with
        Raises:
            AssertionError: if no sampler is registered for the pair (with
                assertions disabled, the dict deletion raises KeyError instead)
        """
        key = (prediction_attr, gt_attr)
        assert key in self._samplers, f"No sampler registered for {key}"
        del self._samplers[key]
|
|
|