import os
import torch_lydorn.torchvision
from tqdm import tqdm
import torch
import torch.distributed
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from torch.utils.tensorboard import SummaryWriter
# from pytorch_memlab import profile, profile_every
from . import measures, plot_utils
from . import local_utils
from lydorn_utils import run_utils
from lydorn_utils import python_utils
from lydorn_utils import math_utils
try:
from apex import amp
APEX_AVAILABLE = True
except ModuleNotFoundError:
APEX_AVAILABLE = False
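# A minimal sketch of the apex amp setup assumed to happen in the launch
# script before a Trainer is built (not in this module; "O1" is one of
# apex's standard opt levels):
#
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#
# loss_batch() below then wraps the backward pass in amp.scale_loss() when
# config["use_amp"] is set and apex is importable.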
def humanbytes(B):
    """Return the given number of bytes as a human-friendly KB, MB, GB, or TB string."""
B = float(B)
KB = float(1024)
MB = float(KB ** 2) # 1,048,576
GB = float(KB ** 3) # 1,073,741,824
TB = float(KB ** 4) # 1,099,511,627,776
if B < KB:
        return '{0} {1}'.format(B, 'Byte' if B == 1 else 'Bytes')
elif KB <= B < MB:
return '{0:.2f} KB'.format(B / KB)
elif MB <= B < GB:
return '{0:.2f} MB'.format(B / MB)
elif GB <= B < TB:
return '{0:.2f} GB'.format(B / GB)
elif TB <= B:
return '{0:.2f} TB'.format(B / TB)
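# Worked examples (KB here is the binary kibibyte, 1024 bytes):
#   humanbytes(512)       -> '512.0 Bytes'
#   humanbytes(1024)      -> '1.00 KB'
#   humanbytes(1536)      -> '1.50 KB'
#   humanbytes(1024 ** 3) -> '1.00 GB'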
class Trainer:
def __init__(self, rank, gpu, config, model, optimizer, loss_func,
run_dirpath, init_checkpoints_dirpath=None, lr_scheduler=None):
self.rank = rank
self.gpu = gpu
self.config = config
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.loss_func = loss_func
self.init_checkpoints_dirpath = init_checkpoints_dirpath
logs_dirpath = run_utils.setup_run_subdir(run_dirpath, config["optim_params"]["logs_dirname"])
self.checkpoints_dirpath = run_utils.setup_run_subdir(run_dirpath, config["optim_params"]["checkpoints_dirname"])
if self.rank == 0:
self.logs_dirpath = logs_dirpath
train_logs_dirpath = os.path.join(self.logs_dirpath, "train")
val_logs_dirpath = os.path.join(self.logs_dirpath, "val")
self.train_writer = SummaryWriter(train_logs_dirpath)
self.val_writer = SummaryWriter(val_logs_dirpath)
else:
self.logs_dirpath = self.train_writer = self.val_writer = None
def log_weights(self, module, module_name, step):
weight_list = module.parameters()
for i, weight in enumerate(weight_list):
if len(weight.shape) == 4:
weight_type = "4d"
elif len(weight.shape) == 1:
weight_type = "1d"
elif len(weight.shape) == 2:
weight_type = "2d"
else:
weight_type = ""
self.train_writer.add_histogram('{}/{}/{}/hist'.format(module_name, i, weight_type), weight, step)
# self.writer.add_scalar('{}/{}/mean'.format(module_name, i), mean, step)
# self.writer.add_scalar('{}/{}/max'.format(module_name, i), maxi, step)
# def log_pr_curve(self, name, pred, batch, iter_step):
# num_thresholds = 100
# thresholds = torch.linspace(0, 2 * self.config["max_disp_global"] + self.config["max_disp_poly"], steps=num_thresholds)
# dists = measures.pos_dists(pred, batch).cpu()
# tiled_dists = dists.repeat(num_thresholds, 1)
# tiled_thresholds = thresholds.repeat(dists.shape[0], 1).t()
# true_positives = tiled_dists < tiled_thresholds
# true_positive_counts = torch.sum(true_positives, dim=1)
# recall = true_positive_counts.float() / true_positives.shape[1]
#
# precision = 1 - thresholds / (2 * self.config["max_disp_global"] + self.config["max_disp_poly"])
#
# false_positive_counts = true_positives.shape[1] - true_positive_counts
# true_negative_counts = torch.zeros(num_thresholds)
# false_negative_counts = torch.zeros(num_thresholds)
# self.writer.add_pr_curve_raw(name, true_positive_counts,
# false_positive_counts,
# true_negative_counts,
# false_negative_counts,
# precision,
# recall,
# global_step=iter_step,
# num_thresholds=num_thresholds)
def sync_outputs(self, loss, individual_metrics_dict):
# Reduce to rank 0:
torch.distributed.reduce(loss, dst=0)
for key in individual_metrics_dict.keys():
torch.distributed.reduce(individual_metrics_dict[key], dst=0)
# Average on rank 0:
if self.rank == 0:
loss /= self.config["world_size"]
for key in individual_metrics_dict.keys():
individual_metrics_dict[key] /= self.config["world_size"]
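    # A sketch of the reduce semantics relied on above (the default op of
    # torch.distributed.reduce is SUM, so rank 0 receives the sum over all
    # ranks; hypothetical 2-GPU values):
    #
    #   rank 0: loss = tensor(0.4)    rank 1: loss = tensor(0.6)
    #   torch.distributed.reduce(loss, dst=0)
    #   rank 0: loss = tensor(1.0), then / world_size (2) -> tensor(0.5)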
# from pytorch_memlab import profile
# @profile
def loss_batch(self, batch, opt=None, epoch=None):
# print("Forward pass:")
# t0 = time.time()
pred, batch = self.model(batch)
# print(f"{time.time() - t0}s")
# print("Loss computation:")
# t0 = time.time()
loss, individual_metrics_dict, extra_dict = self.loss_func(pred, batch, epoch=epoch)
# print(f"{time.time() - t0}s")
# Compute IoUs at different thresholds
if "seg" in pred:
y_pred = pred["seg"][:, 0, ...]
y_true = batch["gt_polygons_image"][:, 0, ...]
iou_thresholds = [0.1, 0.25, 0.5, 0.75, 0.9]
for iou_threshold in iou_thresholds:
iou = measures.iou(y_pred.reshape(y_pred.shape[0], -1), y_true.reshape(y_true.shape[0], -1), threshold=iou_threshold)
mean_iou = torch.mean(iou)
individual_metrics_dict[f"IoU_{iou_threshold}"] = mean_iou
# print("Backward pass:")
# t0 = time.time()
if opt is not None:
# Detect if loss is nan
# contains_nan = bool(torch.sum(torch.isnan(loss)).item())
# if contains_nan:
# raise ValueError("NaN values detected, aborting...")
if self.config["use_amp"] and APEX_AVAILABLE:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
# all_grads = []
# for param in self.model.parameters():
# # print("shape: {}".format(param.shape))
# if param.grad is not None:
# all_grads.append(param.grad.view(-1))
# all_grads = torch.cat(all_grads)
# all_grads_abs = torch.abs(all_grads)
opt.step()
opt.zero_grad()
# print(f"{time.time() - t0}s")
# Synchronize losses/accuracies to GPU 0 so that they can be logged
self.sync_outputs(loss, individual_metrics_dict)
for key in individual_metrics_dict:
individual_metrics_dict[key] = individual_metrics_dict[key].item()
# Log IoU if exists
log_iou = None
iou_name = f"IoU_{0.5}" # Progress bars will show that IoU and it will be saved in checkpoints
if iou_name in individual_metrics_dict:
log_iou = individual_metrics_dict[iou_name]
return pred, batch, loss.item(), individual_metrics_dict, extra_dict, log_iou, batch["image"].shape[0]
def run_epoch(self, split_name, dl, epoch, log_steps=None, opt=None, iter_step=None):
assert split_name in ["train", "val"]
if split_name == "train":
writer = self.train_writer
elif split_name == "val":
writer = self.val_writer
assert iter_step is not None
else:
writer = None
running_loss_meter = math_utils.AverageMeter("running_loss")
running_losses_meter_dict = {loss_func.name: math_utils.AverageMeter(loss_func.name) for loss_func in
self.loss_func.loss_funcs}
total_running_loss_meter = math_utils.AverageMeter("total_running_loss")
running_iou_meter = math_utils.AverageMeter("running_iou")
total_running_iou_meter = math_utils.AverageMeter("total_running_iou")
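        # math_utils.AverageMeter is assumed to keep a sample-weighted running
        # average (the real implementation is in lydorn_utils), e.g.:
        #
        #   meter = math_utils.AverageMeter("loss")
        #   meter.update(0.5, 8)  # batch-average value, batch size
        #   meter.update(0.3, 8)
        #   meter.get_avg()       # -> 0.4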
# batch_index_offset = 0
epoch_iterator = dl
if self.gpu == 0:
epoch_iterator = tqdm(epoch_iterator, desc="{}: ".format(split_name), leave=False)
for i, batch in enumerate(epoch_iterator):
# Send batch to device
batch = local_utils.batch_to_cuda(batch)
# with torch.autograd.detect_anomaly(): # TODO: comment when not debugging
pred, batch, total_loss, metrics_dict, loss_extra_dict, log_iou, nums = self.loss_batch(batch, opt=opt, epoch=epoch)
# with torch.autograd.profiler.profile(use_cuda=True) as prof:
# loss, nums = self.loss_batch(batch, opt=opt)
# print(prof.key_averages().table(sort_by="cuda_time_total"))
running_loss_meter.update(total_loss, nums)
for name, loss in metrics_dict.items():
if name not in running_losses_meter_dict: # Init
running_losses_meter_dict[name] = math_utils.AverageMeter(name)
running_losses_meter_dict[name].update(loss, nums)
total_running_loss_meter.update(total_loss, nums)
if log_iou is not None:
running_iou_meter.update(log_iou, nums)
total_running_iou_meter.update(log_iou, nums)
# Log values
# batch_index = i + batch_index_offset
if split_name == "train":
iter_step = epoch * len(epoch_iterator) + i
if split_name == "train" and (iter_step % log_steps == 0) or \
split_name == "val" and i == (len(epoch_iterator) - 1):
# if iter_step % log_steps == 0:
if self.gpu == 0:
epoch_iterator.set_postfix(loss="{:.4f}".format(running_loss_meter.get_avg()),
iou="{:.4f}".format(running_iou_meter.get_avg()))
# Logs
if self.rank == 0:
writer.add_scalar("Metrics/Loss", running_loss_meter.get_avg(), iter_step)
for key, meter in running_losses_meter_dict.items():
writer.add_scalar(f"Metrics/{key}", meter.get_avg(), iter_step)
                    image_display = torch_lydorn.torchvision.transforms.functional.batch_denormalize(
                        batch["image"], batch["image_mean"], batch["image_std"])
# # Save image overlaid with gt_seg to tensorboard:
# image_gt_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, batch["gt_polygons_image"])
# writer.add_images('gt_seg', image_gt_seg_display, iter_step)
# Save image overlaid with seg to tensorboard:
if "seg" in pred:
crossfield = pred["crossfield"] if "crossfield" in pred else None
image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, pred["seg"], crossfield=crossfield)
writer.add_images('seg', image_seg_display, iter_step)
# self.log_pr_curve("PR curve/{}".format(name), pred, batch, iter_step)
# self.log_weights(self.model.module.backbone, "backbone", iter_step)
# if hasattr(self.model.module, "seg_module"):
# self.log_weights(self.model.module.seg_module, "seg_module", iter_step)
# if hasattr(self.model.module, "crossfield_module"):
# self.log_weights(self.model.module.crossfield_module, "crossfield_module", iter_step)
# self.writer.flush()
# im = batch["image"][0]
# self.writer.add_image('image', im)
running_loss_meter.reset()
for key, meter in running_losses_meter_dict.items():
meter.reset()
running_iou_meter.reset()
return total_running_loss_meter.get_avg(), total_running_iou_meter.get_avg(), iter_step
def compute_loss_norms(self, dl, total_batches):
self.loss_func.reset_norm()
t = None
if self.gpu == 0:
t = tqdm(total=total_batches, desc="Init loss norms", leave=True) # Initialise
batch_i = 0
while batch_i < total_batches:
for batch in dl:
# Update loss norms
batch = local_utils.batch_to_cuda(batch)
pred, batch = self.model(batch)
self.loss_func.update_norm(pred, batch, batch["image"].shape[0])
if t is not None:
t.update(1)
batch_i += 1
                if total_batches <= batch_i:
break
# Now sync loss norms across GPUs:
self.loss_func.sync(self.config["world_size"])
def fit(self, train_dl, val_dl=None, init_dl=None):
# Try loading previous model
checkpoint = self.load_checkpoint(self.checkpoints_dirpath) # Try last checkpoint
        if checkpoint is None and self.init_checkpoints_dirpath is not None:
            # Try with init_checkpoints_dirpath:
            checkpoint = self.load_checkpoint(self.init_checkpoints_dirpath)
            if checkpoint is not None:
                checkpoint["epoch"] = 0  # Re-start from epoch 0 when initializing from another run's checkpoint
if checkpoint is None:
checkpoint = {
"epoch": 0,
}
if init_dl is not None:
# --- Compute norms of losses on several epochs:
self.model.train() # Important for batchnorm and dropout, even in computing loss norms
with torch.no_grad():
                normalization_params = self.config["loss_params"]["multiloss"]["normalization_params"]
                batch_size = self.config["optim_params"]["batch_size"]
                loss_norm_batches_min = normalization_params["min_samples"] // (2 * batch_size) + 1
                loss_norm_batches_max = normalization_params["max_samples"] // (2 * batch_size) + 1
                loss_norm_batches = max(loss_norm_batches_min, min(loss_norm_batches_max, len(init_dl)))
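                # Worked example with hypothetical values min_samples=1000,
                # max_samples=10000, batch_size=16 (the factor of 2 on batch_size
                # is kept from the computation above):
                #   loss_norm_batches_min = 1000 // 32 + 1 = 32
                #   loss_norm_batches_max = 10000 // 32 + 1 = 313
                #   loss_norm_batches = max(32, min(313, len(init_dl)))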
self.compute_loss_norms(init_dl, loss_norm_batches)
if self.gpu == 0:
# Prints loss norms:
print(self.loss_func)
start_epoch = checkpoint["epoch"] # Start at next epoch
fit_iterator = range(start_epoch, self.config["optim_params"]["max_epoch"])
if self.gpu == 0:
fit_iterator = tqdm(fit_iterator, desc="Fitting: ", initial=start_epoch,
total=self.config["optim_params"]["max_epoch"])
        train_loss = None
        val_loss = None
        train_iou = None
        val_iou = None
        epoch = None
for epoch in fit_iterator:
self.model.train()
train_loss, train_iou, iter_step = self.run_epoch("train", train_dl, epoch, self.config["optim_params"]["log_steps"],
opt=self.optimizer)
if val_dl is not None:
self.model.eval()
with torch.no_grad():
val_loss, val_iou, _ = self.run_epoch("val", val_dl, epoch, self.config["optim_params"]["log_steps"], iter_step=iter_step)
else:
val_loss = None
val_iou = None
            if self.lr_scheduler is not None:
                # The scheduler is stepped once per epoch whether or not a validation loss is available:
                self.lr_scheduler.step()
if self.gpu == 0:
postfix_args = {"t_loss": "{:.4f}".format(train_loss), "t_iou": "{:.4f}".format(train_iou)}
                if val_loss is not None:
                    postfix_args["v_loss"] = "{:.4f}".format(val_loss)
                if val_iou is not None:
                    postfix_args["v_iou"] = "{:.4f}".format(val_iou)
fit_iterator.set_postfix(**postfix_args)
if self.rank == 0:
if (epoch + 1) % self.config["optim_params"]["checkpoint_epoch"] == 0:
self.save_last_checkpoint(epoch + 1, train_loss, val_loss, train_iou,
val_iou) # Save the last completed epoch, hence the "+1"
self.delete_old_checkpoint(epoch + 1)
if val_loss is not None:
self.save_best_val_checkpoint(epoch + 1, train_loss, val_loss, train_iou, val_iou)
if self.rank == 0 and epoch is not None:
self.save_last_checkpoint(epoch + 1, train_loss, val_loss, train_iou,
val_iou) # Save the last completed epoch, hence the "+1"
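    # A hypothetical launch-script usage of fit() (names are illustrative,
    # not defined in this module):
    #
    #   trainer = Trainer(rank, gpu, config, model, optimizer, loss_func,
    #                     run_dirpath, lr_scheduler=lr_scheduler)
    #   trainer.fit(train_dl, val_dl=val_dl, init_dl=train_dl)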
def load_checkpoint(self, checkpoints_dirpath):
"""
Loads last checkpoint in checkpoints_dirpath
:param checkpoints_dirpath:
:return:
"""
try:
filepaths = python_utils.get_filepaths(checkpoints_dirpath, endswith_str=".tar",
startswith_str="checkpoint.")
if len(filepaths) == 0:
return None
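            # Checkpoint filenames embed a zero-padded epoch number, so a plain
            # lexicographic sort also orders them by epoch: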
filepaths = sorted(filepaths)
filepath = filepaths[-1] # Last checkpoint
            # map_location loads the checkpoint tensors directly onto this process's device:
            checkpoint = torch.load(filepath, map_location="cuda:{}".format(self.gpu))
self.model.module.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
self.loss_func.load_state_dict(checkpoint['loss_func_state_dict'])
epoch = checkpoint['epoch']
return {
"epoch": epoch,
}
except NotADirectoryError:
return None
def save_checkpoint(self, filepath, epoch, train_loss, val_loss, train_acc, val_acc):
torch.save({
'epoch': epoch,
'model_state_dict': self.model.module.state_dict(), # model is a DistributedDataParallel module
'optimizer_state_dict': self.optimizer.state_dict(),
'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
'loss_func_state_dict': self.loss_func.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'train_acc': train_acc,
'val_acc': val_acc,
}, filepath)
def save_last_checkpoint(self, epoch, train_loss, val_loss, train_acc, val_acc):
filename_format = "checkpoint.epoch_{:06d}.tar"
filepath = os.path.join(self.checkpoints_dirpath, filename_format.format(epoch))
self.save_checkpoint(filepath, epoch, train_loss, val_loss, train_acc, val_acc)
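        # Example: epoch 25 -> "checkpoint.epoch_000025.tar". The zero-padding keeps
        # lexicographic and numeric ordering consistent for load_checkpoint().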
def delete_old_checkpoint(self, current_epoch):
filename_format = "checkpoint.epoch_{:06d}.tar"
to_delete_epoch = current_epoch - self.config["optim_params"]["checkpoints_to_keep"] * self.config["optim_params"]["checkpoint_epoch"]
filepath = os.path.join(self.checkpoints_dirpath, filename_format.format(to_delete_epoch))
if os.path.exists(filepath):
os.remove(filepath)
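        # Worked example with hypothetical config values checkpoints_to_keep=5
        # and checkpoint_epoch=2: at current_epoch=12, to_delete_epoch = 12 - 5 * 2 = 2,
        # so the checkpoint of epoch 2 is removed and the 5 most recent periodic
        # checkpoints (epochs 4, 6, 8, 10, 12) are kept.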
def save_best_val_checkpoint(self, epoch, train_loss, val_loss, train_acc, val_acc):
filepath = os.path.join(self.checkpoints_dirpath, "checkpoint.best_val.epoch_{:06d}.tar".format(epoch))
# Search for a prev best val checkpoint:
prev_filepaths = python_utils.get_filepaths(self.checkpoints_dirpath, startswith_str="checkpoint.best_val.",
endswith_str=".tar")
if len(prev_filepaths):
prev_filepaths = sorted(prev_filepaths)
prev_filepath = prev_filepaths[-1] # Last best val checkpoint filepath in case there is more than one
prev_best_val_checkpoint = torch.load(prev_filepath)
prev_best_loss = prev_best_val_checkpoint["val_loss"]
if val_loss < prev_best_loss:
self.save_checkpoint(filepath, epoch, train_loss, val_loss, train_acc, val_acc)
                # Delete the previous best val checkpoints:
                for prev_filepath in prev_filepaths:
                    os.remove(prev_filepath)
else:
self.save_checkpoint(filepath, epoch, train_loss, val_loss, train_acc, val_acc)