File size: 5,456 Bytes
2b37f77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6586467
2b37f77
 
 
 
6586467
2b37f77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import numpy as np
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image


# -- install detectron2 from source ------------------------------------------------------------------------------
# detectron2 is installed at start-up, before the `import detectron2` below.
# pip is run through the current interpreter (`python -m pip`) so the packages
# land in the environment this script imports from; a bare `pip` found on PATH
# may belong to a different Python installation.
import subprocess
import sys

for _requirement in ('git+https://github.com/facebookresearch/detectron2.git', 'pyyaml==5.1'):
    _pip_status = subprocess.call([sys.executable, '-m', 'pip', 'install', _requirement])
    if _pip_status != 0:
        # Still best effort (the package may already be installed), but the
        # failure is reported instead of being dropped silently.
        print('WARNING: "pip install %s" exited with status %d' % (_requirement, _pip_status),
              file=sys.stderr)

import detectron2

from detectron2.utils.logger import setup_logger

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import cv2


setup_logger()

# -- load rcnn model ---------------------------------------------------------------------------------------------
# The same model-zoo entry provides both the architecture config and the weights.
_MASK_RCNN_ZOO_ENTRY = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(_MASK_RCNN_ZOO_ENTRY))  # model config
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_MASK_RCNN_ZOO_ENTRY)  # model weights
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # score threshold for this model
cfg.MODEL.DEVICE = 'cpu'  # run the model on the CPU

# Build the predictor from the finished config
predictor = DefaultPredictor(cfg)

# -- load design modernity model for classification --------------------------------------------------------------
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only ship a checkpoint that comes from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects fully pickled modules - confirm against the installed torch version.
# map_location='cpu' keeps the classifier on the CPU, like the Mask R-CNN
# predictor (cfg.MODEL.DEVICE = 'cpu'), even if the checkpoint was saved from a GPU.
DesignModernityModel = torch.load("DesignModernityModelBonus.pt", map_location='cpu')

DesignModernityModel.eval()  # set state of the model to inference

# Per-channel mean and standard deviation used for input normalization
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Class labels; position i names the i-th score produced by the classifier
LABELS = [
    '2000-2003',
    '2004-2006',
    '2007-2009',
    '2010-2012',
    '2013-2015',
    '2016-2019',
]
n_labels = len(LABELS)

# Image preprocessing applied to the car cut-out before classification:
# resize, convert to a tensor, then normalize with MEAN / STD.
_car_transform_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
]
carTransforms = transforms.Compose(_car_transform_steps)

     
# -- define a function for extraction of the detected car ---------------------------------------------------------                                                                                                                                          
def cropImage(outputs, im, boxes, car_class_true):
    """Cut the largest detected car out of ``im`` onto a white background.

    Args:
        outputs: Predictor output; ``outputs["instances"].pred_masks`` is
            indexed with ``car_class_true`` to obtain one mask per car.
        im: Image array indexed as ``[row, column, channel]``; it is
            interpreted as RGB when the cut-out is built.
        boxes: One box per selected car, in the same order as the selected
            masks. Each box is indexable with four coordinates such that
            ``(box[2] - box[0]) * (box[3] - box[1])`` is its area.
        car_class_true: Selector that picks the car instances out of
            ``pred_masks``.

    Returns:
        A PIL RGB image cropped tightly to the mask of the car with the
        largest box area; pixels outside the mask are white.

    Raises:
        ValueError: If ``boxes`` is empty or the selected mask has no pixel set.
    """
    if not boxes:
        raise ValueError("cropImage needs at least one detected car box")

    # Pick the car whose bounding box covers the largest area
    areas = [float((box[2] - box[0]) * (box[3] - box[1])) for box in boxes]
    largest = max(range(len(areas)), key=areas.__getitem__)
    car_masks = outputs["instances"].pred_masks[car_class_true]
    item_mask = np.asarray(car_masks[largest]).astype(bool)

    # Tight bounding box of the mask itself (rows are y, columns are x)
    rows, cols = np.nonzero(item_mask)
    if rows.size == 0:
        raise ValueError("the mask of the selected car is empty")
    # Slices and PIL boxes exclude their end coordinate, so the end has to be
    # one past the last mask pixel; using the maximum itself would cut off the
    # last row and column (and give an empty image for a 1-pixel-wide mask).
    y_min, y_end = int(rows.min()), int(rows.max()) + 1
    x_min, x_end = int(cols.min()), int(cols.max()) + 1

    # Cropped car pixels and the matching 0/255 mask used as alpha channel
    car = Image.fromarray(np.uint8(im[y_min:y_end, x_min:x_end, :])).convert('RGB')
    alpha = Image.fromarray(item_mask[y_min:y_end, x_min:x_end].astype('uint8') * 255)

    # White canvas of the same size; the car is kept only where the mask is set
    background = Image.new(mode='RGB', size=car.size, color=(255, 255, 255))
    return Image.composite(car, background, alpha)
    
# -- define function for image segmentation and classification --------------------------------------------------------
def classifyCar(im):
    """Segment ``im``, cut out the largest detected car and classify it.

    Args:
        im: Image array indexed as ``[row, column, channel]``.
            NOTE(review): assumed to be the RGB uint8 array supplied by the
            Gradio 'image' input - confirm if another caller is added.

    Returns:
        A tuple ``(image, label)``. When at least one car is detected,
        ``image`` is the car cut-out returned by ``cropImage`` and ``label``
        maps every entry of ``LABELS`` to its softmax score. Otherwise
        ``image`` is the input converted to a PIL RGB image and ``label`` is
        the string ``"No car detected"``.
    """
    # DefaultPredictor documents its input as BGR, while the array received
    # here is RGB, so the channel axis is flipped for detection only.
    outputs = predictor(np.ascontiguousarray(im[:, :, ::-1]))
    instances = outputs["instances"]

    # Keep only the instances predicted as class index 2, treated as "car"
    car_class_true = instances.pred_classes == 2
    boxes = list(instances.pred_boxes[car_class_true])

    # No car detected: show the original image with a plain message
    if not boxes:
        return Image.fromarray(np.uint8(im)).convert('RGB'), "No car detected"

    # Car detected: extract it and score the cut-out with the classifier
    im2 = cropImage(outputs, im, boxes, car_class_true)
    with torch.no_grad():
        logits = DesignModernityModel(carTransforms(im2).unsqueeze(0))[0]
        # Explicit dim: the implicit choice is deprecated and emits a warning
        scores = torch.nn.functional.softmax(logits, dim=0)
    label = {name: float(score) for name, score in zip(LABELS, scores)}

    return im2, label

# -- create interface for model ----------------------------------------------------------------------------------------
# classifyCar takes the uploaded image and returns (image, label), matching
# the two output components declared here.
interface = gr.Interface(
    classifyCar,
    inputs='image',
    outputs=['image', 'label'],
    cache_examples=False,
    title='Modernity car classification',
)
interface.launch()