import os
import cv2
from model_args import segtracker_args, sam_args, aot_args
from PIL import Image
from aot_tracker import _palette
import numpy as np
import torch
import gc
import imageio
from scipy.ndimage import binary_dilation
def save_prediction(pred_mask, output_dir, file_name):
    # Save an integer id mask as a palette-mode PNG (pixel value = object id).
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask.save(os.path.join(output_dir, file_name))


def colorize_mask(pred_mask):
    # Map object ids to RGB colors using the shared AOT palette.
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask = save_mask.convert(mode='RGB')
    return np.array(save_mask)
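

# The two helpers above write masks as palette-mode PNGs whose pixel values are
# the object ids, so a saved mask can be read straight back into an id array.
# The helper below is only an illustrative sketch of that inverse step; it is
# not part of the original pipeline and is not called anywhere in this file.
def load_prediction(mask_path):
    return np.array(Image.open(mask_path), dtype=np.uint8)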
def draw_mask(img, mask, alpha=0.5, id_countour=False):
    # Overlay the id mask on the RGB frame and black out object contours.
    img_mask = img.copy()
    if id_countour:
        # Per-object blending: very slow, ~1 s per image.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        for obj_id in obj_ids:
            # Blend the object's palette color over the frame.
            if obj_id <= 255:
                color = _palette[obj_id * 3:obj_id * 3 + 3]
            else:
                color = [0, 0, 0]
            foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
            binary_mask = (mask == obj_id)
            # Compose the blended pixels into the output image.
            img_mask[binary_mask] = foreground[binary_mask]
            contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[contours, :] = 0
    else:
        binary_mask = (mask != 0)
        contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[contours, :] = 0
    return img_mask.astype(img.dtype)
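

# Illustrative sketch only (not part of the original app): draw_mask expects an
# RGB uint8 frame and an integer id mask with the same height and width. The
# helper and array names below are hypothetical, and the function is never
# called by the pipeline in this file.
def _demo_draw_mask():
    frame = np.zeros((64, 64, 3), dtype=np.uint8)   # blank RGB frame
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1                          # a single object with id 1
    overlay = draw_mask(frame, mask, alpha=0.5)
    assert overlay.shape == frame.shape
    return overlay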
def create_dir(dir_path):
    if os.path.isdir(dir_path):
        os.system(f"rm -r {dir_path}")
    os.makedirs(dir_path)
aot_model2ckpt = {
    "deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
    "deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV.pth",
    "r50_deaotl": "./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth",
}
def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps):
    # Dispatch to video tracking or image-sequence tracking, depending on which
    # input was provided.
    if input_video is not None:
        video_name = os.path.basename(input_video).split('.')[0]
    elif input_img_seq is not None:
        file_name = input_img_seq.name.split('/')[-1].split('.')[0]
        file_path = f'./assets/{file_name}'
        imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
        video_name = file_name
    else:
        return None, None

    # create dir to save results
    tracking_result_dir = os.path.join(os.path.dirname(__file__), "tracking_results", video_name)
    create_dir(tracking_result_dir)

    io_args = {
        'tracking_result_dir': tracking_result_dir,
        'output_mask_dir': f'{tracking_result_dir}/{video_name}_masks',
        'output_masked_frame_dir': f'{tracking_result_dir}/{video_name}_masked_frames',
        'output_video': f'{tracking_result_dir}/{video_name}_seg.mp4',  # keep the same format as the input video
        'output_gif': f'{tracking_result_dir}/{video_name}_seg.gif',
    }

    if input_video is not None:
        return video_type_input_tracking(SegTracker, input_video, io_args, video_name)
    elif input_img_seq is not None:
        return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps)
def video_type_input_tracking(SegTracker, input_video, io_args, video_name):
    # source video to segment
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)

    # create dirs to save predicted masks and masked frames
    output_mask_dir = io_args['output_mask_dir']
    create_dir(io_args['output_mask_dir'])
    create_dir(io_args['output_masked_frame_dir'])

    pred_list = []
    masked_pred_list = []

    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = SegTracker.sam_gap
    frame_idx = 0

    with torch.cuda.amp.autocast():
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if frame_idx == 0:
                # the first-frame mask comes from the interactive segmentation step
                pred_mask = SegTracker.first_frame_mask
                torch.cuda.empty_cache()
                gc.collect()
            elif (frame_idx % sam_gap) == 0:
                # every sam_gap frames, re-run segmentation to catch objects that entered the scene
                seg_mask = SegTracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = SegTracker.track(frame)
                # find new objects, and update the tracker with them
                new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png')
                pred_mask = track_mask + new_obj_mask
                # segtracker.restart_tracker()
                SegTracker.add_reference(frame, pred_mask)
            else:
                # in-between frames are propagated by the tracker alone
                pred_mask = SegTracker.track(frame, update_memory=True)
            torch.cuda.empty_cache()
            gc.collect()

            save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png')
            pred_list.append(pred_mask)

            print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
            frame_idx += 1
        cap.release()
        print('\nfinished')
    ##################
    # Visualization
    ##################

    # draw the predicted masks on each frame and save the result as a video
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # if input_video[-3:] == 'mp4':
    #     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # elif input_video[-3:] == 'avi':
    #     fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    #     # fourcc = cv2.VideoWriter_fourcc(*"XVID")
    # else:
    #     fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

    frame_idx = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pred_mask = pred_list[frame_idx]
        masked_frame = draw_mask(frame, pred_mask)
        cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{str(frame_idx).zfill(5)}.png", masked_frame[:, :, ::-1])
        masked_pred_list.append(masked_frame)
        masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
        out.write(masked_frame)
        print('frame {} written'.format(frame_idx), end='\r')
        frame_idx += 1
    out.release()
    cap.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # save the colorized masked frames as a gif
    imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
    print("{} saved".format(io_args['output_gif']))

    # zip the predicted masks
    os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

    # manually release memory (after CUDA out of memory)
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps):
    # create dirs to save predicted masks and masked frames
    output_mask_dir = io_args['output_mask_dir']
    create_dir(io_args['output_mask_dir'])
    create_dir(io_args['output_masked_frame_dir'])

    pred_list = []
    masked_pred_list = []

    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = SegTracker.sam_gap
    frame_idx = 0

    with torch.cuda.amp.autocast():
        for img_path in imgs_path:
            frame_name = os.path.basename(img_path).split('.')[0]
            frame = cv2.imread(img_path)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if frame_idx == 0:
                # the first-frame mask comes from the interactive segmentation step
                pred_mask = SegTracker.first_frame_mask
                torch.cuda.empty_cache()
                gc.collect()
            elif (frame_idx % sam_gap) == 0:
                # every sam_gap frames, re-run segmentation to catch objects that entered the scene
                seg_mask = SegTracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = SegTracker.track(frame)
                # find new objects, and update the tracker with them
                new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png')
                pred_mask = track_mask + new_obj_mask
                # segtracker.restart_tracker()
                SegTracker.add_reference(frame, pred_mask)
            else:
                # in-between frames are propagated by the tracker alone
                pred_mask = SegTracker.track(frame, update_memory=True)
            torch.cuda.empty_cache()
            gc.collect()

            save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png')
            pred_list.append(pred_mask)

            print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
            frame_idx += 1
        print('\nfinished')
    ##################
    # Visualization
    ##################

    # draw the predicted masks on each frame and save the result as a video
    height, width = pred_list[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

    frame_idx = 0
    for img_path in imgs_path:
        frame_name = os.path.basename(img_path).split('.')[0]
        frame = cv2.imread(img_path)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pred_mask = pred_list[frame_idx]
        masked_frame = draw_mask(frame, pred_mask)
        masked_pred_list.append(masked_frame)
        cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_name}.png", masked_frame[:, :, ::-1])
        masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
        out.write(masked_frame)
        print('frame {} written'.format(frame_name), end='\r')
        frame_idx += 1
    out.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # save the colorized masked frames as a gif
    imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
    print("{} saved".format(io_args['output_gif']))

    # zip the predicted masks
    os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

    # manually release memory (after CUDA out of memory)
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
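

# ---------------------------------------------------------------------------
# Minimal driver sketch, not part of the original app. It assumes that the
# SegTracker class can be imported from SegTracker.py and accepts the three
# argument dicts imported at the top of this file, and that
# SegTracker.first_frame_mask has already been initialised (normally done by
# the interactive segmentation step of the web UI). The import path, the demo
# video path, and the fps value are all assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from SegTracker import SegTracker  # hypothetical import path

    tracker = SegTracker(segtracker_args, sam_args, aot_args)
    # ... set tracker.first_frame_mask here before tracking ...
    output_video, pred_mask_zip = tracking_objects_in_video(
        tracker, input_video="./assets/demo.mp4", input_img_seq=None, fps=30
    )
    print(output_video, pred_mask_zip)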