File size: 3,248 Bytes
da9195c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f86314
da9195c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import data
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference
import argparse

class Configs():
    """Command-line configuration for the super-resolution ONNX evaluation script."""

    def __init__(self):
        p = argparse.ArgumentParser(description='SR')

        # Target selection: IPU (VitisAI) vs. default providers; an ONNX model
        # path is required in either case.
        p.add_argument('--ipu', action='store_true',
                       help='use ipu')
        p.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
                       help='onnx path')
        p.add_argument('--provider_config', type=str, default=None,
                       help='provider config path')
        # Dataset options; defaults match the bundled Set5 layout.
        p.add_argument('--dir_data', type=str, default='dataset/',
                       help='dataset directory')
        p.add_argument('--data_test', type=str, default='Set5',
                       help='test dataset name')
        p.add_argument('--n_threads', type=int, default=6,
                       help='number of threads for data loading')
        p.add_argument('--scale', type=str, default='2',
                       help='super resolution scale, now only support x2')
        self.parser = p

    def parse(self):
        """Parse sys.argv; '+'-joined scale/dataset strings become Python lists."""
        parsed = self.parser.parse_args()
        parsed.scale = [int(s) for s in parsed.scale.split('+')]
        parsed.data_test = parsed.data_test.split('+')
        print(parsed)
        return parsed



def quantize(img, rgb_range):
    """Snap *img* onto the 255-level pixel grid, clamped to [0, rgb_range].

    Scales values so the valid range maps to [0, 255], clamps, rounds to
    integers, then scales back.
    """
    scale = 255 / rgb_range
    clipped = (img * scale).clamp(0, 255)
    return clipped.round() / scale

def test_model(session, loader, device):
    """Evaluate an ONNX super-resolution model over the test dataloaders.

    For each test dataset, runs tiled inference on every (lr, hr) pair,
    accumulates PSNR/SSIM via the project `metric` module, and prints the
    per-dataset means.

    Args:
        session: onnxruntime.InferenceSession for the SR model.
        loader: project data wrapper exposing `loader_test` (iterable of
            test dataloaders, each with `.dataset.set_scale`).
        device: torch device string for the SR output tensor (caller uses "cpu").

    Returns:
        (mean_psnr, mean_ssim) for the last dataset processed.
    """
    torch.set_grad_enabled(False)  # inference only; skip autograd bookkeeping
    self_scale = [2]  # hard-coded: only x2 super-resolution is supported
    for idx_data, d in enumerate(loader.loader_test):
        eval_ssim = 0
        eval_psnr = 0
        for idx_scale, scale in enumerate(self_scale):
            d.dataset.set_scale(idx_scale)
            for lr, hr, filename in tqdm(d, ncols=80):
                # Run the ONNX session on 56x56 tiles with 8-pixel overlap,
                # then stitch the SR patches back into a full image.
                sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                sr = torch.from_numpy(sr).to(device)
                sr = quantize(sr, 255)  # clamp/round onto the valid pixel grid
                # NOTE(review): `benchmark=d` / `dataset=d` pass the dataloader
                # object itself; presumably metric.* only truth-tests it or
                # reads attributes — confirm against metric's signatures.
                eval_psnr += metric.calc_psnr(
                    sr, hr, scale, 255, benchmark=d)
                eval_ssim += metric.calc_ssim(
                    sr, hr, scale, 255, dataset=d)
            # NOTE(review): accumulators are not reset between scales, so with
            # more than one entry in self_scale these means would mix scales;
            # harmless while self_scale == [2].
            mean_ssim = eval_ssim / len(d)
            mean_psnr = eval_psnr  / len(d)
        # NOTE(review): mean_psnr/mean_ssim are undefined if loader_test yields
        # an empty dataloader (NameError) — assumes non-empty test sets.
        print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim

def main(args):
    """Build the test dataloader and ONNX session, then evaluate on CPU.

    Provider choice follows `args.ipu`: VitisAI (with its config file) for
    IPU runs, otherwise CUDA with CPU fallback.
    """
    loader = data.Data(args)
    if args.ipu:
        # VitisAI execution requires an explicit provider config file.
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        # Prefer CUDA when available; onnxruntime falls back to CPU.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        args.onnx_path,
        providers=providers,
        provider_options=provider_options,
    )
    test_model(ort_session, loader, device="cpu")


if __name__ == '__main__':
    # Parse CLI arguments and run the evaluation pipeline.
    main(Configs().parse())