|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
class SatCLIPLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over image/coordinate logits.

    The target for row ``i`` of each logit matrix is class ``i`` (optionally
    shifted by ``num_logits * rank``), and the loss is the mean of the two
    cross-entropy terms computed in ``forward``.

    Args:
        local_loss: When true and ``world_size > 1``, targets are offset by
            ``num_logits * rank``.
            NOTE(review): presumably the logit columns then span the samples
            gathered from all ranks -- confirm against the caller.
        cache_labels: When true, generated target tensors are kept per device
            and reused while the requested length does not change.
        rank: Index of this process, used only for the target offset.
        world_size: Number of processes; the offset applies only when > 1.
    """

    def __init__(
        self,
        local_loss: bool = False,
        cache_labels: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size

        # Cache state: target tensors keyed by device, plus the length of the
        # most recently cached tensor (kept for backward compatibility).
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target class indices for ``num_logits`` rows.

        Args:
            device: Device on which the target tensor is created.
            num_logits: Number of rows, i.e. the length of the result.

        Returns:
            A 1-D ``torch.long`` tensor ``[0, ..., num_logits - 1]``, shifted
            by ``num_logits * rank`` when ``world_size > 1`` and
            ``local_loss`` is set.
        """
        # Validate the cache entry against its own length. A single shared
        # ``prev_num_logits`` cannot tell which device it describes, so an
        # entry cached for another size on this device could be returned stale.
        cached = self.labels.get(device)
        if cached is not None and cached.shape[0] == num_logits:
            return cached

        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            labels = labels + num_logits * self.rank

        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def forward(self, logits_per_image, logits_per_coord, output_dict=False):
        """Compute the averaged two-way cross-entropy loss.

        Args:
            logits_per_image: Logit matrix; its first dimension sets the
                number of targets and its device is used for them.
            logits_per_coord: Logit matrix scored against the same targets.
            output_dict: When true, wrap the loss in a dict.

        Returns:
            The scalar loss tensor, or ``{"contrastive_loss": loss}`` when
            ``output_dict`` is true.
        """
        device = logits_per_image.device
        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_coord, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
|
|