"""Gradio demo that translates SAR images to RGB with an ONNX model."""

import os

import gradio as gr
import numpy as np
import onnxruntime
from PIL import Image

# Load the ONNX model once at startup; the session is reused for every request.
onnx_model_path = "sar2rgb.onnx"
sess = onnxruntime.InferenceSession(onnx_model_path)

# Spatial size the input is resized to before inference. Assumed to match the
# model's expected input resolution -- TODO confirm against the ONNX graph.
MODEL_INPUT_SIZE = (256, 256)

# File extensions treated as example images in the examples directory.
EXAMPLE_EXTENSIONS = (
    ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp",
)


def predict(input_image):
    """Run the SAR-to-RGB model on one image.

    Args:
        input_image: PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``PIL.Image`` built from the model's first output, mapped from
        [-1, 1] back to 8-bit [0, 255].

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Force 3 channels: a grayscale ("L") image would otherwise become a 2-D
    # array that cannot be transposed to CHW, and RGBA would add a 4th channel.
    rgb = input_image.convert("RGB").resize(MODEL_INPUT_SIZE)
    chw = np.asarray(rgb).transpose(2, 0, 1)  # HWC -> CHW
    chw = chw.astype(np.float32) / 255.0  # [0, 1]
    chw = (chw - 0.5) / 0.5  # [-1, 1]
    batch = np.expand_dims(chw, axis=0)  # add batch dimension

    outputs = sess.run(None, {sess.get_inputs()[0].name: batch})

    # Drop only the batch axis; squeeze() could also drop another size-1 axis.
    hwc = outputs[0][0].transpose(1, 2, 0)  # CHW -> HWC
    hwc = (hwc + 1) / 2  # [-1, 1] -> [0, 1]
    # Clip before casting: values outside [0, 1] would wrap around in uint8.
    hwc = (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(hwc)


def _list_example_images(directory="examples"):
    """Return Gradio example rows for image files in *directory*, sorted by name.

    Returns ``None`` when the directory is missing or holds no images, so the
    interface still starts (just without examples).
    """
    if not os.path.isdir(directory):
        return None
    rows = [
        [os.path.join(directory, fname)]
        for fname in sorted(os.listdir(directory))
        if fname.lower().endswith(EXAMPLE_EXTENSIONS)
    ]
    return rows or None


example_images = _list_example_images()

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    examples=example_images,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch()