import argparse
import os
import time

import cv2
import numpy as np
import torch
import wget
import yolov7
from mobile_sam import SamPredictor, sam_model_registry
from PIL import Image
from tqdm import tqdm
from transformers import YolosForObjectDetection, YolosImageProcessor

from images_to_video import VideoCreator
from video_to_images import ImageCreator


def download_mobile_sam_weight(path):
    """Download the MobileSAM checkpoint to `path` if it is not already there."""
    if not os.path.exists(path):
        sam_weights = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
        # Create the intermediate directories of `path` if they do not exist.
        for i in range(2, len(path.split("/"))):
            temp = path.split("/")[:i]
            cur_path = "/".join(temp)
            if not os.path.isdir(cur_path):
                os.mkdir(cur_path)
        model_name = path.split("/")[-1]
        if model_name in sam_weights:
            wget.download(sam_weights, path)
        else:
            raise NameError(
                "There is no pretrained weight to download for %s, you need to provide a path to MobileSAM weights."
                % model_name
            )


def get_closest_bbox(bbox_list, bbox_target):
    """
    Given a list of bounding boxes, find the one that is closest to the target bounding box.

    Args:
        bbox_list: list of bounding boxes
        bbox_target: target bounding box

    Returns:
        closest bounding box
    """
    min_dist = 100000000
    min_idx = 0
    for idx, bbox in enumerate(bbox_list):
        dist = np.linalg.norm(bbox - bbox_target)
        if dist < min_dist:
            min_dist = dist
            min_idx = idx
    return bbox_list[min_idx]


def get_bboxes(image_file, image, model, image_processor, threshold=0.9):
    """Run the detector and return candidate bounding boxes as (x1, y1, x2, y2) arrays."""
    if image_processor is None:
        # YOLOv7 path: the model takes the image file directly.
        results = model(image_file)
        predictions = results.pred[0]
        boxes = predictions[:, :4].detach().numpy()
        return boxes
    else:
        # YOLOS path: preprocess with the image processor, then post-process the outputs.
        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)
        target_sizes = torch.tensor([image.size[::-1]])
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]
        return results["boxes"].detach().numpy()


def segment_video(
    video_filename,
    dir_frames,
    image_start,
    image_end,
    bbox_file,
    skip_vid2im,
    mobile_sam_weights,
    auto_detect=False,
    tracker_name="yolov7",
    background_color="#009000",
    output_dir="output_frames",
    output_video="output.mp4",
    pbar=False,
    reverse_mask=False,
):
    """Segment the tracked object in every frame of `video_filename` and rebuild the
    video with the rest of the frame filled with `background_color` (or with the object
    itself filled instead, when `reverse_mask` is True)."""
    if not skip_vid2im:
        vid_to_im = ImageCreator(
            video_filename,
            dir_frames,
            image_start=image_start,
            image_end=image_end,
            pbar=pbar,
        )
        vid_to_im.get_images()

    # Get fps of video
    vid = cv2.VideoCapture(video_filename)
    fps = vid.get(cv2.CAP_PROP_FPS)
    vid.release()

    # Convert the hex background color to normalized RGB values.
    background_color = background_color.lstrip("#")
    background_color = (
        np.array([int(background_color[i : i + 2], 16) for i in (0, 2, 4)]) / 255.0
    )

    with open(bbox_file, "r") as f:
        bbox_orig = [int(coord) for coord in f.read().split(" ")]

    download_mobile_sam_weight(mobile_sam_weights)

    if image_end == 0:
        frames = sorted(os.listdir(dir_frames))[image_start:]
    else:
        frames = sorted(os.listdir(dir_frames))[image_start:image_end]

    model_type = "vit_t"

    # Pick the best available device.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=mobile_sam_weights)
    sam.to(device=device)
    sam.eval()
    predictor = SamPredictor(sam)

    # Load the object detector used to track the bounding box across frames.
    if not auto_detect:
        if tracker_name == "yolov7":
            model = yolov7.load("kadirnar/yolov7-tiny-v0.1", hf_model=True)
            model.conf = 0.25  # NMS confidence threshold
            model.iou = 0.45  # NMS IoU threshold
            model.classes = None
            image_processor = None
        else:
            model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
            image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")

    output_frames = []
    if pbar:
        pb = tqdm(frames)
    else:
        pb = frames

    processed_frames = 0
    init_time = time.time()
    for frame in pb:
        processed_frames += 1
        image_file = dir_frames + "/" + frame
        image_pil = Image.open(image_file)
        image_np = np.array(image_pil)

        # Either track the target with the detector or use the whole frame as the box.
        if not auto_detect:
            bboxes = get_bboxes(image_file, image_pil, model, image_processor)
            closest_bbox = get_closest_bbox(bboxes, bbox_orig)
            input_box = np.array(closest_bbox)
        else:
            input_box = np.array([0, 0, image_np.shape[1], image_np.shape[0]])

        # Prompt MobileSAM with the bounding box to get the object mask.
        predictor.set_image(image_np)
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=True,
        )

        mask = masks[0]
        h, w = mask.shape[-2:]
        if reverse_mask:
            # Keep the background and fill the object with the background color.
            mask_image = (
                mask.reshape(h, w, 1) * background_color.reshape(1, 1, -1)
            ) * 255
            masked_image = image_np * (1 - mask).reshape(h, w, 1)
        else:
            # Keep the object and fill the background with the background color.
            mask_image = (
                (1 - mask).reshape(h, w, 1) * background_color.reshape(1, 1, -1)
            ) * 255
            masked_image = image_np * mask.reshape(h, w, 1)
        masked_image = masked_image + mask_image
        output_frames.append(masked_image)

        if not pbar and processed_frames % 10 == 0:
            remaining_time = (
                (time.time() - init_time)
                / processed_frames
                * (len(frames) - processed_frames)
            )
            remaining_time = int(remaining_time)
            remaining_time_str = f"{remaining_time // 60}m {remaining_time % 60}s"
            print(
                f"Processed frame {processed_frames}/{len(frames)} - Remaining time: {remaining_time_str}"
            )

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Write the composited frames to disk, then rebuild the video at the original fps.
    zfill_max = len(str(len(output_frames)))
    for idx, frame in enumerate(output_frames):
        cv2.imwrite(
            f"{output_dir}/frame_{str(idx).zfill(zfill_max)}.png",
            # Frames are RGB (from PIL); convert to BGR for OpenCV before writing.
            cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR),
        )
    vid_creator = VideoCreator(output_dir, output_video, pbar=pbar)
    vid_creator.create_video(fps=int(fps))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_filename",
        default="assets/example.mp4",
        type=str,
        help="path to the video",
    )
    parser.add_argument(
        "--dir_frames",
        type=str,
        default="frames",
        help="path to the directory in which all input frames will be stored",
    )
    parser.add_argument(
        "--image_start", type=int, default=0, help="first image to be stored"
    )
    parser.add_argument(
        "--image_end",
        type=int,
        default=0,
        help="last image to be stored, last one if 0",
    )
    parser.add_argument(
        "--bbox_file",
        type=str,
        default="bbox.txt",
        help="path to the bounding box text file",
    )
    parser.add_argument(
        "--skip_vid2im",
        action="store_true",
        help="skip writing the video frames as images (assumes they already exist)",
    )
    parser.add_argument(
        "--mobile_sam_weights",
        type=str,
        default="./models/mobile_sam.pt",
        help="path to MobileSAM weights",
    )
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="yolov7",
        help="tracker name",
        choices=["yolov7", "yoloS"],
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_frames",
        help="directory to store the output frames",
    )
    parser.add_argument(
        "--output_video",
        type=str,
        default="output.mp4",
        help="path to store the output video",
    )
    parser.add_argument(
        "--auto_detect",
        action="store_true",
        help="segment using the full frame instead of tracking the bounding box from --bbox_file",
    )
    parser.add_argument(
        "--background_color",
        type=str,
        default="#009000",
        help="background color for the output (hex)",
    )
    args = parser.parse_args()
    # Pass the optional settings by keyword so they line up with the signature of segment_video.
    segment_video(
        args.video_filename,
        args.dir_frames,
        args.image_start,
        args.image_end,
        args.bbox_file,
        args.skip_vid2im,
        args.mobile_sam_weights,
        auto_detect=args.auto_detect,
        tracker_name=args.tracker_name,
        background_color=args.background_color,
        output_dir=args.output_dir,
        output_video=args.output_video,
    )
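# Example invocation (a sketch: the script filename "segment_video.py" is assumed here,
# substitute the actual file name in this repository). The bounding-box file appears to be
# expected to hold four space-separated integers, x1 y1 x2 y2, locating the object to track
# in the first frame:
#
#   python segment_video.py \
#       --video_filename assets/example.mp4 \
#       --bbox_file bbox.txt \
#       --mobile_sam_weights ./models/mobile_sam.pt \
#       --tracker_name yolov7 \
#       --background_color "#009000" \
#       --output_video output.mp4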