Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
| import pdb | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from nets.sampler import * | |
| from nets.repeatability_loss import * | |
| from nets.reliability_loss import * | |
class MultiLoss(nn.Module):
    """Combine several loss modules into a single weighted sum.

    Args:
        *args: Flat, alternating sequence ``weight, loss, weight, loss, ...``.
            Each weight is converted with ``float()``; each loss must be an
            ``nn.Module``.
        dbg: Accepted for backward compatibility; not used by this class.

    Example:
        loss = MultiLoss( 1, MyFirstLoss(), 0.5, MySecondLoss() )
    """

    def __init__(self, *args, dbg=()):
        super().__init__()
        assert len(args) % 2 == 0, "args must be a list of (float, loss)"
        self.weights = []
        # ModuleList registers the sub-losses as child modules of this one.
        self.losses = nn.ModuleList()
        # Even positions hold the weights, odd positions the loss modules.
        for raw_weight, loss in zip(args[0::2], args[1::2]):
            weight = float(raw_weight)
            # Wrap in a 1-tuple so a tuple-valued `loss` is formatted as a
            # whole instead of being unpacked as several format arguments.
            assert isinstance(loss, nn.Module), "%s is not a loss!" % (loss,)
            self.weights.append(weight)
            self.losses.append(loss)

    def forward(self, select=None, **variables):
        """Evaluate the (selected) losses and return their weighted sum.

        Args:
            select: Optional collection of 1-based loss indices to evaluate.
                ``None`` evaluates every loss; an empty collection evaluates
                none of them.
            **variables: Keyword arguments forwarded unchanged to every
                evaluated loss.

        Returns:
            A pair ``(cum_loss, details)``. ``cum_loss`` is the sum of
            ``weight * value`` over the evaluated losses (the int ``0`` when
            nothing was evaluated). ``details`` maps ``"loss_<key>"`` to a
            float for each reported component, plus ``"loss"`` for the total.

        Each loss may return either a bare value or a ``(value, dict)`` pair.
        A bare value is reported under the loss's ``name`` attribute, or under
        its class name when it has no such attribute.
        """
        assert not select or all(1 <= n <= len(self.losses) for n in select)
        details = {}
        cum_loss = 0
        for num, (weight, loss_func) in enumerate(zip(self.weights, self.losses), 1):
            if select is not None and num not in select:
                continue
            result = loss_func(**variables)
            if isinstance(result, tuple):
                assert len(result) == 2 and isinstance(result[1], dict)
                value, parts = result
            else:
                # Losses without an explicit `name` are keyed by class name
                # rather than failing with AttributeError.
                key_name = getattr(loss_func, "name", type(loss_func).__name__)
                value, parts = result, {key_name: result}
            cum_loss = cum_loss + weight * value
            for key, val in parts.items():
                details["loss_" + key] = float(val)
        details["loss"] = float(cum_loss)
        return cum_loss, details
 
			
