AAAAAAyq commited on
Commit
8406393
·
1 Parent(s): d901fa4

Update requirements

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -92,8 +92,9 @@ def post_process(annotations, image, mask_random_color=True, bbox=None, points=N
92
 
93
  # post_process(results[0].masks, Image.open("../data/cake.png"))
94
 
95
- def predict(inp):
96
- results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=720)
 
97
  results = format_results(results[0], 100)
98
  results.sort(key=lambda x: x['area'], reverse=True)
99
  pil_image = post_process(annotations=results, image=inp)
@@ -105,9 +106,10 @@ def predict(inp):
105
  # post_process(annotations=results, image_path=inp)
106
 
107
  demo = gr.Interface(fn=predict,
108
- inputs=gr.inputs.Image(type='pil'),
109
  outputs=['plot'],
110
- examples=[["assets/sa_8776.jpg"],],
 
111
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
112
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
113
  # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
 
92
 
93
  # post_process(results[0].masks, Image.open("../data/cake.png"))
94
 
95
+ def predict(inp, imgsz):
96
+ imgsz = int(imgsz) # 确保 imgsz 是整数
97
+ results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=imgsz)
98
  results = format_results(results[0], 100)
99
  results.sort(key=lambda x: x['area'], reverse=True)
100
  pil_image = post_process(annotations=results, image=inp)
 
106
  # post_process(annotations=results, image_path=inp)
107
 
108
  demo = gr.Interface(fn=predict,
109
+ inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[800, 960, 1024])],
110
  outputs=['plot'],
111
+ examples=[["assets/sa_8776.jpg", 1024],
112
+ ["assets/sa_1309.jpg", 1024]],
113
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
114
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
115
  # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],