yuxindu commited on
Commit
0e8330e
·
1 Parent(s): 12be15f

fix shape bug

Browse files
Files changed (1) hide show
  1. model/inference_cpu.py +4 -3
model/inference_cpu.py CHANGED
@@ -30,7 +30,8 @@ def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point
30
  image_single_resize = image_resize
31
  image_single = image[0,0]
32
  ori_shape = image_single.shape
33
-
 
34
  # generate prompts
35
  text_single = None if text_prompt is None else [text_prompt]
36
  points_single = None
@@ -39,10 +40,10 @@ def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point
39
  if args.use_point_prompt:
40
  point, point_label = point_prompt
41
  points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
42
- binary_points_resize = build_binary_points(point, point_label, ori_shape)
43
  if args.use_box_prompt:
44
  box_single = box_prompt.unsqueeze(0).float()
45
- binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=ori_shape)
46
 
47
  ####################
48
  # zoom-out inference:
 
30
  image_single_resize = image_resize
31
  image_single = image[0,0]
32
  ori_shape = image_single.shape
33
+ resize_shape = image_single_resize.shape[2:]
34
+
35
  # generate prompts
36
  text_single = None if text_prompt is None else [text_prompt]
37
  points_single = None
 
40
  if args.use_point_prompt:
41
  point, point_label = point_prompt
42
  points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
43
+ binary_points_resize = build_binary_points(point, point_label, resize_shape)
44
  if args.use_box_prompt:
45
  box_single = box_prompt.unsqueeze(0).float()
46
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)
47
 
48
  ####################
49
  # zoom-out inference: