File size: 1,500 Bytes
3589cb2 33923aa 3589cb2 33923aa 3589cb2 33923aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import os
import onnxruntime
import gradio as gr
import numpy as np
from PIL import Image
# Load the ONNX model once at import time; predict() reuses this session.
onnx_model_path = "sar2rgb.onnx"
# Since onnxruntime 1.9, builds exposing more than one execution provider
# (e.g. onnxruntime-gpu) raise ValueError unless `providers` is given
# explicitly. get_available_providers() is ordered from highest to lowest
# priority, so GPU providers are tried first and CPU remains the fallback.
sess = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=onnxruntime.get_available_providers(),
)
# Function to process the input and make predictions
def predict(input_image):
    """Translate a SAR image into an RGB image with the ONNX generator.

    The image is converted to 3-channel RGB, resized to 256x256, scaled to
    [-1, 1] in NCHW layout, run through the module-level ``sess`` session,
    and the first output is mapped back to 8-bit pixels.

    Args:
        input_image: PIL image to translate. Grayscale, palette or RGBA
            images are converted to RGB first (alpha is dropped).

    Returns:
        PIL.Image.Image: the model output at the model's output resolution;
        the input's original size is not restored.
    """
    # Force 3 channels: a grayscale ("L") image yields a 2-D array that
    # transpose(2, 0, 1) cannot handle, and RGBA yields 4 channels.
    # NOTE(review): assumes the model expects 3-channel input — confirm.
    image = input_image.convert("RGB").resize((256, 256))  # Adjust size as needed
    chw = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
    chw = chw.astype(np.float32) / 255.0  # [0, 1]
    chw = (chw - 0.5) / 0.5  # [-1, 1]
    batch = np.expand_dims(chw, axis=0)  # add batch dimension

    # Run the model
    outputs = sess.run(None, {sess.get_inputs()[0].name: batch})

    # Drop only the batch axis; squeeze() would also drop a size-1 channel
    # axis and break the transpose below.
    # NOTE(review): assumes the first output is (1, C, H, W) — confirm.
    hwc = outputs[0][0].transpose(1, 2, 0)  # CHW -> HWC
    if hwc.shape[2] == 1:
        hwc = hwc[:, :, 0]  # single-channel output -> grayscale image
    # Map [-1, 1] -> [0, 255]. Clip before casting: out-of-range floats cast
    # to uint8 wrap around (or are undefined) instead of saturating.
    pixels = np.clip((hwc + 1) / 2 * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
# Specify example images: one single-input example per image file in the
# "examples" directory. Names are sorted so the example order is stable
# across runs; non-image files (e.g. .gitkeep, .DS_Store) are skipped, and a
# missing directory simply means there are no examples.
_EXAMPLES_DIR = "examples"
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp"}
try:
    _example_files = sorted(os.listdir(_EXAMPLES_DIR))
except FileNotFoundError:
    _example_files = []
example_images = [
    [os.path.join(_EXAMPLES_DIR, fname)]
    for fname in _example_files
    if os.path.splitext(fname)[1].lower() in _IMAGE_EXTENSIONS
]
# Create Gradio interface. With no examples, pass None (Gradio's default)
# rather than an empty list.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    examples=example_images or None,
)
# Launch the interface
iface.launch()
|