John Ho commited on
Commit
07c2352
·
1 Parent(s): aaa1b00

removed visualization for debugging

Browse files
Files changed (2) hide show
  1. samv2_handler.py +10 -9
  2. visualizer.py +5 -3
samv2_handler.py CHANGED
@@ -9,10 +9,11 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
9
  from sam2.utils.misc import variant_to_config_mapping
10
  from sam2.utils.visualization import show_masks
11
  from ffmpeg_extractor import extract_frames, logger
12
- from visualizer import annotate_masks, mask_to_xyxy
13
  from toolbox.vid_utils import VidInfo, VidReader
14
  from toolbox.mask_encoding import b64_mask_encode
15
- from toolbox.img_utils import get_pil_im
 
16
 
17
  variant_checkpoints_mapping = {
18
  "tiny": "checkpoints/sam2_hiera_tiny.pt",
@@ -232,13 +233,13 @@ def run_sam_video_inference(
232
  logger.debug(f"model initiated with object_ids of len {len(object_ids)}")
233
  init_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
234
  init_masks = [m.squeeze() for m in init_masks]
235
- ref_frame_im = get_pil_im(np.array(vr.get_data(ref_frame_idx)))
236
- init_masks_im_fp = os.path.join(vframes_dir, f"model_init_masks.jpg")
237
- input_masks_im_fp = os.path.join(vframes_dir, f"input_masks.jpg")
238
- annotate_masks(ref_frame_im, init_masks).save(init_masks_im_fp)
239
- annotate_masks(ref_frame_im, masks).save(input_masks_im_fp)
240
- logger.debug(f"masks received by model visualized at {init_masks_im_fp}")
241
- logger.debug(f"masks provided to model visualized at {input_masks_im_fp}")
242
 
243
  masks_generator = model.propagate_in_video(inference_state)
244
  detections = unpack_masks(
 
9
  from sam2.utils.misc import variant_to_config_mapping
10
  from sam2.utils.visualization import show_masks
11
  from ffmpeg_extractor import extract_frames, logger
12
+ from visualizer import mask_to_xyxy
13
  from toolbox.vid_utils import VidInfo, VidReader
14
  from toolbox.mask_encoding import b64_mask_encode
15
+
16
+ # from toolbox.img_utils import get_pil_im
17
 
18
  variant_checkpoints_mapping = {
19
  "tiny": "checkpoints/sam2_hiera_tiny.pt",
 
233
  logger.debug(f"model initiated with object_ids of len {len(object_ids)}")
234
  init_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
235
  init_masks = [m.squeeze() for m in init_masks]
236
+ # ref_frame_im = get_pil_im(np.array(vr.get_data(ref_frame_idx)))
237
+ # init_masks_im_fp = os.path.join(vframes_dir, f"model_init_masks.jpg")
238
+ # input_masks_im_fp = os.path.join(vframes_dir, f"input_masks.jpg")
239
+ # annotate_masks(ref_frame_im, init_masks).save(init_masks_im_fp)
240
+ # annotate_masks(ref_frame_im, masks).save(input_masks_im_fp)
241
+ # logger.debug(f"masks received by model visualized at {init_masks_im_fp}")
242
+ # logger.debug(f"masks provided to model visualized at {input_masks_im_fp}")
243
 
244
  masks_generator = model.propagate_in_video(inference_state)
245
  detections = unpack_masks(
visualizer.py CHANGED
@@ -1,8 +1,10 @@
1
  from PIL import Image, ImageColor
2
- import matplotlib.colors as mcolors
 
3
  import numpy as np
4
- from toolbox.mask_encoding import b64_mask_decode
5
- from toolbox.img_utils import im_draw_bbox, im_draw_point, im_color_mask
 
6
 
7
 
8
  def mask_to_xyxy(mask: np.ndarray, verbose: bool = False) -> tuple:
 
1
  from PIL import Image, ImageColor
2
+
3
+ # import matplotlib.colors as mcolors
4
  import numpy as np
5
+
6
+ # from toolbox.mask_encoding import b64_mask_decode
7
+ # from toolbox.img_utils import im_draw_bbox, im_draw_point, im_color_mask
8
 
9
 
10
  def mask_to_xyxy(mask: np.ndarray, verbose: bool = False) -> tuple: