AAAAAAyq committed on
Commit
f2149ba
·
1 Parent(s): 8406393

Update requirements

Browse files
Files changed (1) hide show
  1. app.py +6 -42
app.py CHANGED
@@ -13,7 +13,6 @@ def format_results(result,filter = 0):
13
  annotation = {}
14
  mask = result.masks.data[i] == 1.0
15
 
16
-
17
  if torch.sum(mask) < filter:
18
  continue
19
  annotation['id'] = i
@@ -50,51 +49,16 @@ def post_process(annotations, image, mask_random_color=True, bbox=None, points=N
50
  for i, mask in enumerate(annotations):
51
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
52
  plt.axis('off')
53
- # # create a BytesIO object
54
- # buf = io.BytesIO()
55
-
56
- # # save plot to buf
57
- # plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
58
-
59
- # # use PIL to open the image
60
- # img = Image.open(buf)
61
 
62
- # # copy the image data
63
- # img_copy = img.copy()
64
  plt.tight_layout()
65
-
66
- # # don't forget to close the buffer
67
- # buf.close()
68
  return fig
69
 
70
 
71
- # def show_mask(annotation, ax, random_color=False):
72
- # if random_color : # 掩膜颜色是否随机决定
73
- # color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
74
- # else:
75
- # color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
76
- # mask = annotation.cpu().numpy()
77
- # h, w = mask.shape[-2:]
78
- # mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
79
- # ax.imshow(mask_image)
80
-
81
- # def post_process(annotations, image):
82
- # plt.figure(figsize=(10, 10))
83
- # plt.imshow(image)
84
- # for i, mask in enumerate(annotations):
85
- # show_mask(mask.data, plt.gca(),random_color=True)
86
- # plt.axis('off')
87
-
88
- # 获取渲染后的像素数据并转换为PIL图像
89
-
90
- return pil_image
91
-
92
-
93
  # post_process(results[0].masks, Image.open("../data/cake.png"))
94
 
95
- def predict(inp, imgsz):
96
- imgsz = int(imgsz) # 确保 imgsz 是整数
97
- results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=imgsz)
98
  results = format_results(results[0], 100)
99
  results.sort(key=lambda x: x['area'], reverse=True)
100
  pil_image = post_process(annotations=results, image=inp)
@@ -106,10 +70,10 @@ def predict(inp, imgsz):
106
  # post_process(annotations=results, image_path=inp)
107
 
108
  demo = gr.Interface(fn=predict,
109
- inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[800, 960, 1024])],
110
  outputs=['plot'],
111
- examples=[["assets/sa_8776.jpg", 1024],
112
- ["assets/sa_1309.jpg", 1024]],
113
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
114
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
115
  # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
 
13
  annotation = {}
14
  mask = result.masks.data[i] == 1.0
15
 
 
16
  if torch.sum(mask) < filter:
17
  continue
18
  annotation['id'] = i
 
49
  for i, mask in enumerate(annotations):
50
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
51
  plt.axis('off')
 
 
 
 
 
 
 
 
52
 
 
 
53
  plt.tight_layout()
 
 
 
54
  return fig
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # post_process(results[0].masks, Image.open("../data/cake.png"))
58
 
59
+ def predict(inp, input_size):
60
+ input_size = int(input_size) # 确保 imgsz 是整数
61
+ results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
62
  results = format_results(results[0], 100)
63
  results.sort(key=lambda x: x['area'], reverse=True)
64
  pil_image = post_process(annotations=results, image=inp)
 
70
  # post_process(annotations=results, image_path=inp)
71
 
72
  demo = gr.Interface(fn=predict,
73
+ inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024])],
74
  outputs=['plot'],
75
+ examples=[["assets/sa_8776.jpg", 1024]],
76
+ # ["assets/sa_1309.jpg", 1024]],
77
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
78
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
79
  # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],