yuulind commited on
Commit
33923aa
·
1 Parent(s): b4ccdc6

Add app.py

Browse files
Files changed (2) hide show
  1. app.py +36 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # Load the ONNX model
7
+ onnx_model_path = "sar2rgb.onnx"
8
+ sess = onnxruntime.InferenceSession(onnx_model_path)
9
+
10
+ # Function to process the input and make predictions
11
+ def predict(input_image):
12
+ # Preprocess the input image (e.g., resize, normalize)
13
+ input_image = input_image.resize((256, 256)) # Adjust size as needed
14
+ input_image = np.array(input_image).transpose(2, 0, 1) # HWC to CHW
15
+ input_image = input_image.astype(np.float32) / 255.0 # [0,1]
16
+ input_image = (input_image - 0.5) / 0.5 # [-1,1]
17
+ input_image = np.expand_dims(input_image, axis=0) # Add batch dimension
18
+
19
+ # Run the model
20
+ inputs = {sess.get_inputs()[0].name: input_image}
21
+ output = sess.run(None, inputs)
22
+
23
+ # Post-process the output image (if necessary)
24
+ output_image = output[0].squeeze().transpose(1, 2, 0) # CHW to HWC
25
+ output_image = (output_image + 1) / 2 # [0,1]
26
+ output_image = (output_image * 255).astype(np.uint8) # Denormalize [0,255]
27
+
28
+ return Image.fromarray(output_image)
29
+
30
+ # Create Gradio interface
31
+ iface = gr.Interface(fn=predict,
32
+ inputs=gr.Image(type="pil"),
33
+ outputs=gr.Image(type="pil"))
34
+
35
+ # Launch the interface
36
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==5.5.0
2
+ numpy==2.1.3
3
+ onnxruntime==1.16.3
4
+ Pillow==10.1.0
5
+ Pillow==11.0.0