xiongjie commited on
Commit
e27d92f
·
1 Parent(s): a78808a
Files changed (2) hide show
  1. app.py +23 -13
  2. u2net.onnx +3 -0
app.py CHANGED
@@ -11,42 +11,52 @@ from PIL import Image
11
  import gradio
12
 
13
  def run_inference(onnx_session, input_size, image):
14
- # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
15
  temp_image = copy.deepcopy(image)
16
- resize_image = cv.resize(temp_image, dsize=(input_size[0], input_size[1]))
17
  x = cv.cvtColor(resize_image, cv.COLOR_BGR2RGB)
 
 
18
  x = np.array(x, dtype=np.float32)
19
  mean = [0.485, 0.456, 0.406]
20
  std = [0.229, 0.224, 0.225]
21
  x = (x / 255 - mean) / std
22
- x = x.reshape(-1, input_size[0], input_size[1], 3).astype('float32')
 
23
 
24
- # Inference
25
  input_name = onnx_session.get_inputs()[0].name
26
  output_name = onnx_session.get_outputs()[0].name
27
  onnx_result = onnx_session.run([output_name], {input_name: x})
28
 
29
- # Post process
30
  onnx_result = np.array(onnx_result).squeeze()
31
- onnx_result = (1 - onnx_result)
32
  min_value = np.min(onnx_result)
33
  max_value = np.max(onnx_result)
34
  onnx_result = (onnx_result - min_value) / (max_value - min_value)
35
  onnx_result *= 255
36
  onnx_result = onnx_result.astype('uint8')
 
 
37
 
38
  return onnx_result
39
 
40
  # Load model
41
- onnx_session = onnxruntime.InferenceSession("model_float32.onnx")
42
 
43
  def create_rgba(image):
44
- print(images.shape)
45
- return run_inference(
46
- onnx_session,
47
- [480,640],
48
- image,
49
- )
 
 
 
 
 
 
50
 
51
  css = ".output_image {height: 100% !important; width: 100% !important;}"
52
  inputs = gradio.inputs.Image()
 
11
  import gradio
12
 
13
  def run_inference(onnx_session, input_size, image):
14
+ # リサイズ
15
  temp_image = copy.deepcopy(image)
16
+ resize_image = cv.resize(temp_image, dsize=(input_size, input_size))
17
  x = cv.cvtColor(resize_image, cv.COLOR_BGR2RGB)
18
+
19
+ # 前処理
20
  x = np.array(x, dtype=np.float32)
21
  mean = [0.485, 0.456, 0.406]
22
  std = [0.229, 0.224, 0.225]
23
  x = (x / 255 - mean) / std
24
+ x = x.transpose(2, 0, 1).astype('float32')
25
+ x = x.reshape(-1, 3, input_size, input_size)
26
 
27
+ # 推論
28
  input_name = onnx_session.get_inputs()[0].name
29
  output_name = onnx_session.get_outputs()[0].name
30
  onnx_result = onnx_session.run([output_name], {input_name: x})
31
 
32
+ # 後処理
33
  onnx_result = np.array(onnx_result).squeeze()
 
34
  min_value = np.min(onnx_result)
35
  max_value = np.max(onnx_result)
36
  onnx_result = (onnx_result - min_value) / (max_value - min_value)
37
  onnx_result *= 255
38
  onnx_result = onnx_result.astype('uint8')
39
+ onnx_result[onnx_result >= 125] = 255
40
+ onnx_result[onnx_result < 125] = 0
41
 
42
  return onnx_result
43
 
44
  # Load model
45
+ onnx_session = onnxruntime.InferenceSession("u2net.onnx")
46
 
47
  def create_rgba(image):
48
+ out = run_inference(
49
+ onnx_session,
50
+ 320,
51
+ image,
52
+ )
53
+ resize_image = cv.resize(out, dsize=(image.shape[1], image.shape[0]))
54
+ mask = Image.fromarray(resize_image)
55
+
56
+ rgba_image = Image.fromarray(image).convert('RGBA')
57
+ rgba_image.putalpha(mask)
58
+
59
+ return rgba_image
60
 
61
  css = ".output_image {height: 100% !important; width: 100% !important;}"
62
  inputs = gradio.inputs.Image()
u2net.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:309ea1fc458e8f6711efb645da85d1cc2a9cd2aec261d379bba31c0a0ddc78af
3
+ size 175995934