# CV-Agent / tool_utils / mask2former.py
# Author: Samarth991
# Commit 0e78cbf: adding CV agent file
import numpy as np
import cv2
import argparse
import warnings
# torch and transformers are optional heavy dependencies; fail with an
# actionable message when either one is missing.
try:
    import torch as th
    from transformers import AutoImageProcessor, Mask2FormerModel, Mask2FormerForUniversalSegmentation
except ImportError as error:
    # Raising a plain string is itself a TypeError ("exceptions must derive
    # from BaseException") and would hide the real cause, so raise a proper
    # ImportError chained to the original failure.
    raise ImportError('Try installing torch and Transfomers module using pip.') from error
warnings.filterwarnings("ignore")
class MASK2FORMER:
    """Extract a binary mask for one semantic class using a Mask2Former checkpoint.

    The class loads an image processor and a universal-segmentation model from
    the same checkpoint, runs semantic segmentation on an RGB image and keeps
    only the pixels whose predicted label matches ``class_id``.
    """

    # Label ids merged into a single mask when ``class_id == "vehicle"``.
    # NOTE(review): the meaning of ids 2/5/7 depends on the checkpoint's label
    # map -- confirm against the model's ``config.id2label`` before relying on it.
    VEHICLE_CLASS_IDS = (2, 5, 7)

    def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):  ## use large
        """
        Load the processor and model weights and move the model to the device.
        args : model_name -> str - Hugging Face checkpoint name or local path
               class_id -> int or "vehicle" - label id to extract; the string
                           "vehicle" selects every id in VEHICLE_CLASS_IDS
        """
        # Load the processor from the same checkpoint as the model so that the
        # preprocessing always matches the weights (it used to be hard-coded
        # to the small ADE checkpoint regardless of ``model_name``).
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        # NOTE(review): this bare Mask2FormerModel is never used inside this
        # class; it is kept only so the attribute stays available to callers.
        self.maskformer_processor = Mask2FormerModel.from_pretrained(model_name)
        self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
        self.DEVICE = "cuda" if th.cuda.is_available() else 'cpu'
        self.segment_id = class_id
        self.maskformer_model.to(self.DEVICE)

    def create_rgb_mask(self, mask, value=255):
        """
        Stack a single-channel mask into a 3-channel uint8 mask.
        args : mask -> ndarray - 2D mask
               value -> int - mask value that is forced to white in the output
        return : ndarray -> 3-channel uint8 mask
        """
        rgb_mask = cv2.merge((mask, mask, mask))
        rgb_mask[mask == value] = (255, 255, 255)
        return rgb_mask.astype(np.uint8)

    def get_mask(self, segmentation):
        """
        Mask out the pixels that belong to the configured class.
        args : segmentation -> torch tensor - per-pixel label map produced by
               the semantic post-processing step
        return : ndarray -> 2D uint8 mask, 255 where the label matches
                 self.segment_id (or any vehicle id), 0 elsewhere
        """
        # Convert to NumPy once instead of once per comparison.
        labels = segmentation.cpu().numpy()
        if self.segment_id == "vehicle":
            selected = np.isin(labels, self.VEHICLE_CLASS_IDS)
        else:
            # Honour the class id given to the constructor (previously the
            # literal 6 was used here, so ``class_id`` had no effect).
            selected = labels == self.segment_id
        return (selected * 255).astype(np.uint8)

    def generate_road_mask(self, img):
        """
        Run semantic segmentation on an image and extract the class mask.
        args : img -> ndarray - input image; its first two dimensions are used
               as the output (height, width)
        return : ndarray -> 2D uint8 mask as returned by get_mask
        """
        inputs = self.image_processor(img, return_tensors="pt")
        inputs = inputs.to(self.DEVICE)
        # Inference only: no gradients are needed.
        with th.no_grad():
            outputs = self.maskformer_model(**inputs)
        # Resize the prediction back to the original image resolution.
        height, width = img.shape[0], img.shape[1]
        segmentation = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(height, width)]
        )[0]
        return self.get_mask(segmentation=segmentation)

    def get_rgb_mask(self, img, segmented_mask):
        """
        Keep the image pixels under the mask and zero out everything else.
        args : img -> ndarray - raw image
               segmented_mask -> ndarray - 2D mask from generate_road_mask
        return : ndarray -> image with the pixels outside the mask set to 0
        """
        predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
        return cv2.bitwise_and(img, predicted_rgb_mask)

    def run_inference(self, image_name):
        """
        Compute how much of an image is covered by the configured class.
        args : image_name -> str - path of the image to read
        return : float -> percentage (rounded to 1 decimal) of non-zero values
                 in the masked RGB image
        raises : FileNotFoundError when the image cannot be read
        """
        bgr_image = cv2.imread(image_name)
        if bgr_image is None:
            # cv2.imread signals failure by returning None; fail here with a
            # clear message instead of an opaque cvtColor assertion.
            raise FileNotFoundError(f"Could not read image: {image_name!r}")
        input_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        road_mask = self.generate_road_mask(input_image)
        road_image = self.get_rgb_mask(input_image, road_mask)
        # NOTE(review): this counts non-zero channel values of the masked
        # image, so fully black pixels inside the segment are not counted;
        # confirm whether the mask itself should be measured instead.
        obj_prop = round((np.count_nonzero(road_image) / np.size(road_image)) * 100, 1)
        # Release cached GPU memory between calls; nothing to do on CPU.
        if th.cuda.is_available():
            th.cuda.empty_cache()
        return obj_prop
def main(args):
    """
    Segment the image given on the command line with the default model.
    args : args -> namespace with an ``image_path`` attribute
    return : tuple -> (2D mask, masked RGB image, percentage of non-zero
             values in the masked image rounded to 1 decimal)
    raises : FileNotFoundError when the image cannot be read
    """
    # The class defined in this file is MASK2FORMER; the previous name
    # (ROADMASK_WITH_MASK2FORMER) does not exist and raised NameError.
    mask2former = MASK2FORMER()
    bgr_image = cv2.imread(args.image_path)
    if bgr_image is None:
        # cv2.imread returns None when the file is missing or unreadable.
        raise FileNotFoundError(f"Could not read image: {args.image_path!r}")
    input_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    road_mask = mask2former.generate_road_mask(input_image)
    road_image = mask2former.get_rgb_mask(input_image, road_mask)
    obj_prop = round(np.count_nonzero(road_image) / np.size(road_image) * 100, 1)
    return road_mask, road_image, obj_prop
if __name__ == "__main__":
    # Command-line entry point: segment one image and report the coverage.
    parser = argparse.ArgumentParser()
    # '-image_path' is kept for existing invocations; '--image_path' is the
    # conventional spelling. Both store into ``args.image_path``.
    parser.add_argument('-image_path', '--image_path', help='raw_image_path', required=True)
    args = parser.parse_args()
    # main() returns (mask, masked image, percentage); previously the result
    # was discarded, so the script produced no output at all.
    _, _, obj_prop = main(args)
    print(f"object proportion: {obj_prop}%")