File size: 1,500 Bytes
3589cb2
33923aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3589cb2
 
 
33923aa
 
 
3589cb2
 
 
33923aa
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import onnxruntime
import gradio as gr
import numpy as np
from PIL import Image

# Load the ONNX model once at import time; `predict` reuses this session for
# every request. The path is resolved relative to the current working directory.
onnx_model_path = "sar2rgb.onnx"
# Since onnxruntime 1.9, builds that ship non-CPU execution providers (e.g.
# onnxruntime-gpu) raise ValueError unless `providers` is passed explicitly.
# Prefer CUDA when this build offers it; CPU is always available as fallback.
sess = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=[
        provider
        for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if provider in onnxruntime.get_available_providers()
    ],
)

# Function to process the input and make predictions
def predict(input_image):
    """Run the ONNX model on one image and return the generated image.

    Args:
        input_image: PIL image from the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submits without uploading an image.

    Returns:
        A 256x256 ``PIL.Image`` built from the model's first output, or
        ``None`` when no input image was given.
    """
    if input_image is None:
        return None

    # Preprocess. Force 3 channels first: grayscale ("L"), palette ("P") or
    # RGBA inputs would otherwise give a 2-D array (the HWC->CHW transpose
    # below fails) or the wrong channel count. This is a no-op for RGB images.
    image = input_image.convert("RGB").resize((256, 256))  # Adjust size as needed
    x = np.asarray(image, dtype=np.float32).transpose(2, 0, 1)  # HWC to CHW
    x = x / 255.0            # [0, 1]
    x = (x - 0.5) / 0.5      # [-1, 1]
    x = np.expand_dims(x, axis=0)  # Add batch dimension -> (1, 3, 256, 256)

    # Run the model; assumes it has a single input (only the first is fed).
    inputs = {sess.get_inputs()[0].name: x}
    output = sess.run(None, inputs)

    # Post-process the first output. Assumes it is (1, C, H, W) with values
    # roughly in [-1, 1] (presumably a tanh head) — TODO confirm.
    output_image = output[0].squeeze().transpose(1, 2, 0)  # CHW to HWC
    output_image = (output_image + 1) / 2  # [0, 1]
    # Clip before casting: NumPy's float->uint8 conversion of out-of-range
    # values is platform-dependent and can wrap (e.g. 256.0 -> 0), which
    # shows up as speckle artefacts in the output image.
    output_image = (np.clip(output_image, 0.0, 1.0) * 255).astype(np.uint8)  # Denormalize [0,255]

    return Image.fromarray(output_image)

# Specify example images: every image file in ./examples, in sorted order.
# os.listdir order is arbitrary, and non-image or hidden entries (.gitkeep,
# .DS_Store, ...) would be offered as examples the Image input cannot load.
_EXAMPLE_DIR = "examples"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
if os.path.isdir(_EXAMPLE_DIR):
    example_images = [
        [os.path.join(_EXAMPLE_DIR, fname)]
        for fname in sorted(os.listdir(_EXAMPLE_DIR))
        if not fname.startswith(".") and fname.lower().endswith(_IMAGE_EXTENSIONS)
    ]
else:
    # A missing examples directory should not prevent the app from starting.
    example_images = []

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    examples=example_images or None,  # omit the examples table when there are none
)

# Launch the interface only when run as a script (e.g. `python app.py`), so
# importing this module does not start a server.
if __name__ == "__main__":
    iface.launch()