|
""" |
|
Training and evaluation loops used by main.py:

  * train_one_epoch: one pass over the training data, with loss logging and optional gradient clipping.
  * evaluate:        referring-expression evaluation on RefCOCO / RefCOCO+ / RefCOCOg.
  * evaluate_a2d:    segmentation evaluation on A2D-Sentences / JHMDB-Sentences.

Modified from DETR (https://github.com/facebookresearch/detr)
|
""" |
|
import math

import os

import sys

from typing import Iterable
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
import util.misc as utils |
|
from datasets.coco_eval import CocoEvaluator |
|
from datasets.refexp_eval import RefExpEvaluator |
|
|
|
from pycocotools.coco import COCO |
|
from pycocotools.cocoeval import COCOeval |
|
from datasets.a2d_eval import calculate_precision_at_k_and_iou_metrics, calculate_bbox_precision_at_k_and_iou_metrics |
|
|
|
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, |
|
data_loader: Iterable, optimizer: torch.optim.Optimizer, |
|
device: torch.device, epoch: int, max_norm: float = 0): |
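    """Train for one epoch; return the globally averaged metrics as a dict."""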
|
model.train() |
|
criterion.train() |
|
metric_logger = utils.MetricLogger(delimiter=" ") |
|
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) |
|
header = 'Epoch: [{}]'.format(epoch) |
|
print_freq = 10 |
|
for samples, targets in metric_logger.log_every(data_loader, print_freq, header): |
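        # Move the batch to the target device; captions stay as raw Python strings and are passed to the model unchanged.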
|
samples = samples.to(device) |
|
captions = [t["caption"] for t in targets] |
|
targets = utils.targets_to(targets, device) |
|
|
|
        # Forward pass: the model consumes both the visual samples and the raw caption strings.
        outputs = model(samples, captions, targets)
|
loss_dict = criterion(outputs, targets) |
|
|
|
weight_dict = criterion.weight_dict |
|
        # Total loss: weighted sum of the individual criterion terms.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
|
|
|
|
|
        # Reduce losses over all processes; this is used only for logging, not for the backward pass.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
|
loss_dict_reduced_unscaled = {f'{k}_unscaled': v |
|
for k, v in loss_dict_reduced.items()} |
|
loss_dict_reduced_scaled = {k: v * weight_dict[k] |
|
for k, v in loss_dict_reduced.items() if k in weight_dict} |
|
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) |
|
|
|
loss_value = losses_reduced_scaled.item() |
|
|
|
        # Stop training if the loss has diverged (NaN or inf); print the per-term losses to aid debugging.
        if not math.isfinite(loss_value):
|
print("Loss is {}, stopping training".format(loss_value)) |
|
print(loss_dict_reduced) |
|
sys.exit(1) |
|
        # Backward pass and parameter update, with gradient clipping when max_norm > 0.
        optimizer.zero_grad()
|
losses.backward() |
|
if max_norm > 0: |
|
grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) |
|
else: |
|
grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) |
|
optimizer.step() |
|
|
|
metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) |
|
metric_logger.update(lr=optimizer.param_groups[0]["lr"]) |
|
metric_logger.update(grad_norm=grad_total_norm) |
|
|
|
|
|
metric_logger.synchronize_between_processes() |
|
print("Averaged stats:", metric_logger) |
|
return {k: meter.global_avg for k, meter in metric_logger.meters.items()} |
|
|
|
|
|
@torch.no_grad() |
|
def evaluate(model, criterion, postprocessors, data_loader, evaluator_list, device, args): |
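    """Evaluate on RefCOCO / RefCOCO+ / RefCOCOg referring-expression data; return a dict of stats."""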
|
model.eval() |
|
criterion.eval() |
|
|
|
metric_logger = utils.MetricLogger(delimiter=" ") |
|
header = 'Test:' |
|
|
|
predictions = [] |
|
for samples, targets in metric_logger.log_every(data_loader, 10, header): |
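        # The loader is assumed to serve a single dataset, so the name taken from the first target applies to the whole run.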
|
dataset_name = targets[0]["dataset_name"] |
|
samples = samples.to(device) |
|
captions = [t["caption"] for t in targets] |
|
targets = utils.targets_to(targets, device) |
|
|
|
outputs = model(samples, captions, targets) |
|
loss_dict = criterion(outputs, targets) |
|
weight_dict = criterion.weight_dict |
|
|
|
|
|
loss_dict_reduced = utils.reduce_dict(loss_dict) |
|
loss_dict_reduced_scaled = {k: v * weight_dict[k] |
|
for k, v in loss_dict_reduced.items() if k in weight_dict} |
|
loss_dict_reduced_unscaled = {f'{k}_unscaled': v |
|
for k, v in loss_dict_reduced.items()} |
|
metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), |
|
**loss_dict_reduced_scaled, |
|
**loss_dict_reduced_unscaled) |
|
|
|
        # Convert raw outputs to boxes (and masks, when a 'segm' post-processor is given) in original image coordinates.
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
|
results = postprocessors['bbox'](outputs, orig_target_sizes) |
|
if 'segm' in postprocessors.keys(): |
|
target_sizes = torch.stack([t["size"] for t in targets], dim=0) |
|
results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) |
|
res = {target['image_id'].item(): output for target, output in zip(targets, results)} |
|
|
|
for evaluator in evaluator_list: |
|
evaluator.update(res) |
|
|
|
|
|
        # Accumulate COCO-format predictions; category_id is fixed to 1 since grounding is class-agnostic,
        # and 'rle_masks' assumes the 'segm' post-processor ran above.
        for p, target in zip(results, targets):
|
for s, b, m in zip(p['scores'], p['boxes'], p['rle_masks']): |
|
predictions.append({'image_id': target['image_id'].item(), |
|
'category_id': 1, |
|
'bbox': b.tolist(), |
|
'segmentation': m, |
|
'score': s.item()}) |
|
|
|
|
|
|
|
metric_logger.synchronize_between_processes() |
|
print("Averaged stats:", metric_logger) |
|
for evaluator in evaluator_list: |
|
evaluator.synchronize_between_processes() |
|
|
|
|
|
refexp_res = None |
|
    # Accumulate and summarize each evaluator; RefExpEvaluator returns its metrics so they can be merged into the stats below.
    for evaluator in evaluator_list:
|
if isinstance(evaluator, CocoEvaluator): |
|
evaluator.accumulate() |
|
evaluator.summarize() |
|
elif isinstance(evaluator, RefExpEvaluator): |
|
refexp_res = evaluator.summarize() |
|
|
|
stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} |
|
|
|
|
|
for evaluator in evaluator_list: |
|
if isinstance(evaluator, CocoEvaluator): |
|
if "bbox" in postprocessors.keys(): |
|
stats["coco_eval_bbox"] = evaluator.coco_eval["bbox"].stats.tolist() |
|
if "segm" in postprocessors.keys(): |
|
stats["coco_eval_masks"] = evaluator.coco_eval["segm"].stats.tolist() |
|
if refexp_res is not None: |
|
stats.update(refexp_res) |
|
|
|
|
|
|
|
    # Gather predictions from all processes so the main process can run the full COCO-style evaluation.
    gathered_pred_lists = utils.all_gather(predictions)
|
predictions = [p for p_list in gathered_pred_lists for p in p_list] |
|
|
|
eval_metrics = {} |
|
if utils.is_main_process(): |
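        # Ground-truth annotations are expected in COCO format under args.coco_path.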
|
if dataset_name == 'refcoco': |
|
coco_gt = COCO(os.path.join(args.coco_path, 'refcoco/instances_refcoco_val.json')) |
|
elif dataset_name == 'refcoco+': |
|
coco_gt = COCO(os.path.join(args.coco_path, 'refcoco+/instances_refcoco+_val.json')) |
|
elif dataset_name == 'refcocog': |
|
coco_gt = COCO(os.path.join(args.coco_path, 'refcocog/instances_refcocog_val.json')) |
|
else: |
|
raise NotImplementedError |
|
coco_pred = coco_gt.loadRes(predictions) |
|
coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') |
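        # Class-agnostic evaluation: category labels are ignored when matching predictions to ground truth.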
|
coco_eval.params.useCats = 0 |
|
coco_eval.evaluate() |
|
coco_eval.accumulate() |
|
coco_eval.summarize() |
|
|
        # Beyond COCO AP, report Precision@K and overall/mean IoU, first for boxes and then for masks.
        precision_at_k, overall_iou, mean_iou = calculate_bbox_precision_at_k_and_iou_metrics(coco_gt, coco_pred)
|
eval_metrics.update({f'bbox P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) |
|
eval_metrics.update({'bbox overall_iou': overall_iou, 'bbox mean_iou': mean_iou}) |
|
|
|
precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred) |
|
eval_metrics.update({f'segm P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) |
|
eval_metrics.update({'segm overall_iou': overall_iou, 'segm mean_iou': mean_iou}) |
|
print(eval_metrics) |
|
stats.update(eval_metrics) |
|
|
|
return stats |
|
|
|
|
|
@torch.no_grad() |
|
def evaluate_a2d(model, data_loader, postprocessor, device, args): |
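    """Evaluate segmentation on A2D-Sentences or JHMDB-Sentences; return a dict of metrics."""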
|
model.eval() |
|
predictions = [] |
|
metric_logger = utils.MetricLogger(delimiter=" ") |
|
header = 'Test:' |
|
|
|
for samples, targets in metric_logger.log_every(data_loader, 10, header): |
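        # Keep the original image ids before the targets are moved to the device; they key the COCO-format predictions below.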
|
image_ids = [t['image_id'] for t in targets] |
|
|
|
samples = samples.to(device) |
|
captions = [t["caption"] for t in targets] |
|
targets = utils.targets_to(targets, device) |
|
|
|
outputs = model(samples, captions, targets) |
|
|
|
        # Post-process the outputs into RLE masks at the original frame resolution.
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
|
target_sizes = torch.stack([t["size"] for t in targets], dim=0) |
|
processed_outputs = postprocessor(outputs, orig_target_sizes, target_sizes) |
|
|
|
for p, image_id in zip(processed_outputs, image_ids): |
|
for s, m in zip(p['scores'], p['rle_masks']): |
|
predictions.append({'image_id': image_id, |
|
'category_id': 1, |
|
'segmentation': m, |
|
'score': s.item()}) |
|
|
|
|
|
    # Gather predictions from all processes; only the main process runs the evaluation below.
    gathered_pred_lists = utils.all_gather(predictions)
|
predictions = [p for p_list in gathered_pred_lists for p in p_list] |
|
|
|
eval_metrics = {} |
|
if utils.is_main_process(): |
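        # Select the ground-truth annotation file (already converted to COCO format) for the chosen dataset.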
|
if args.dataset_file == 'a2d': |
|
coco_gt = COCO(os.path.join(args.a2d_path, 'a2d_sentences_test_annotations_in_coco_format.json')) |
|
elif args.dataset_file == 'jhmdb': |
|
coco_gt = COCO(os.path.join(args.jhmdb_path, 'jhmdb_sentences_gt_annotations_in_coco_format.json')) |
|
else: |
|
raise NotImplementedError |
|
coco_pred = coco_gt.loadRes(predictions) |
|
coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') |
|
coco_eval.params.useCats = 0 |
|
coco_eval.evaluate() |
|
coco_eval.accumulate() |
|
coco_eval.summarize() |
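        # Keep the first six COCO summary numbers (AP, AP50, AP75 and the S/M/L breakdown) under readable labels.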
|
ap_labels = ['mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', 'AP 0.5:0.95 M', 'AP 0.5:0.95 L'] |
|
ap_metrics = coco_eval.stats[:6] |
|
eval_metrics = {l: m for l, m in zip(ap_labels, ap_metrics)} |
|
|
|
precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred) |
|
eval_metrics.update({f'P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) |
|
eval_metrics.update({'overall_iou': overall_iou, 'mean_iou': mean_iou}) |
|
print(eval_metrics) |
|
|
|
|
|
    # Synchronize all processes before returning (assumes torch.distributed has been initialized).
    dist.barrier()
|
return eval_metrics |
|