import os
import numpy as np
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image
# -- install detectron2 from source ------------------------------------------------------------------------------
os.system('pip install pyyaml==5.1')  # pin pyyaml first; detectron2's build expects it
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
import detectron2
from detectron2.utils.logger import setup_logger
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import cv2
setup_logger()
# -- load rcnn model ---------------------------------------------------------------------------------------------
cfg = get_cfg()
# load model config
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
# set model weights
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.DEVICE= 'cpu' # move to cpu
predictor = DefaultPredictor(cfg) # create model
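# A minimal sketch of what the predictor returns (assumed standalone usage;
# "example_car.jpg" is a hypothetical file):
#   im = cv2.imread("example_car.jpg")    # BGR uint8 array of shape (H, W, 3)
#   outputs = predictor(im)               # dict with an "instances" field
#   outputs["instances"].pred_classes     # per-instance COCO class ids
#   outputs["instances"].pred_boxes       # per-instance XYXY boxes
#   outputs["instances"].pred_masks       # per-instance boolean masks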
# -- load design modernity model for classification --------------------------------------------------------------
DesignModernityModel = torch.load("DesignModernityModel.pt", map_location=torch.device('cpu'))  # load on CPU to match the predictor
DesignModernityModel.eval() # set state of the model to inference
# Set class labels
LABELS = ['2000-2003', '2006-2008', '2009-2011', '2012-2014', '2015-2018']
n_labels = len(LABELS)
# define mean and std dev for normalization (the standard ImageNet statistics)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# define image transformation steps
carTransforms = transforms.Compose([transforms.Resize(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=MEAN, std=STD)])
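# Note: Resize(224) scales the shorter side to 224 and keeps the aspect ratio,
# so the tensor is not necessarily square; the classifier is assumed to handle
# that (e.g. via adaptive pooling). A quick sanity check, not part of the app flow:
#   dummy = Image.new('RGB', (640, 480))
#   carTransforms(dummy).shape  # (3, 224, ~299): shorter side 480 -> 224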
# -- define a function for extraction of the detected car ---------------------------------------------------------
def cropImage(outputs, im, boxes, car_class_true):
    # Get the predicted masks of the instances classified as cars
    masks = list(np.array(outputs["instances"].pred_masks[car_class_true]))
    # Pick the detection with the largest bounding-box area
    max_idx = torch.tensor([(x[2] - x[0]) * (x[3] - x[1]) for x in boxes]).argmax().item()
    item_mask = masks[max_idx]
    # Get the tight bounding box of the mask
    segmentation = np.where(item_mask == True)  # pixel positions that belong to the detected object
    x_min = int(np.min(segmentation[1]))  # minimum x position
    x_max = int(np.max(segmentation[1]))  # maximum x position
    y_min = int(np.min(segmentation[0]))  # minimum y position
    y_max = int(np.max(segmentation[0]))  # maximum y position
    # Crop just the portion of the image we want
    cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
    # Create a PIL image out of the boolean mask (scaled to 0/255 grayscale)
    mask = Image.fromarray((item_mask * 255).astype('uint8'))
    # Crop the mask to match the cropped image
    cropped_mask = mask.crop((x_min, y_min, x_max, y_max))
    # Create a white background image of the same size as the crop
    height = y_max - y_min
    width = x_max - x_min
    background = Image.new(mode='RGB', size=(width, height), color=(255, 255, 255))
    # Create a new foreground image as large as the composite and paste the cropped image on top
    new_fg_image = Image.new('RGB', background.size)
    new_fg_image.paste(cropped)
    # Create a new alpha mask as large as the composite and paste the cropped mask
    new_alpha_mask = Image.new('L', background.size, color=0)
    new_alpha_mask.paste(cropped_mask)
    # Composite the foreground and background using the alpha mask
    composite = Image.composite(new_fg_image, background, new_alpha_mask)
    return composite
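# Example of how cropImage is driven (mirrors classifyCar below):
#   outputs = predictor(im)
#   keep = outputs["instances"].pred_classes == 2        # class 2 is "car" in COCO
#   boxes = list(outputs["instances"].pred_boxes[keep])
#   car = cropImage(outputs, im, boxes, keep)            # PIL image of the largest car on white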
# -- define function for image segmentation and classification --------------------------------------------------------
def classifyCar(im):
    # Gradio already supplies the image as a numpy array, so no file read is needed
    # im = cv2.imread(im)
    # perform instance segmentation
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1)
    out = v.draw_instance_predictions(outputs["instances"])  # drawn visualization (currently unused)
    # check if a car was detected in the image (class 2 is "car" in COCO)
    car_class_true = outputs["instances"].pred_classes == 2
    boxes = list(outputs["instances"].pred_boxes[car_class_true])
    # if a car was detected, extract the car and perform modernity score classification
    if len(boxes) != 0:
        im2 = cropImage(outputs, im, boxes, car_class_true)
        with torch.no_grad():
            scores = torch.nn.functional.softmax(DesignModernityModel(carTransforms(im2).unsqueeze(0))[0], dim=0)
        label = {LABELS[i]: float(scores[i]) for i in range(n_labels)}
    # if no car was detected, return the original image with the label "No car detected"
    else:
        im2 = Image.fromarray(np.uint8(im)).convert('RGB')
        label = "No car detected"
    return im2, label
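# A quick local check without the UI (hypothetical file name; Gradio's 'image'
# input supplies an RGB numpy array the same way):
#   im = np.array(Image.open("example_car.jpg").convert('RGB'))
#   im2, label = classifyCar(im)
#   print(label)  # {era: probability, ...} or "No car detected"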
# -- create interface for model ----------------------------------------------------------------------------------------
interface = gr.Interface(classifyCar, inputs='image', outputs=['image','label'], cache_examples=False, title='Modernity car classification')
interface.launch()