Capx
/

WhereAmAt / loss.py
Alyosha11's picture
Upload 8 files
5e83696 verified
import torch
import torch.nn.functional as F
import torch.nn as nn
class SatCLIPLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over two logit matrices.

    ``forward`` averages the cross-entropy of ``logits_per_image`` and
    ``logits_per_coord`` against the same integer targets ``0..N-1``
    (shifted by ``N * rank`` when ``local_loss`` is set and
    ``world_size > 1``), where ``N`` is ``logits_per_image.shape[0]``.

    Args:
        local_loss: When true and ``world_size > 1``, targets are offset by
            ``num_logits * rank``.
            NOTE(review): presumably the logits then span the gathered
            global batch along dim 1 -- confirm against the caller.
        cache_labels: When true, target tensors are cached per device and
            reused while the requested length is unchanged.
        rank: Rank of this process, used only for the ``local_loss`` offset.
        world_size: Number of participating processes.
    """

    def __init__(
        self,
        local_loss=False,
        cache_labels=False,
        rank=0,
        world_size=1,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size

        # Cache state. ``prev_num_logits`` records the length of the most
        # recently cached target tensor; it is kept for backward
        # compatibility but no longer drives cache validation.
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target indices for a batch of ``num_logits`` rows.

        Args:
            device: Device on which to create the targets; also the cache key.
            num_logits: Number of targets to produce.

        Returns:
            A 1-D ``torch.long`` tensor of length ``num_logits``.
        """
        # Validate the cache entry against its own length rather than the
        # shared ``prev_num_logits`` counter: that counter is overwritten by
        # whichever device was cached last, so comparing against it could
        # hand back a stale, wrongly sized tensor for another device.
        cached = self.labels.get(device)
        if cached is not None and cached.shape[0] == num_logits:
            return cached

        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            labels = labels + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def forward(self, logits_per_image, logits_per_coord, output_dict=False):
        """Compute the averaged two-way cross-entropy loss.

        Args:
            logits_per_image: Logit matrix; its device and first dimension
                determine the targets.
            logits_per_coord: Second logit matrix, scored against the same
                targets.
            output_dict: When true, wrap the loss in a dict under the key
                ``"contrastive_loss"``.

        Returns:
            The scalar loss tensor, or ``{"contrastive_loss": loss}`` when
            ``output_dict`` is true.
        """
        device = logits_per_image.device
        labels = self.get_ground_truth(device, logits_per_image.shape[0])
        total_loss = (
            F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_coord, labels)
        ) / 2
        return {"contrastive_loss": total_loss} if output_dict else total_loss