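"""Replace the background of a video with a solid color using MobileSAM.

Each frame is segmented by MobileSAM, prompted either by the detector box
(yolov7 or YOLOS) closest to a user-provided bounding box, or by the whole
frame when --auto_detect is set. The composited frames are written to disk
and re-encoded into a video at the source frame rate.
"""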
import argparse
import os
import time
import cv2
import numpy as np
import torch
import wget
import yolov7
from mobile_sam import SamPredictor, sam_model_registry
from PIL import Image
from tqdm import tqdm
from transformers import YolosForObjectDetection, YolosImageProcessor
from images_to_video import VideoCreator
from video_to_images import ImageCreator
def download_mobile_sam_weight(path):
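    """Download the MobileSAM checkpoint to ``path`` if it is not already present."""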
    if not os.path.exists(path):
        sam_weights = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
        # Create any missing parent directories before downloading.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        model_name = os.path.basename(path)
        if model_name in sam_weights:
            wget.download(sam_weights, path)
        else:
            raise NameError(
                "There is no pretrained weight to download for %s, you need to provide a path to MobileSAM weights."
                % model_name
            )
def get_closest_bbox(bbox_list, bbox_target):
"""
Given a list of bounding boxes, find the one that is closest to the target bounding box.
Args:
bbox_list: list of bounding boxes
bbox_target: target bounding box
Returns:
closest bounding box
"""
    min_dist = float("inf")
min_idx = 0
for idx, bbox in enumerate(bbox_list):
dist = np.linalg.norm(bbox - bbox_target)
if dist < min_dist:
min_dist = dist
min_idx = idx
return bbox_list[min_idx]
def get_bboxes(image_file, image, model, image_processor, threshold=0.9):
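    """
    Detect objects in the frame and return their bounding boxes as an (N, 4) array.

    When ``image_processor`` is None the yolov7 hub model is run on the image
    file directly; otherwise the Hugging Face YOLOS model is used.
    """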
if image_processor is None:
results = model(image_file)
predictions = results.pred[0]
boxes = predictions[:, :4].detach().numpy()
return boxes
else:
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(
outputs, threshold=threshold, target_sizes=target_sizes
)[0]
return results["boxes"].detach().numpy()
def segment_video(
video_filename,
dir_frames,
image_start,
image_end,
bbox_file,
skip_vid2im,
mobile_sam_weights,
auto_detect=False,
tracker_name="yolov7",
background_color="#009000",
output_dir="output_frames",
output_video="output.mp4",
pbar=False,
reverse_mask=False,
):
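    """
    Segment the target object in every frame of ``video_filename`` and composite
    each frame over ``background_color`` (or keep the background and color the
    object when ``reverse_mask`` is set), then write the frames to ``output_dir``
    and re-encode them as ``output_video``.
    """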
if not skip_vid2im:
vid_to_im = ImageCreator(
video_filename,
dir_frames,
image_start=image_start,
image_end=image_end,
pbar=pbar,
)
vid_to_im.get_images()
# Get fps of video
vid = cv2.VideoCapture(video_filename)
fps = vid.get(cv2.CAP_PROP_FPS)
vid.release()
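    # Convert the hex color (e.g. "#009000") to normalized RGB in [0, 1].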
background_color = background_color.lstrip("#")
background_color = (
np.array([int(background_color[i : i + 2], 16) for i in (0, 2, 4)]) / 255.0
)
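    # The initial bounding box is given as space-separated integer coordinates.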
with open(bbox_file, "r") as f:
bbox_orig = [int(coord) for coord in f.read().split(" ")]
download_mobile_sam_weight(mobile_sam_weights)
if image_end == 0:
frames = sorted(os.listdir(dir_frames))[image_start:]
else:
frames = sorted(os.listdir(dir_frames))[image_start:image_end]
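    # Load MobileSAM (ViT-Tiny) on the best available device: MPS, CUDA, or CPU.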
model_type = "vit_t"
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
sam = sam_model_registry[model_type](checkpoint=mobile_sam_weights)
sam.to(device=device)
sam.eval()
predictor = SamPredictor(sam)
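    # Unless auto-detection is requested, load an object detector; on each frame
    # its box closest to the user-provided one will be used to prompt SAM.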
if not auto_detect:
if tracker_name == "yolov7":
model = yolov7.load("kadirnar/yolov7-tiny-v0.1", hf_model=True)
model.conf = 0.25 # NMS confidence threshold
model.iou = 0.45 # NMS IoU threshold
model.classes = None
image_processor = None
else:
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
image_processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
output_frames = []
if pbar:
pb = tqdm(frames)
else:
pb = frames
processed_frames = 0
init_time = time.time()
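    # Segment each frame and composite it over the background color.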
for frame in pb:
processed_frames += 1
image_file = dir_frames + "/" + frame
image_pil = Image.open(image_file)
image_np = np.array(image_pil)
if not auto_detect:
bboxes = get_bboxes(image_file, image_pil, model, image_processor)
closest_bbox = get_closest_bbox(bboxes, bbox_orig)
input_box = np.array(closest_bbox)
else:
input_box = np.array([0, 0, image_np.shape[1], image_np.shape[0]])
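        # Prompt MobileSAM with the selected box to get the object mask.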
predictor.set_image(image_np)
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=True,
)
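        # Keep the object (or the background when reverse_mask is set) and fill
        # the rest of the frame with the background color.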
        mask = masks[0]
        h, w = mask.shape[-2:]
        keep = (1 - mask) if reverse_mask else mask
        mask_image = (
            (1 - keep).reshape(h, w, 1) * background_color.reshape(1, 1, -1)
        ) * 255
        masked_image = image_np * keep.reshape(h, w, 1) + mask_image
        output_frames.append(masked_image)
if not pbar and processed_frames % 10 == 0:
remaining_time = (
(time.time() - init_time)
/ processed_frames
* (len(frames) - processed_frames)
)
remaining_time = int(remaining_time)
remaining_time_str = f"{remaining_time//60}m {remaining_time%60}s"
print(
f"Processed frame {processed_frames}/{len(frames)} - Remaining time: {remaining_time_str}"
)
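    # Write the composited frames to disk and rebuild a video at the source fps.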
if not os.path.exists(output_dir):
os.mkdir(output_dir)
zfill_max = len(str(len(output_frames)))
    for idx, frame in enumerate(output_frames):
        # Frames are in RGB order (from PIL); convert to BGR and uint8 before
        # writing with OpenCV so colors are stored correctly.
        cv2.imwrite(
            f"{output_dir}/frame_{str(idx).zfill(zfill_max)}.png",
            cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR),
        )
vid_creator = VideoCreator(output_dir, output_video, pbar=pbar)
vid_creator.create_video(fps=int(fps))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--video_filename",
default="assets/example.mp4",
type=str,
help="path to the video",
)
parser.add_argument(
"--dir_frames",
type=str,
default="frames",
help="path to the directory in which all input frames will be stored",
)
    parser.add_argument(
        "--image_start", type=int, default=0, help="index of the first frame to extract"
    )
    parser.add_argument(
        "--image_end",
        type=int,
        default=0,
        help="index of the last frame to extract (0 means through the last frame)",
    )
parser.add_argument(
"--bbox_file",
type=str,
default="bbox.txt",
help="path to the bounding box text file",
)
    parser.add_argument(
        "--skip_vid2im",
        action="store_true",
        help="skip extracting the video frames to images (they must already exist in dir_frames)",
    )
parser.add_argument(
"--mobile_sam_weights",
type=str,
default="./models/mobile_sam.pt",
help="path to MobileSAM weights",
)
    parser.add_argument(
        "--tracker_name",
        type=str,
        default="yolov7",
        help="object detector used to locate the tracked object",
        choices=["yolov7", "yoloS"],
    )
parser.add_argument(
"--output_dir",
type=str,
default="output_frames",
help="directory to store the output frames",
)
parser.add_argument(
"--output_video",
type=str,
default="output.mp4",
help="path to store the output video",
)
    parser.add_argument(
        "--auto_detect",
        action="store_true",
        help="skip bounding box tracking and prompt SAM with the whole frame instead",
    )
parser.add_argument(
"--background_color",
type=str,
default="#009000",
help="background color for the output (hex)",
)
args = parser.parse_args()
    # Optional settings are passed by keyword so each value reaches the right parameter.
    segment_video(
        args.video_filename,
        args.dir_frames,
        args.image_start,
        args.image_end,
        args.bbox_file,
        args.skip_vid2im,
        args.mobile_sam_weights,
        auto_detect=args.auto_detect,
        tracker_name=args.tracker_name,
        background_color=args.background_color,
        output_dir=args.output_dir,
        output_video=args.output_video,
    )