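# Evaluation script for super-resolution models: computes per-class PSNR/SSIM
# over a test split defined by a YAML config, and can optionally save
# input/prediction/GT images and feature-map heatmaps to disk.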
import argparse
import json
import os
from functools import partial

import cv2
import numpy as np
import seaborn as sns
import torch
import yaml
from einops import rearrange
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets
import models
import utils

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def batched_predict(model, img, coord, bsize):
    # NOTE: `bsize` is currently unused; the model is queried on the full
    # coordinate set in a single forward pass.
    with torch.no_grad():
        pred = model(img, coord)
    return pred


def eval_psnr(loader, class_names, model,
              data_norm=None, eval_type=None, save_fig=False, save_featmap=False,
              scale_ratio=1, save_path=None, verbose=False, crop_border=4,
              cal_metrics=True,
              ):
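    """Evaluate `model` on `loader` and return per-class PSNR and SSIM dicts
    (each including an 'all' entry).

    Inputs are normalised with `data_norm` before the forward pass and
    predictions are de-normalised before metric computation. If `save_fig` or
    `save_featmap` is set, input/prediction/GT images (and, for the latter,
    feature-map heatmaps) are written to `save_path`.
    """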
    crop_border = int(crop_border) if crop_border else crop_border
    print('crop border: ', crop_border)
    model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).to(device)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).to(device)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).to(device)
    gt_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).to(device)

    if eval_type is None or eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    # NOTE: the div2k/benchmark branches yield a single PSNR callable, whereas
    # the metric computation below assumes the [psnr_fn, ssim_fn] list form.
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    val_res_psnr = utils.Averager(class_names)
    val_res_ssim = utils.Averager(class_names)

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)

        img = (batch['img'] - img_sub) / img_div
        with torch.no_grad():
            preds = model(img, batch['gt'].shape[-2:])
        if save_featmap:
            # In this mode the model is expected to return (predictions, feature maps).
            pred = preds[0][-1]
            returned_featmap = preds[1]
            assert returned_featmap.size(1) == 6
        else:
            pred = preds[-1] if isinstance(preds, list) else preds

        # Map predictions back to the original value range.
        pred = pred * gt_div + gt_sub

        if cal_metrics:
            res_psnr = metric_fn[0](pred, batch['gt'], crop_border=crop_border)
            res_ssim = metric_fn[1](pred, batch['gt'], crop_border=crop_border)
        else:
            # Placeholder values when metric computation is disabled.
            res_psnr = torch.ones(len(pred))
            res_ssim = torch.ones(len(pred))

        file_names = batch.get('filename', None)
        if file_names is not None and save_featmap:
            for idx in range(len(batch['img'])):
                ori_img = batch['img'][idx].cpu().numpy() * 255
                ori_img = np.clip(ori_img, a_min=0, a_max=255)
                ori_img = ori_img.astype(np.uint8)
                ori_img = rearrange(ori_img, 'C H W -> H W C')

                pred_img = pred[idx].cpu().numpy() * 255
                pred_img = np.clip(pred_img, a_min=0, a_max=255)
                pred_img = pred_img.astype(np.uint8)
                pred_img = rearrange(pred_img, 'C H W -> H W C')

                # Save each returned feature map as a heatmap.
                is_normalize = True
                f_tensors = returned_featmap[idx]
                for idx_f in range(len(f_tensors)):
                    f_tensor = f_tensors[idx_f]
                    if is_normalize:
                        # Squash to (0, 1) with a sigmoid, then min-max normalise.
                        f_tensor = torch.sigmoid(f_tensor)
                        f_tensor = f_tensor.detach().cpu().numpy()
                        f_tensor = (f_tensor - np.min(f_tensor)) / (np.max(f_tensor) - np.min(f_tensor))

                    sns.heatmap(f_tensor, vmin=0, vmax=1, cmap="jet", center=0.5)
                    plt.axis('off')
                    plt.xticks([])
                    plt.yticks([])
                    featmap_file_name = f'{save_path}/{file_names[idx]}_{idx_f}.png'
                    plt.savefig(featmap_file_name, dpi=600)
                    plt.close()

                gt_img = batch['gt'][idx].cpu().numpy() * 255
                gt_img = np.clip(gt_img, a_min=0, a_max=255)
                gt_img = gt_img.astype(np.uint8)
                gt_img = rearrange(gt_img, 'C H W -> H W C')

                psnr = res_psnr[idx].cpu().numpy()
                ssim = res_ssim[idx].cpu().numpy()
                ori_file_name = f'{save_path}/{file_names[idx]}_Ori.png'
                cv2.imwrite(ori_file_name, ori_img)
                pred_file_name = f'{save_path}/{file_names[idx]}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png'
                cv2.imwrite(pred_file_name, pred_img)
                gt_file_name = f'{save_path}/{file_names[idx]}_GT.png'
                cv2.imwrite(gt_file_name, gt_img)

        if file_names is not None and save_fig:
            for idx in range(len(batch['img'])):
                ori_img = batch['img'][idx].cpu().numpy() * 255
                ori_img = np.clip(ori_img, a_min=0, a_max=255)
                ori_img = ori_img.astype(np.uint8)
                ori_img = rearrange(ori_img, 'C H W -> H W C')

                pred_img = pred[idx].cpu().numpy() * 255
                pred_img = np.clip(pred_img, a_min=0, a_max=255)
                pred_img = pred_img.astype(np.uint8)
                pred_img = rearrange(pred_img, 'C H W -> H W C')

                gt_img = batch['gt'][idx].cpu().numpy() * 255
                gt_img = np.clip(gt_img, a_min=0, a_max=255)
                gt_img = gt_img.astype(np.uint8)
                gt_img = rearrange(gt_img, 'C H W -> H W C')

                psnr = res_psnr[idx].cpu().numpy()
                ssim = res_ssim[idx].cpu().numpy()
                ori_file_name = f'{save_path}/{file_names[idx]}_Ori.png'
                cv2.imwrite(ori_file_name, ori_img)
                pred_file_name = f'{save_path}/{file_names[idx]}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png'
                cv2.imwrite(pred_file_name, pred_img)
                gt_file_name = f'{save_path}/{file_names[idx]}_GT.png'
                cv2.imwrite(gt_file_name, gt_img)

        val_res_psnr.add(batch['class_name'], res_psnr)
        val_res_ssim.add(batch['class_name'], res_ssim)

        if verbose:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.item()['all'], val_res_ssim.item()['all']))

    return val_res_psnr.item(), val_res_ssim.item()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/test_UC_INR_mysr.yaml')
    parser.add_argument('--model', default='checkpoints/EXP20220610_5/epoch-best.pth')
    parser.add_argument('--scale_ratio', default=4, type=float)
    # Boolean options: argparse's `type=bool` turns any non-empty string
    # (including 'False') into True, so use store_true flags / explicit parsing.
    parser.add_argument('--save_fig', action='store_true')
    parser.add_argument('--save_featmap', action='store_true')
    parser.add_argument('--save_path', default='tmp', type=str)
    parser.add_argument('--cal_metrics', default=True,
                        type=lambda v: str(v).lower() in ('true', '1', 'yes'))
    parser.add_argument('--return_class_metrics', action='store_true')
    parser.add_argument('--dataset_name', default='UC', type=str)
    args = parser.parse_args()
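    # Example invocation (assuming this file is saved as test.py; the config and
    # checkpoint paths below are just the argparse defaults and may need adjusting):
    #   python test.py --config configs/test_UC_INR_mysr.yaml \
    #       --model checkpoints/EXP20220610_5/epoch-best.pth \
    #       --scale_ratio 4 --save_fig --save_path tmp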

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
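    # Assumed config layout (inferred from the keys accessed below):
    #   test_dataset:
    #     dataset: {name: ..., args: {root_path: ..., split_file: ...}}
    #     wrapper: {name: ..., args: {scale_ratio: ...}}
    #     batch_size: ...
    #   data_norm: {img: {sub: [...], div: [...]}, gt: {sub: [...], div: [...]}}
    #   eval_type: psnr+ssim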

    # Dataset roots and split files (note: machine-specific absolute paths).
    root_split_file = {
        'UC': {
            'root_path': '/Users/kyanchen/Documents/UC/256',
            'split_file': '/Users/kyanchen/My_Code/sr/data_split/UC_split.json'
        },
        'AID': {
            'root_path': '/data/kyanchen/datasets/AID',
            'split_file': 'data_split/AID_split.json'
        }
    }

    config['test_dataset']['dataset']['args']['root_path'] = root_split_file[args.dataset_name]['root_path']
    config['test_dataset']['dataset']['args']['split_file'] = root_split_file[args.dataset_name]['split_file']
    config['test_dataset']['wrapper']['args']['scale_ratio'] = args.scale_ratio

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'], num_workers=0,
                        pin_memory=True, shuffle=False, drop_last=False)

    if not os.path.exists(args.model):
        raise FileNotFoundError(f'checkpoint not found: {args.model}')
    model_spec = torch.load(args.model, map_location='cpu')['model']
    print(model_spec['args'])
    model = models.make(model_spec, load_sd=True).to(device)

    file_names = json.load(open(config['test_dataset']['dataset']['args']['split_file']))['test']
    class_names = list(set([os.path.basename(os.path.dirname(x)) for x in file_names]))
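
    # Border crop used for PSNR/SSIM: by default (scale_ratio + 5) pixels; when
    # the requested scale exceeds the dataset's maximum training scale, crop the
    # extrapolated region instead (the factor of 48 presumably corresponds to the
    # low-resolution patch size used by the dataset wrapper).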
    crop_border = config['test_dataset']['wrapper']['args']['scale_ratio'] + 5
    dataset_name = os.path.basename(config['test_dataset']['dataset']['args']['split_file']).split('_')[0]
    max_scale = {'UC': 5, 'AID': 12}
    if args.scale_ratio > max_scale[dataset_name]:
        crop_border = int((args.scale_ratio - max_scale[dataset_name]) / 2 * 48)

    if args.save_fig or args.save_featmap:
        os.makedirs(args.save_path, exist_ok=True)

    res = eval_psnr(
        loader, class_names, model,
        data_norm=config.get('data_norm'),
        eval_type=config.get('eval_type'),
        crop_border=crop_border,
        verbose=True,
        save_fig=args.save_fig,
        save_featmap=args.save_featmap,
        scale_ratio=args.scale_ratio,
        save_path=args.save_path,
        cal_metrics=args.cal_metrics
    )

    if args.return_class_metrics:
        keys = sorted(res[0].keys())
        print('psnr')
        for k in keys:
            print(f'{k}: {res[0][k]:0.2f}')
        print('ssim')
        for k in keys:
            print(f'{k}: {res[1][k]:0.4f}')
    print(f'psnr: {res[0]["all"]:0.2f}')
    print(f'ssim: {res[1]["all"]:0.4f}')