xiongjie commited on
Commit
d8f7d4a
·
1 Parent(s): 0798b8f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import time
4
+
5
+ import cv2 as cv
6
+ import numpy as np
7
+ import onnxruntime
8
+
9
+ import gradio
10
+
11
+ def run_inference(onnx_session, input_size, image):
12
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
13
+ temp_image = copy.deepcopy(image)
14
+ resize_image = cv.resize(temp_image, dsize=(input_size[0], input_size[1]))
15
+ x = cv.cvtColor(resize_image, cv.COLOR_BGR2RGB)
16
+ x = np.array(x, dtype=np.float32)
17
+ mean = [0.485, 0.456, 0.406]
18
+ std = [0.229, 0.224, 0.225]
19
+ x = (x / 255 - mean) / std
20
+ x = x.reshape(-1, input_size[0], input_size[1], 3).astype('float32')
21
+
22
+ # Inference
23
+ input_name = onnx_session.get_inputs()[0].name
24
+ output_name = onnx_session.get_outputs()[0].name
25
+ onnx_result = onnx_session.run([output_name], {input_name: x})
26
+
27
+ # Post process
28
+ onnx_result = np.array(onnx_result).squeeze()
29
+ onnx_result = (1 - onnx_result)
30
+ min_value = np.min(onnx_result)
31
+ max_value = np.max(onnx_result)
32
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
33
+ onnx_result *= 255
34
+ onnx_result = onnx_result.astype('uint8')
35
+
36
+ return onnx_result
37
+
38
+ # Load model
39
+ onnx_session = onnxruntime.InferenceSession(model_path)
40
+
41
+ def create_rgba(image):
42
+ return run_inference(
43
+ onnx_session,
44
+ image.shape,
45
+ image,
46
+ )
47
+
48
+ css = ".output_image {height: 100% !important; width: 100% !important;}"
49
+ inputs = gradio.inputs.Image()
50
+ outputs = gradio.outputs.Image()
51
+ gradio.Interface(fn=create_rgba, inputs=inputs, outputs=outputs, css=css).launch()