import os import time import json import pprint import random import numpy as np from easydict import EasyDict as EDict from tqdm import tqdm, trange from collections import OrderedDict import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from baselines.excl.config import BaseOptions from baselines.excl.model import EXCL from baselines.excl.start_end_dataset import \ ExCLDataset, start_end_collate, prepare_batch_inputs from baselines.excl.inference import eval_epoch, start_inference from utils.basic_utils import AverageMeter from utils.model_utils import count_parameters import logging logger = logging.getLogger(__name__) logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) def set_seed(seed, use_cuda=True): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if use_cuda: torch.cuda.manual_seed_all(seed) def train_epoch(model, train_loader, optimizer, opt, epoch_i, training=True): logger.info("use train_epoch func for training: {}".format(training)) model.train(mode=training) # init meters dataloading_time = AverageMeter() prepare_inputs_time = AverageMeter() model_forward_time = AverageMeter() model_backward_time = AverageMeter() loss_meters = OrderedDict(loss_st_ed=AverageMeter()) num_training_examples = len(train_loader) timer_dataloading = time.time() for batch_idx, batch in tqdm(enumerate(train_loader), desc="Training Iteration", total=num_training_examples): global_step = epoch_i * num_training_examples + batch_idx dataloading_time.update(time.time() - timer_dataloading) # continue timer_start = time.time() model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) prepare_inputs_time.update(time.time() - timer_start) # logger.info("model_inputs {}" # .format({k: (type(k), v.shape if isinstance(v, torch.Tensor) else v) # for k, v in model_inputs.items()})) # logger.info("model_inputs \n{}".format({k: (type(v), v.shape, v.dtype) for k, v in model_inputs.items()})) timer_start = time.time() loss, loss_dict, _, _ = model(**model_inputs) model_forward_time.update(time.time() - timer_start) timer_start = time.time() if training: optimizer.zero_grad() loss.backward() if opt.grad_clip != -1: nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) optimizer.step() model_backward_time.update(time.time() - timer_start) opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step) for k, v in loss_dict.items(): opt.writer.add_scalar("Train/{}".format(k), v, global_step) for k, v in loss_dict.items(): loss_meters[k].update(float(v)) timer_dataloading = time.time() if opt.debug and batch_idx == 3: break if training: to_write = opt.train_log_txt_formatter.format( time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), epoch=epoch_i, loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()])) with open(opt.train_log_filepath, "a") as f: f.write(to_write) print("Epoch time stats:") print("dataloading_time: max {dataloading_time.max} " "min {dataloading_time.min} avg {dataloading_time.avg}\n" "prepare_inputs_time: max {prepare_inputs_time.max} " "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n" "model_forward_time: max {model_forward_time.max} " "min {model_forward_time.min} avg {model_forward_time.avg}\n" "model_backward_time: max {model_backward_time.max} " "min {model_backward_time.min} avg {model_backward_time.avg}\n" "".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time, model_forward_time=model_forward_time, model_backward_time=model_backward_time)) else: for k, v in loss_meters.items(): opt.writer.add_scalar("Eval_Loss/{}".format(k), v.avg, epoch_i) def rm_key_from_odict(odict_obj, rm_suffix): """remove key entry from the OrderedDict""" return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k]) def train(model, train_dataset, val_dataset, opt): # Prepare optimizer if opt.device.type == "cuda": logger.info("CUDA enabled.") model.to(opt.device) if len(opt.device_ids) > 1: logger.info("Use multi GPU", opt.device_ids) model = torch.nn.DataParallel(model, device_ids=opt.device_ids) # use multi GPU train_loader = DataLoader(train_dataset, collate_fn=start_end_collate, batch_size=opt.bsz, num_workers=opt.num_workers, shuffle=True, pin_memory=opt.pin_memory) # Prepare optimizer optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr) prev_best_score = 0. es_cnt = 0 start_epoch = -1 if opt.eval_untrained else 0 eval_tasks_at_training = opt.eval_tasks_at_training # VR is computed along with VCMR save_submission_filename = \ "latest_{}_{}_predictions_{}.json".format(opt.dset_name, opt.eval_split_name, "_".join(eval_tasks_at_training)) for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): if epoch_i > -1: with torch.autograd.detect_anomaly(): train_epoch(model, train_loader, optimizer, opt, epoch_i, training=True) global_step = (epoch_i + 1) * len(train_loader) if opt.eval_path is not None: with torch.no_grad(): metrics_no_nms, metrics_nms, latest_file_paths = \ eval_epoch(model, val_dataset, opt, save_submission_filename, tasks=eval_tasks_at_training, max_after_nms=100) to_write = opt.eval_log_txt_formatter.format( time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), epoch=epoch_i, eval_metrics_str=json.dumps(metrics_no_nms)) with open(opt.eval_log_filepath, "a") as f: f.write(to_write) logger.info("metrics_no_nms {}".format(pprint.pformat(rm_key_from_odict(metrics_no_nms, rm_suffix="by_type"), indent=4))) logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms, indent=4))) # metrics = metrics_nms if metrics_nms is not None else metrics_no_nms metrics = metrics_no_nms # early stop/ log / save model for task_type in ["SVMR", "VCMR"]: if task_type in metrics: task_metrics = metrics[task_type] for iou_thd in [0.5, 0.7]: opt.writer.add_scalars("Eval/{}-{}".format(task_type, iou_thd), {k: v for k, v in task_metrics.items() if str(iou_thd) in k}, global_step) task_type = "VR" if task_type in metrics: task_metrics = metrics[task_type] opt.writer.add_scalars("Eval/{}".format(task_type), {k: v for k, v in task_metrics.items()}, global_step) # use the most strict metric available stop_metric_names = ["r1"] if opt.stop_task == "VR" else ["0.5-r1", "0.7-r1"] stop_score = sum([metrics[opt.stop_task][e] for e in stop_metric_names]) if stop_score > prev_best_score: es_cnt = 0 prev_best_score = stop_score checkpoint = { "model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i} torch.save(checkpoint, opt.ckpt_filepath) best_file_paths = [e.replace("latest", "best") for e in latest_file_paths] for src, tgt in zip(latest_file_paths, best_file_paths): os.renames(src, tgt) logger.info("The checkpoint file has been updated.") else: es_cnt += 1 if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop with open(opt.train_log_filepath, "a") as f: f.write("Early Stop at epoch {}".format(epoch_i)) logger.info("Early stop at {} with {} {}" .format(epoch_i, " ".join([opt.stop_task] + stop_metric_names), prev_best_score)) break else: checkpoint = { "model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i} torch.save(checkpoint, opt.ckpt_filepath) if opt.debug: break opt.writer.close() def start_training(): logger.info("Setup config, data and model...") opt = BaseOptions().parse() set_seed(opt.seed) if opt.debug: # keep the model run deterministically # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config. # Enable this only when input size is fixed. cudnn.benchmark = False cudnn.deterministic = True opt.writer = SummaryWriter(opt.tensorboard_log_dir) opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n" train_dataset = ExCLDataset( dset_name=opt.dset_name, data_path=opt.train_path, desc_bert_path_or_handler=opt.desc_bert_path, sub_bert_path_or_handler=opt.sub_bert_path, max_desc_len=opt.max_desc_l, max_ctx_len=opt.max_ctx_l, vid_feat_path_or_handler=opt.vid_feat_path, clip_length=opt.clip_length, ctx_mode=opt.ctx_mode, h5driver=opt.h5driver, data_ratio=opt.data_ratio, normalize_vfeat=not opt.no_norm_vfeat, normalize_tfeat=not opt.no_norm_tfeat, ) if opt.eval_path is not None: eval_dataset = ExCLDataset( dset_name=opt.dset_name, data_path=opt.eval_path, desc_bert_path_or_handler=train_dataset.desc_bert_h5, sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None, max_desc_len=opt.max_desc_l, max_ctx_len=opt.max_ctx_l, vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None, clip_length=opt.clip_length, ctx_mode=opt.ctx_mode, h5driver=opt.h5driver, data_ratio=opt.data_ratio, normalize_vfeat=not opt.no_norm_vfeat, normalize_tfeat=not opt.no_norm_tfeat, corpus_path=opt.corpus_path, eval_split_name=opt.eval_split_name ) else: eval_dataset = None model_config = EDict( visual_input_size=opt.vid_feat_size, sub_input_size=opt.sub_feat_size, # for both desc and subtitles query_input_size=opt.q_feat_size, # for both desc and subtitles hidden_size=opt.hidden_size, drop=opt.drop, ctx_mode=opt.ctx_mode, # video, sub or video_sub initializer_range=opt.initializer_range ) logger.info("model_config {}".format(model_config)) model = EXCL(model_config) count_parameters(model) logger.info("Start Training...") train(model, train_dataset, eval_dataset, opt) return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug if __name__ == '__main__': model_dir, eval_split_name, eval_path, debug = start_training() if not debug: model_dir = model_dir.split(os.sep)[-1] tasks = ["SVMR"] input_args = ["--model_dir", model_dir, "--eval_split_name", eval_split_name, "--eval_path", eval_path, "--tasks"] + tasks import sys sys.argv[1:] = input_args logger.info("\n\n\nFINISHED TRAINING!!!") logger.info("Evaluating model in {}".format(model_dir)) logger.info("Input args {}".format(sys.argv[1:])) start_inference()