白鹭先生
init
abd2a81
import math
from functools import partial
import scipy.interpolate
import numpy as np
import torch
import torch.distributed
from torch.nn import functional as F
from . import measures
from . import frame_field_utils
import torch_lydorn.kornia
from lydorn_utils import math_utils, print_utils
# --- Base classes --- #
class Loss(torch.nn.Module):
    """
    Base class for a named loss whose value can be divided by a running estimate of its own magnitude.

    Subclasses implement compute(pred_batch, gt_batch). forward() returns compute()'s result divided by
    self.norm[0] (when normalize=True); self.norm is estimated with update_norm() and can be averaged across
    processes with sync().
    """

    def __init__(self, name):
        """
        Attribute extra_info can be used in self.compute() to add intermediary results of loss computation for
        visualization for example.
        It is the second output of self.__call__()

        :param name: name of the loss (also used to name its norm meter).
        """
        super().__init__()
        self.name = name
        self.norm_meter = None
        # Non-trainable Parameter, presumably so the norm follows the module (device, state_dict)
        self.norm = torch.nn.parameter.Parameter(torch.Tensor(1), requires_grad=False)
        self.reset_norm()
        self.extra_info = {}  # Filled by compute(), returned and then cleared by forward()

    def reset_norm(self):
        """Recreate the norm meter (with init_val=1) and copy its current value into self.norm."""
        self.norm_meter = math_utils.AverageMeter("{}_norm".format(self.name), init_val=1)
        self.norm[0] = self.norm_meter.val

    def update_norm(self, pred_batch, gt_batch, nums):
        """
        Feed this batch's unnormalized loss to the norm meter and refresh self.norm with the meter's value.

        :param pred_batch: predictions passed to compute().
        :param gt_batch: ground truth passed to compute().
        :param nums: count passed to the meter's update (presumably the number of samples in the batch).
        """
        # The norm is only a normalizing constant: compute it without autograd so that whatever the meter
        # stores does not reference (and keep alive) this batch's computation graph.
        with torch.no_grad():
            loss = self.compute(pred_batch, gt_batch)
        self.norm_meter.update(loss, nums)
        self.norm[0] = self.norm_meter.val

    def sync(self, world_size):
        """
        This method should be used to synchronize loss norms across GPUs when using distributed training:
        self.norm becomes the mean of the norms of all processes.

        :param world_size: number of processes taking part in the all_reduce.
        """
        torch.distributed.all_reduce(self.norm)
        self.norm /= world_size

    def compute(self, pred_batch, gt_batch):
        """Return the unnormalized loss. Must be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, pred_batch, gt_batch, normalize=True):
        """
        :param normalize: if True, divide the loss by self.norm[0] (which must be > 1e-9).
        :return: (loss, extra_info) where extra_info is the dict filled by compute() during this call.
        """
        loss = self.compute(pred_batch, gt_batch)
        if normalize:
            assert 1e-9 < self.norm[0], "self.norm[0] <= 1e-9 -> this might lead to numerical instabilities."
            loss = loss / self.norm[0]
        extra_info = self.extra_info
        self.extra_info = {}  # Re-init extra_info
        return loss, extra_info

    def __repr__(self):
        return "{} (name={}, norm={:0.06})".format(self.__class__.__name__, self.name, self.norm[0])
class MultiLoss(torch.nn.Module):
    """
    Weighted sum of several Loss modules, optionally preceded by pre-processing steps shared by those losses.
    """

    def __init__(self, loss_funcs, weights, epoch_thresholds=None, pre_processes=None):
        """
        @param loss_funcs: list of Loss modules (each must have a .name attribute).
        @param weights: one coefficient per loss. Either a number (constant weight) or a list of coefs, one per entry
            of epoch_thresholds, linearly interpolated in-between and held at the first/last coef outside that range.
        @param epoch_thresholds: epochs at which list weights take their listed values. Required if any weight is a list.
        @param pre_processes: List of functions to call with 2 arguments (which are updated): pred_batch, gt_batch to
            compute only once values used by several losses. Each returns the updated (pred_batch, gt_batch).
        @raise TypeError: if a weight is neither a number nor a list.
        @raise ValueError: if a weight is a list and epoch_thresholds is None.
        """
        super().__init__()
        assert len(loss_funcs) == len(weights), \
            "Should have the same amount of loss_funcs ({}) and weights ({})".format(len(loss_funcs), len(weights))
        self.loss_funcs = torch.nn.ModuleList(loss_funcs)
        self.weights = []
        for weight in weights:
            if isinstance(weight, list):
                if epoch_thresholds is None:
                    raise ValueError("epoch_thresholds must be given when a loss coef weight is a list.")
                # Weight is a list of coefs corresponding to epoch_thresholds, they will be interpolated in-between
                self.weights.append(scipy.interpolate.interp1d(epoch_thresholds, weight, bounds_error=False,
                                                               fill_value=(weight[0], weight[-1])))
            elif isinstance(weight, (float, int)):
                self.weights.append(float(weight))
            else:
                raise TypeError(f"Type {type(weight)} not supported as a loss coef weight.")
        self.pre_processes = pre_processes
        for loss_func, weight in zip(self.loss_funcs, weights):
            # A list weight that is zero at every threshold is zero at every epoch too
            coefs = weight if isinstance(weight, list) else [weight]
            if all(coef == 0 for coef in coefs):
                print_utils.print_info(f"INFO: loss '{loss_func.name}' has a weight of zero and thus won't affect grad update.")

    def _pre_process(self, pred_batch, gt_batch):
        """Apply self.pre_processes in order and return the updated (pred_batch, gt_batch)."""
        if self.pre_processes is not None:
            for pre_process in self.pre_processes:
                pred_batch, gt_batch = pre_process(pred_batch, gt_batch)
        return pred_batch, gt_batch

    def reset_norm(self):
        for loss_func in self.loss_funcs:
            loss_func.reset_norm()

    def update_norm(self, pred_batch, gt_batch, nums):
        pred_batch, gt_batch = self._pre_process(pred_batch, gt_batch)
        for loss_func in self.loss_funcs:
            loss_func.update_norm(pred_batch, gt_batch, nums)

    def sync(self, world_size):
        """
        This method should be used to synchronize loss norms across GPUs when using distributed training

        @param world_size: number of processes taking part in the synchronization.
        """
        for loss_func in self.loss_funcs:
            loss_func.sync(world_size)

    def forward(self, pred_batch, gt_batch, normalize=True, epoch=None):
        """
        @param normalize: forwarded to each loss.
        @param epoch: current epoch, required when some weight is epoch-dependent (given as a list).
        @return: (total_loss, individual_losses_dict, extra_dict). total_loss is the weighted sum. Both dicts are
            keyed by loss name; individual losses are un-weighted (normalized if normalize).
        @raise ValueError: if some weight is epoch-dependent and epoch is None.
        """
        pred_batch, gt_batch = self._pre_process(pred_batch, gt_batch)
        total_loss = 0
        individual_losses_dict = {}
        extra_dict = {}
        for loss_func_i, weight_i in zip(self.loss_funcs, self.weights):
            loss_i, extra_dict_i = loss_func_i(pred_batch, gt_batch, normalize=normalize)
            if isinstance(weight_i, scipy.interpolate.interp1d):
                if epoch is None:
                    raise ValueError(f"Loss '{loss_func_i.name}' has an epoch-dependent weight: epoch must be given.")
                current_weight = float(weight_i(epoch))
            else:
                current_weight = weight_i
            total_loss += current_weight * loss_i
            individual_losses_dict[loss_func_i.name] = loss_i
            extra_dict[loss_func_i.name] = extra_dict_i
        return total_loss, individual_losses_dict, extra_dict

    def __repr__(self):
        ret = "\n\t".join([str(loss_func) for loss_func in self.loss_funcs])
        return "{}:\n\t{}".format(self.__class__.__name__, ret)
# --- Build combined loss: --- #
def compute_seg_loss_weigths(pred_batch, gt_batch, config):
    """
    Compute per-pixel weights for the segmentation loss and store them in gt_batch["seg_loss_weights"].

    Combines class frequencies with distances (from U-Net paper) and sizes
    (from https://github.com/neptune-ai/open-solution-mapping-challenge):

        weights = 1 / pixel_class_freq                               (always)
        weights += w0 * exp(-(d * (H + W)) ** 2 / sigma ** 2)        (if use_dist, d = gt_batch["distances"])
        weights *= 1 + 1 / (im_radius * sizes)                       (if use_size, sizes = gt_batch["sizes"])

    with H, W the height and width of gt_batch["image"] and im_radius = sqrt(H * W) / 2.

    @param pred_batch: returned unchanged.
    @param gt_batch: dict with keys "image", "gt_polygons_image" and "class_freq", plus "distances" if use_dist
        and "sizes" if use_size. Modified in place: key "seg_loss_weights" is added.
    @param config: dict where config["loss_params"]["seg_loss_params"] holds "use_dist", "use_size", "w0", "sigma".
    @return: (pred_batch, gt_batch)
    @raise ZeroDivisionError: if some pixel class frequency is 0, or if use_size and some size is 0.
    """
    seg_loss_params = config["loss_params"]["seg_loss_params"]
    use_dist = seg_loss_params["use_dist"]
    use_size = seg_loss_params["use_size"]
    w0 = seg_loss_params["w0"]
    sigma = seg_loss_params["sigma"]
    height = gt_batch["image"].shape[2]
    width = gt_batch["image"].shape[3]
    im_radius = math.sqrt(height * width) / 2

    # --- Class imbalance weight (not forgetting background):
    gt_polygons_mask = (0 < gt_batch["gt_polygons_image"]).float()
    background_freq = 1 - torch.sum(gt_batch["class_freq"], dim=1)
    pixel_class_freq = gt_polygons_mask * gt_batch["class_freq"][:, :, None, None] + \
        (1 - gt_polygons_mask) * background_freq[:, None, None, None]
    if pixel_class_freq.min() == 0:
        print_utils.print_error("ERROR: pixel_class_freq has some zero values, can't divide by zero!")
        raise ZeroDivisionError("pixel_class_freq has some zero values, can't divide by zero!")
    freq_weights = 1 / pixel_class_freq

    # --- Size weights:
    size_weights = None
    if use_size:
        if gt_batch["sizes"].min() == 0:
            print_utils.print_error("ERROR: sizes tensor has zero values, can't divide by zero!")
            raise ZeroDivisionError("sizes tensor has zero values, can't divide by zero!")
        size_weights = 1 + 1 / (im_radius * gt_batch["sizes"])

    # --- Distance weights:
    distance_weights = None
    if use_dist:
        distance_weights = gt_batch["distances"] * (height + width)  # Denormalize distances
        distance_weights = w0 * torch.exp(-(distance_weights ** 2) / (sigma ** 2))

    # In-place ops on the freshly computed freq_weights tensor: the result keeps freq_weights' shape
    seg_loss_weights = freq_weights
    if use_dist:
        seg_loss_weights += distance_weights
    if use_size:
        seg_loss_weights *= size_weights
    gt_batch["seg_loss_weights"] = seg_loss_weights
    return pred_batch, gt_batch
def compute_gt_field(pred_batch, gt_batch):
    """
    Pre-process: store the unit vector (cos, sin) of the gt crossfield angle in gt_batch["gt_field"].

    cos and sin are concatenated along dim 1, so an angle tensor with 1 channel gives a 2-channel field.

    @param pred_batch: returned unchanged.
    @param gt_batch: needs key "gt_crossfield_angle"; modified in place.
    @return: (pred_batch, gt_batch)
    """
    angle = gt_batch["gt_crossfield_angle"]
    gt_batch["gt_field"] = torch.cat((angle.cos(), angle.sin()), dim=1)
    return pred_batch, gt_batch
class ComputeSegGrads:
    """
    Pre-process computing spatial gradients of the predicted segmentation pred_batch["seg"].

    Adds to pred_batch:
    - "seg_grads": the gradients, (b, c, 2, h, w)
    - "seg_grad_norm": their norm over the gradient axis, (b, c, h, w)
    - "seg_grads_normed": seg_grads divided by (seg_grad_norm + 1e-6)
    """

    def __init__(self, device):
        self.spatial_gradient = torch_lydorn.kornia.filters.SpatialGradient(mode="scharr", coord="ij",
                                                                            normalized=True, device=device)

    def __call__(self, pred_batch, gt_batch):
        seg = pred_batch["seg"]  # (b, c, h, w)
        # Normalize (kornia normalizes to -0.5, 0.5 for input in [0, 1])
        grads = self.spatial_gradient(seg) * 2  # (b, c, 2, h, w)
        grad_norm = grads.norm(dim=2)  # (b, c, h, w)
        pred_batch["seg_grads"] = grads
        pred_batch["seg_grad_norm"] = grad_norm
        # Small epsilon avoids dividing by zero where the seg is flat
        pred_batch["seg_grads_normed"] = grads / (grad_norm.unsqueeze(2) + 1e-6)  # (b, c, 2, h, w)
        return pred_batch, gt_batch
def build_combined_loss(config):
    """
    Build the MultiLoss used for training from config.

    Depending on config["compute_seg"], config["compute_crossfield"] and config["seg_params"], this registers:
    - "seg" (SegLoss), with the compute_seg_loss_weigths pre-process;
    - "crossfield_align", "crossfield_align90" and "crossfield_smooth", with the compute_gt_field pre-process;
    - the coupling losses "seg_interior_crossfield", "seg_edge_crossfield" and "seg_edge_interior", plus the
      ComputeSegGrads pre-process when any of them is used.
    Each loss takes its coefficient from config["loss_params"]["multiloss"]["coefs"][<loss name>].

    @param config: nested config dict.
    @return: MultiLoss
    """
    def coef(key):
        # Read on demand: only the coefs of enabled losses (and "epoch_thresholds") need to exist
        return config["loss_params"]["multiloss"]["coefs"][key]

    pre_processes, loss_funcs, weights = [], [], []

    def register(loss_func, key):
        loss_funcs.append(loss_func)
        weights.append(coef(key))

    if config["compute_seg"]:
        pre_processes.append(partial(compute_seg_loss_weigths, config=config))
        seg_params = config["seg_params"]
        gt_channel_selector = [seg_params[key] for key in ("compute_interior", "compute_edge", "compute_vertex")]
        seg_loss_params = config["loss_params"]["seg_loss_params"]
        register(SegLoss(name="seg",
                         gt_channel_selector=gt_channel_selector,
                         bce_coef=seg_loss_params["bce_coef"],
                         dice_coef=seg_loss_params["dice_coef"]),
                 "seg")

    if config["compute_crossfield"]:
        pre_processes.append(compute_gt_field)
        for loss_cls, loss_name in ((CrossfieldAlignLoss, "crossfield_align"),
                                    (CrossfieldAlign90Loss, "crossfield_align90"),
                                    (CrossfieldSmoothLoss, "crossfield_smooth")):
            register(loss_cls(name=loss_name), loss_name)

    # --- Coupling losses:
    if config["compute_seg"]:
        seg_params = config["seg_params"]
        need_seg_grads = False
        # Incremented before use: channel of pred_batch["seg"] handled by the next seg <-> crossfield loss
        pred_channel = -1
        # Seg interior <-> Crossfield, then seg edge <-> Crossfield coupling:
        for seg_part_key, loss_name in (("compute_interior", "seg_interior_crossfield"),
                                        ("compute_edge", "seg_edge_crossfield")):
            if seg_params[seg_part_key] and config["compute_crossfield"]:
                need_seg_grads = True
                pred_channel += 1
                register(SegCrossfieldLoss(name=loss_name, pred_channel=pred_channel), loss_name)
        # Seg edge <-> seg interior coupling:
        if seg_params["compute_interior"] and seg_params["compute_edge"]:
            need_seg_grads = True
            register(SegEdgeInteriorLoss(name="seg_edge_interior"), "seg_edge_interior")
        if need_seg_grads:
            pre_processes.append(ComputeSegGrads(config["device"]))

    return MultiLoss(loss_funcs, weights, epoch_thresholds=coef("epoch_thresholds"), pre_processes=pre_processes)
# --- Specific losses --- #
class SegLoss(Loss):
    """
    Segmentation loss: bce_coef * (pixel-weighted binary cross-entropy) + dice_coef * (mean dice loss).
    """

    def __init__(self, name, gt_channel_selector, bce_coef=0.5, dice_coef=0.5):
        """
        :param name:
        :param gt_channel_selector: used to select which channels of gt_polygons_image to compare to the predicted
            seg (see docstring of method compute() for more details).
        :param bce_coef: coefficient of the binary cross-entropy term.
        :param dice_coef: coefficient of the dice term.
        """
        super().__init__(name)
        self.gt_channel_selector = gt_channel_selector
        self.bce_coef = bce_coef
        self.dice_coef = dice_coef

    def compute(self, pred_batch, gt_batch):
        """
        seg and gt_polygons_image do not necessarily have the same channel count.
        gt_channel_selector selects which channels of gt_polygons_image (and of seg_loss_weights) are used.
        For example, if seg has C_pred=2 (interior and edge) and
        gt_polygons_image has C_gt=3 (interior, edge and vertex), use gt_channel_selector=slice(0, 2)

        @param pred_batch: key "seg" is shape (N, C_pred, H, W); passed directly to F.binary_cross_entropy, so it
            is expected to hold probabilities in [0, 1].
        @param gt_batch: key "gt_polygons_image" is shape (N, C_gt, H, W); key "seg_loss_weights" is indexed with
            the same channel selector.
        @return: scalar tensor.
        """
        selector = self.gt_channel_selector
        pred_seg = pred_batch["seg"]
        gt_seg = gt_batch["gt_polygons_image"][:, selector, ...]
        pixel_weights = gt_batch["seg_loss_weights"][:, selector, ...]
        dice_term = torch.mean(measures.dice_loss(pred_seg, gt_seg))
        bce_term = F.binary_cross_entropy(pred_seg, gt_seg, weight=pixel_weights, reduction="mean")
        return self.bce_coef * bce_term + self.dice_coef * dice_term
class CrossfieldAlignLoss(Loss):
    """
    Mean over all pixels of the frame-field alignment error between the predicted crossfield and
    gt_batch["gt_field"], multiplied by the gt edge channel (channel 1 of gt_polygons_image).

    The constructor is inherited from Loss: CrossfieldAlignLoss(name).
    """

    def compute(self, pred_batch, gt_batch):
        crossfield = pred_batch["crossfield"]
        c0, c2 = crossfield[:, :2], crossfield[:, 2:]
        gt_field = gt_batch["gt_field"]
        gt_polygons_image = gt_batch["gt_polygons_image"]
        assert 2 <= gt_polygons_image.shape[1], \
            "gt_polygons_image should have at least 2 channels for interior and edges"
        edge_mask = gt_polygons_image[:, 1, ...]
        error = frame_field_utils.framefield_align_error(c0, c2, gt_field, complex_dim=1)
        mean_error = torch.mean(error * edge_mask)
        # Exposed for visualization through the second output of forward()
        self.extra_info["gt_field"] = gt_field
        return mean_error
class CrossfieldAlign90Loss(Loss):
    """
    Mean over all pixels of the frame-field alignment error between the predicted crossfield and gt_field rotated
    by 90 degrees, applied on gt edge pixels that are not gt vertices.

    The constructor is inherited from Loss: CrossfieldAlign90Loss(name).
    """

    def compute(self, pred_batch, gt_batch):
        crossfield = pred_batch["crossfield"]
        c0, c2 = crossfield[:, :2], crossfield[:, 2:]
        z = gt_batch["gt_field"]
        # (x, y) -> (-y, x) on the two channels of z: a +90 degree rotation
        z_90deg = torch.cat((-z[:, 1:2, ...], z[:, 0:1, ...]), dim=1)
        gt_polygons_image = gt_batch["gt_polygons_image"]
        assert gt_polygons_image.shape[1] == 3, \
            "gt_polygons_image should have 3 channels for interior, edges and vertices"
        # Edge channel minus vertex channel, kept in [0, 1]
        edge_not_vertex = (gt_polygons_image[:, 1, ...] - gt_polygons_image[:, 2, ...]).clamp(0, 1)
        error = frame_field_utils.framefield_align_error(c0, c2, z_90deg, complex_dim=1)
        return torch.mean(error * edge_not_vertex)
class CrossfieldSmoothLoss(Loss):
    """
    Laplacian smoothness penalty on the predicted crossfield, weighted by (1 - gt edge channel).
    """

    def __init__(self, name):
        super().__init__(name)
        self.laplacian_penalty = frame_field_utils.LaplacianPenalty(channels=4)

    def compute(self, pred_batch, gt_batch):
        # 1 - edge channel of gt_polygons_image: the penalty is weighted down on gt edges
        not_edge = 1 - gt_batch["gt_polygons_image"][:, 1, ...]
        penalty = self.laplacian_penalty(pred_batch["crossfield"])
        return torch.mean(penalty * not_edge.unsqueeze(1))
class SegCrossfieldLoss(Loss):
    """
    Aligns the predicted crossfield with the normalized gradient of one channel of the predicted segmentation,
    weighted by that channel's (detached) gradient norm.
    """

    def __init__(self, name, pred_channel):
        """
        :param name:
        :param pred_channel: channel of pred_batch["seg_grads_normed"] / ["seg_grad_norm"] to use.
        """
        super().__init__(name)
        self.pred_channel = pred_channel

    def compute(self, pred_batch, gt_batch):
        # TODO: don't apply on corners: corner_map = gt_batch["gt_polygons_image"][:, 2, :, :]
        # TODO: apply on all seg at once? Like seg is now?
        channel = self.pred_channel
        crossfield = pred_batch["crossfield"]
        c0, c2 = crossfield[:, :2], crossfield[:, 2:]
        grads_normed = pred_batch["seg_grads_normed"][:, channel, ...]
        grad_norm = pred_batch["seg_grad_norm"][:, channel, ...]
        error = frame_field_utils.framefield_align_error(c0, c2, grads_normed, complex_dim=1)
        # grad_norm is detached: don't back-propagate to it so that seg smoothness is not encouraged
        mean_error = (error * grad_norm.detach()).mean()
        # Save extra info for viz:
        self.extra_info["seg_slice_grads"] = pred_batch["seg_grads"][:, channel, ...]
        return mean_error
class SegEdgeInteriorLoss(Loss):
    """
    Enforce seg edge to be equal to interior grad norm except inside buildings

    The constructor is inherited from Loss: SegEdgeInteriorLoss(name).
    """

    def compute(self, pred_batch, batch):
        """
        Mean of |edge - grad_norm(interior)| times a soft mask that keeps interior boundaries and the outside of
        objects. Uses seg channel 0 as interior and channel 1 as edge.

        @param pred_batch: needs keys "seg" and "seg_grad_norm" (the latter is set by ComputeSegGrads).
        @param batch: unused.
        @return: scalar tensor.
        """
        seg = pred_batch["seg"]
        interior = seg[:, 0, ...]
        edge = seg[:, 1, ...]
        interior_grad_norm = pred_batch["seg_grad_norm"][:, 0, ...]
        abs_diff = (edge - interior_grad_norm).abs()
        # (1 + cos(pi * x)) / 2 is 1 at x = 0 and 0 at x = 1: high outside objects
        outside = (1 + torch.cos(interior * np.pi)) / 2
        # (1 - cos(pi * x)) / 2 is 0 at x = 0 and 1 at x = 1: high where the interior seg changes
        boundary = (1 - torch.cos(interior_grad_norm * np.pi)) / 2
        weight = torch.max(outside, boundary).float()
        return torch.mean(abs_diff * weight)