|
import os |
|
from argparse import ArgumentParser |
|
import warnings |
|
|
|
from omegaconf import OmegaConf |
|
import torch |
|
from torch.nn import functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.utils.tensorboard import SummaryWriter |
|
from torchvision.utils import make_grid |
|
from accelerate import Accelerator |
|
from accelerate.utils import set_seed |
|
from einops import rearrange |
|
from tqdm import tqdm |
|
import lpips |
|
|
|
from model import SwinIR |
|
from utils.common import instantiate_from_config |
|
|
|
|
|
|
|
def rgb2ycbcr_pt(img, y_only=False):
    """Convert a batch of RGB images to YCbCr (PyTorch version).

    Uses the ITU-R BT.601 coefficients (standard-definition television), see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): RGB images, shape (n, 3, h, w), float, range [0, 1].
        y_only (bool): If True, return only the luma (Y) channel.
            Default: False.

    Returns:
        Tensor: Converted images with shape (n, 1, h, w) when ``y_only`` is
        set, otherwise (n, 3, h, w). Same dtype/device as ``img``; values are
        divided by 255 before being returned.
    """
    # Move channels last so that matmul contracts over the RGB axis.
    channels_last = img.permute(0, 2, 3, 1)

    if y_only:
        # Single column: RGB -> Y.
        coeffs = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        offset = 16.0
    else:
        # Columns are the Y, Cb and Cr projections respectively.
        coeffs = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        offset = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)

    # Project, restore the channels-first layout, then add the per-channel offset.
    converted = torch.matmul(channels_last, coeffs).permute(0, 3, 1, 2) + offset
    return converted / 255.
|
|
|
|
|
|
|
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False):
    """Compute per-image PSNR (Peak Signal-to-Noise Ratio) with PyTorch.

    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Number of pixels removed from every edge before
            the error is measured; 0 disables cropping.
        test_y_channel (bool): Measure on the Y channel of YCbCr instead of
            the given channels. Default: False.

    Returns:
        Tensor: float64 tensor of shape (n,) holding one PSNR value (in dB,
        peak value 1) per image in the batch.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    pair = (img, img2)

    if crop_border != 0:
        # Same margin is stripped from both spatial axes.
        inner = slice(crop_border, -crop_border)
        pair = tuple(t[:, :, inner, inner] for t in pair)

    if test_y_channel:
        pair = tuple(rgb2ycbcr_pt(t, y_only=True) for t in pair)

    # Accumulate the error in double precision.
    first, second = (t.to(torch.float64) for t in pair)

    squared_error = (first - second) ** 2
    mse = squared_error.mean(dim=[1, 2, 3])
    # The 1e-8 term keeps the ratio finite when both images are identical.
    return 10. * torch.log10(1. / (mse + 1e-8))
|
|
|
|
|
def _to_bchw(batch, device):
    """Rearrange a channels-last batch to (b, c, h, w), cast to float32 and move it to ``device``."""
    return rearrange(batch, "b h w c -> b c h w").contiguous().float().to(device)


def _gather_mean(accelerator, values, device):
    """Average a list of Python scalars over all entries and all processes."""
    local = torch.tensor(values, device=device).unsqueeze(0)
    return accelerator.gather(local).mean().item()


def _validate(swinir, val_loader, lpips_model, accelerator, device):
    """Run one pass over ``val_loader`` and return averaged ``(loss, lpips, psnr)``.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. Each metric is averaged over batches and over processes.
    """
    swinir.eval()
    val_loss = []
    val_lpips = []
    val_psnr = []
    val_pbar = tqdm(iterable=None, disable=not accelerator.is_local_main_process, unit="batch",
                    total=len(val_loader), leave=False, desc="Validation")

    for val_gt, val_lq, _ in val_loader:
        val_gt = _to_bchw((val_gt + 1) / 2, device)
        val_lq = _to_bchw(val_lq, device)
        with torch.no_grad():
            val_pred = swinir(val_lq)
            val_loss.append(F.mse_loss(input=val_pred, target=val_gt, reduction="sum").item())
            val_lpips.append(lpips_model(val_pred, val_gt, normalize=True).mean().item())
            val_psnr.append(calculate_psnr_pt(val_pred, val_gt, crop_border=0).mean().item())
        val_pbar.update(1)
    val_pbar.close()

    averages = (
        _gather_mean(accelerator, val_loss, device),
        _gather_mean(accelerator, val_lpips, device),
        _gather_mean(accelerator, val_psnr, device),
    )
    swinir.train()
    return averages


def main(args) -> None:
    """Train a SwinIR model as described by the config file at ``args.config``.

    The config is loaded with OmegaConf and must provide:

    * ``model.swinir``: spec passed to ``instantiate_from_config``.
    * ``dataset.train`` / ``dataset.val``: dataset specs; each sample is
      unpacked as ``(gt, lq, _)``.
    * ``train.exp_dir``, ``train.resume``, ``train.learning_rate``,
      ``train.batch_size``, ``train.num_workers``, ``train.train_steps``,
      ``train.log_every``, ``train.ckpt_every``, ``train.image_every``,
      ``train.val_every``.

    Side effects (local main process only): creates ``exp_dir`` and
    ``exp_dir/checkpoints``, writes TensorBoard events to ``exp_dir`` and
    saves the unwrapped model's ``state_dict`` every ``ckpt_every`` steps.

    Args:
        args: Parsed CLI namespace; only ``args.config`` is read.

    Raises:
        ValueError: If the prepared training loader yields no batches.
    """
    # Setup accelerator and reproducibility.
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)
    is_main = accelerator.is_local_main_process

    # Setup the experiment folder (only the local main process writes to disk).
    if is_main:
        exp_dir = cfg.train.exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        ckpt_dir = os.path.join(exp_dir, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # Create the model, optionally resuming from a checkpoint.
    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    if cfg.train.resume:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        swinir.load_state_dict(torch.load(cfg.train.resume, map_location="cpu"), strict=True)
        if is_main:
            print(f"strictly load weight from checkpoint: {cfg.train.resume}")
    elif is_main:
        print("initialize from scratch")

    # Setup optimizer.
    opt = torch.optim.AdamW(
        swinir.parameters(), lr=cfg.train.learning_rate,
        weight_decay=0
    )

    # Setup data.
    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    val_dataset = instantiate_from_config(cfg.dataset.val)
    val_loader = DataLoader(
        dataset=val_dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=False, drop_last=False
    )
    if is_main:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    # Prepare everything for (possibly distributed) training.
    swinir.train().to(device)
    swinir, opt, loader, val_loader = accelerator.prepare(swinir, opt, loader, val_loader)
    # Unwrapped model: used so checkpoints carry plain parameter names.
    pure_swinir = accelerator.unwrap_model(swinir)

    # Without this guard an empty loader would make the outer loop spin forever,
    # because global_step could never advance.
    if len(loader) == 0:
        raise ValueError(
            "Training loader yields no batches; check the dataset size against train.batch_size."
        )

    # Variables for monitoring/logging purposes.
    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    with warnings.catch_warnings():
        # Silence warnings emitted while the LPIPS network is constructed.
        warnings.simplefilter("ignore")
        lpips_model = lpips.LPIPS(net="alex", verbose=is_main).eval().to(device)
    if is_main:
        writer = SummaryWriter(exp_dir)
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not is_main, unit="batch", total=len(loader))
        for gt, lq, _ in loader:
            # gt is shifted by (x + 1) / 2 before use, presumably from [-1, 1] to [0, 1]
            # (TODO confirm against the dataset); lq is used as-is.
            gt = _to_bchw((gt + 1) / 2, device)
            lq = _to_bchw(lq, device)
            pred = swinir(lq)
            loss = F.mse_loss(input=pred, target=gt, reduction="sum")

            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()
            accelerator.wait_for_everyone()

            global_step += 1
            # Read the scalar once: every .item() call forces a host/device sync.
            loss_value = loss.item()
            step_loss.append(loss_value)
            epoch_loss.append(loss_value)
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

            # Log the training loss averaged since the previous log point.
            if global_step % cfg.train.log_every == 0:
                avg_loss = _gather_mean(accelerator, step_loss, device)
                step_loss.clear()
                if is_main:
                    writer.add_scalar("train/loss_step", avg_loss, global_step)

            # Save a checkpoint of the unwrapped model.
            if global_step % cfg.train.ckpt_every == 0 and is_main:
                checkpoint = pure_swinir.state_dict()
                ckpt_path = f"{ckpt_dir}/{global_step:07d}.pt"
                torch.save(checkpoint, ckpt_path)

            # Log sample images (also on the very first step).
            if global_step % cfg.train.image_every == 0 or global_step == 1:
                swinir.eval()
                num_log_images = 12
                log_gt, log_lq = gt[:num_log_images], lq[:num_log_images]
                with torch.no_grad():
                    log_pred = swinir(log_lq)
                if is_main:
                    for tag, image in [
                        ("image/pred", log_pred),
                        ("image/gt", log_gt),
                        ("image/lq", log_lq),
                    ]:
                        writer.add_image(tag, make_grid(image, nrow=4), global_step)
                swinir.train()

            # Evaluate on the validation set.
            if global_step % cfg.train.val_every == 0:
                avg_val_loss, avg_val_lpips, avg_val_psnr = _validate(
                    swinir, val_loader, lpips_model, accelerator, device
                )
                if is_main:
                    for tag, val in [
                        ("val/loss", avg_val_loss),
                        ("val/lpips", avg_val_lpips),
                        ("val/psnr", avg_val_psnr),
                    ]:
                        writer.add_scalar(tag, val, global_step)

            accelerator.wait_for_everyone()

            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = _gather_mean(accelerator, epoch_loss, device)
        epoch_loss.clear()
        if is_main:
            writer.add_scalar("train/loss_epoch", avg_epoch_loss, global_step)

    if is_main:
        print("done!")
        writer.close()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the only CLI option is the (required) path to the training config.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    main(cli.parse_args())
|
|