# Deploy_Restoration / Lowlight.py
# Hosting-page header: last commit "Update Lowlight.py" (aed4629) by AlexZou.
# File size: 1.61 kB.
import torch
import torchvision
import onnxruntime
import onnx
import cv2
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import os
# Command-line interface: every option is a string flag with a default.
#   --test_path  image file to read
#   --pk_path    ONNX model file to load
#   --save_path  directory the result image is written into
parser = argparse.ArgumentParser()
for _flag, _default in (
    ('--test_path', '/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg'),
    ('--pk_path', 'model_zoo/Low.onnx'),
    ('--save_path', 'Results/'),
):
    parser.add_argument(_flag, type=str, default=_default)
config = parser.parse_args()
# Make sure the output directory exists before anything is written to it.
# os.makedirs also creates missing parent directories (os.mkdir would fail
# for a nested --save_path), and exist_ok=True avoids the check-then-create
# race of a separate isdir() test.
os.makedirs(config.save_path, exist_ok=True)
# Load the test image and turn it into the float32 batch passed to the model.
img = plt.imread(config.test_path)
pixels = np.asarray(img)

# plt.imread returns floats already in [0, 1] for PNG files and integer
# arrays for other formats, so only integer data is rescaled. For uint8 the
# divisor is 255, matching the previous behaviour for JPEG input.
if np.issubdtype(pixels.dtype, np.integer):
    pixels = pixels / float(np.iinfo(pixels.dtype).max)

if pixels.ndim == 2:
    # Grayscale image: replicate the single channel so the axis reorder below
    # has a channel axis to move.
    # NOTE(review): assumes the model wants 3 input channels - confirm.
    pixels = np.stack([pixels] * 3, axis=-1)
elif pixels.shape[-1] == 4:
    # RGBA image: drop the alpha channel (same 3-channel assumption as above).
    pixels = pixels[..., :3]

# (H, W, C) -> (1, C, H, W), as float32 in a contiguous buffer.
input_image = np.ascontiguousarray(
    pixels.transpose(2, 0, 1)[np.newaxis], dtype=np.float32
)
# Run the ONNX model on the prepared batch and save the result image.
providers = ['CPUExecutionProvider']
model_name = 'IAT'
print('-' * 50)
try:
    onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
    # The tensor names 'input' and 'output' must match the names in the
    # loaded ONNX graph.
    onnx_input = {'input': input_image}
    onnx_output = onnx_session.run(['output'], onnx_input)
    # Join path components instead of concatenating strings, so that a
    # --save_path without a trailing separator still writes inside that
    # directory (previously 'Results' + 'output.png' -> 'Resultsoutput.png').
    output_path = os.path.join(config.save_path, 'output.png')
    torchvision.utils.save_image(torch.from_numpy(onnx_output[0]), output_path)
except Exception as e:
    # Top-level boundary of the script: report the failure and finish
    # without re-raising.
    print(f'Input on model:{model_name} failed')
    print(e)
else:
    print(f'Input on model:{model_name} succeed')