from collections import OrderedDict
from typing import List, Tuple

import argparse
import copy
import json
import os
import shutil
import warnings

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision import models
import torchvision.transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.ops import MultiScaleRoIAlign

import detection.transforms as transforms
import detection.utils as utils
from detection.engine import train_one_epoch, evaluate
# First we create the base Faster R-CNN model.
def get_FRCNN_model(num_classes=1):
    model = models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True,
        trainable_backbone_layers=3,
        min_size=1800,
        max_size=3600,
        image_std=(1.0, 1.0, 1.0),
        box_score_thresh=0.001,
    )
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one (num_classes + 1 to account for background)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)
    return model
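# Hypothetical usage sketch (illustrative only; the input size below is an
# assumption, not a requirement of the model):
#
#   model = get_FRCNN_model(num_classes=1)
#   model.eval()
#   with torch.no_grad():
#       preds = model([torch.rand(3, 1800, 1800)])
#   # preds[0] is a dict with "boxes", "labels" and "scores"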
# Some utility heads for the Bilateral model.
class RoIpool(nn.Module):
    """Pools the same proposals from two feature pyramids and concatenates
    the results along the channel dimension."""
    def __init__(self, pool):
        super().__init__()
        self.box_roi_pool1 = copy.deepcopy(pool)
        self.box_roi_pool2 = copy.deepcopy(pool)

    def forward(self, features, proposals, image_shapes):
        x = self.box_roi_pool1(features[0], proposals, image_shapes)
        y = self.box_roi_pool2(features[1], proposals, image_shapes)
        z = torch.cat((x, y), dim=1)
        return z
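# Shape sketch (assuming the standard FPN channel width of 256): each pooled
# stream is [num_rois, 256, 7, 7], so the concatenated output is
# [num_rois, 512, 7, 7]. This is why the box head below is constructed with
# in_channels = 512 * 7 * 7.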
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models.

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """
    def __init__(self, in_channels, representation_size):
        super().__init__()
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x
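# Example (assumed shapes): an input of [num_rois, 512, 7, 7] is flattened to
# [num_rois, 25088] and mapped to a [num_rois, 1024] representation, matching
# the TwoMLPHead(512 * 7 * 7, 1024) instantiation below.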
# Next, the bilateral model: two frozen copies of the backbone (one per view)
# feed a shared, frozen RPN; only the new ROI heads are trained.
class Bilateral_model(nn.Module):
    def __init__(self, frcnn_model):
        super().__init__()
        self.frcnn = frcnn_model
        self.transform = copy.deepcopy(frcnn_model.transform)
        self.backbone1 = copy.deepcopy(frcnn_model.backbone)
        self.backbone2 = copy.deepcopy(frcnn_model.backbone)
        self.rpn = copy.deepcopy(frcnn_model.rpn)
        self._has_warned = False  # used by the scripting warning in forward()
        # Freeze the RPN and both backbones.
        for param in self.rpn.parameters():
            param.requires_grad = False
        for param in self.backbone1.parameters():
            param.requires_grad = False
        for param in self.backbone2.parameters():
            param.requires_grad = False
        box_roi_pool = RoIpool(frcnn_model.roi_heads.box_roi_pool)
        box_head = TwoMLPHead(512 * 7 * 7, 1024)
        box_predictor = copy.deepcopy(frcnn_model.roi_heads.box_predictor)
        box_score_thresh = 0.001
        box_nms_thresh = 0.5
        box_detections_per_img = 100
        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        # Positional arguments follow torchvision's RoIHeads signature.
        self.roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
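    # Hypothetical optimizer setup (sketch; the learning rate is an
    # illustrative value, not from this repo): with both backbones and the
    # RPN frozen, training only needs the new ROI-head parameters.
    #
    #   bi_model = Bilateral_model(get_FRCNN_model())
    #   optimizer = torch.optim.SGD(bi_model.roi_heads.parameters(),
    #                               lr=0.005, momentum=0.9)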
    def eager_outputs(self, losses, detections):
        # torchvision convention: return losses in training mode, detections in eval mode.
        if self.training:
            return losses
        return detections
    def forward(self, images, targets=None):
        """
        Args:
            images (list[tuple[Tensor, Tensor]]): image pairs to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] with additional fields
                like `scores` and `labels`.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        if self.training:
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                else:
                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img[0].shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))
        # Split each pair into its two views and normalise/resize both with the
        # same transform (targets are deep-copied so both views see identical boxes).
        images1 = [img[0] for img in images]
        images2 = [img[1] for img in images]
        targets2 = copy.deepcopy(targets)
        images1, targets = self.transform(images1, targets)
        images2, targets2 = self.transform(images2, targets2)
        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}."
                    )
        features1 = self.backbone1(images1.tensors)
        features2 = self.backbone2(images2.tensors)
        if isinstance(features1, torch.Tensor):
            features1 = OrderedDict([("0", features1)])
        if isinstance(features2, torch.Tensor):
            features2 = OrderedDict([("0", features2)])
        # Proposals come from view 1 only; the ROI heads then pool both views
        # (the RoIpool head above indexes this dict by view).
        proposals, proposal_losses = self.rpn(images1, features1, targets)
        features = {0: features1, 1: features2}
        detections, detector_losses = self.roi_heads(features, proposals, images1.image_sizes, targets)
        detections = self.transform.postprocess(detections, images1.image_sizes, original_image_sizes)  # type: ignore[operator]
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
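# Hypothetical end-to-end sketch (the pair layout matches forward() above;
# image sizes are illustrative assumptions):
#
#   frcnn = get_FRCNN_model(num_classes=1)
#   bi_model = Bilateral_model(frcnn)
#   bi_model.eval()
#   pair = (torch.rand(3, 1800, 1800), torch.rand(3, 1800, 1800))
#   with torch.no_grad():
#       dets = bi_model([pair])
#   # dets[0]["boxes"], dets[0]["labels"], dets[0]["scores"]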