import argparse
import warnings

import cv2
import numpy as np

try:
    import torch as th
    from transformers import (
        AutoImageProcessor,
        Mask2FormerModel,
        Mask2FormerForUniversalSegmentation,
    )
except ImportError as error:
    raise ImportError("Try installing the torch and transformers modules using pip.") from error

warnings.filterwarnings("ignore")


class MASK2FORMER:
    def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):
        # Swap in the large checkpoint via model_name if higher accuracy is needed.
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        self.maskformer_processor = Mask2FormerModel.from_pretrained(model_name)
        self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
        self.DEVICE = "cuda" if th.cuda.is_available() else "cpu"
        self.segment_id = class_id
        self.maskformer_model.to(self.DEVICE)

    def create_rgb_mask(self, mask, value=255):
        """Expand a single-channel binary mask into a 3-channel mask."""
        gray_3_channel = cv2.merge((mask, mask, mask))
        gray_3_channel[mask == value] = (255, 255, 255)
        return gray_3_channel.astype(np.uint8)

    def get_mask(self, segmentation):
        """
        Mask out the pixels belonging to the configured segment_id.

        args:
            segmentation -> torch.Tensor - semantic segmentation output from the Mask2Former post-processing
        return:
            ndarray -> 2D mask of the image (0 or 255)
        """
        segmentation = segmentation.cpu().numpy()
        if self.segment_id == "vehicle":
            # Union of the class ids treated as vehicles.
            mask = (segmentation == 2) | (segmentation == 5) | (segmentation == 7)
        else:
            mask = segmentation == self.segment_id
        visual_mask = (mask * 255).astype(np.uint8)
        return visual_mask

    def generate_road_mask(self, img):
        """
        Extract the semantic road mask from a raw image.

        args:
            img -> ndarray - input image (RGB)
        return:
            ndarray -> binary mask of the road pixels
        """
        inputs = self.image_processor(img, return_tensors="pt")
        inputs = inputs.to(self.DEVICE)
        with th.no_grad():
            outputs = self.maskformer_model(**inputs)
        segmentation = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(img.shape[0], img.shape[1])]
        )[0]
        segmented_mask = self.get_mask(segmentation=segmentation)
        return segmented_mask

    def get_rgb_mask(self, img, segmented_mask):
        """
        Extract the RGB road image by removing the background.

        args:
            img -> ndarray - raw image
            segmented_mask -> ndarray - binary mask from the semantic segmentation
        return:
            ndarray -> RGB road image with background pixels set to 0
        """
        predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
        rgb_mask_img = cv2.bitwise_and(img, predicted_rgb_mask)
        return rgb_mask_img

    def run_inference(self, image_name):
        """
        Create a segmentation mask for the configured segment_id and return the
        percentage of the image it covers. Uses the
        "facebook/mask2former-swin-small-ade-semantic" Mask2Former model to
        extract the segmentation mask for the provided image.

        args:
            image_name -> str - path of the image to read and process
        return:
            float -> percentage of non-zero pixels in the masked image
        """
        input_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)
        road_mask = self.generate_road_mask(input_image)
        road_image = self.get_rgb_mask(input_image, road_mask)
        obj_prop = round((np.count_nonzero(road_image) / np.size(road_image)) * 100, 1)
        # Empty the GPU cache after inference.
        if th.cuda.is_available():
            th.cuda.empty_cache()
        return obj_prop


def main(args):
    mask2former = MASK2FORMER()
    input_image = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB)
    road_mask = mask2former.generate_road_mask(input_image)
    road_image = mask2former.get_rgb_mask(input_image, road_mask)
    obj_prop = round(np.count_nonzero(road_image) / np.size(road_image) * 100, 1)
    return road_mask, road_image, obj_prop


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-image_path', help='raw_image_path', required=True)
    args = parser.parse_args()
    main(args)
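
# Example invocation (a minimal sketch; the script name and image path below are hypothetical):
#   python road_mask.py -image_path ./samples/street.jpg
# main() returns the binary road mask, the background-removed RGB road image, and the
# percentage of non-zero pixels in that image (a rough road-coverage estimate).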