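"""
Road-mask extraction with Mask2Former (Hugging Face transformers).

Runs semantic segmentation on an input image, keeps the pixels of a target
class (road by default), and reports the percentage of the image they cover.

Illustrative CLI call (the script name below is a placeholder):

    python mask2former_road_mask.py --image_path /path/to/frame.jpg
"""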
import argparse
import warnings

import cv2
import numpy as np

try:
    import torch as th
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
except ImportError as error:
    raise ImportError("torch and transformers are required: pip install torch transformers") from error

warnings.filterwarnings("ignore")


class MASK2FORMER:
    def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):
        # Swap in the "large" checkpoint here for better accuracy at higher cost.
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
        self.DEVICE = "cuda" if th.cuda.is_available() else "cpu"
        self.segment_id = class_id
        self.maskformer_model.to(self.DEVICE)
    
    def create_rgb_mask(self, mask, value=255):
        """Expand a single-channel binary mask into a 3-channel mask."""
        gray_3_channel = cv2.merge((mask, mask, mask))
        gray_3_channel[mask == value] = (255, 255, 255)
        return gray_3_channel.astype(np.uint8)

    def get_mask(self, segmentation):
        """
        Build a binary mask for the configured segment_id from the model output.
        args : segmentation -> torch.Tensor - semantic segmentation map from the Mask2Former post-processor
        return : ndarray -> 2D uint8 mask (255 where the class is present, 0 elsewhere)
        """
        seg = segmentation.cpu().numpy()
        if self.segment_id == "vehicle":
            mask = (seg == 2) | (seg == 5) | (seg == 7)
        else:
            mask = (seg == self.segment_id)
        visual_mask = (mask * 255).astype(np.uint8)
        return visual_mask

    def generate_road_mask(self, img):
        """
        Extract a semantic road mask from a raw RGB image.
        args : img -> ndarray - input image
        return : ndarray -> binary mask of the road pixels
        """
        inputs = self.image_processor(img, return_tensors="pt").to(self.DEVICE)
        with th.no_grad():
            outputs = self.maskformer_model(**inputs)

        # Resize the predicted segmentation back to the original image resolution.
        segmentation = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(img.shape[0], img.shape[1])]
        )[0]
        segmented_mask = self.get_mask(segmentation=segmentation)
        return segmented_mask

    def get_rgb_mask(self, img, segmented_mask):
        """
        Extract the RGB road image by removing the background.
        args: img -> ndarray - raw RGB image
              segmented_mask -> ndarray - binary mask from the semantic segmentation
        return : ndarray -> RGB road image with background pixels set to 0.
        """
        predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
        rgb_mask_img = cv2.bitwise_and(img, predicted_rgb_mask)
        return rgb_mask_img

    def run_inference(self, image_name):
        """
        Create a segmentation mask for the configured segment_id and return the
        percentage of the image covered by that class.
        args: image_name -> str - path of the image to read and process
        return : float - percentage (0-100, rounded to 1 decimal) of non-zero pixels in the masked image
        """
        input_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)
        road_mask = self.generate_road_mask(input_image)
        road_image = self.get_rgb_mask(input_image, road_mask)
        obj_prop = round((np.count_nonzero(road_image) / np.size(road_image)) * 100, 1)
        # Free cached GPU memory between calls.
        if th.cuda.is_available():
            th.cuda.empty_cache()
        return obj_prop
    
        
def main(args):
    mask2former = MASK2FORMER()
    input_image = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB)

    road_mask = mask2former.generate_road_mask(input_image)
    road_image = mask2former.get_rgb_mask(input_image, road_mask)
    obj_prop = round(np.count_nonzero(road_image) / np.size(road_image) * 100, 1)
    print(f"Road covers {obj_prop}% of the image")

    return road_mask, road_image, obj_prop

if __name__=="__main__": 
    parser = argparse.ArgumentParser()
    parser.add_argument('-image_path',help='raw_image_path', required=True)
    args = parser.parse_args()
    main(args)
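
# Illustrative programmatic use of the class above (the image path is a
# placeholder; class_id 6 assumes the ADE20K label set of the default checkpoint):
#   m2f = MASK2FORMER(class_id=6)
#   road_pct = m2f.run_inference("frame_000.jpg")
#   print(f"road covers ~{road_pct}% of the frame")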