""" Train and eval functions used in main.py Modified from DETR (https://github.com/facebookresearch/detr) """ import math from models import postprocessors import os import sys from typing import Iterable import torch import torch.distributed as dist import util.misc as utils from datasets.coco_eval import CocoEvaluator from datasets.refexp_eval import RefExpEvaluator from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from datasets.a2d_eval import calculate_precision_at_k_and_iou_metrics, calculate_bbox_precision_at_k_and_iou_metrics def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, max_norm: float = 0): model.train() criterion.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 10 for samples, targets in metric_logger.log_every(data_loader, print_freq, header): samples = samples.to(device) captions = [t["caption"] for t in targets] targets = utils.targets_to(targets, device) outputs = model(samples, captions, targets) loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) loss_value = losses_reduced_scaled.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) print(loss_dict_reduced) sys.exit(1) optimizer.zero_grad() losses.backward() if max_norm > 0: grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) else: grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) optimizer.step() metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) metric_logger.update(grad_norm=grad_total_norm) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} @torch.no_grad() def evaluate(model, criterion, postprocessors, data_loader, evaluator_list, device, args): model.eval() criterion.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' predictions = [] for samples, targets in metric_logger.log_every(data_loader, 10, header): dataset_name = targets[0]["dataset_name"] samples = samples.to(device) captions = [t["caption"] for t in targets] targets = utils.targets_to(targets, device) outputs = model(samples, captions, targets) loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) orig_target_sizes 
= torch.stack([t["orig_size"] for t in targets], dim=0) results = postprocessors['bbox'](outputs, orig_target_sizes) if 'segm' in postprocessors.keys(): target_sizes = torch.stack([t["size"] for t in targets], dim=0) results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) res = {target['image_id'].item(): output for target, output in zip(targets, results)} for evaluator in evaluator_list: evaluator.update(res) # REC & RES predictions for p, target in zip(results, targets): for s, b, m in zip(p['scores'], p['boxes'], p['rle_masks']): predictions.append({'image_id': target['image_id'].item(), 'category_id': 1, # dummy label, as categories are not predicted in ref-vos 'bbox': b.tolist(), 'segmentation': m, 'score': s.item()}) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) for evaluator in evaluator_list: evaluator.synchronize_between_processes() # accumulate predictions from all images refexp_res = None for evaluator in evaluator_list: if isinstance(evaluator, CocoEvaluator): evaluator.accumulate() evaluator.summarize() elif isinstance(evaluator, RefExpEvaluator): refexp_res = evaluator.summarize() stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} # update stats for evaluator in evaluator_list: if isinstance(evaluator, CocoEvaluator): if "bbox" in postprocessors.keys(): stats["coco_eval_bbox"] = evaluator.coco_eval["bbox"].stats.tolist() if "segm" in postprocessors.keys(): stats["coco_eval_masks"] = evaluator.coco_eval["segm"].stats.tolist() if refexp_res is not None: stats.update(refexp_res) # evaluate RES # gather and merge predictions from all gpus gathered_pred_lists = utils.all_gather(predictions) predictions = [p for p_list in gathered_pred_lists for p in p_list] eval_metrics = {} if utils.is_main_process(): if dataset_name == 'refcoco': coco_gt = COCO(os.path.join(args.coco_path, 'refcoco/instances_refcoco_val.json')) elif dataset_name == 'refcoco+': coco_gt = COCO(os.path.join(args.coco_path, 'refcoco+/instances_refcoco+_val.json')) elif dataset_name == 'refcocog': coco_gt = COCO(os.path.join(args.coco_path, 'refcocog/instances_refcocog_val.json')) else: raise NotImplementedError coco_pred = coco_gt.loadRes(predictions) coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') coco_eval.params.useCats = 0 # ignore categories as they are not predicted in ref-vos task coco_eval.evaluate() coco_eval.accumulate() coco_eval.summarize() # ap_labels = ['mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', 'AP 0.5:0.95 M', 'AP 0.5:0.95 L'] # ap_metrics = coco_eval.stats[:6] # eval_metrics = {l: m for l, m in zip(ap_labels, ap_metrics)} # Precision and IOU # bbox precision_at_k, overall_iou, mean_iou = calculate_bbox_precision_at_k_and_iou_metrics(coco_gt, coco_pred) eval_metrics.update({f'bbox P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) eval_metrics.update({'bbox overall_iou': overall_iou, 'bbox mean_iou': mean_iou}) # mask precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred) eval_metrics.update({f'segm P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) eval_metrics.update({'segm overall_iou': overall_iou, 'segm mean_iou': mean_iou}) print(eval_metrics) stats.update(eval_metrics) return stats @torch.no_grad() def evaluate_a2d(model, data_loader, postprocessor, device, args): model.eval() predictions = [] metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' 
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        image_ids = [t['image_id'] for t in targets]

        samples = samples.to(device)
        captions = [t["caption"] for t in targets]
        targets = utils.targets_to(targets, device)

        outputs = model(samples, captions, targets)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        target_sizes = torch.stack([t["size"] for t in targets], dim=0)
        processed_outputs = postprocessor(outputs, orig_target_sizes, target_sizes)

        for p, image_id in zip(processed_outputs, image_ids):
            for s, m in zip(p['scores'], p['rle_masks']):
                predictions.append({'image_id': image_id,
                                    'category_id': 1,  # dummy label, as categories are not predicted in ref-vos
                                    'segmentation': m,
                                    'score': s.item()})

    # gather and merge predictions from all gpus
    gathered_pred_lists = utils.all_gather(predictions)
    predictions = [p for p_list in gathered_pred_lists for p in p_list]

    # evaluation
    eval_metrics = {}
    if utils.is_main_process():
        if args.dataset_file == 'a2d':
            coco_gt = COCO(os.path.join(args.a2d_path, 'a2d_sentences_test_annotations_in_coco_format.json'))
        elif args.dataset_file == 'jhmdb':
            coco_gt = COCO(os.path.join(args.jhmdb_path, 'jhmdb_sentences_gt_annotations_in_coco_format.json'))
        else:
            raise NotImplementedError
        coco_pred = coco_gt.loadRes(predictions)
        coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm')
        coco_eval.params.useCats = 0  # ignore categories as they are not predicted in ref-vos task
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        ap_labels = ['mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', 'AP 0.5:0.95 M', 'AP 0.5:0.95 L']
        ap_metrics = coco_eval.stats[:6]
        eval_metrics = {l: m for l, m in zip(ap_labels, ap_metrics)}

        # Precision and IOU
        precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred)
        eval_metrics.update({f'P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)})
        eval_metrics.update({'overall_iou': overall_iou, 'mean_iou': mean_iou})
        print(eval_metrics)

    # sync all processes before starting a new epoch or exiting
    dist.barrier()

    return eval_metrics