import os
import time
import json
import pprint
import random
import numpy as np
from easydict import EasyDict as EDict
from tqdm import tqdm, trange
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from baselines.excl.config import BaseOptions
from baselines.excl.model import EXCL
from baselines.excl.start_end_dataset import \
    ExCLDataset, start_end_collate, prepare_batch_inputs
from baselines.excl.inference import eval_epoch, start_inference
from utils.basic_utils import AverageMeter
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def set_seed(seed, use_cuda=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)

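# Seeding covers the Python, NumPy and PyTorch RNGs only; cuDNN is switched to
# deterministic mode in start_training(), and only when opt.debug is set.
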
def train_epoch(model, train_loader, optimizer, opt, epoch_i, training=True):
    logger.info("use train_epoch func for training: {}".format(training))
    model.train(mode=training)

    # meters for per-batch timing and losses
    dataloading_time = AverageMeter()
    prepare_inputs_time = AverageMeter()
    model_forward_time = AverageMeter()
    model_backward_time = AverageMeter()
    loss_meters = OrderedDict(loss_st_ed=AverageMeter())

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        global_step = epoch_i * num_training_examples + batch_idx
        dataloading_time.update(time.time() - timer_dataloading)

        timer_start = time.time()
        model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        prepare_inputs_time.update(time.time() - timer_start)

        timer_start = time.time()
        loss, loss_dict, _, _ = model(**model_inputs)
        model_forward_time.update(time.time() - timer_start)

        timer_start = time.time()
        if training:
            optimizer.zero_grad()
            loss.backward()
            if opt.grad_clip != -1:
                nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            model_backward_time.update(time.time() - timer_start)

            opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step)
            for k, v in loss_dict.items():
                opt.writer.add_scalar("Train/{}".format(k), v, global_step)

        for k, v in loss_dict.items():
            loss_meters[k].update(float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    if training:
        to_write = opt.train_log_txt_formatter.format(
            time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
            epoch=epoch_i,
            loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
        with open(opt.train_log_filepath, "a") as f:
            f.write(to_write)
        print("Epoch time stats:")
        print("dataloading_time: max {dataloading_time.max} "
              "min {dataloading_time.min} avg {dataloading_time.avg}\n"
              "prepare_inputs_time: max {prepare_inputs_time.max} "
              "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
              "model_forward_time: max {model_forward_time.max} "
              "min {model_forward_time.min} avg {model_forward_time.avg}\n"
              "model_backward_time: max {model_backward_time.max} "
              "min {model_backward_time.min} avg {model_backward_time.avg}\n"
              "".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time,
                        model_forward_time=model_forward_time, model_backward_time=model_backward_time))
    else:
        for k, v in loss_meters.items():
            opt.writer.add_scalar("Eval_Loss/{}".format(k), v.avg, epoch_i)

def rm_key_from_odict(odict_obj, rm_suffix):
    """Remove every entry whose key contains `rm_suffix` from the OrderedDict."""
    return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])

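# Illustration (hypothetical keys): rm_key_from_odict filters by substring match, so
#   rm_key_from_odict(OrderedDict([("0.5-r1-by_type", 1), ("0.5-r1", 2)]), "by_type")
# returns OrderedDict([("0.5-r1", 2)]).
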
def train(model, train_dataset, val_dataset, opt):
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            logger.info("Use multi GPU %s", opt.device_ids)
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids)

    train_loader = DataLoader(train_dataset,
                              collate_fn=start_end_collate,
                              batch_size=opt.bsz,
                              num_workers=opt.num_workers,
                              shuffle=True,
                              pin_memory=opt.pin_memory)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr)

    prev_best_score = 0.
    es_cnt = 0
    start_epoch = -1 if opt.eval_untrained else 0
    eval_tasks_at_training = opt.eval_tasks_at_training
    save_submission_filename = \
        "latest_{}_{}_predictions_{}.json".format(opt.dset_name, opt.eval_split_name, "_".join(eval_tasks_at_training))
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            with torch.autograd.detect_anomaly():
                train_epoch(model, train_loader, optimizer, opt, epoch_i, training=True)
        global_step = (epoch_i + 1) * len(train_loader)
        if opt.eval_path is not None:
            with torch.no_grad():
                metrics_no_nms, metrics_nms, latest_file_paths = \
                    eval_epoch(model, val_dataset, opt, save_submission_filename,
                               tasks=eval_tasks_at_training, max_after_nms=100)
            to_write = opt.eval_log_txt_formatter.format(
                time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                epoch=epoch_i,
                eval_metrics_str=json.dumps(metrics_no_nms))
            with open(opt.eval_log_filepath, "a") as f:
                f.write(to_write)
            logger.info("metrics_no_nms {}".format(pprint.pformat(rm_key_from_odict(metrics_no_nms, rm_suffix="by_type"), indent=4)))
            logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms, indent=4)))

            # tensorboard logging and early stopping are based on the no-NMS metrics
            metrics = metrics_no_nms

            for task_type in ["SVMR", "VCMR"]:
                if task_type in metrics:
                    task_metrics = metrics[task_type]
                    for iou_thd in [0.5, 0.7]:
                        opt.writer.add_scalars("Eval/{}-{}".format(task_type, iou_thd),
                                               {k: v for k, v in task_metrics.items() if str(iou_thd) in k},
                                               global_step)

            task_type = "VR"
            if task_type in metrics:
                task_metrics = metrics[task_type]
                opt.writer.add_scalars("Eval/{}".format(task_type),
                                       {k: v for k, v in task_metrics.items()},
                                       global_step)

            stop_metric_names = ["r1"] if opt.stop_task == "VR" else ["0.5-r1", "0.7-r1"]
            stop_score = sum([metrics[opt.stop_task][e] for e in stop_metric_names])

            if stop_score > prev_best_score:
                es_cnt = 0
                prev_best_score = stop_score

                checkpoint = {
                    "model": model.state_dict(),
                    "model_cfg": model.config,
                    "epoch": epoch_i}
                torch.save(checkpoint, opt.ckpt_filepath)

                best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
                for src, tgt in zip(latest_file_paths, best_file_paths):
                    os.renames(src, tgt)
                logger.info("The checkpoint file has been updated.")
            else:
                es_cnt += 1
                if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt:  # early stop
                    with open(opt.train_log_filepath, "a") as f:
                        f.write("Early Stop at epoch {}".format(epoch_i))
                    logger.info("Early stop at {} with {} {}"
                                .format(epoch_i, " ".join([opt.stop_task] + stop_metric_names), prev_best_score))
                    break
        else:
            checkpoint = {
                "model": model.state_dict(),
                "model_cfg": model.config,
                "epoch": epoch_i}
            torch.save(checkpoint, opt.ckpt_filepath)

        if opt.debug:
            break

    opt.writer.close()

def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # keep the run deterministic while debugging
        # cudnn.benchmark auto-tunes convolution algorithms and is a source of
        # non-determinism, so disable it and enable deterministic mode.
        cudnn.benchmark = False
        cudnn.deterministic = True

    opt.writer = SummaryWriter(opt.tensorboard_log_dir)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"

    train_dataset = ExCLDataset(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        vid_feat_path_or_handler=opt.vid_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )

    if opt.eval_path is not None:
        # reuse the feature handlers opened by the train dataset so the files are not opened twice
        eval_dataset = ExCLDataset(
            dset_name=opt.dset_name,
            data_path=opt.eval_path,
            desc_bert_path_or_handler=train_dataset.desc_bert_h5,
            sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
            max_desc_len=opt.max_desc_l,
            max_ctx_len=opt.max_ctx_l,
            vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
            clip_length=opt.clip_length,
            ctx_mode=opt.ctx_mode,
            h5driver=opt.h5driver,
            data_ratio=opt.data_ratio,
            normalize_vfeat=not opt.no_norm_vfeat,
            normalize_tfeat=not opt.no_norm_tfeat,
            corpus_path=opt.corpus_path,
            eval_split_name=opt.eval_split_name
        )
    else:
        eval_dataset = None

    model_config = EDict(
        visual_input_size=opt.vid_feat_size,
        sub_input_size=opt.sub_feat_size,
        query_input_size=opt.q_feat_size,
        hidden_size=opt.hidden_size,
        drop=opt.drop,
        ctx_mode=opt.ctx_mode,
        initializer_range=opt.initializer_range
    )
    logger.info("model_config {}".format(model_config))
    model = EXCL(model_config)
    count_parameters(model)
    logger.info("Start Training...")
    train(model, train_dataset, eval_dataset, opt)
    return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug

if __name__ == '__main__':
    model_dir, eval_split_name, eval_path, debug = start_training()
    if not debug:
        model_dir = model_dir.split(os.sep)[-1]
        tasks = ["SVMR"]
        input_args = ["--model_dir", model_dir,
                      "--eval_split_name", eval_split_name,
                      "--eval_path", eval_path,
                      "--tasks"] + tasks

        import sys
        sys.argv[1:] = input_args
        logger.info("\n\n\nFINISHED TRAINING!!!")
        logger.info("Evaluating model in {}".format(model_dir))
        logger.info("Input args {}".format(sys.argv[1:]))
        start_inference()
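# After training, the block above effectively re-runs inference as if the script had
# been invoked with:
#   --model_dir <basename of results_dir> --eval_split_name <split> --eval_path <path> --tasks SVMR
# (flag spellings are taken from input_args; start_inference() is expected to re-parse
# sys.argv, which is why it is overwritten before the call).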