import logging
import os
from collections import OrderedDict
from typing import List, Optional, Union
import torch
from torch import nn

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import CfgNode
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import (
    DatasetEvaluator,
    DatasetEvaluators,
    inference_on_dataset,
    print_csv_format,
)
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
from detectron2.utils import comm
from detectron2.utils.events import EventWriter, get_event_storage

from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
from densepose.data import (
    DatasetMapper,
    build_combined_loader,
    build_detection_test_loader,
    build_detection_train_loader,
    build_inference_based_loaders,
    has_inference_based_loaders,
)
from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
from densepose.modeling.cse import Embedder


class SampleCountingLoader:
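    """
    Wraps a data loader and, for every batch, records the number of instances
    drawn from each dataset into the event storage as "batch/<dataset>" scalars.
    """
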
    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        it = iter(self.loader)
        storage = get_event_storage()
        while True:
            try:
                batch = next(it)
                num_inst_per_dataset = {}
                for data in batch:
                    dataset_name = data["dataset"]
                    if dataset_name not in num_inst_per_dataset:
                        num_inst_per_dataset[dataset_name] = 0
                    num_inst = len(data["instances"])
                    num_inst_per_dataset[dataset_name] += num_inst
                for dataset_name in num_inst_per_dataset:
                    storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
                yield batch
            except StopIteration:
                break


class SampleCountMetricPrinter(EventWriter):
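    """
    EventWriter that logs the windowed average (last 20 values) of the
    per-dataset "batch/*" counters collected by SampleCountingLoader.
    """
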
    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def write(self):
        storage = get_event_storage()
        batch_stats_strs = []
        for key, buf in storage.histories().items():
            if key.startswith("batch/"):
                batch_stats_strs.append(f"{key} {buf.avg(20)}")
        self.logger.info(", ".join(batch_stats_strs))


class Trainer(DefaultTrainer):
    @classmethod
    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
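        """
        Return the CSE embedder from the model's ROI heads, if present,
        unwrapping DistributedDataParallel first; otherwise return None.
        """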
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
            return model.roi_heads.embedder
        return None

    @classmethod
    def test(
        cls,
        cfg: CfgNode,
        model: nn.Module,
        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
    ):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments, implicitly assume
            # that they can be created before the data loader
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    embedder = cls.extract_embedder_from_model(model)
                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # With distributed inference, all ranks run evaluation;
            # otherwise only the main process does
            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            else:
                results_i = {}
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(
        cls,
        cfg: CfgNode,
        dataset_name: str,
        output_folder: Optional[str] = None,
        embedder: Optional[Embedder] = None,
    ) -> DatasetEvaluators:
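        """
        Build a COCO-style evaluator for box/segmentation metrics and, when
        ``cfg.MODEL.DENSEPOSE_ON`` is set, a DensePose evaluator on top of it.
        """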
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluators = []
        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
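        # A COCO-style evaluator adapter is used for all test datasets so that
        # bbox metrics stay comparable across dataset types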
        evaluators.append(
            Detectron2COCOEvaluatorAdapter(
                dataset_name, output_dir=output_folder, distributed=distributed
            )
        )
        if cfg.MODEL.DENSEPOSE_ON:
            storage = build_densepose_evaluator_storage(cfg, output_folder)
            evaluators.append(
                DensePoseCOCOEvaluator(
                    dataset_name,
                    distributed,
                    output_folder,
                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
                    storage=storage,
                    embedder=embedder,
                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
                )
            )
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
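        # CSE feature and embedding parameters get dedicated LR factors
        # applied on top of the base learning rate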
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            overrides={
                "features": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
                },
                "embeddings": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
                },
            },
        )
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

        return maybe_add_gradient_clipping(cfg, optimizer)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
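        # Start from the standard detection train loader; when the config
        # defines inference-based (bootstrap) loaders, combine them with it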
        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return data_loader
        model = cls.build_model(cfg)
        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
        loaders = [data_loader] + inference_based_loaders
        ratios = [1.0] + ratios
        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
        sample_counting_loader = SampleCountingLoader(combined_data_loader)
        return sample_counting_loader

    def build_writers(self):
        writers = super().build_writers()
        writers.append(SampleCountMetricPrinter())
        return writers

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
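        """
        Run :meth:`test` with test-time augmentation: the model is wrapped into
        DensePoseGeneralizedRCNNWithTTA and result keys get a "_TTA" suffix.
        """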
        logger = logging.getLogger("detectron2.trainer")
        logger.info("Running inference with test-time augmentation ...")
        transform_data = load_from_cfg(cfg)
        model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
        )
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res