Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from detectors.multi_attention_detector import AttentionPooling | |
from .abstract_loss_func import AbstractLossClass | |
from metrics.registry import LOSSFUNC | |
class RegionIndependentLoss(AbstractLossClass):
    """Regional independence loss for multi-attention detectors.

    Keeps one running feature center per attention region (an ``M x N``
    buffer) and combines two hinge terms:

    * intra-class: penalises the distance between each region's pooled
      feature and that region's center beyond a per-label margin taken
      from ``intra_margin``;
    * inter-class: penalises every pair of region centers that are closer
      than ``inter_margin``.

    During training the centers move towards the batch mean with step
    ``alpha``, and ``alpha`` is multiplied by ``alpha_decay`` every
    ``decay_batch`` training batches.
    """

    def __init__(self, M, N, alpha, alpha_decay, decay_batch, inter_margin, intra_margin):
        """
        Args:
            M: number of attention regions (rows of the center buffer).
            N: feature dimension of each region center.
            alpha: initial step size of the center update.
            alpha_decay: factor applied to ``alpha`` every ``decay_batch`` batches.
            decay_batch: number of training batches between two ``alpha`` decays.
            inter_margin: minimum desired distance between two region centers.
            intra_margin: sequence of per-label margins, indexed by ``labels``.
        """
        super().__init__()
        feature_centers = torch.zeros(M, N)
        # NOTE(review): buffers are eagerly placed on CUDA when it is available;
        # as registered buffers they also follow a later ``.to(device)`` call.
        self.register_buffer("feature_centers",
                             feature_centers.cuda() if torch.cuda.is_available() else feature_centers)
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.decay_batch = decay_batch
        # Training batches seen so far. The alpha schedule is driven by this
        # counter because the loss has no access to the epoch number.
        self.batch_cnt = 0
        self.inter_margin = inter_margin
        intra_margin = torch.Tensor(intra_margin)
        self.register_buffer("intra_margin",
                             intra_margin.cuda() if torch.cuda.is_available() else intra_margin)
        self.atp = AttentionPooling()

    def forward(self, feature_maps_d, attention_maps, labels):
        """Compute the regional independence loss for one batch.

        Args:
            feature_maps_d: 4-D tensor of size ``(B, N, H, W)``.
            attention_maps: 4-D tensor of size ``(B, M, AH, AW)``; bilinearly
                resized to ``(H, W)`` when its spatial size differs.
            labels: integer tensor of ``B`` class indices into ``intra_margin``
                (used as a ``torch.gather`` index, so it must be int64).

        Returns:
            Scalar tensor ``inter_class_loss + intra_class_loss``.

        Side effects (training mode only): overwrites the ``feature_centers``
        buffer, increments ``batch_cnt`` and decays ``alpha`` every
        ``decay_batch`` batches. In eval mode the module state is untouched.
        """
        B, N, H, W = feature_maps_d.size()
        B, M, AH, AW = attention_maps.size()
        if AH != H or AW != W:
            attention_maps = F.interpolate(attention_maps, (H, W), mode='bilinear', align_corners=True)
        # presumably (B, M, N): one pooled N-dim feature per region -- TODO confirm against AttentionPooling
        feature_matrix = self.atp(feature_maps_d, attention_maps)

        # Move the centers towards the batch mean. The persisted buffer is
        # detached so no autograd graph survives across batches, but the
        # centers used in the losses below stay attached to ``feature_matrix``.
        # Otherwise the inter-class term would carry no gradient, and the
        # ``/ self.alpha`` rescaling (which offsets the ``alpha`` factor here)
        # would only distort the reported loss value.
        feature_centers = self.feature_centers.detach()
        new_feature_centers = feature_centers + self.alpha * torch.mean(feature_matrix - feature_centers, dim=0)
        if self.training:
            self.feature_centers = new_feature_centers.detach()

        # Intra-class loss: per-sample margin picked by label, shape (B, 1),
        # broadcast against the per-region distances.
        intra_margins = torch.gather(self.intra_margin.repeat(B, 1), dim=1, index=labels.unsqueeze(1))
        intra_class_loss = torch.mean(
            F.relu(torch.norm(feature_matrix - new_feature_centers, dim=-1) - intra_margins))

        # Inter-class loss: hinge on the distance of every unordered pair of centers.
        inter_class_loss = 0
        for i in range(M):
            for j in range(i + 1, M):
                inter_class_loss += F.relu(
                    self.inter_margin - torch.dist(new_feature_centers[i], new_feature_centers[j]))
        inter_class_loss = inter_class_loss / M / self.alpha

        # Alpha is decayed every ``decay_batch`` training batches to simulate an
        # epoch-based schedule; evaluation batches must not advance it.
        if self.training:
            self.batch_cnt += 1
            if self.batch_cnt % self.decay_batch == 0:
                self.alpha *= self.alpha_decay
        return inter_class_loss + intra_class_loss