File size: 5,442 Bytes
6b2d891
9a30e62
28b287d
2e4ea99
 
 
adca53b
1094aba
6b2d891
76c1707
fa10257
a64c8d1
958d113
2e4ea99
d01979a
76c1707
2e4ea99
 
 
 
b8316ce
 
5f79397
d54d7ce
76c1707
2e4ea99
 
76c1707
2e4ea99
11bf99e
2e4ea99
 
11bf99e
2e4ea99
11bf99e
 
2e4ea99
76c1707
2e4ea99
 
 
 
76c1707
1094aba
146be9b
2e4ea99
76c1707
6f792de
 
 
76c1707
28b287d
 
 
2e4ea99
76c1707
 
50861f1
 
 
 
 
 
 
76c1707
50861f1
 
 
 
 
 
76c1707
50861f1
 
 
 
 
 
76c1707
50861f1
 
 
 
76c1707
50861f1
 
 
76c1707
50861f1
 
 
76c1707
50861f1
 
76c1707
50861f1
 
76c1707
2e4ea99
76c1707
8e8bbb8
76c1707
00fbc7c
2f0508c
00fbc7c
76c1707
7d640ab
2f0508c
 
76c1707
17cf5c8
20328b0
 
17cf5c8
20328b0
 
 
17cf5c8
76c1707
17cf5c8
 
 
2e4ea99
462d840
2e4ea99
76c1707
50861f1
2e4ea99
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131

import os
import numpy as np
import gradio as gr
import torch
from torchvision import models, transforms
from PIL import Image


# -- install detectron2 from source ------------------------------------------------------------------------------
# These installs run through the shell at import time, before `import detectron2` below.
# NOTE(review): os.system's exit status is ignored, so a failed install only surfaces
# later as an ImportError on the detectron2 imports.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
# NOTE(review): pins PyYAML to the old 5.1 release, presumably for config loading —
# confirm the pin is still required by the installed detectron2.
os.system('pip install pyyaml==5.1')

import detectron2

from detectron2.utils.logger import setup_logger

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
import cv2


setup_logger()

# -- load rcnn model ---------------------------------------------------------------------------------------------
# Model-zoo entry used for both the architecture config and the pretrained checkpoint,
# so the two always refer to the same model.
_MASK_RCNN_ZOO_ENTRY = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(_MASK_RCNN_ZOO_ENTRY))   # architecture + defaults
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_MASK_RCNN_ZOO_ENTRY)  # matching pretrained weights
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # score threshold for reported detections
cfg.MODEL.DEVICE = 'cpu'                      # run inference on the CPU
predictor = DefaultPredictor(cfg)

# -- load design modernity model for classification --------------------------------------------------------------
# The .pt file holds a whole pickled nn.Module (it is used directly as a model below),
# so torch.load unpickles arbitrary objects: only ever load a trusted file.
# map_location='cpu' keeps the parameters on the CPU, which is where carTransforms puts
# the input tensors, and lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# full pickled modules; pass weights_only=False there (the flag does not exist before 1.13).
DesignModernityModel = torch.load("DesignModernityModel.pt", map_location='cpu')

DesignModernityModel.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)

# Output classes of the modernity model: the design year range of a car, in the
# order of the model's output scores.
LABELS = [
    '2000-2003',
    '2006-2008',
    '2009-2011',
    '2012-2014',
    '2015-2018',
]
n_labels = len(LABELS)

# Per-channel (R, G, B) mean and standard deviation for input normalization;
# the values are the standard ImageNet channel statistics.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Preprocessing for the cropped car image: scale so the shorter side is 224 px,
# convert to a float tensor, then normalize each channel with MEAN/STD.
carTransforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

     
# -- define a function for extraction of the detected car ---------------------------------------------------------
def cropImage(outputs, im, boxes, car_class_true):
    """Cut the largest detected car out of ``im`` and place it on a white background.

    Args:
        outputs: detectron2 predictor output; ``outputs["instances"].pred_masks``
            supplies the per-instance masks.
        im: the image as an array indexed ``im[row, col, channel]``; presumably
            RGB uint8 (it is wrapped with ``Image.fromarray(..., mode='RGB')``).
        boxes: car boxes ``(x0, y0, x1, y1)`` in the same order as
            ``pred_masks[car_class_true]``; assumed non-empty.
        car_class_true: boolean selector of the car instances.

    Returns:
        A PIL RGB image cropped to the tight bounding box of the largest car's
        mask (largest by box area, first on ties), with every pixel outside
        the mask set to white.
    """
    car_masks = outputs["instances"].pred_masks[car_class_true]
    areas = torch.tensor([(b[2] - b[0]) * (b[3] - b[1]) for b in boxes])
    # Convert only the chosen mask to NumPy instead of every instance's mask.
    item_mask = car_masks[areas.argmax().item()].cpu().numpy()

    rows, cols = np.nonzero(item_mask)
    # Numpy slices and PIL crop boxes are end-exclusive, so add 1 to the last
    # mask row/column; otherwise the car's bottom row and right column are lost
    # (and a 1-pixel-wide mask would give an empty image).
    y_min, y_max = int(rows.min()), int(rows.max()) + 1
    x_min, x_max = int(cols.min()), int(cols.max()) + 1

    cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
    # 'L' alpha mask: 255 keeps the car pixel, 0 shows the white background.
    cropped_mask = Image.fromarray((item_mask[y_min:y_max, x_min:x_max] * 255).astype(np.uint8))
    background = Image.new('RGB', cropped.size, color=(255, 255, 255))

    return Image.composite(cropped, background, cropped_mask)
    
# -- define function for image segmentation and classification --------------------------------------------------------
def classifyCar(im):
    """Segment the largest car in ``im`` and score its design modernity.

    Args:
        im: image array of shape (H, W, 3); assumed RGB uint8 as delivered by the
            Gradio ``'image'`` input — TODO confirm for any other caller.

    Returns:
        ``(image, label)``. If a car (COCO class 2) is detected, ``image`` is the
        masked crop of the largest car. ``label`` maps each entry of ``LABELS`` to
        its softmax probability. Otherwise ``image`` is the input converted to a
        PIL RGB image and ``label`` is the string ``"No car detected"``.
    """
    # DefaultPredictor takes BGR images (it converts internally when the model
    # wants RGB), but the input here is RGB: flip channels for detection only.
    outputs = predictor(np.ascontiguousarray(im[:, :, ::-1]))
    instances = outputs["instances"]

    # Keep only detections of class 2 ("car").
    car_class_true = instances.pred_classes == 2
    boxes = list(instances.pred_boxes[car_class_true])

    if not boxes:
        # No car detected: show the original image with a message.
        return Image.fromarray(np.uint8(im)).convert('RGB'), "No car detected"

    # Classify the extracted car (cropImage works on the RGB image).
    im2 = cropImage(outputs, im, boxes, car_class_true)
    with torch.no_grad():
        logits = DesignModernityModel(carTransforms(im2).unsqueeze(0))[0]
        scores = torch.nn.functional.softmax(logits, dim=0)
    label = {LABELS[i]: float(scores[i]) for i in range(n_labels)}

    return im2, label

# -- create interface for model ----------------------------------------------------------------------------------------
# Web UI: one image in, the processed image plus a label/score display out.
interface = gr.Interface(
    fn=classifyCar,
    inputs='image',
    outputs=['image', 'label'],
    cache_examples=False,
    title='Modernity car classification',
)
interface.launch()