|
import os
import json
import math
import time
import random
import colorsys
from collections import defaultdict

import numpy as np
from numpy import asarray
import cv2
import requests
import nltk

import torch
from torch import nn
import torch.nn.functional as F  # used below for F.pad / F.l1_loss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from torchvision import models, transforms
from torchvision.models import resnet50

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from transformers import BertTokenizer, BertModel
from transformers import GPT2LMHeadModel, GPT2Config

from scipy.optimize import linear_sum_assignment

from tqdm.notebook import tqdm
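

# The matcher and criterion below call box_cxcywh_to_xyxy and
# generalized_box_iou, which are not defined anywhere in this file. The
# versions below follow the DETR reference implementation (util/box_ops.py)
# so the script is self-contained.
def box_cxcywh_to_xyxy(x):
    # [center_x, center_y, w, h] -> [x_min, y_min, x_max, y_max]
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_iou(boxes1, boxes2):
    # Pairwise IoU plus the union areas (needed by generalized_box_iou).
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    union = area1[:, None] + area2 - inter
    return inter / union, union


def generalized_box_iou(boxes1, boxes2):
    # Generalized IoU (https://giou.stanford.edu/) on [x_min, y_min, x_max, y_max] boxes.
    iou, union = box_iou(boxes1, boxes2)
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    area = wh[:, :, 0] * wh[:, :, 1]
    return iou - (area - union) / area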
|
|
|
|
|
|
|
|
|
class CocoDataset(Dataset): |
|
def __init__(self, root_dir, annotation_file, instance_file, max_objects=40, transform=None): |
|
self.root_dir = root_dir |
|
self.transform = transform |
|
self.max_objects = max_objects |
|
        # Cache decoded images by id. Caching the full COCO train split this
        # way grows without bound; consider disabling or bounding it for large runs.
        self.img_cache = dict()
|
|
|
|
|
with open(instance_file, 'r') as file: |
|
data = json.load(file) |
|
instances = data['annotations'] |
|
categories = data['categories'] |
|
|
|
with open(annotation_file, 'r') as file: |
|
annotations = json.load(file)['annotations'] |
|
|
|
self.image_captions = defaultdict(list) |
|
for annotation in annotations: |
|
img_id = annotation['image_id'] |
|
self.image_captions[img_id].append(annotation['caption']) |
|
|
|
self.image_instances = defaultdict(list) |
|
self.category_id_to_name = {category['id']: category['name'] for category in categories} |
|
|
|
for instance in instances: |
|
img_id = instance['image_id'] |
|
bbox = instance['bbox'] |
|
category_id = instance['category_id'] |
|
self.image_instances[img_id].append((bbox, category_id)) |
|
|
|
self.img_ids = list(self.image_captions.keys()) |
|
|
|
def __len__(self): |
|
return len(self.img_ids) |
|
|
|
def __getitem__(self, index): |
|
img_id = self.img_ids[index] |
|
img_path = os.path.join(self.root_dir, f'{str(img_id).zfill(12)}.jpg') |
|
|
|
|
|
|
|
if img_id in self.img_cache: |
|
img = self.img_cache[img_id] |
|
else: |
|
img = Image.open(img_path).convert("RGB") |
|
self.img_cache[img_id] = img |
|
|
|
|
|
captions = self.image_captions[img_id] |
|
caption = random.choice(captions) |
|
|
|
annotations = self.image_instances[img_id] |
|
bboxes = [] |
|
labels = [] |
|
        for obbox, label_id in annotations:
            # COCO boxes come as [x_min, y_min, w, h] in pixels; convert to
            # [center_x, center_y, w, h] normalized by image size.
            bbox = torch.tensor(obbox, dtype=torch.float32)
            bbox[0] = bbox[0] / img.width + (bbox[2] / img.width) / 2
            bbox[1] = bbox[1] / img.height + (bbox[3] / img.height) / 2
            bbox[2] = bbox[2] / img.width
            bbox[3] = bbox[3] / img.height
            label = self.category_id_to_name[label_id]
            bboxes.append(bbox)
            labels.append(label)
|
|
|
        # Attach the caption to a pseudo-box covering the whole image.
        bboxes.append(torch.tensor([0.5, 0.5, 1.0, 1.0]))
        labels.append(caption)
|
|
|
        total_boxes = len(bboxes)

        # Pad with dummy entries (float boxes, so torch.stack in the collate
        # keeps a single dtype), or truncate, so every sample has exactly
        # max_objects (box, label) pairs. Note that truncation can drop the
        # caption pseudo-box when an image has max_objects or more instances.
        if total_boxes < self.max_objects:
            for _ in range(self.max_objects - total_boxes):
                bboxes.append(torch.tensor([0.0, 0.0, 0.0, 0.0]))
                labels.append("na")
        else:
            bboxes = bboxes[:self.max_objects]
            labels = labels[:self.max_objects]
|
|
|
if self.transform: |
|
img = self.transform(img) |
|
|
|
return img, bboxes, labels |
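

# A single sample from this dataset is:
#   img    -> FloatTensor [3, 256, 256] after the transform below
#   bboxes -> list of max_objects tensors of shape [4] in normalized cxcywh
#   labels -> list of max_objects strings: category names, one caption for the
#             whole-image pseudo-box, and "na" for padding slots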
|
|
|
|
|
transform = transforms.Compose([ |
|
transforms.Resize((256, 256)), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
]) |
|
|
|
|
|
def custom_collate(batch): |
|
images, boxes_list, labels_list = zip(*batch) |
|
|
|
|
|
stacked_images = torch.stack(images) |
|
|
|
|
|
    # Boxes are already fixed-length per image (max_objects entries), so each
    # sample stacks into one [max_objects, 4] tensor.
    stacked_boxes = [torch.stack([box.clone().detach() for box in boxes]) for boxes in boxes_list]

    return stacked_images, stacked_boxes, labels_list
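

# train_fn and eval_fn below track losses with AverageMeter, which is never
# defined in this file. This is the standard running-average utility; a
# minimal version:
class AverageMeter:
    """Tracks a running average of a scalar (here: the batch loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count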
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_fn(data_loader, model, criterion, optimizer, device, scheduler, epoch):
    model.train()
    criterion.train()
    summary_loss = AverageMeter()

    tk0 = tqdm(data_loader, total=len(data_loader))

    for step, (images, bboxes, captions) in enumerate(tk0):

        # Tokenize the whole batch of captions at once, then reshape to
        # [batch_size, num_queries, seq_len]. A final partial batch fails
        # the reshape and is skipped.
        try:
            flattened_captions = [caption for sublist in captions for caption in sublist]
            captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
            captions = captions["input_ids"]
            input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
        except RuntimeError as e:
            print("Reshape failed:", e)
            continue
|
|
|
|
|
|
        # BERT always emits at least [CLS] ... [SEP]; pad anyway so that the
        # shifted slices below never end up empty.
        min_length = 2
        if input_ids.size(-1) < min_length:
            padding_needed = min_length - input_ids.size(-1)
            input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)

        # Teacher forcing: the model consumes input_ids[..., :-1] and is
        # supervised on the left-shifted tokens input_ids[..., 1:].
        targets = build_targets(bboxes, input_ids[:, :, 1:])

        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        output = model(images, input_ids[:, :, :-1])
|
|
|
loss_dict = criterion(output, targets) |
|
weight_dict = criterion.weight_dict |
|
|
|
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) |
|
|
|
optimizer.zero_grad() |
|
losses.backward() |
|
optimizer.step() |
|
|
|
if scheduler is not None: |
|
scheduler.step() |
|
|
|
|
|
        summary_loss.update(losses.item(), batch_size)
        tk0.set_postfix(loss=summary_loss.avg)

        # Release per-step tensors eagerly to keep GPU memory flat.
        del images, bboxes, captions, output, targets, loss_dict, losses
        torch.cuda.empty_cache()
|
|
|
|
|
return summary_loss |
|
class HungarianMatcher(nn.Module): |
|
"""This class computes an assignment between the targets and the predictions of the network |
|
|
|
For efficiency reasons, the targets don't include the no_object. Because of this, in general, |
|
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, |
|
while the others are un-matched (and thus treated as non-objects). |
|
""" |
|
|
|
def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): |
|
"""Creates the matcher |
|
|
|
Params: |
|
cost_class: This is the relative weight of the classification error in the matching cost |
|
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost |
|
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost |
|
""" |
|
super().__init__() |
|
self.cost_class = cost_class |
|
self.cost_bbox = cost_bbox |
|
self.cost_giou = cost_giou |
|
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" |
|
|
|
    @torch.no_grad()
    def forward(self, outputs, targets):
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # pred_logits is [bs, num_queries, seq_len, vocab]; flatten the first
        # three dims so out_prob is [bs * num_queries * seq_len, vocab].
        out_prob = outputs["pred_logits"].flatten(0, 2).softmax(-1)

        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [bs * num_queries, 4]

        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Approximate classification cost: negative probability of the target tokens.
        cost_class = -out_prob[:, tgt_ids]

        # L1 cost between predicted and target boxes.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Negative generalized IoU cost.
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Note: cost_class.mean() is a scalar, so it shifts every entry of C
        # equally and does not affect the assignment; the matching is driven
        # by the box costs only.
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class.mean() + self.cost_giou * cost_giou

        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
|
|
|
|
|
|
def build_matcher(args): |
|
return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) |
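

# For intuition, the assignment step in isolation: given an N_pred x N_tgt
# cost matrix, linear_sum_assignment returns the row/column pairing with
# minimal total cost. A tiny illustration (values made up; this function is
# never called by the training code):
def _demo_linear_sum_assignment():
    cost = np.array([[0.1, 0.9, 0.4],
                     [0.8, 0.2, 0.3]])
    rows, cols = linear_sum_assignment(cost)
    # rows == [0, 1], cols == [0, 1]: prediction 0 -> target 0,
    # prediction 1 -> target 1 (total cost 0.3), target 2 left unmatched.
    return rows, cols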
|
|
|
|
|
|
|
class SetCriterion(nn.Module): |
|
""" This class computes the loss for DETR. |
|
The process happens in two steps: |
|
1) we compute hungarian assignment between ground truth boxes and the outputs of the model |
|
2) we supervise each pair of matched ground-truth / prediction (supervise class and box) |
|
""" |
|
    def __init__(self, vocab_size, matcher, weight_dict, eos_coef, losses, pad_token):
        """ Create the criterion.
        Parameters:
            vocab_size: size of the token vocabulary (captions are supervised
                        as token sequences rather than single class ids)
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            pad_token: token id ignored by the cross-entropy loss
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.pad_token = pad_token
        # Kept for parity with DETR; unused here, since CrossEntropyLoss with
        # ignore_index handles padding instead of per-class weights.
        empty_weight = torch.ones(self.vocab_size)
        self.register_buffer('empty_weight', empty_weight)
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token)
|
|
|
|
|
    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Token-level classification loss (cross-entropy) for the caption sequences.
        targets dicts must contain the key "labels" containing a tensor of dim
        [nb_target_boxes, seq_length]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # [bs, num_queries, seq_len, vocab]

        # Reorder the predicted sequences so they line up with their matched
        # targets; flattening in raw query order would pair logits with the
        # wrong target sequences.
        batch_idx, src_idx = self._get_src_permutation_idx(indices)
        src_logits = src_logits[batch_idx, src_idx]  # [total_matched, seq_len, vocab]

        target_classes = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])

        assert (target_classes >= 0).all() and (target_classes < self.vocab_size).all(), "Invalid token index in target!"

        loss_ce = self.criterion(src_logits.reshape(-1, src_logits.size(-1)), target_classes.reshape(-1))

        losses = {'loss_ce': loss_ce}
        return losses
|
|
|
|
|
|
|
|
|
|
|
|
    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']  # [bs, num_queries, seq_len, vocab]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)

        # Count predicted tokens that are not the last vocabulary id (used
        # here as the "no object" surrogate), then aggregate over the
        # sequence dimension to get one count per batch element.
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_pred = card_pred.sum(dim=1)

        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses
|
|
|
def loss_boxes(self, outputs, targets, indices, num_boxes): |
|
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss |
|
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] |
|
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. |
|
""" |
|
assert 'pred_boxes' in outputs |
|
idx = self._get_src_permutation_idx(indices) |
|
|
|
src_boxes = outputs['pred_boxes'][idx] |
|
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) |
|
|
|
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') |
|
|
|
losses = {} |
|
losses['loss_bbox'] = loss_bbox.sum() / num_boxes |
|
|
|
loss_giou = 1 - torch.diag(generalized_box_iou( |
|
box_cxcywh_to_xyxy(src_boxes), |
|
box_cxcywh_to_xyxy(target_boxes))) |
|
losses['loss_giou'] = loss_giou.sum() / num_boxes |
|
return losses |
|
|
|
    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        Note: kept verbatim from DETR. It depends on helpers from DETR's util
        (nested_tensor_from_tensor_list, interpolate, sigmoid_focal_loss,
        dice_loss) that are not defined here; 'masks' is not in the losses
        list used below, so this method is never called in this script.
        """
        assert "pred_masks" in outputs
|
|
|
src_idx = self._get_src_permutation_idx(indices) |
|
tgt_idx = self._get_tgt_permutation_idx(indices) |
|
src_masks = outputs["pred_masks"] |
|
src_masks = src_masks[src_idx] |
|
masks = [t["masks"] for t in targets] |
|
|
|
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() |
|
target_masks = target_masks.to(src_masks) |
|
target_masks = target_masks[tgt_idx] |
|
|
|
|
|
src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], |
|
mode="bilinear", align_corners=False) |
|
src_masks = src_masks[:, 0].flatten(1) |
|
|
|
target_masks = target_masks.flatten(1) |
|
target_masks = target_masks.view(src_masks.shape) |
|
losses = { |
|
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), |
|
"loss_dice": dice_loss(src_masks, target_masks, num_boxes), |
|
} |
|
return losses |
|
|
|
def _get_src_permutation_idx(self, indices): |
|
|
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) |
|
src_idx = torch.cat([src for (src, _) in indices]) |
|
return batch_idx, src_idx |
|
|
|
def _get_tgt_permutation_idx(self, indices): |
|
|
|
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) |
|
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) |
|
return batch_idx, tgt_idx |
|
|
|
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): |
|
loss_map = { |
|
'labels': self.loss_labels, |
|
'cardinality': self.loss_cardinality, |
|
'boxes': self.loss_boxes, |
|
'masks': self.loss_masks |
|
} |
|
assert loss in loss_map, f'do you really want to compute {loss} loss?' |
|
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) |
|
|
|
def forward(self, outputs, targets): |
|
""" This performs the loss computation. |
|
Parameters: |
|
outputs: dict of tensors, see the output specification of the model for the format |
|
targets: list of dicts, such that len(targets) == batch_size. |
|
The expected keys in each dict depends on the losses applied, see each loss' doc |
|
""" |
|
|
|
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} |
|
|
|
|
|
indices = self.matcher(outputs_without_aux, targets) |
|
|
|
|
|
|
|
|
|
|
|
num_boxes = sum(len(t["labels"]) for t in targets) |
|
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) |
|
if is_dist_avail_and_initialized(): |
|
torch.distributed.all_reduce(num_boxes) |
|
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() |
|
|
|
|
|
losses = {} |
|
for loss in self.losses: |
|
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) |
|
''' |
|
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer. |
|
if 'aux_outputs' in outputs: |
|
for i, aux_outputs in enumerate(outputs['aux_outputs']): |
|
indices = self.matcher(aux_outputs, targets) |
|
for loss in self.losses: |
|
if loss == 'masks': |
|
# Intermediate masks losses are too costly to compute, we ignore them. |
|
continue |
|
kwargs = {} |
|
if loss == 'labels': |
|
# Logging is enabled only for the last layer |
|
kwargs = {'log': False} |
|
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) |
|
l_dict = {k + f'_{i}': v for k, v in l_dict.items()} |
|
losses.update(l_dict) |
|
''' |
|
return losses |
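

# SetCriterion.forward references two distributed-training helpers that are
# not defined in this file; the versions below match DETR's util/misc.py and
# degrade to no-ops in single-process runs.
def is_dist_avail_and_initialized():
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return torch.distributed.get_world_size()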
|
|
|
def eval_fn(data_loader, model,criterion, device): |
|
model.eval() |
|
criterion.eval() |
|
summary_loss = AverageMeter() |
|
|
|
with torch.no_grad(): |
|
|
|
|
|
|
|
|
|
|
|
        tk0 = tqdm(data_loader, total=len(data_loader))
        for step, (images, bboxes, captions) in enumerate(tk0):

            # Same tokenize-and-reshape step as in train_fn; a final partial
            # batch fails the reshape and is skipped.
            try:
                flattened_captions = [caption for sublist in captions for caption in sublist]
                captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
                captions = captions["input_ids"]
                input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
            except RuntimeError as e:
                print("Reshape failed:", e)
                continue

            min_length = 2
            if input_ids.size(-1) < min_length:
                padding_needed = min_length - input_ids.size(-1)
                input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)
|
|
|
|
|
|
|
|
|
targets = build_targets(bboxes, input_ids[:, :, 1:]) |
|
|
|
|
|
|
|
images = list(image.to(device) for image in images) |
|
|
|
targets = [{k: v.to(device) for k, v in t.items()} for t in targets] |
|
|
|
|
|
output = model(images,input_ids[:,:,:-1]) |
|
|
|
|
|
loss_dict = criterion(output, targets) |
|
weight_dict = criterion.weight_dict |
|
|
|
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) |
|
|
|
|
|
            summary_loss.update(losses.item(), batch_size)
            tk0.set_postfix(loss=summary_loss.avg)

            # Release per-step tensors eagerly to keep GPU memory flat.
            del images, bboxes, captions, output, targets, loss_dict, losses
            torch.cuda.empty_cache()
|
|
|
|
|
return summary_loss |
|
|
|
def build_targets(bboxes, captions):
    """Pack per-image boxes and caption token ids into the list-of-dicts
    format expected by HungarianMatcher / SetCriterion: one
    {"boxes": [num_queries, 4], "labels": [num_queries, seq_len]} per image.
    """
    targets = []
    for i, (bbox, caption) in enumerate(zip(bboxes, captions)):
        target = {
            "boxes": bbox,
            "labels": caption,
        }
        targets.append(target)
    return targets
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
train_dataset = CocoDataset(root_dir="../data/coco91/train2017", |
|
annotation_file="../data/coco91/annotations/captions_train2017.json", |
|
instance_file="../data/coco91/annotations/instances_train2017.json", |
|
transform=transform) |
|
val_dataset = CocoDataset(root_dir="../data/coco91/val2017", annotation_file="../data/coco91/annotations/captions_val2017.json", |
|
instance_file="../data/coco91/annotations/instances_val2017.json", |
|
transform=transform) |
|
|
|
|
|
batch_size=4 |
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate) |
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=custom_collate) |
|
|
|
|
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') |
|
|
|
|
|
|
|
PAD_TOKEN = tokenizer.pad_token_id |
|
|
|
|
|
|
|
|
|
PAD_SOS = tokenizer.cls_token_id |
|
|
|
|
|
vocab_size = tokenizer.vocab_size |
|
|
|
    print(f"Pad token id: {PAD_TOKEN}")
    print(f"Start-of-sequence ([CLS]) token id: {PAD_SOS}")
    print(f"Vocab size: {vocab_size}")
|
|
|
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # num_queries must equal the number of (box, label) slots per sample
    # produced by CocoDataset (max_objects, 40 by default).
    NUM_QUERIES = 40
    num_queries = NUM_QUERIES
    # eos_coef as in DETR's default; note the current SetCriterion does not
    # actually use it (CrossEntropyLoss with ignore_index handles padding).
    NULL_CLASS_COEF = 0.1
    LR = 2e-6
    EPOCHS = 1

    matcher = HungarianMatcher()
    weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
    losses = ['labels', 'boxes', 'cardinality']

    # LLMEyaCapModel is defined elsewhere in the project.
    model = LLMEyaCapModel(num_queries=NUM_QUERIES, vocab_size=vocab_size)
    model = model.to(device)

    criterion = SetCriterion(vocab_size, matcher=matcher, weight_dict=weight_dict,
                             eos_coef=NULL_CLASS_COEF, losses=losses, pad_token=PAD_TOKEN)
    criterion = criterion.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    best_loss = 10**5
|
    for epoch in range(EPOCHS):
        time_start = time.time()
        train_loss = train_fn(train_loader, model, criterion, optimizer, device, scheduler=None, epoch=epoch)
        valid_loss = eval_fn(val_loader, model, criterion, device)
|
|
|
elapsed = time.time() - time_start |
|
chk_name = f'LLMEyeCap_01_e{epoch}.bin' |
|
torch.save(model.state_dict(), chk_name) |
|
print(f"[Epoch {epoch+1:2d} / {EPOCHS:2d}] Train loss: {train_loss.avg:.3f}. Val loss: {valid_loss.avg:.3f} --> {chk_name} [{elapsed/60:.0f} mins]") |
|
|
|
if valid_loss.avg < best_loss: |
|
best_loss = valid_loss.avg |
|
print(f'Best model found in epoch {epoch+1}........Saving Model') |
|
torch.save(model.state_dict(), 'LLMEyeCap_01_model.bin') |
|
|