# FunSR / test_inr_diinn_arbrcan_sadnarc_funsr_overnet.py
# Last commit by KyanChen: 02c5426 ("add")
import argparse
import json
import os
import math
from functools import partial
import seaborn as sns
import cv2.dnn
import numpy as np
import yaml
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
import datasets
import models
import utils
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
def batched_predict(model, img, coord, bsize):
    """Forward ``(img, coord)`` through ``model`` with autograd disabled.

    ``bsize`` is accepted for call-site compatibility but is not used: the whole
    query is sent to the model in a single call rather than in chunks.
    """
    with torch.no_grad():
        return model(img, coord)
def _to_uint8_hwc(tensor):
    """Convert a ``C x H x W`` tensor to an ``H x W x C`` uint8 array.

    Values are scaled by 255 and clipped to [0, 255] before the cast, so
    out-of-range values saturate instead of wrapping. The input is assumed to be
    roughly in [0, 1] (TODO confirm against the dataset/model value range).
    """
    arr = tensor.detach().cpu().numpy() * 255
    arr = np.clip(arr, a_min=0, a_max=255).astype(np.uint8)
    return rearrange(arr, 'C H W -> H W C')


def _save_featmaps(featmaps, save_path, file_name):
    """Save each slice of ``featmaps`` (iterated along dim 0) as a heatmap PNG."""
    for idx_f, f_tensor in enumerate(featmaps):
        # Squash with a sigmoid, then min-max normalise to [0, 1] for display.
        f_map = torch.sigmoid(f_tensor).detach().cpu().numpy()
        f_min, f_max = np.min(f_map), np.max(f_map)
        if f_max > f_min:
            f_map = (f_map - f_min) / (f_max - f_min)
        else:
            # A constant map would otherwise divide by zero and render as NaN.
            f_map = np.zeros_like(f_map)
        sns.heatmap(f_map, vmin=0, vmax=1, cmap="jet", center=0.5)
        plt.axis('off')
        plt.xticks([])
        plt.yticks([])
        plt.savefig(f'{save_path}/{file_name}_{idx_f}.png', dpi=600)
        plt.close()


def _save_triplet(save_path, file_name, scale_ratio, img, pred, gt, psnr, ssim):
    """Write input, prediction (metrics encoded in its file name) and GT as PNGs."""
    # NOTE(review): cv2.imwrite treats the last axis as BGR. If these tensors are
    # RGB, the saved colours are channel-swapped. Confirm the dataset's order.
    cv2.imwrite(f'{save_path}/{file_name}_Ori.png', _to_uint8_hwc(img))
    cv2.imwrite(f'{save_path}/{file_name}_{scale_ratio}X_{psnr:.2f}_{ssim:.4f}.png',
                _to_uint8_hwc(pred))
    cv2.imwrite(f'{save_path}/{file_name}_GT.png', _to_uint8_hwc(gt))


def eval_psnr(loader, class_names, model,
              data_norm=None, eval_type=None, save_fig=False, save_featmap=False,
              scale_ratio=1, save_path=None, verbose=False, crop_border=4,
              cal_metrics=True,
              ):
    """Evaluate ``model`` on ``loader`` and accumulate per-class PSNR / SSIM.

    Each batch is a dict with tensors ``'img'`` (model input) and ``'gt'``
    (target), a ``'class_name'`` entry, and an optional ``'filename'`` entry.
    The model is called as ``model(img, gt.shape[-2:])``.

    Args:
        loader: Iterable of batch dicts as described above.
        class_names: Class names passed to ``utils.Averager``.
        model: Network to evaluate; switched to eval mode.
        data_norm: ``{'img': {'sub', 'div'}, 'gt': {'sub', 'div'}}``
            normalisation lists. Defaults to identity.
        eval_type: ``None`` or ``'psnr+ssim'`` for per-sample PSNR+SSIM. The
            ``'div2k-<s>'`` / ``'benchmark-<s>'`` variants are only usable with
            ``cal_metrics=False``.
        save_fig: Save input / prediction / GT images per sample.
        save_featmap: Model returns ``(preds, featmaps)``. Also save one heatmap
            per feature channel, plus the image triplet.
        scale_ratio: Only used in saved file names.
        save_path: Output directory; required when saving.
        verbose: Show running averages in the progress bar.
        crop_border: Border passed to the metric functions (cast to int if truthy).
        cal_metrics: If False, skip metrics and record 1.0 placeholders.

    Returns:
        ``(psnr_averager.item(), ssim_averager.item())``. The caller indexes these
        by class name and ``'all'``.

    Raises:
        ValueError: saving is requested but ``save_path`` is None.
        NotImplementedError: unknown ``eval_type``, or a single-function
            ``eval_type`` combined with ``cal_metrics=True``.
    """
    if (save_fig or save_featmap) and save_path is None:
        raise ValueError('save_path is required when save_fig or save_featmap is set')

    crop_border = int(crop_border) if crop_border else crop_border
    print('crop border: ', crop_border)
    model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]},
        }

    def _as_bchw(values):
        # Broadcastable (1, C, 1, 1) tensor on the evaluation device.
        return torch.FloatTensor(values).view(1, -1, 1, 1).to(device)

    img_sub, img_div = _as_bchw(data_norm['img']['sub']), _as_bchw(data_norm['img']['div'])
    gt_sub, gt_div = _as_bchw(data_norm['gt']['sub']), _as_bchw(data_norm['gt']['div'])

    if eval_type is None or eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    if cal_metrics and not isinstance(metric_fn, list):
        # The loop below needs a [psnr_fn, ssim_fn] pair that accepts crop_border.
        raise NotImplementedError(
            f'eval_type={eval_type!r} provides a single PSNR function; per-sample '
            'PSNR+SSIM evaluation supports only eval_type None or "psnr+ssim"')

    val_res_psnr = utils.Averager(class_names)
    val_res_ssim = utils.Averager(class_names)

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)

        img = (batch['img'] - img_sub) / img_div
        with torch.no_grad():
            preds = model(img, batch['gt'].shape[-2:])

        returned_featmap = None
        if save_featmap:
            # Model output is (list_of_predictions, feature_maps).
            pred = preds[0][-1]
            returned_featmap = preds[1]
            assert returned_featmap.size(1) == 6
        elif isinstance(preds, list):
            # Presumably one prediction per stage; evaluate the last one.
            pred = preds[-1]
        else:
            # A plain tensor output is the prediction itself.
            pred = preds
        pred = pred * gt_div + gt_sub

        if cal_metrics:
            res_psnr = metric_fn[0](pred, batch['gt'], crop_border=crop_border)
            res_ssim = metric_fn[1](pred, batch['gt'], crop_border=crop_border)
        else:
            res_psnr = torch.ones(len(pred))
            res_ssim = torch.ones(len(pred))

        file_names = batch.get('filename', None)
        if file_names is not None and (save_featmap or save_fig):
            for idx in range(len(batch['img'])):
                if save_featmap:
                    _save_featmaps(returned_featmap[idx], save_path, file_names[idx])
                _save_triplet(
                    save_path, file_names[idx], scale_ratio,
                    batch['img'][idx], pred[idx], batch['gt'][idx],
                    float(res_psnr[idx]), float(res_ssim[idx]),
                )

        val_res_psnr.add(batch['class_name'], res_psnr)
        val_res_ssim.add(batch['class_name'], res_ssim)
        if verbose:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.item()['all'], val_res_ssim.item()['all']))
    return val_res_psnr.item(), val_res_ssim.item()
# Default dataset locations per --dataset_name; override with --root_path / --split_file.
DATASET_PATHS = {
    'UC': {
        'root_path': '/Users/kyanchen/Documents/UC/256',
        'split_file': '/Users/kyanchen/My_Code/sr/data_split/UC_split.json',
    },
    'AID': {
        'root_path': '/data/kyanchen/datasets/AID',
        'split_file': 'data_split/AID_split.json',
    },
}

# Above these scale ratios the crop border is derived from the excess scale (see main).
MAX_SCALE = {'UC': 5, 'AID': 12}


def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, including ``'False'``. This accepts common spellings and
    keeps ``''`` as False, matching ``bool('')``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Parse CLI args, build the test loader and model, run ``eval_psnr`` and print results."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/test_UC_INR_mysr.yaml')
    parser.add_argument('--model', default='checkpoints/EXP20220610_5/epoch-best.pth')
    parser.add_argument('--scale_ratio', default=4, type=float)
    parser.add_argument('--save_fig', default=False, type=str2bool)
    parser.add_argument('--save_featmap', default=False, type=str2bool)
    parser.add_argument('--save_path', default='tmp', type=str)
    parser.add_argument('--cal_metrics', default=True, type=str2bool)
    parser.add_argument('--return_class_metrics', default=False, type=str2bool)
    parser.add_argument('--dataset_name', default='UC', type=str, choices=sorted(DATASET_PATHS))
    parser.add_argument('--root_path', default=None, type=str,
                        help='override the default dataset root for --dataset_name')
    parser.add_argument('--split_file', default=None, type=str,
                        help='override the default split JSON for --dataset_name')
    args = parser.parse_args()

    # Fail fast before building the dataset.
    if not os.path.exists(args.model):
        raise FileNotFoundError(f'checkpoint not found: {args.model}')

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    defaults = DATASET_PATHS[args.dataset_name]
    root_path = args.root_path or defaults['root_path']
    split_file = args.split_file or defaults['split_file']
    config['test_dataset']['dataset']['args']['root_path'] = root_path
    config['test_dataset']['dataset']['args']['split_file'] = split_file
    config['test_dataset']['wrapper']['args']['scale_ratio'] = args.scale_ratio

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'], num_workers=0,
                        pin_memory=True, shuffle=False, drop_last=False)

    model_spec = torch.load(args.model, map_location='cpu')['model']
    print(model_spec['args'])
    model = models.make(model_spec, load_sd=True).to(device)

    with open(split_file) as f:
        file_names = json.load(f)['test']
    # The class name is the parent directory of each test file.
    class_names = sorted({os.path.basename(os.path.dirname(x)) for x in file_names})

    crop_border = args.scale_ratio + 5
    max_scale = MAX_SCALE[args.dataset_name]
    if args.scale_ratio > max_scale:
        # NOTE(review): 48 presumably relates to the training LR patch size — confirm.
        crop_border = int((args.scale_ratio - max_scale) / 2 * 48)

    if args.save_fig or args.save_featmap:
        os.makedirs(args.save_path, exist_ok=True)

    res = eval_psnr(
        loader, class_names, model,
        data_norm=config.get('data_norm'),
        eval_type=config.get('eval_type'),
        crop_border=crop_border,
        verbose=True,
        save_fig=args.save_fig,
        save_featmap=args.save_featmap,
        scale_ratio=args.scale_ratio,
        save_path=args.save_path,
        cal_metrics=args.cal_metrics,
    )

    if args.return_class_metrics:
        keys = sorted(res[0])
        print('psnr')
        for k in keys:
            print(f'{k}: {res[0][k]:0.2f}')
        print('ssim')
        for k in keys:
            print(f'{k}: {res[1][k]:0.4f}')
    print(f'psnr: {res[0]["all"]:0.2f}')
    print(f'ssim: {res[1]["all"]:0.4f}')


if __name__ == '__main__':
    main()