import torch
import torch.nn.functional as F
import torch.nn as nn


class SatCLIPLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over two sets of logits.

    The loss is the mean of two cross-entropy terms, one computed on
    ``logits_per_image`` and one on ``logits_per_coord``, both scored
    against the same integer target vector ``0..num_logits-1`` (optionally
    shifted per rank, see ``get_ground_truth``).

    Args:
        local_loss: When true and ``world_size > 1``, targets are offset by
            ``num_logits * rank``.
        cache_labels: When true, target tensors are memoised per device.
        rank: Index of this process; only used for the target offset.
        world_size: Number of processes; the offset is applied only when
            this is greater than 1.
    """

    def __init__(
        self,
        local_loss=False,
        cache_labels=False,
        rank=0,
        world_size=1,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size

        # Cache state: ``labels`` maps device -> target tensor;
        # ``prev_num_logits`` records the size of the most recently built one.
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target indices for ``num_logits`` rows on ``device``.

        Args:
            device: Device on which the target tensor is created; also the
                cache key.
            num_logits: Number of target entries to produce.

        Returns:
            A ``torch.long`` tensor ``[0, ..., num_logits - 1]``, shifted by
            ``num_logits * rank`` when ``world_size > 1`` and ``local_loss``
            is set.
        """
        # Validate each cache entry against its own length rather than a
        # single shared counter: with a shared counter, a batch-size change
        # observed on one device made the stale entries of every other
        # device look valid.
        cached = self.labels.get(device)
        if cached is not None and cached.shape[0] == num_logits:
            return cached

        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            labels = labels + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def forward(self, logits_per_image, logits_per_coord, output_dict=False):
        """Compute the averaged two-way cross-entropy loss.

        Args:
            logits_per_image: Logits tensor; its device and first dimension
                determine the targets.
                NOTE(review): presumably (batch, num_candidates) - confirm
                against the caller.
            logits_per_coord: Logits tensor scored against the same targets.
            output_dict: When true, wrap the loss in a dict under the key
                ``"contrastive_loss"``.

        Returns:
            The scalar loss tensor, or ``{"contrastive_loss": loss}`` when
            ``output_dict`` is true.
        """
        device = logits_per_image.device
        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_coord, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss