xiongjie commited on
Commit
289b883
·
1 Parent(s): e27d92f
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -36,21 +36,24 @@ def run_inference(onnx_session, input_size, image):
36
  onnx_result = (onnx_result - min_value) / (max_value - min_value)
37
  onnx_result *= 255
38
  onnx_result = onnx_result.astype('uint8')
39
- onnx_result[onnx_result >= 125] = 255
40
- onnx_result[onnx_result < 125] = 0
41
 
42
  return onnx_result
43
 
44
  # Load model
45
  onnx_session = onnxruntime.InferenceSession("u2net.onnx")
46
 
47
- def create_rgba(image):
48
  out = run_inference(
49
  onnx_session,
50
  320,
51
  image,
52
  )
53
  resize_image = cv.resize(out, dsize=(image.shape[1], image.shape[0]))
 
 
 
 
 
54
  mask = Image.fromarray(resize_image)
55
 
56
  rgba_image = Image.fromarray(image).convert('RGBA')
@@ -59,6 +62,6 @@ def create_rgba(image):
59
  return rgba_image
60
 
61
  css = ".output_image {height: 100% !important; width: 100% !important;}"
62
- inputs = gradio.inputs.Image()
63
  outputs = gradio.outputs.Image()
64
  gradio.Interface(fn=create_rgba, inputs=inputs, outputs=outputs, css=css).launch()
 
36
  onnx_result = (onnx_result - min_value) / (max_value - min_value)
37
  onnx_result *= 255
38
  onnx_result = onnx_result.astype('uint8')
 
 
39
 
40
  return onnx_result
41
 
42
  # Load model
43
  onnx_session = onnxruntime.InferenceSession("u2net.onnx")
44
 
45
+ def create_rgba(mode, image):
46
  out = run_inference(
47
  onnx_session,
48
  320,
49
  image,
50
  )
51
  resize_image = cv.resize(out, dsize=(image.shape[1], image.shape[0]))
52
+
53
+ if mode == "binary":
54
+ resize_image[resize_image > 255] = 255
55
+ resize_image[resize_image < 125] = 0
56
+
57
  mask = Image.fromarray(resize_image)
58
 
59
  rgba_image = Image.fromarray(image).convert('RGBA')
 
62
  return rgba_image
63
 
64
  css = ".output_image {height: 100% !important; width: 100% !important;}"
65
+ inputs = [gradio.inputs.Radio(["binary", "smooth"]), gradio.inputs.Image()]
66
  outputs = gradio.outputs.Image()
67
  gradio.Interface(fn=create_rgba, inputs=inputs, outputs=outputs, css=css).launch()