from functools import cache
import io
import itertools
import torch
import torchvision.transforms as T
import os
import numpy as np
import seaborn as sns
from torch import nn
from torchvision.models import resnet50
from panopticapi.utils import id2rgb, rgb2id
from supervision import Detections, BoxAnnotator, MaskAnnotator
from PIL import Image
torch.set_grad_enabled(False)
# https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=cfCcEYjg7y46
DETR_DEMO_WEIGHTS_URI = "https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth"
TORCH_HOME = os.path.abspath(os.curdir) + "/data/cache"
os.environ["TORCH_HOME"] = TORCH_HOME
print("Torch home:", TORCH_HOME)
# standard PyTorch mean-std input image normalization
def normalize_img(image):
transform = T.Compose(
[
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
return transform(image).unsqueeze(0)
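# e.g. a 640x480 PIL image comes back as a batched float tensor of shape (1, 3, 800, ~1066):
# Resize(800) maps the shorter side to 800 pixels and keeps the aspect ratio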
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
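# Quick worked example: a normalized (cx, cy, w, h) box of (0.5, 0.5, 0.2, 0.4) on a
# 640x480 image maps to absolute corners (256.0, 144.0, 384.0, 336.0):
#   rescale_bboxes(torch.tensor([[0.5, 0.5, 0.2, 0.4]]), (640, 480))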
class DETRdemo(nn.Module):
"""
Demo DETR implementation.
    Demo implementation of DETR in a minimal number of lines, with the
    following differences w.r.t. DETR in the paper:
* learned positional encoding (instead of sine)
* positional encoding is passed at input (instead of attention)
* fc bbox predictor (instead of MLP)
The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
Only batch size 1 supported.
"""
def __init__(
self,
num_classes,
hidden_dim=256,
nheads=8,
num_encoder_layers=6,
num_decoder_layers=6,
):
super().__init__()
# create ResNet-50 backbone
self.backbone = resnet50()
del self.backbone.fc
# create conversion layer
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# create a default PyTorch transformer
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers
)
# prediction heads, one extra class for predicting non-empty slots
# note that in baseline DETR linear_bbox layer is 3-layer MLP
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
# output positional encodings (object queries)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
# spatial positional encodings
# note that in baseline DETR we use sine positional encodings
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
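        # (the 50 learned row/col embeddings cap the backbone feature map at 50x50;
        #  with ResNet-50's total stride of 32 this is what limits inputs to ~1600 px per side)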
def forward(self, inputs):
# propagate inputs through ResNet-50 up to avg-pool layer
x = self.backbone.conv1(inputs)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
# convert from 2048 to 256 feature planes for the transformer
h = self.conv(x)
# construct positional encodings
H, W = h.shape[-2:]
pos = (
torch.cat(
[
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
],
dim=-1,
)
.flatten(0, 1)
.unsqueeze(1)
)
# propagate through the transformer
h = self.transformer(
pos + 0.1 * h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1)
).transpose(0, 1)
# finally project transformer outputs to class labels and bounding boxes
return {
"pred_logits": self.linear_class(h),
"pred_boxes": self.linear_bbox(h).sigmoid(),
}
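# Output shape sketch for DETRdemo.forward (batch size 1, 100 object queries):
#   "pred_logits": (1, 100, num_classes + 1)  -- per-query class scores, incl. the no-object slot
#   "pred_boxes":  (1, 100, 4)                -- per-query (cx, cy, w, h), normalized to [0, 1]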
class SimpleDetr:
@cache
def __init__(self):
self.model = DETRdemo(num_classes=91)
state_dict = torch.hub.load_state_dict_from_url(
url=DETR_DEMO_WEIGHTS_URI,
map_location="cpu",
check_hash=True,
)
self.model.load_state_dict(state_dict)
self.model.eval()
self.box_annotator: BoxAnnotator = BoxAnnotator()
def detect(self, image, conf):
# mean-std normalize the input image (batch-size: 1)
img = normalize_img(image)
        # the demo model only supports images with an aspect ratio between 0.5 and 2 by default;
        # if you want to use images with an aspect ratio outside this range,
        # rescale your image so that the maximum side is at most 1333 pixels for best results
assert (
img.shape[-2] <= 1600 and img.shape[-1] <= 1600
), "demo model only supports images up to 1600 pixels on each side"
# propagate through the model
outputs = self.model(img)
        # keep only predictions above the requested confidence threshold
scores = outputs["pred_logits"].softmax(-1)[0, :, :-1]
keep = scores.max(-1).values > conf
        # convert boxes from [0, 1] relative coordinates to absolute image coordinates
bboxes_scaled = rescale_bboxes(outputs["pred_boxes"][0, keep], image.size)
probas = scores[keep]
class_id = []
confidence = []
for prob in probas:
cls_id = prob.argmax()
c = prob[cls_id]
class_id.append(int(cls_id))
confidence.append(float(c))
print(class_id, confidence)
detections = Detections(
xyxy=bboxes_scaled.cpu().detach().numpy(),
class_id=np.array(class_id),
confidence=np.array(confidence),
)
annotated = self.box_annotator.annotate(
scene=np.array(image),
skip_label=False,
detections=detections,
labels=[
f"{CLASSES[cls_id]} {conf:.2f}"
for cls_id, conf in zip(detections.class_id, detections.confidence)
],
)
return annotated
class PanopticDetrResenet101:
@cache
def __init__(self):
self.model, self.postprocessor = torch.hub.load(
"facebookresearch/detr",
"detr_resnet101_panoptic",
pretrained=True,
return_postprocessor=True,
num_classes=250,
)
self.model.eval()
def detect(self, image, conf):
# mean-std normalize the input image (batch-size: 1)
img = normalize_img(image)
outputs = self.model(img)
result = self.postprocessor(
outputs, torch.as_tensor(img.shape[-2:]).unsqueeze(0)
)[0]
print(result.keys())
palette = itertools.cycle(sns.color_palette())
# The segmentation is stored in a special-format png
panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8).copy()
# We retrieve the ids corresponding to each mask
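        # (rgb2id packs each pixel's (R, G, B) triple into a single integer id,
        #  id = R + 256 * G + 256**2 * B, which is how panopticapi encodes segment ids)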
panoptic_seg_id = rgb2id(panoptic_seg)
# Finally we color each mask individually
panoptic_seg[:, :, :] = 0
        for seg_id in range(panoptic_seg_id.max() + 1):
            panoptic_seg[panoptic_seg_id == seg_id] = np.asarray(next(palette)) * 255
return panoptic_seg
# COCO classes
CLASSES = [
"N/A",
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"N/A",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"N/A",
"backpack",
"umbrella",
"N/A",
"N/A",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"N/A",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"N/A",
"dining table",
"N/A",
"N/A",
"toilet",
"N/A",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"N/A",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
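# Minimal usage sketch (not part of the original Space wiring): running this module directly
# exercises both models; the "example.jpg" path and thresholds are illustrative assumptions.
if __name__ == "__main__":
    sample_image = Image.open("example.jpg").convert("RGB")
    # box detection with the lightweight demo model
    detector = SimpleDetr()
    annotated = detector.detect(sample_image, conf=0.7)
    Image.fromarray(annotated).save("example_detr_boxes.jpg")
    # panoptic segmentation with the ResNet-101 checkpoint from torch.hub
    panoptic = PanopticDetrResenet101()
    segmented = panoptic.detect(sample_image, conf=0.85)
    Image.fromarray(segmented).save("example_detr_panoptic.png")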