Spaces:
Build error
Build error
import torch | |
import torchvision | |
import onnxruntime | |
import onnx | |
import cv2 | |
import argparse | |
import warnings | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
# Command-line options for the ONNX inference demo.
parser = argparse.ArgumentParser()
# NOTE(review): this default is an absolute path on the original author's
# machine; pass --test_path explicitly when running anywhere else.
parser.add_argument('--test_path', type=str, default='/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg',
                    help='path of the image to run through the model')
parser.add_argument('--pk_path', type=str, default='model_zoo/Low.onnx',
                    help='path of the ONNX model file')
parser.add_argument('--save_path', type=str, default='Results/',
                    help='directory that output.png is written to')
config = parser.parse_args()

# makedirs (unlike mkdir) also creates missing parent directories, and
# exist_ok=True removes the check-then-create race of isdir() + mkdir().
os.makedirs(config.save_path, exist_ok=True)
# Load the test image and convert it into a float32 batch of one image in
# channel-first layout: (1, C, H, W).
img = np.asarray(plt.imread(config.test_path))

# plt.imread returns PNG files as floats already scaled to 0-1 and other
# formats as integer arrays, so only integer data must be rescaled.
# Dividing unconditionally by 255 would turn PNG inputs almost black.
if np.issubdtype(img.dtype, np.integer):
    img = img / float(np.iinfo(img.dtype).max)

# NOTE(review): the two adjustments below assume the model takes a
# 3-channel RGB input - confirm against the exported ONNX graph.
if img.ndim == 2:
    # Grayscale image (H, W): replicate the single channel three times.
    img = np.stack([img, img, img], axis=-1)
elif img.shape[-1] == 4:
    # RGBA image: drop the alpha channel.
    img = img[..., :3]

# HWC -> CHW, add the batch axis, and hand the runtime a contiguous
# float32 buffer.
input_image = np.ascontiguousarray(
    np.transpose(img, (2, 0, 1))[np.newaxis, ...], dtype=np.float32
)
# Run the ONNX model on the prepared input and save the result as a PNG.
providers = ['CPUExecutionProvider']
model_name = 'IAT'
print('-' * 50)
try:
    onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
    # The graph's input tensor is addressed by the name 'input'; only the
    # tensor named 'output' is requested back.
    onnx_input = {'input': input_image}
    onnx_output = onnx_session.run(['output'], onnx_input)
    # os.path.join inserts the separator when --save_path lacks a trailing
    # slash; plain concatenation would write e.g. 'Resultsoutput.png'
    # outside the output directory.
    output_file = os.path.join(config.save_path, 'output.png')
    torchvision.utils.save_image(torch.from_numpy(onnx_output[0]), output_file)
except Exception as e:
    # Top-level boundary of the demo: report the failure instead of crashing.
    print(f'Input on model:{model_name} failed')
    print(e)
else:
    print(f'Input on model:{model_name} succeed')