# Spaces:
# Running
# Running
import argparse
import warnings

import cv2
import numpy as np

try:
    import torch as th
    from transformers import (
        AutoImageProcessor,
        Mask2FormerForUniversalSegmentation,
        Mask2FormerModel,
    )
except ImportError as error:
    # Raising a bare string is itself a TypeError in Python 3 and would hide
    # this hint; raise a real ImportError and keep the original as the cause.
    raise ImportError(
        "Try installing the torch and transformers modules using pip."
    ) from error

# Ignores every warning category process-wide (including library warnings).
warnings.filterwarnings("ignore")
class MASK2FORMER:
    """Mask2Former semantic-segmentation wrapper that extracts a single-class mask.

    Runs a Hugging Face Mask2Former checkpoint on an RGB image and turns the
    per-pixel label map into a binary (0/255, uint8) mask for ``segment_id``.
    ``segment_id`` is either one label id or the string ``"vehicle"``, which
    selects a fixed group of ids.
    """

    def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):
        """Load the image processor and the segmentation model.

        Args:
            model_name: Hugging Face checkpoint id. It is used for both the
                image processor and the model, so pre-/post-processing
                matches the model.
            class_id: Label id to extract, or ``"vehicle"``. The default 6 is
                the id this file treats as "road".
        """
        # NOTE: the original author suggested trying a larger ("large") checkpoint.
        self.model_name = model_name
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        # The bare Mask2FormerModel is not used by any method of this class.
        # It is loaded lazily by the ``maskformer_processor`` property rather
        # than eagerly keeping a second copy of the weights in memory.
        self._maskformer_processor = None
        self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
        self.DEVICE = "cuda" if th.cuda.is_available() else "cpu"
        self.segment_id = class_id
        self.maskformer_model.to(self.DEVICE)

    @property
    def maskformer_processor(self):
        """Bare ``Mask2FormerModel`` for ``model_name``, loaded on first access.

        Kept for backward compatibility; this class's methods do not use it.
        """
        if getattr(self, "_maskformer_processor", None) is None:
            self._maskformer_processor = Mask2FormerModel.from_pretrained(self.model_name)
        return self._maskformer_processor

    @maskformer_processor.setter
    def maskformer_processor(self, value):
        self._maskformer_processor = value

    def create_rgb_mask(self, mask, value=255):
        """Stack a 2-D mask into a 3-channel uint8 image.

        Pixels equal to ``value`` are set to white (255, 255, 255). Other
        pixels keep their original value replicated into all three channels.
        """
        gray_3_channel = cv2.merge((mask, mask, mask))
        gray_3_channel[mask == value] = (255, 255, 255)
        return gray_3_channel.astype(np.uint8)

    @staticmethod
    def _mask_percentage(mask):
        """Return the percentage (rounded to 1 decimal) of non-zero pixels in ``mask``.

        Returns 0.0 for an empty mask instead of dividing by zero.
        """
        mask = np.asarray(mask)
        if mask.size == 0:
            return 0.0
        return round(np.count_nonzero(mask) / mask.size * 100, 1)

    def get_mask(self, segmentation):
        """Build a binary mask of the pixels labelled ``self.segment_id``.

        Args:
            segmentation: 2-D label map, either a torch tensor (as returned by
                ``post_process_semantic_segmentation``) or an array-like.

        Returns:
            np.ndarray: uint8 mask, 255 where the label matches and 0 elsewhere.
        """
        # Convert to numpy once; accept plain arrays as well as tensors.
        if hasattr(segmentation, "cpu"):
            labels = segmentation.cpu().numpy()
        else:
            labels = np.asarray(segmentation)
        if self.segment_id == "vehicle":
            # NOTE(review): ids 2/5/7 look like COCO car/bus/truck ids. They may
            # not be vehicle classes in the default ADE checkpoint's label map.
            # Confirm against ``model.config.id2label``.
            mask = np.isin(labels, (2, 5, 7))
        else:
            mask = labels == self.segment_id
        return (mask * 255).astype(np.uint8)

    def generate_road_mask(self, img):
        """Segment ``img`` and return the binary mask for ``self.segment_id``.

        Args:
            img: RGB image as an ndarray of shape (H, W, ...).

        Returns:
            np.ndarray: uint8 (H, W) mask with 255 on the target class.
        """
        inputs = self.image_processor(img, return_tensors="pt")
        inputs = inputs.to(self.DEVICE)
        with th.no_grad():
            outputs = self.maskformer_model(**inputs)
        # Resize the prediction back to the input image's height/width.
        segmentation = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(img.shape[0], img.shape[1])]
        )[0]
        return self.get_mask(segmentation=segmentation)

    def get_rgb_mask(self, img, segmented_mask):
        """Keep only the masked pixels of ``img``.

        Args:
            img: raw image (ndarray).
            segmented_mask: binary mask returned by the semantic segmentation.

        Returns:
            np.ndarray: ``img`` with all pixels outside the mask set to 0.
        """
        predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
        return cv2.bitwise_and(img, predicted_rgb_mask)

    def run_inference(self, image_name):
        """Return how much of an image file is covered by ``self.segment_id``.

        Args:
            image_name: path of the image file, read with ``cv2.imread``.

        Returns:
            float: percentage of pixels in the predicted mask, rounded to one
            decimal place.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read ``image_name``.
        """
        bgr = cv2.imread(image_name)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_name}")
        input_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        road_mask = self.generate_road_mask(input_image)
        # Count mask *pixels*. Counting non-zero channels of the masked RGB
        # image undercounts pixels that have a zero channel (e.g. pure red).
        obj_prop = self._mask_percentage(road_mask)
        # Release cached GPU memory between calls.
        if self.DEVICE == "cuda":
            th.cuda.empty_cache()
        return obj_prop
def main(args):
    """Segment ``args.image_path`` and measure how much of it the target class covers.

    Args:
        args: ``argparse.Namespace`` with an ``image_path`` attribute.

    Returns:
        tuple: ``(road_mask, road_image, obj_prop)``:
        - ``road_mask`` is the uint8 binary mask.
        - ``road_image`` is the RGB image with pixels outside the mask set to 0.
        - ``obj_prop`` is the mask's share of the image in percent, rounded to
          one decimal.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the image.
    """
    # Read the image first so a bad path fails before the model is loaded.
    bgr = cv2.imread(args.image_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {args.image_path}")
    input_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    mask2former = MASK2FORMER()
    road_mask = mask2former.generate_road_mask(input_image)
    road_image = mask2former.get_rgb_mask(input_image, road_mask)
    # Count mask *pixels*. Counting non-zero channels of road_image would
    # undercount pixels that have a zero channel (e.g. pure red or black).
    if road_mask.size:
        obj_prop = round(np.count_nonzero(road_mask) / road_mask.size * 100, 1)
    else:
        obj_prop = 0.0
    return road_mask, road_image, obj_prop
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Segment an image with Mask2Former and report the mask coverage."
    )
    # '-image_path' is kept for existing callers; '--image_path' is the
    # conventional long-option spelling. Both store into args.image_path.
    parser.add_argument(
        '-image_path', '--image_path', dest='image_path', help='raw_image_path', required=True
    )
    args = parser.parse_args()
    _, _, coverage = main(args)
    # Without this, the CLI ran to completion and discarded the result.
    print(f"Mask coverage: {coverage}%")