import argparse
import json
import os
import math
from functools import partial

import cv2
import numpy as np
import yaml
import torch
from PIL.Image import Image
from einops import rearrange
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

import datasets
import models
import utils

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def batched_predict(model, img, bsize):
    """Run ``model`` on ``img`` with gradient tracking disabled.

    ``bsize`` is kept for interface compatibility but is not used: the
    whole input is passed to the model in a single call.
    """
    with torch.no_grad():
        pred = model(img)
    return pred


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string
    (including ``"False"``) as ``True``; this parser reads the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


def _to_uint8_hwc(tensor):
    """Scale a ``C H W`` tensor by 255, clip to [0, 255], return ``H W C`` uint8."""
    array = tensor.cpu().numpy() * 255
    array = np.clip(array, a_min=0, a_max=255).astype(np.uint8)
    return rearrange(array, 'C H W -> H W C')


def _resolve_metric_fns(eval_type):
    """Return ``(psnr_fn, ssim_fn)`` for ``eval_type``; ``ssim_fn`` may be None.

    Raises:
        NotImplementedError: if ``eval_type`` is not a supported value.
    """
    if eval_type is None:
        return utils.calculate_psnr_pt, None
    if eval_type == 'psnr+ssim':
        return utils.calculate_psnr_pt, utils.calculate_ssim_pt
    if eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        return partial(utils.calc_psnr, dataset='div2k', scale=scale), None
    if eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        return partial(utils.calc_psnr, dataset='benchmark', scale=scale), None
    raise NotImplementedError


def eval_psnr(loader, class_names,
              data_norm=None, eval_type=None, save_fig=False,
              scale_ratio=1, save_path=None, verbose=False, crop_border=4,
              cal_metrics=True,
              ):
    """Evaluate bicubic upsampling of ``batch['img']`` against ``batch['gt']``.

    For every batch the normalised input is resized with bicubic
    interpolation to the size of the last dimension of ``batch['gt']``,
    de-normalised with the ``gt`` statistics, and scored.

    Args:
        loader: iterable of dict batches with keys ``'img'``, ``'gt'``,
            ``'class_name'`` and optionally ``'filename'``.
        class_names: class names passed to ``utils.Averager``.
        data_norm: dict with ``'img'`` and ``'gt'`` entries, each holding
            ``'sub'`` and ``'div'`` lists. Defaults to identity (0 / 1).
        eval_type: ``None``, ``'psnr+ssim'``, ``'div2k-<scale>'`` or
            ``'benchmark-<scale>'``. Only ``'psnr+ssim'`` computes SSIM; for
            the others SSIM is filled with ones, the same placeholder used
            when ``cal_metrics`` is false.
        save_fig: when true and the batch carries ``'filename'``, write the
            input, prediction and ground truth as PNG files.
        scale_ratio: only used in the saved prediction file name.
        save_path: directory for saved figures (must already exist).
        verbose: show running ``'all'`` PSNR/SSIM on the progress bar.
        crop_border: forwarded to the metric functions; truthy values are
            cast to ``int``.
        cal_metrics: when false, skip metrics and use ones as placeholders.

    Returns:
        Tuple ``(val_res_psnr.item(), val_res_ssim.item())`` from the two
        ``utils.Averager`` instances.

    Raises:
        NotImplementedError: if ``eval_type`` is not supported.
    """
    crop_border = int(crop_border) if crop_border else crop_border
    print('crop border: ', crop_border)

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    img_stats = data_norm['img']
    img_sub = torch.FloatTensor(img_stats['sub']).view(1, -1, 1, 1).to(device)
    img_div = torch.FloatTensor(img_stats['div']).view(1, -1, 1, 1).to(device)
    # The prediction is an image batch (it is later rearranged from
    # 'C H W'), so the gt statistics must broadcast over the channel axis
    # like the img statistics do. A (1, 1, -1) shape would align
    # per-channel values with the width axis instead.
    gt_stats = data_norm['gt']
    gt_sub = torch.FloatTensor(gt_stats['sub']).view(1, -1, 1, 1).to(device)
    gt_div = torch.FloatTensor(gt_stats['div']).view(1, -1, 1, 1).to(device)

    # Resolve the metrics once up front. The old code indexed ``metric_fn``
    # as a list even when it was a single callable, which raised TypeError
    # for every eval_type except 'psnr+ssim'.
    psnr_fn, ssim_fn = _resolve_metric_fns(eval_type)

    val_res_psnr = utils.Averager(class_names)
    val_res_ssim = utils.Averager(class_names)

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)

        img = (batch['img'] - img_sub) / img_div
        upsample = transforms.Resize(batch['gt'].size(-1),
                                     InterpolationMode.BICUBIC)
        pred = upsample(img)
        pred = pred * gt_div + gt_sub

        if cal_metrics:
            # NOTE(review): crop_border is forwarded to every PSNR function
            # as the original call did; confirm utils.calc_psnr accepts it.
            res_psnr = psnr_fn(pred, batch['gt'], crop_border=crop_border)
            if ssim_fn is not None:
                res_ssim = ssim_fn(pred, batch['gt'], crop_border=crop_border)
            else:
                res_ssim = torch.ones(len(pred))
        else:
            res_psnr = torch.ones(len(pred))
            res_ssim = torch.ones(len(pred))

        file_names = batch.get('filename', None)
        if file_names is not None and save_fig:
            for idx, file_name in enumerate(file_names[:len(batch['img'])]):
                ori_img = _to_uint8_hwc(batch['img'][idx])
                pred_img = _to_uint8_hwc(pred[idx])
                gt_img = _to_uint8_hwc(batch['gt'][idx])
                psnr = res_psnr[idx].cpu().numpy()
                ssim = res_ssim[idx].cpu().numpy()
                # NOTE(review): cv2.imwrite expects BGR channel order;
                # confirm the dataset's channel order if colours look swapped.
                ori_file_name = f'{save_path}/{file_name}_Ori.png'
                cv2.imwrite(ori_file_name, ori_img)
                pred_file_name = f'{save_path}/{file_name}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png'
                cv2.imwrite(pred_file_name, pred_img)
                gt_file_name = f'{save_path}/{file_name}_GT.png'
                cv2.imwrite(gt_file_name, gt_img)

        val_res_psnr.add(batch['class_name'], res_psnr)
        val_res_ssim.add(batch['class_name'], res_ssim)

        if verbose:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(
                    val_res_psnr.item()['all'], val_res_ssim.item()['all']))

    return val_res_psnr.item(), val_res_ssim.item()


def main():
    """Parse CLI arguments, build the test loader and print PSNR/SSIM."""
    root_split_file = {
        'UC': {
            'root_path': '/data/kyanchen/datasets/UC/256',
            'split_file': 'data_split/UC_split.json'
        },
        'AID': {
            'root_path': '/data/kyanchen/datasets/AID',
            'split_file': 'data_split/AID_split.json'
        }
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/test_fixed_scale_sr.yaml')
    parser.add_argument('--model', default='checkpoints/EXP20220610_5/epoch-best.pth')
    parser.add_argument('--scale_ratio', default=4, type=float)
    # Boolean flags use _str2bool: with type=bool, "--save_fig False" was
    # parsed as True because any non-empty string is truthy.
    parser.add_argument('--save_fig', default=False, type=_str2bool)
    parser.add_argument('--save_path', default='tmp', type=str)
    parser.add_argument('--cal_metrics', default=True, type=_str2bool)
    parser.add_argument('--return_class_metrics', default=False, type=_str2bool)
    # Restrict to known datasets so a typo gives a usage error, not KeyError.
    parser.add_argument('--dataset_name', default='UC', type=str,
                        choices=sorted(root_split_file))
    args = parser.parse_args()

    # NOTE: FullLoader can construct more than plain data; only load
    # config files you trust.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    dataset_args = config['test_dataset']['dataset']['args']
    wrapper_args = config['test_dataset']['wrapper']['args']
    dataset_args['root_path'] = root_split_file[args.dataset_name]['root_path']
    dataset_args['split_file'] = root_split_file[args.dataset_name]['split_file']
    wrapper_args['scale_ratio'] = args.scale_ratio

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'], num_workers=0,
                        pin_memory=True, shuffle=False, drop_last=False)

    with open(dataset_args['split_file']) as f:
        file_names = json.load(f)['test']
    # Sorted so the class order is reproducible between runs.
    class_names = sorted(
        {os.path.basename(os.path.dirname(x)) for x in file_names})

    crop_border = wrapper_args['scale_ratio']
    dataset_name = os.path.basename(dataset_args['split_file']).split('_')[0]
    max_scale = {'UC': 5, 'AID': 12}
    if args.scale_ratio > max_scale[dataset_name]:
        crop_border = int((args.scale_ratio - max_scale[dataset_name]) / 2 * 48)

    if args.save_fig:
        os.makedirs(args.save_path, exist_ok=True)

    res = eval_psnr(
        loader, class_names,
        data_norm=config.get('data_norm'),
        eval_type=config.get('eval_type'),
        crop_border=crop_border,
        verbose=True,
        save_fig=args.save_fig,
        scale_ratio=args.scale_ratio,
        save_path=args.save_path,
        cal_metrics=args.cal_metrics
    )

    if args.return_class_metrics:
        keys = sorted(res[0].keys())
        print('psnr')
        for k in keys:
            print(f'{k}: {res[0][k]:0.2f}')
        print('ssim')
        for k in keys:
            print(f'{k}: {res[1][k]:0.4f}')
    print(f'psnr: {res[0]["all"]:0.2f}')
    print(f'ssim: {res[1]["all"]:0.4f}')


if __name__ == '__main__':
    main()