import os
import time
import json
import pprint
import random
import numpy as np
from collections import OrderedDict
from easydict import EasyDict as EDict
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from baselines.mixture_embedding_experts.config import BaseOptions
from baselines.mixture_embedding_experts.model import MEE
from baselines.mixture_embedding_experts.retrieval_dataset import \
    RetrievalDataset, retrieval_collate, RetrievalEvalDataset, prepare_batch_inputs
from baselines.mixture_embedding_experts.inference import eval_epoch, start_inference
from utils.basic_utils import save_jsonl, save_json, AverageMeter
from utils.model_utils import count_parameters
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)
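

# Fix the Python, NumPy and PyTorch (CPU and, when available, all CUDA devices) random seeds for reproducible runs.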
def set_seed(seed, use_cuda=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)
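

# Train the model for a single epoch. Times data loading, input preparation, forward and
# backward passes, logs LR/loss to TensorBoard, and appends the average epoch loss to the
# plain-text train log.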
def train_epoch(model, train_loader, optimizer, opt, epoch_i):
    model.train()

    # init meters
    dataloading_time = AverageMeter()
    prepare_inputs_time = AverageMeter()
    model_forward_time = AverageMeter()
    model_backward_time = AverageMeter()
    loss_meter = AverageMeter()

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        dataloading_time.update(time.time() - timer_dataloading)
        # continue

        timer_start = time.time()
        model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        prepare_inputs_time.update(time.time() - timer_start)

        timer_start = time.time()
        loss = model(**model_inputs)
        model_forward_time.update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        loss.backward()
        if opt.grad_clip != -1:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        model_backward_time.update(time.time() - timer_start)

        global_step = epoch_i * num_training_examples + batch_idx
        opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step)
        opt.writer.add_scalar("Train/Loss", float(loss), global_step)

        loss_meter.update(float(loss))
        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i,
        loss_str=str(loss_meter.avg))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    print("Epoch time stats:")
    print("dataloading_time: max {dataloading_time.max} "
          "min {dataloading_time.min} avg {dataloading_time.avg}\n"
          "prepare_inputs_time: max {prepare_inputs_time.max} "
          "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
          "model_forward_time: max {model_forward_time.max} "
          "min {model_forward_time.min} avg {model_forward_time.avg}\n"
          "model_backward_time: max {model_backward_time.max} "
          "min {model_backward_time.min} avg {model_backward_time.avg}\n"
          "".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time,
                    model_forward_time=model_forward_time, model_backward_time=model_backward_time))
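

# Main training loop: builds the optimizer, LR scheduler and train DataLoader, trains one
# epoch at a time, optionally evaluates on the validation split after each epoch, keeps the
# checkpoint with the best VR R@1, and early-stops after opt.max_es_cnt epochs without
# improvement.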
def train(model, train_dataset, val_dataset, opt):
    # Prepare optimizer
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            logger.info("Use multi GPU %s", opt.device_ids)
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr)
    # decay the lr exponentially: multiply it by gamma (0.95) after each epoch
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=0.95
    )
    train_loader = DataLoader(train_dataset,
                              collate_fn=retrieval_collate,
                              batch_size=opt.bsz,
                              num_workers=opt.num_workers,
                              shuffle=True,
                              pin_memory=opt.pin_memory)

    prev_best_score = 0.
    es_cnt = 0
    # start_epoch = -1 adds an extra evaluation-only pass on the untrained model
    start_epoch = -1 if opt.eval_untrained else 0
    eval_tasks_at_training = ["VR"]  # VR: video retrieval
    save_submission_filename = "latest_{}_{}_predictions_{}.json".format(
        opt.dset_name, opt.eval_split_name, "_".join(eval_tasks_at_training))
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            with torch.autograd.detect_anomaly():
                train_epoch(model, train_loader, optimizer, opt, epoch_i)
        global_step = (epoch_i + 1) * len(train_loader)
        scheduler.step()

        if opt.eval_path is not None:
            with torch.no_grad():
                metrics_no_nms, metrics_nms, latest_file_paths = \
                    eval_epoch(model, val_dataset, opt, save_submission_filename, tasks=eval_tasks_at_training)
            logger.info("metrics_no_nms {}".format(
                pprint.pformat(rm_key_from_odict(metrics_no_nms, rm_suffix="by_type"), indent=4)))
            logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))

            to_write = opt.eval_log_txt_formatter.format(
                time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                epoch=epoch_i,
                eval_metrics_str=json.dumps(metrics_no_nms))
            with open(opt.eval_log_filepath, "a") as f:
                f.write(to_write)

            # metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
            metrics = metrics_no_nms

            # early stop / log / save model
            for task_type, task_metrics in metrics.items():
                for iou_thd in [0.5, 0.7]:
                    opt.writer.add_scalars("Eval/{}-{}".format(task_type, iou_thd),
                                           {k: v for k, v in task_metrics.items() if str(iou_thd) in k},
                                           global_step)

            # use the most strict metric available
            if metrics["VR"]["r1"] > prev_best_score:
                es_cnt = 0
                prev_best_score = metrics["VR"]["r1"]
                checkpoint = {
                    "model": model.state_dict(),
                    "model_cfg": model.config,
                    "epoch": epoch_i}
                torch.save(checkpoint, opt.ckpt_filepath)

                best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
                for src, tgt in zip(latest_file_paths, best_file_paths):
                    os.renames(src, tgt)
                logger.info("The checkpoint file has been updated.")
            else:
                es_cnt += 1
                if es_cnt > opt.max_es_cnt:  # early stop
                    with open(opt.train_log_filepath, "a") as f:
                        f.write("Early Stop at epoch {}\n".format(epoch_i))
                    logger.info("Early stop at {} with VR r1 {}".format(epoch_i, prev_best_score))
                    break
        else:
            checkpoint = {
                "model": model.state_dict(),
                "model_cfg": model.config,
                "epoch": epoch_i}
            torch.save(checkpoint, opt.ckpt_filepath)

        if opt.debug:
            break

    opt.writer.close()


def rm_key_from_odict(odict_obj, rm_suffix):
    """remove key entry from the OrderedDict"""
    return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])
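

# Entry point for training: parses options, builds the train/eval datasets and the MEE
# (Mixture of Embedding Experts) model, then hands everything to train().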
def start_training():
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # make the model run deterministically
        # 'cudnn.benchmark = True' enables auto-tuning to find the fastest algorithm for a fixed
        # input/net configuration. Disable it (and use deterministic kernels) when debugging.
        cudnn.benchmark = False
        cudnn.deterministic = True

    opt.writer = SummaryWriter(opt.tensorboard_log_dir)
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"
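
    # Training split: description/subtitle BERT features and video features are read from the
    # precomputed feature files configured in BaseOptions.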
    train_dataset = RetrievalDataset(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        vid_feat_path_or_handler=opt.vid_feat_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        ctx_mode=opt.ctx_mode,
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )

    if opt.eval_path is not None:
        eval_dataset = RetrievalEvalDataset(
            dset_name=opt.dset_name,
            eval_split_name=opt.eval_split_name,  # should only be val set
            data_path=opt.eval_path,
            desc_bert_path_or_handler=train_dataset.desc_bert_h5,
            sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
            max_desc_len=opt.max_desc_l,
            max_ctx_len=opt.max_ctx_l,
            corpus_path=opt.corpus_path,
            vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
            ctx_mode=opt.ctx_mode,
            data_mode="query",
            h5driver=opt.h5driver,
            data_ratio=opt.data_ratio,
            normalize_vfeat=not opt.no_norm_vfeat,
            normalize_tfeat=not opt.no_norm_tfeat,
        )
    else:
        eval_dataset = None
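
    # MEE model configuration: input/output feature sizes per modality plus the margin used by
    # the ranking loss.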
    model_config = EDict(
        ctx_mode=opt.ctx_mode,
        text_input_size=opt.sub_feat_size,
        vid_input_size=opt.vid_feat_size,
        output_size=opt.output_size,
        margin=opt.margin,  # margin for ranking loss
    )
    logger.info("model_config {}".format(model_config))
    model = MEE(model_config)
    count_parameters(model)

    logger.info("Start Training...")
    train(model, train_dataset, eval_dataset, opt)
    return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug
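

# After training finishes (non-debug runs), reuse the inference entry point to evaluate the
# model saved under model_dir by rewriting sys.argv before calling start_inference().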
if __name__ == '__main__':
    model_dir, eval_split_name, eval_path, debug = start_training()
    if not debug:
        model_dir = model_dir.split(os.sep)[-1]
        tasks = ["VR"]
        input_args = ["--model_dir", model_dir,
                      "--eval_split_name", eval_split_name,
                      "--eval_path", eval_path,
                      "--tasks"] + tasks

        import sys
        sys.argv[1:] = input_args
        logger.info("\n\n\nFINISHED TRAINING!!!")
        logger.info("Evaluating model in {}".format(model_dir))
        start_inference()