diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f64400b7ca62cceb317e7e4966391439fbfa516a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/pic/div2k_comparison.jpg filter=lfs diff=lfs merge=lfs -text
+assets/pic/london2.jpg filter=lfs diff=lfs merge=lfs -text
+assets/pic/main_framework.jpg filter=lfs diff=lfs merge=lfs -text
+assets/pic/realsr_vis3.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/assets/mm-realsr/de_net.pth b/assets/mm-realsr/de_net.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7f74e7a76f9f2dafb1cf2d5f9c7b2ab0a162d059
--- /dev/null
+++ b/assets/mm-realsr/de_net.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6e77c1cb0dd51e01ef60bbf7b9d85b3f9399c3c1253889ddb73de1436231b1a
+size 9424338
diff --git a/assets/pic/div2k_comparison.jpg b/assets/pic/div2k_comparison.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ff9d7a9cd189e351bc6a9325dd2d98dd86f8906e
--- /dev/null
+++ b/assets/pic/div2k_comparison.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d52daed3de6be211bda5e1c11b8d9352fee84fca53af53e10b3cddc86036db99
+size 6224264
diff --git a/assets/pic/gradio.png b/assets/pic/gradio.png
new file mode 100644
index 0000000000000000000000000000000000000000..77314c7983ded13f13ea92cd7dd72857b3854112
Binary files /dev/null and b/assets/pic/gradio.png differ
diff --git a/assets/pic/london2.jpg b/assets/pic/london2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7f0d621c00e06874b6e3a789b028ad9e69211ffc
--- /dev/null
+++ b/assets/pic/london2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6055d986cff9fe82d97b60f2ed728f00fb81600aa69541a97db80aa73cd906fa
+size 6388509
diff --git a/assets/pic/main_framework.jpg b/assets/pic/main_framework.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2232fe62b3b14cafc99ad1234c5872b4acce6264
--- /dev/null
+++ b/assets/pic/main_framework.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f089c6f09a44300d36cb3c9b6b4905ca4f801d63ae1907fa83ab66aec70ec893
+size 4684687
diff --git a/assets/pic/realsr_vis3.jpg b/assets/pic/realsr_vis3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b94ebf4ed0c82dcbb63b3fd3152d568799123fd8
--- /dev/null
+++ b/assets/pic/realsr_vis3.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b07d0ffe838adf48958e3e910b467e1726d4a7a2f1623a9ae530f8db622fbb06
+size 5329675
diff --git a/basicsr/__init__.py b/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28437544a254656cca7fb7021ef7bbf724cf2879
--- /dev/null
+++ b/basicsr/__init__.py
@@ -0,0 +1,12 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .test import *
+from .train import *
+from .utils import *
+# from .version import __gitsha__, __version__
diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6bcbd97bb3e4914c3c91dc53e0708bcac66075
--- /dev/null
+++ b/basicsr/archs/__init__.py
@@ -0,0 +1,24 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+    opt = deepcopy(opt)
+    network_type = opt.pop('type')
+    net = ARCH_REGISTRY.get(network_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Network [{net.__class__.__name__}] is created.')
+    return net
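+
+
+# Usage sketch (illustrative): any class decorated with @ARCH_REGISTRY.register()
+# can be built from an options dict whose 'type' key names the class; the
+# remaining keys are passed through as keyword arguments. For example, with the
+# DEResNet arch added in this PR (argument values here are illustrative):
+#
+#     net = build_network({'type': 'DEResNet', 'num_in_ch': 3, 'num_degradation': 2})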
diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..07fd7762814136c0bf5d34432f46722723e68e3f
--- /dev/null
+++ b/basicsr/archs/arch_util.py
@@ -0,0 +1,355 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
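+# Usage sketch (values are illustrative): scaling down the initialization of a
+# conv layer, as done with scale=0.1 in ResidualBlockNoBN below.
+#
+#     conv = nn.Conv2d(64, 64, 3, 1, 1)
+#     default_init_weights(conv, scale=0.1, bias_fill=0)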
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
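+
+# Usage sketch (numbers are illustrative): stacking 15 ResidualBlockNoBN blocks
+# into one nn.Sequential, as done in ConvResidualBlocks in basicvsr_arch.py.
+#
+#     trunk = make_layer(ResidualBlockNoBN, 15, num_feat=64)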
+
+class PixelShufflePack(nn.Module):
+    """Pixel Shuffle upsample layer.
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        scale_factor (int): Upsample ratio.
+        upsample_kernel (int): Kernel size of Conv layer to expand channels.
+    Returns:
+        Upsampled feature map.
+    """
+
+    def __init__(self, in_channels, out_channels, scale_factor,
+                 upsample_kernel):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        self.upsample_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels * scale_factor * scale_factor,
+            self.upsample_kernel,
+            padding=(self.upsample_kernel - 1) // 2)
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weights for PixelShufflePack."""
+        default_init_weights(self, 1)
+
+    def forward(self, x):
+        """Forward function for PixelShufflePack.
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+        Returns:
+            Tensor: Forward results.
+        """
+        x = self.upsample_conv(x)
+        x = F.pixel_shuffle(x, self.scale_factor)
+        return x
+
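+# Usage sketch (shapes are illustrative): a 2x upsampling layer that first
+# expands channels with a conv, then rearranges them spatially.
+#
+#     up = PixelShufflePack(64, 64, scale_factor=2, upsample_kernel=3)
+#     y = up(torch.rand(1, 64, 32, 32))   # -> (1, 64, 64, 64)
+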
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU()
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1(x)
+        x = self.relu(x)
+        out = self.conv2(x)
+        # out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
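+# Usage sketch (shapes are illustrative): x4 upsampling is realised as two
+# conv + PixelShuffle(2) stages; scale=3 uses a single PixelShuffle(3) stage.
+#
+#     up = Upsample(scale=4, num_feat=64)
+#     y = up(torch.rand(1, 64, 16, 16))   # -> (1, 64, 64, 64)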
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before PyTorch 1.3, the default value was
+            align_corners=True; since PyTorch 1.3, it is align_corners=False.
+            Here, True is used as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
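+# Usage sketch (shapes are illustrative): flow is expected in (n, h, w, 2)
+# layout with (dx, dy) in pixels; an all-zero flow returns the input unchanged
+# (up to interpolation error).
+#
+#     x = torch.rand(1, 3, 32, 32)
+#     flow = torch.zeros(1, 32, 32, 2)
+#     y = flow_warp(x, flow)   # -> (1, 3, 32, 32)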
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+            ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+    return resized_flow
+
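+# Usage sketch (shapes are illustrative): note that the flow *values* are
+# rescaled together with the spatial size, so the flow stays consistent with
+# the resized resolution.
+#
+#     flow = torch.rand(1, 2, 32, 32)
+#     flow_x2 = resize_flow(flow, 'ratio', [2.0, 2.0])   # -> (1, 2, 64, 64)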
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """ Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
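+# Usage sketch (shapes are illustrative): the inverse of F.pixel_shuffle; the
+# spatial size is divided by `scale` and the channels are multiplied by `scale ** 2`.
+#
+#     x = torch.rand(1, 3, 64, 64)
+#     y = pixel_unshuffle(x, scale=2)   # -> (1, 12, 32, 32)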
+
+class DCNv2Pack(ModulatedDeformConvPack):
+    """Modulated deformable conv for deformable alignment.
+
+    Different from the official DCNv2Pack, which generates offsets and masks
+    from the preceding features, this DCNv2Pack takes another different
+    features to generate offsets and masks.
+
+    ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
+    """
+
+    def forward(self, x, feat):
+        out = self.conv_offset(feat)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+
+        offset_absmean = torch.mean(torch.abs(offset))
+        if offset_absmean > 50:
+            logger = get_root_logger()
+            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+                                                 self.dilation, mask)
+        else:
+            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+                                         self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        low = norm_cdf((a - mean) / std)
+        up = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [low, up], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution.
+
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+    The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
diff --git a/basicsr/archs/basicvsr_arch.py b/basicsr/archs/basicvsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed7b824eae108a9bcca57f1c14dd0d8afafc4f58
--- /dev/null
+++ b/basicsr/archs/basicvsr_arch.py
@@ -0,0 +1,336 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
+from .edvr_arch import PCDAlignment, TSAFusion
+from .spynet_arch import SpyNet
+
+
+@ARCH_REGISTRY.register()
+class BasicVSR(nn.Module):
+    """A recurrent network for video SR. Now only x4 is supported.
+
+    Args:
+        num_feat (int): Number of channels. Default: 64.
+        num_block (int): Number of residual blocks for each branch. Default: 15
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+    """
+
+    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
+        super().__init__()
+        self.num_feat = num_feat
+
+        # alignment
+        self.spynet = SpyNet(spynet_path)
+
+        # propagation
+        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+        self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+        # reconstruction
+        self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
+        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        # activation functions
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def get_flow(self, x):
+        b, n, c, h, w = x.size()
+
+        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+        return flows_forward, flows_backward
+
+    def forward(self, x):
+        """Forward function of BasicVSR.
+
+        Args:
+            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
+        """
+        flows_forward, flows_backward = self.get_flow(x)
+        b, n, _, h, w = x.size()
+
+        # backward branch
+        out_l = []
+        feat_prop = x.new_zeros(b, self.num_feat, h, w)
+        for i in range(n - 1, -1, -1):
+            x_i = x[:, i, :, :, :]
+            if i < n - 1:
+                flow = flows_backward[:, i, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.backward_trunk(feat_prop)
+            out_l.insert(0, feat_prop)
+
+        # forward branch
+        feat_prop = torch.zeros_like(feat_prop)
+        for i in range(0, n):
+            x_i = x[:, i, :, :, :]
+            if i > 0:
+                flow = flows_forward[:, i - 1, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.forward_trunk(feat_prop)
+
+            # upsample
+            out = torch.cat([out_l[i], feat_prop], dim=1)
+            out = self.lrelu(self.fusion(out))
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+            out = self.lrelu(self.conv_hr(out))
+            out = self.conv_last(out)
+            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+            out += base
+            out_l[i] = out
+
+        return torch.stack(out_l, dim=1)
+
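+# Usage sketch (shapes are illustrative; weights are untrained here -- in practice
+# a pretrained SPyNet checkpoint should be passed via `spynet_path`):
+#
+#     model = BasicVSR(num_feat=64, num_block=15)
+#     lqs = torch.rand(1, 5, 3, 64, 64)    # (b, n, c, h, w)
+#     out = model(lqs)                     # -> (1, 5, 3, 256, 256), i.e. x4 SR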
+
+class ConvResidualBlocks(nn.Module):
+    """Conv and residual block used in BasicVSR.
+
+    Args:
+        num_in_ch (int): Number of input channels. Default: 3.
+        num_out_ch (int): Number of output channels. Default: 64.
+        num_block (int): Number of residual blocks. Default: 15.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
+        super().__init__()
+        self.main = nn.Sequential(
+            nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+            make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
+
+    def forward(self, fea):
+        return self.main(fea)
+
+
+@ARCH_REGISTRY.register()
+class IconVSR(nn.Module):
+    """IconVSR, proposed also in the BasicVSR paper.
+
+    Args:
+        num_feat (int): Number of channels. Default: 64.
+        num_block (int): Number of residual blocks for each branch. Default: 15.
+        keyframe_stride (int): Keyframe stride. Default: 5.
+        temporal_padding (int): Temporal padding. Default: 2.
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+        edvr_path (str): Path to the pretrained EDVR model. Default: None.
+    """
+
+    def __init__(self,
+                 num_feat=64,
+                 num_block=15,
+                 keyframe_stride=5,
+                 temporal_padding=2,
+                 spynet_path=None,
+                 edvr_path=None):
+        super().__init__()
+
+        self.num_feat = num_feat
+        self.temporal_padding = temporal_padding
+        self.keyframe_stride = keyframe_stride
+
+        # keyframe_branch
+        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
+        # alignment
+        self.spynet = SpyNet(spynet_path)
+
+        # propagation
+        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
+
+        # reconstruction
+        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        # activation functions
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def pad_spatial(self, x):
+        """Apply padding spatially.
+
+        Since the PCD module in EDVR requires that the resolution is a multiple
+        of 4, we apply padding to the input LR images if their resolution is
+        not divisible by 4.
+
+        Args:
+            x (Tensor): Input LR sequence with shape (n, t, c, h, w).
+        Returns:
+            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
+        """
+        n, t, c, h, w = x.size()
+
+        pad_h = (4 - h % 4) % 4
+        pad_w = (4 - w % 4) % 4
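+        # Worked example (illustrative): h=100, w=101 gives pad_h=0, pad_w=3,
+        # so the padded frames are 100 x 104, both divisible by 4.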
+
+        # padding
+        x = x.view(-1, c, h, w)
+        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
+
+        return x.view(n, t, c, h + pad_h, w + pad_w)
+
+    def get_flow(self, x):
+        b, n, c, h, w = x.size()
+
+        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+        return flows_forward, flows_backward
+
+    def get_keyframe_feature(self, x, keyframe_idx):
+        if self.temporal_padding == 2:
+            x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
+        elif self.temporal_padding == 3:
+            x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
+        x = torch.cat(x, dim=1)
+
+        num_frames = 2 * self.temporal_padding + 1
+        feats_keyframe = {}
+        for i in keyframe_idx:
+            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
+        return feats_keyframe
+
+    def forward(self, x):
+        b, n, _, h_input, w_input = x.size()
+
+        x = self.pad_spatial(x)
+        h, w = x.shape[3:]
+
+        keyframe_idx = list(range(0, n, self.keyframe_stride))
+        if keyframe_idx[-1] != n - 1:
+            keyframe_idx.append(n - 1)  # last frame is a keyframe
+
+        # compute flow and keyframe features
+        flows_forward, flows_backward = self.get_flow(x)
+        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
+
+        # backward branch
+        out_l = []
+        feat_prop = x.new_zeros(b, self.num_feat, h, w)
+        for i in range(n - 1, -1, -1):
+            x_i = x[:, i, :, :, :]
+            if i < n - 1:
+                flow = flows_backward[:, i, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            if i in keyframe_idx:
+                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+                feat_prop = self.backward_fusion(feat_prop)
+            feat_prop = torch.cat([x_i, feat_prop], dim=1)
+            feat_prop = self.backward_trunk(feat_prop)
+            out_l.insert(0, feat_prop)
+
+        # forward branch
+        feat_prop = torch.zeros_like(feat_prop)
+        for i in range(0, n):
+            x_i = x[:, i, :, :, :]
+            if i > 0:
+                flow = flows_forward[:, i - 1, :, :, :]
+                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+            if i in keyframe_idx:
+                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+                feat_prop = self.forward_fusion(feat_prop)
+
+            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
+            feat_prop = self.forward_trunk(feat_prop)
+
+            # upsample
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+            out = self.lrelu(self.conv_hr(out))
+            out = self.conv_last(out)
+            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+            out += base
+            out_l[i] = out
+
+        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
+
+
+class EDVRFeatureExtractor(nn.Module):
+    """EDVR feature extractor used in IconVSR.
+
+    Args:
+        num_input_frame (int): Number of input frames.
+        num_feat (int): Number of feature channels.
+        load_path (str): Path to the pretrained weights of EDVR. Default: None.
+    """
+
+    def __init__(self, num_input_frame, num_feat, load_path):
+
+        super(EDVRFeatureExtractor, self).__init__()
+
+        self.center_frame_idx = num_input_frame // 2
+
+        # extract pyramid features
+        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
+        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
+        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+        # pcd and tsa module
+        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
+        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        if load_path:
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+    def forward(self, x):
+        b, n, c, h, w = x.size()
+
+        # extract features for each frame
+        # L1
+        feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+        feat_l1 = self.feature_extraction(feat_l1)
+        # L2
+        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+        # L3
+        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+        feat_l1 = feat_l1.view(b, n, -1, h, w)
+        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
+        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
+
+        # PCD alignment
+        ref_feat_l = [  # reference feature list
+            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+            feat_l3[:, self.center_frame_idx, :, :, :].clone()
+        ]
+        aligned_feat = []
+        for i in range(n):
+            nbr_feat_l = [  # neighboring feature list
+                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+            ]
+            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)
+
+        # TSA fusion
+        return self.fusion(aligned_feat)
diff --git a/basicsr/archs/basicvsrpp_arch.py b/basicsr/archs/basicvsrpp_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a9952e4b441de0030d665a3db141774184f332f
--- /dev/null
+++ b/basicsr/archs/basicvsrpp_arch.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import warnings
+
+from basicsr.archs.arch_util import flow_warp
+from basicsr.archs.basicvsr_arch import ConvResidualBlocks
+from basicsr.archs.spynet_arch import SpyNet
+from basicsr.ops.dcn import ModulatedDeformConvPack
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class BasicVSRPlusPlus(nn.Module):
+    """BasicVSR++ network structure.
+
+    Supports either x4 upsampling or same-size output. Since DCN is used in this
+    model, it can only be used with CUDA enabled. If CUDA is not available,
+    feature alignment will be skipped. Besides, we adopt the official DCN
+    implementation, so the torch version needs to be higher than 1.9.
+
+    ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
+
+    Args:
+        mid_channels (int, optional): Channel number of the intermediate
+            features. Default: 64.
+        num_blocks (int, optional): The number of residual blocks in each
+            propagation branch. Default: 7.
+        max_residue_magnitude (int): The maximum magnitude of the offset
+            residue (Eq. 6 in paper). Default: 10.
+        is_low_res_input (bool, optional): Whether the input is low-resolution
+            or not. If False, the output resolution is equal to the input
+            resolution. Default: True.
+        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+        cpu_cache_length (int, optional): When the length of sequence is larger
+            than this value, the intermediate features are sent to CPU. This
+            saves GPU memory, but slows down the inference speed. You can
+            increase this number if you have a GPU with large memory.
+            Default: 100.
+    """
+
+    def __init__(self,
+                 mid_channels=64,
+                 num_blocks=7,
+                 max_residue_magnitude=10,
+                 is_low_res_input=True,
+                 spynet_path=None,
+                 cpu_cache_length=100):
+
+        super().__init__()
+        self.mid_channels = mid_channels
+        self.is_low_res_input = is_low_res_input
+        self.cpu_cache_length = cpu_cache_length
+
+        # optical flow
+        self.spynet = SpyNet(spynet_path)
+
+        # feature extraction module
+        if is_low_res_input:
+            self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
+        else:
+            self.feat_extract = nn.Sequential(
+                nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                ConvResidualBlocks(mid_channels, mid_channels, 5))
+
+        # propagation branches
+        self.deform_align = nn.ModuleDict()
+        self.backbone = nn.ModuleDict()
+        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+        for i, module in enumerate(modules):
+            if torch.cuda.is_available():
+                self.deform_align[module] = SecondOrderDeformableAlignment(
+                    2 * mid_channels,
+                    mid_channels,
+                    3,
+                    padding=1,
+                    deformable_groups=16,
+                    max_residue_magnitude=max_residue_magnitude)
+            self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
+
+        # upsampling module
+        self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
+
+        self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
+        self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
+
+        self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+        self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # check if the sequence is augmented by flipping
+        self.is_mirror_extended = False
+
+        if len(self.deform_align) > 0:
+            self.is_with_alignment = True
+        else:
+            self.is_with_alignment = False
+            warnings.warn('Deformable alignment module is not added. '
+                          'Probably your CUDA is not configured correctly. DCN can only '
+                          'be used with CUDA enabled. Alignment is skipped now.')
+
+    def check_if_mirror_extended(self, lqs):
+        """Check whether the input is a mirror-extended sequence.
+
+        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
+        """
+
+        if lqs.size(1) % 2 == 0:
+            lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
+            if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
+                self.is_mirror_extended = True
+
+    def compute_flow(self, lqs):
+        """Compute optical flow using SPyNet for feature alignment.
+
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+        is not needed, since it is equal to 'flows_backward.flip(1)'.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+
+        Return:
+            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
+                (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
+                propagation (current to next).
+        """
+
+        n, t, c, h, w = lqs.size()
+        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
+        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
+
+        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
+            flows_forward = flows_backward.flip(1)
+        else:
+            flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
+
+        if self.cpu_cache:
+            flows_backward = flows_backward.cpu()
+            flows_forward = flows_forward.cpu()
+
+        return flows_forward, flows_backward
+
+    def propagate(self, feats, flows, module_name):
+        """Propagate the latent features throughout the sequence.
+
+        Args:
+            feats (dict[list[tensor]]): Features from previous branches. Each
+                component is a list of tensors with shape (n, c, h, w).
+            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
+            module_name (str): The name of the propagation branch. One of
+                'backward_1', 'forward_1', 'backward_2', 'forward_2'.
+
+        Return:
+            dict(list[tensor]): A dictionary containing all the propagated \
+                features. Each key in the dictionary corresponds to a \
+                propagation branch, which is represented by a list of tensors.
+        """
+
+        n, t, _, h, w = flows.size()
+
+        frame_idx = range(0, t + 1)
+        flow_idx = range(-1, t)
+        mapping_idx = list(range(0, len(feats['spatial'])))
+        mapping_idx += mapping_idx[::-1]
+
+        if 'backward' in module_name:
+            frame_idx = frame_idx[::-1]
+            flow_idx = frame_idx
+
+        feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
+        for i, idx in enumerate(frame_idx):
+            feat_current = feats['spatial'][mapping_idx[idx]]
+            if self.cpu_cache:
+                feat_current = feat_current.cuda()
+                feat_prop = feat_prop.cuda()
+            # second-order deformable alignment
+            if i > 0 and self.is_with_alignment:
+                flow_n1 = flows[:, flow_idx[i], :, :, :]
+                if self.cpu_cache:
+                    flow_n1 = flow_n1.cuda()
+
+                cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+                # initialize second-order features
+                feat_n2 = torch.zeros_like(feat_prop)
+                flow_n2 = torch.zeros_like(flow_n1)
+                cond_n2 = torch.zeros_like(cond_n1)
+
+                if i > 1:  # second-order features
+                    feat_n2 = feats[module_name][-2]
+                    if self.cpu_cache:
+                        feat_n2 = feat_n2.cuda()
+
+                    flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+                    if self.cpu_cache:
+                        flow_n2 = flow_n2.cuda()
+
+                    flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
+                    cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
+
+                # flow-guided deformable convolution
+                cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+                feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+                feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
+
+            # concatenate and residual blocks
+            feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
+            if self.cpu_cache:
+                feat = [f.cuda() for f in feat]
+
+            feat = torch.cat(feat, dim=1)
+            feat_prop = feat_prop + self.backbone[module_name](feat)
+            feats[module_name].append(feat_prop)
+
+            if self.cpu_cache:
+                feats[module_name][-1] = feats[module_name][-1].cpu()
+                torch.cuda.empty_cache()
+
+        if 'backward' in module_name:
+            feats[module_name] = feats[module_name][::-1]
+
+        return feats
+
+    def upsample(self, lqs, feats):
+        """Compute the output image given the features.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+            feats (dict): The features from the propagation branches.
+
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+        """
+
+        outputs = []
+        num_outputs = len(feats['spatial'])
+
+        mapping_idx = list(range(0, num_outputs))
+        mapping_idx += mapping_idx[::-1]
+
+        for i in range(0, lqs.size(1)):
+            hr = [feats[k].pop(0) for k in feats if k != 'spatial']
+            hr.insert(0, feats['spatial'][mapping_idx[i]])
+            hr = torch.cat(hr, dim=1)
+            if self.cpu_cache:
+                hr = hr.cuda()
+
+            hr = self.reconstruction(hr)
+            hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
+            hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
+            hr = self.lrelu(self.conv_hr(hr))
+            hr = self.conv_last(hr)
+            if self.is_low_res_input:
+                hr += self.img_upsample(lqs[:, i, :, :, :])
+            else:
+                hr += lqs[:, i, :, :, :]
+
+            if self.cpu_cache:
+                hr = hr.cpu()
+                torch.cuda.empty_cache()
+
+            outputs.append(hr)
+
+        return torch.stack(outputs, dim=1)
+
+    def forward(self, lqs):
+        """Forward function for BasicVSR++.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+        """
+
+        n, t, c, h, w = lqs.size()
+
+        # whether to cache the features in CPU
+        self.cpu_cache = True if t > self.cpu_cache_length else False
+
+        if self.is_low_res_input:
+            lqs_downsample = lqs.clone()
+        else:
+            lqs_downsample = F.interpolate(
+                lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
+
+        # check whether the input is an extended sequence
+        self.check_if_mirror_extended(lqs)
+
+        feats = {}
+        # compute spatial features
+        if self.cpu_cache:
+            feats['spatial'] = []
+            for i in range(0, t):
+                feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
+                feats['spatial'].append(feat)
+                torch.cuda.empty_cache()
+        else:
+            feats_ = self.feat_extract(lqs.view(-1, c, h, w))
+            h, w = feats_.shape[2:]
+            feats_ = feats_.view(n, t, -1, h, w)
+            feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
+
+        # compute optical flow using the low-res inputs
+        assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
+            'The height and width of low-res inputs must be at least 64, '
+            f'but got {h} and {w}.')
+        flows_forward, flows_backward = self.compute_flow(lqs_downsample)
+
+        # feature propagation
+        for iter_ in [1, 2]:
+            for direction in ['backward', 'forward']:
+                module = f'{direction}_{iter_}'
+
+                feats[module] = []
+
+                if direction == 'backward':
+                    flows = flows_backward
+                elif flows_forward is not None:
+                    flows = flows_forward
+                else:
+                    flows = flows_backward.flip(1)
+
+                feats = self.propagate(feats, flows, module)
+                if self.cpu_cache:
+                    del flows
+                    torch.cuda.empty_cache()
+
+        return self.upsample(lqs, feats)
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
+    """Second-order deformable alignment module.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+        max_residue_magnitude (int): The maximum magnitude of the offset
+            residue (Eq. 6 in paper). Default: 10.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+
+        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Sequential(
+            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
+            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
+            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
+            nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
+        )
+
+        self.init_offset()
+
+    def init_offset(self):
+
+        def _constant_init(module, val, bias=0):
+            if hasattr(module, 'weight') and module.weight is not None:
+                nn.init.constant_(module.weight, val)
+            if hasattr(module, 'bias') and module.bias is not None:
+                nn.init.constant_(module.bias, bias)
+
+        _constant_init(self.conv_offset[-1], val=0, bias=0)
+
+    def forward(self, x, extra_feat, flow_1, flow_2):
+        extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
+        out = self.conv_offset(extra_feat)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+        # offset
+        offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+        offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+        offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
+        offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
+        offset = torch.cat([offset_1, offset_2], dim=1)
+
+        # mask
+        mask = torch.sigmoid(mask)
+
+        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+                                             self.dilation, mask)
+
+
+# if __name__ == '__main__':
+#     spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
+#     model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
+#     input = torch.rand(1, 2, 3, 64, 64).cuda()
+#     output = model(input)
+#     print('===================')
+#     print(output.shape)
diff --git a/basicsr/archs/degradat_arch.py b/basicsr/archs/degradat_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce09ad666a90f175fb6268435073b314df543813
--- /dev/null
+++ b/basicsr/archs/degradat_arch.py
@@ -0,0 +1,90 @@
+from torch import nn as nn
+
+from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
+from basicsr.utils.registry import ARCH_REGISTRY
+
+@ARCH_REGISTRY.register()
+class DEResNet(nn.Module):
+    """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
+    As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal',
+    resnet arch works for image quality estimation.
+    Args:
+        num_in_ch (int): channel number of inputs. Default: 3.
+        num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise).
+        degradation_embed_size (int): embedding size of each degradation vector.
+        degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid.
+        num_feats (list): channel number of each stage.
+        num_blocks (list): residual block of each stage.
+        downscales (list): downscales of each stage.
+    """
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_degradation=2,
+                 degradation_degree_actv='sigmoid',
+                 num_feats=(64, 128, 256, 512),
+                 num_blocks=(2, 2, 2, 2),
+                 downscales=(2, 2, 2, 1)):
+        super(DEResNet, self).__init__()
+
+        assert isinstance(num_feats, (list, tuple))
+        assert isinstance(num_blocks, (list, tuple))
+        assert isinstance(downscales, (list, tuple))
+        assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
+
+        num_stage = len(num_feats)
+
+        self.conv_first = nn.ModuleList()
+        for _ in range(num_degradation):
+            self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
+        self.body = nn.ModuleList()
+        for _ in range(num_degradation):
+            body = list()
+            for stage in range(num_stage):
+                for _ in range(num_blocks[stage]):
+                    body.append(ResidualBlockNoBN(num_feats[stage]))
+                if downscales[stage] == 1:
+                    if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
+                        body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
+                    continue
+                elif downscales[stage] == 2:
+                    body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
+                else:
+                    raise NotImplementedError
+            self.body.append(nn.Sequential(*body))
+
+        # self.body = nn.Sequential(*body)
+
+        self.num_degradation = num_degradation
+        self.fc_degree = nn.ModuleList()
+        if degradation_degree_actv == 'sigmoid':
+            actv = nn.Sigmoid
+        elif degradation_degree_actv == 'tanh':
+            actv = nn.Tanh
+        else:
+            raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
+                                      f'{degradation_degree_actv} is not supported yet.')
+        for _ in range(num_degradation):
+            self.fc_degree.append(
+                nn.Sequential(
+                    nn.Linear(num_feats[-1], 512),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(512, 1),
+                    actv(),
+                ))
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
+
+    def forward(self, x):
+        degrees = []
+        for i in range(self.num_degradation):
+            x_out = self.conv_first[i](x)
+            feat = self.body[i](x_out)
+            feat = self.avg_pool(feat)
+            feat = feat.squeeze(-1).squeeze(-1)
+            # for i in range(self.num_degradation):
+            degrees.append(self.fc_degree[i](feat).squeeze(-1))
+
+        return degrees
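+
+# A minimal usage sketch (illustrative; the input size is arbitrary): the estimator returns
+# one degree tensor of shape (b,) per degradation, squashed into [0, 1] by the sigmoid head.
+# if __name__ == '__main__':
+#     import torch
+#     de_net = DEResNet(num_in_ch=3, num_degradation=2)
+#     blur_degree, noise_degree = de_net(torch.rand(1, 3, 64, 64))
+#     print(blur_degree.shape, noise_degree.shape)  # expected: torch.Size([1]) torch.Size([1])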
diff --git a/basicsr/archs/dfdnet_arch.py b/basicsr/archs/dfdnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4751434c2f17efbb682d9344951604602d853aaa
--- /dev/null
+++ b/basicsr/archs/dfdnet_arch.py
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.spectral_norm import spectral_norm
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
+from .vgg_arch import VGGFeatureExtractor
+
+
+class SFTUpBlock(nn.Module):
+    """Spatial feature transform (SFT) with upsampling block.
+
+    Args:
+        in_channel (int): Number of input channels.
+        out_channel (int): Number of output channels.
+        kernel_size (int): Kernel size in convolutions. Default: 3.
+        padding (int): Padding in convolutions. Default: 1.
+    """
+
+    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
+        super(SFTUpBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            Blur(in_channel),
+            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.04, True),
+            # the official code applies LeakyReLU(0.2) twice here, so 0.04 gives an equivalent negative slope
+        )
+        self.convup = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.2, True),
+        )
+
+        # for SFT scale and shift
+        self.scale_block = nn.Sequential(
+            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+        self.shift_block = nn.Sequential(
+            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+        # the official code uses a sigmoid for the shift block; the reason is unclear
+
+    def forward(self, x, updated_feat):
+        out = self.conv1(x)
+        # SFT
+        scale = self.scale_block(updated_feat)
+        shift = self.shift_block(updated_feat)
+        out = out * scale + shift
+        # upsample
+        out = self.convup(out)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class DFDNet(nn.Module):
+    """DFDNet: Deep Face Dictionary Network.
+
+    It only processes faces with 512x512 size.
+
+    Args:
+        num_feat (int): Number of feature channels.
+        dict_path (str): Path to the facial component dictionary.
+    """
+
+    def __init__(self, num_feat, dict_path):
+        super().__init__()
+        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
+        # part_sizes: [80, 80, 50, 110]
+        channel_sizes = [128, 256, 512, 512]
+        self.feature_sizes = np.array([256, 128, 64, 32])
+        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
+        self.flag_dict_device = False
+
+        # dict
+        self.dict = torch.load(dict_path)
+
+        # vgg face extractor
+        self.vgg_extractor = VGGFeatureExtractor(
+            layer_name_list=self.vgg_layers,
+            vgg_type='vgg19',
+            use_input_norm=True,
+            range_norm=True,
+            requires_grad=False)
+
+        # attention block for fusing dictionary features and input features
+        self.attn_blocks = nn.ModuleDict()
+        for idx, feat_size in enumerate(self.feature_sizes):
+            for name in self.parts:
+                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
+
+        # multi scale dilation block
+        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
+
+        # upsampling and reconstruction
+        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
+        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
+        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
+        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
+        self.upsample4 = nn.Sequential(
+            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
+            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+
+    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
+        """swap the features from the dictionary."""
+        # get the original vgg features
+        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+        # resize original vgg features
+        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
+        # use adaptive instance normalization to adjust color and illuminations
+        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
+        # get similarity scores
+        similarity_score = F.conv2d(part_resize_feat, dict_feat)
+        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
+        # select the most similar features in the dict (after norm)
+        select_idx = torch.argmax(similarity_score)
+        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+        # attention
+        attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
+        attn_feat = attn * swap_feat
+        # update features
+        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+        return updated_feat
+
+    def put_dict_to_device(self, x):
+        if self.flag_dict_device is False:
+            for k, v in self.dict.items():
+                for kk, vv in v.items():
+                    self.dict[k][kk] = vv.to(x)
+            self.flag_dict_device = True
+
+    def forward(self, x, part_locations):
+        """
+        Now only supports testing with batch size = 1.
+
+        Args:
+            x (Tensor): Input faces with shape (b, c, 512, 512).
+            part_locations (list[Tensor]): Part locations.
+        """
+        self.put_dict_to_device(x)
+        # extract vggface features
+        vgg_features = self.vgg_extractor(x)
+        # update vggface features using the dictionary for each part
+        updated_vgg_features = []
+        batch = 0  # only supports testing with batch size = 1
+        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
+            dict_features = self.dict[f'{f_size}']
+            vgg_feat = vgg_features[vgg_layer]
+            updated_feat = vgg_feat.clone()
+
+            # swap features from dictionary
+            for part_idx, part_name in enumerate(self.parts):
+                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
+                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
+                                              f_size)
+
+            updated_vgg_features.append(updated_feat)
+
+        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
+        # use updated vgg features to modulate the upsampled features with
+        # SFT (Spatial Feature Transform) scaling and shifting manner.
+        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
+        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
+        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
+        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
+        out = self.upsample4(upsampled_feat)
+
+        return out
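+
+# A hedged usage sketch (illustrative only): DFDNet needs the offline facial component
+# dictionary and per-part bounding boxes from a face landmark detector. The path and box
+# values below are arbitrary placeholders, not real assets.
+# if __name__ == '__main__':
+#     net = DFDNet(num_feat=64, dict_path='path/to/facial_component_dict.pth')
+#     face = torch.rand(1, 3, 512, 512)
+#     # one (num_faces, 4) tensor of (x1, y1, x2, y2) boxes per part, in the order of net.parts;
+#     # the same placeholder box is reused here just to show the expected structure
+#     part_locations = [torch.tensor([[100, 160, 180, 240]]) for _ in net.parts]
+#     restored = net(face, part_locations)  # expected shape: (1, 3, 512, 512)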
diff --git a/basicsr/archs/dfdnet_util.py b/basicsr/archs/dfdnet_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4dc0ff738c76852e830b32fffbe65bffb5ddf50
--- /dev/null
+++ b/basicsr/archs/dfdnet_util.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm
+
+
+class BlurFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+        grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_output):
+        kernel, _ = ctx.saved_tensors
+        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
+        return grad_input, None, None
+
+
+class BlurFunction(Function):
+
+    @staticmethod
+    def forward(ctx, x, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+        output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, kernel_flip = ctx.saved_tensors
+        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+        return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+
+    def __init__(self, channel):
+        super().__init__()
+        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+        kernel = kernel.view(1, 1, 3, 3)
+        kernel = kernel / kernel.sum()
+        kernel_flip = torch.flip(kernel, [2, 3])
+
+        self.kernel = kernel.repeat(channel, 1, 1, 1)
+        self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
+
+    def forward(self, x):
+        return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """Calculate mean and std for adaptive_instance_normalization.
+
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    n, c = size[:2]
+    feat_var = feat.view(n, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(n, c, 1, 1)
+    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    """Adaptive instance normalization.
+
+    Adjust the reference features to have color and illumination similar
+    to those of the degraded features.
+
+    Args:
+        content_feat (Tensor): The reference features.
+        style_feat (Tensor): The degraded features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def AttentionBlock(in_channel):
+    return nn.Sequential(
+        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
+
+
+def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
+    """Conv block used in MSDilationBlock."""
+
+    return nn.Sequential(
+        spectral_norm(
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=((kernel_size - 1) // 2) * dilation,
+                bias=bias)),
+        nn.LeakyReLU(0.2),
+        spectral_norm(
+            nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=((kernel_size - 1) // 2) * dilation,
+                bias=bias)),
+    )
+
+
+class MSDilationBlock(nn.Module):
+    """Multi-scale dilation block."""
+
+    def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
+        super(MSDilationBlock, self).__init__()
+
+        self.conv_blocks = nn.ModuleList()
+        for i in range(4):
+            self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
+        self.conv_fusion = spectral_norm(
+            nn.Conv2d(
+                in_channels * 4,
+                in_channels,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=(kernel_size - 1) // 2,
+                bias=bias))
+
+    def forward(self, x):
+        out = []
+        for i in range(4):
+            out.append(self.conv_blocks[i](x))
+        out = torch.cat(out, 1)
+        out = self.conv_fusion(out) + x
+        return out
+
+
+class UpResBlock(nn.Module):
+
+    def __init__(self, in_channel):
+        super(UpResBlock, self).__init__()
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+        )
+
+    def forward(self, x):
+        out = x + self.body(x)
+        return out
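+
+# A small self-check sketch (illustrative; the tensor shapes are arbitrary): after adaptive
+# instance normalization, the content feature's per-channel mean/std should match the style
+# feature's statistics.
+# if __name__ == '__main__':
+#     content = torch.rand(2, 8, 16, 16)
+#     style = torch.rand(2, 8, 16, 16) * 3 + 1
+#     out = adaptive_instance_normalization(content, style)
+#     out_mean, out_std = calc_mean_std(out)
+#     style_mean, style_std = calc_mean_std(style)
+#     print(torch.allclose(out_mean, style_mean, atol=1e-4), torch.allclose(out_std, style_std, atol=1e-2))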
diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f9a8f1b25c2052cd3ba801534861a425752e69
--- /dev/null
+++ b/basicsr/archs/discriminator_arch.py
@@ -0,0 +1,150 @@
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class VGGStyleDiscriminator(nn.Module):
+    """VGG style discriminator with input size 128 x 128 or 256 x 256.
+
+    It is used to train SRGAN, ESRGAN, and VideoGAN.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        input_size (int): Input spatial size, 128 or 256. Default: 128.
+    """
+
+    def __init__(self, num_in_ch, num_feat, input_size=128):
+        super(VGGStyleDiscriminator, self).__init__()
+        self.input_size = input_size
+        assert self.input_size == 128 or self.input_size == 256, (
+            f'input size must be 128 or 256, but received {input_size}')
+
+        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
+        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
+
+        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
+        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
+        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
+        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
+
+        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
+        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
+        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
+        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
+
+        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        if self.input_size == 256:
+            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
+        self.linear2 = nn.Linear(100, 1)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
+
+        feat = self.lrelu(self.conv0_0(x))
+        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2
+
+        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
+        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4
+
+        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
+        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8
+
+        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
+        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16
+
+        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
+        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32
+
+        if self.input_size == 256:
+            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
+            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: / 64
+
+        # spatial size: (4, 4)
+        feat = feat.view(feat.size(0), -1)
+        feat = self.lrelu(self.linear1(feat))
+        out = self.linear2(feat)
+        return out
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class UNetDiscriminatorSN(nn.Module):
+    """Defines a U-Net discriminator with spectral normalization (SN)
+
+    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    Arg:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
+    """
+
+    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+        super(UNetDiscriminatorSN, self).__init__()
+        self.skip_connection = skip_connection
+        norm = spectral_norm
+        # the first convolution
+        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+        # downsample
+        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+        # upsample
+        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+        # extra convolutions
+        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+    def forward(self, x):
+        # downsample
+        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+        # upsample
+        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x4 = x4 + x2
+        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x5 = x5 + x1
+        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x6 = x6 + x0
+
+        # extra convolutions
+        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+        out = self.conv9(out)
+
+        return out
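+
+# A minimal shape sketch (illustrative; the input size is arbitrary): the U-Net discriminator
+# outputs a per-pixel realness map with the same spatial size as its input.
+# if __name__ == '__main__':
+#     import torch
+#     net_d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
+#     print(net_d(torch.rand(1, 3, 128, 128)).shape)  # expected: torch.Size([1, 1, 128, 128])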
diff --git a/basicsr/archs/duf_arch.py b/basicsr/archs/duf_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b3ab7df4d890c9220d74ed8c461ad9d155120a
--- /dev/null
+++ b/basicsr/archs/duf_arch.py
@@ -0,0 +1,276 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class DenseBlocksTemporalReduce(nn.Module):
+    """A concatenation of 3 dense blocks with reduction in temporal dimension.
+
+    Note that the output temporal dimension is 6 less than the input temporal dimension, since there are 3 blocks.
+
+    Args:
+        num_feat (int): Number of channels in the blocks. Default: 64.
+        num_grow_ch (int): Growing factor of the dense blocks. Default: 32
+        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+            Set to false if you want to train from scratch. Default: False.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
+        super(DenseBlocksTemporalReduce, self).__init__()
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.temporal_reduce1 = nn.Sequential(
+            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
+            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+        self.temporal_reduce2 = nn.Sequential(
+            nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + num_grow_ch,
+                num_feat + num_grow_ch, (1, 1, 1),
+                stride=(1, 1, 1),
+                padding=(0, 0, 0),
+                bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+        self.temporal_reduce3 = nn.Sequential(
+            nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + 2 * num_grow_ch,
+                num_feat + 2 * num_grow_ch, (1, 1, 1),
+                stride=(1, 1, 1),
+                padding=(0, 0, 0),
+                bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+        Returns:
+            Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
+        """
+        x1 = self.temporal_reduce1(x)
+        x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
+
+        x2 = self.temporal_reduce2(x1)
+        x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
+
+        x3 = self.temporal_reduce3(x2)
+        x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
+
+        return x3
+
+
+class DenseBlocks(nn.Module):
+    """ A concatenation of N dense blocks.
+
+    Args:
+        num_feat (int): Number of channels in the blocks. Default: 64.
+        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
+        num_block (int): Number of dense blocks. The values are:
+            DUF-S (16 layers): 3
+            DUF-M (28 layers): 9
+            DUF-L (52 layers): 21
+        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+            Set to false if you want to train from scratch. Default: False.
+    """
+
+    def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
+        super(DenseBlocks, self).__init__()
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.dense_blocks = nn.ModuleList()
+        for i in range(0, num_block):
+            self.dense_blocks.append(
+                nn.Sequential(
+                    nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+                    nn.Conv3d(
+                        num_feat + i * num_grow_ch,
+                        num_feat + i * num_grow_ch, (1, 1, 1),
+                        stride=(1, 1, 1),
+                        padding=(0, 0, 0),
+                        bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+                    nn.ReLU(inplace=True),
+                    nn.Conv3d(
+                        num_feat + i * num_grow_ch,
+                        num_grow_ch, (3, 3, 3),
+                        stride=(1, 1, 1),
+                        padding=(1, 1, 1),
+                        bias=True)))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+        Returns:
+            Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
+        """
+        for i in range(0, len(self.dense_blocks)):
+            y = self.dense_blocks[i](x)
+            x = torch.cat((x, y), 1)
+        return x
+
+
+class DynamicUpsamplingFilter(nn.Module):
+    """Dynamic upsampling filter used in DUF.
+
+    Reference: https://github.com/yhjo09/VSR-DUF
+
+    It only supports inputs with 3 channels and applies the same filters to all 3 channels.
+
+    Args:
+        filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
+    """
+
+    def __init__(self, filter_size=(5, 5)):
+        super(DynamicUpsamplingFilter, self).__init__()
+        if not isinstance(filter_size, tuple):
+            raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}.')
+        if len(filter_size) != 2:
+            raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
+        # generate a local expansion filter, similar to im2col
+        self.filter_size = filter_size
+        filter_prod = np.prod(filter_size)
+        expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
+        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)  # repeat for all the 3 channels
+
+    def forward(self, x, filters):
+        """Forward function for DynamicUpsamplingFilter.
+
+        Args:
+            x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
+            filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
+                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
+                upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
+                e.g., for x4 upsampling, upsampling_square = 4*4 = 16.
+
+        Returns:
+            Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
+        """
+        n, filter_prod, upsampling_square, h, w = filters.size()
+        kh, kw = self.filter_size
+        expanded_input = F.conv2d(
+            x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3)  # (n, 3*filter_prod, h, w)
+        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
+                                                                              2)  # (n, h, w, 3, filter_prod)
+        filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
+        out = torch.matmul(expanded_input, filters)  # (n, h, w, 3, upsampling_square)
+        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
+
+
+@ARCH_REGISTRY.register()
+class DUF(nn.Module):
+    """Network architecture for DUF
+
+    ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
+
+    Reference: https://github.com/yhjo09/VSR-DUF
+
+    For all the models below, 'adapt_official_weights' is only necessary when
+    loading the weights converted from the official TensorFlow weights.
+    Please set it to False if you are training the model from scratch.
+
+    There are three models with different model sizes: DUF16Layers, DUF28Layers,
+    and DUF52Layers. This class is the base class for these models.
+
+    Args:
+        scale (int): The upsampling factor. Default: 4.
+        num_layer (int): The number of layers. Default: 52.
+        adapt_official_weights (bool): Whether to adapt the weights
+            translated from the official implementation. Set to false if you
+            want to train from scratch. Default: False.
+    """
+
+    def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
+        super(DUF, self).__init__()
+        self.scale = scale
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+        self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
+
+        if num_layer == 16:
+            num_block = 3
+            num_grow_ch = 32
+        elif num_layer == 28:
+            num_block = 9
+            num_grow_ch = 16
+        elif num_layer == 52:
+            num_block = 21
+            num_grow_ch = 16
+        else:
+            raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
+
+        self.dense_block1 = DenseBlocks(
+            num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
+            adapt_official_weights=adapt_official_weights)  # T = 7
+        self.dense_block2 = DenseBlocksTemporalReduce(
+            64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights)  # T = 1
+        channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
+        self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
+        self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+
+        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+        self.conv3d_f2 = nn.Conv3d(
+            512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input with shape (b, 7, c, h, w)
+
+        Returns:
+            Tensor: Output with shape (b, c, h * scale, w * scale)
+        """
+        num_batches, num_imgs, _, h, w = x.size()
+
+        x = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
+        x_center = x[:, :, num_imgs // 2, :, :]
+
+        x = self.conv3d1(x)
+        x = self.dense_block1(x)
+        x = self.dense_block2(x)
+        x = F.relu(self.bn3d2(x), inplace=True)
+        x = F.relu(self.conv3d2(x), inplace=True)
+
+        # residual image
+        res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
+
+        # filter
+        filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
+        filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
+
+        # dynamic filter
+        out = self.dynamic_filter(x_center, filter_)
+        out += res.squeeze_(2)
+        out = F.pixel_shuffle(out, self.scale)
+
+        return out
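+
+# A minimal shape sketch (illustrative; the frame size is arbitrary): DUF consumes 7 consecutive
+# frames and super-resolves the center one, so a (b, 7, 3, h, w) input yields (b, 3, h*scale, w*scale).
+# if __name__ == '__main__':
+#     duf = DUF(scale=4, num_layer=16)
+#     print(duf(torch.rand(1, 7, 3, 32, 32)).shape)  # expected: torch.Size([1, 3, 128, 128])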
diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe20e772587d74c67fffb40f3b4731cf4f42268b
--- /dev/null
+++ b/basicsr/archs/ecbsr_arch.py
@@ -0,0 +1,275 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class SeqConv3x3(nn.Module):
+    """The re-parameterizable block used in the ECBSR architecture.
+
+    ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
+
+    Reference: https://github.com/xindongzhang/ECBSR
+
+    Args:
+        seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
+        in_channels (int): Channel number of input.
+        out_channels (int): Channel number of output.
+        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+    """
+
+    def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
+        super(SeqConv3x3, self).__init__()
+        self.seq_type = seq_type
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        if self.seq_type == 'conv1x1-conv3x3':
+            self.mid_planes = int(out_channels * depth_multiplier)
+            conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
+            self.k0 = conv0.weight
+            self.b0 = conv0.bias
+
+            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
+            self.k1 = conv1.weight
+            self.b1 = conv1.bias
+
+        elif self.seq_type == 'conv1x1-sobelx':
+            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+            self.k0 = conv0.weight
+            self.b0 = conv0.bias
+
+            # init scale and bias
+            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+            self.scale = nn.Parameter(scale)
+            bias = torch.randn(self.out_channels) * 1e-3
+            bias = torch.reshape(bias, (self.out_channels, ))
+            self.bias = nn.Parameter(bias)
+            # init mask
+            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+            for i in range(self.out_channels):
+                self.mask[i, 0, 0, 0] = 1.0
+                self.mask[i, 0, 1, 0] = 2.0
+                self.mask[i, 0, 2, 0] = 1.0
+                self.mask[i, 0, 0, 2] = -1.0
+                self.mask[i, 0, 1, 2] = -2.0
+                self.mask[i, 0, 2, 2] = -1.0
+            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+        elif self.seq_type == 'conv1x1-sobely':
+            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+            self.k0 = conv0.weight
+            self.b0 = conv0.bias
+
+            # init scale and bias
+            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+            self.scale = nn.Parameter(torch.FloatTensor(scale))
+            bias = torch.randn(self.out_channels) * 1e-3
+            bias = torch.reshape(bias, (self.out_channels, ))
+            self.bias = nn.Parameter(torch.FloatTensor(bias))
+            # init mask
+            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+            for i in range(self.out_channels):
+                self.mask[i, 0, 0, 0] = 1.0
+                self.mask[i, 0, 0, 1] = 2.0
+                self.mask[i, 0, 0, 2] = 1.0
+                self.mask[i, 0, 2, 0] = -1.0
+                self.mask[i, 0, 2, 1] = -2.0
+                self.mask[i, 0, 2, 2] = -1.0
+            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+        elif self.seq_type == 'conv1x1-laplacian':
+            conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+            self.k0 = conv0.weight
+            self.b0 = conv0.bias
+
+            # init scale and bias
+            scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+            self.scale = nn.Parameter(torch.FloatTensor(scale))
+            bias = torch.randn(self.out_channels) * 1e-3
+            bias = torch.reshape(bias, (self.out_channels, ))
+            self.bias = nn.Parameter(torch.FloatTensor(bias))
+            # init mask
+            self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+            for i in range(self.out_channels):
+                self.mask[i, 0, 0, 1] = 1.0
+                self.mask[i, 0, 1, 0] = 1.0
+                self.mask[i, 0, 1, 2] = 1.0
+                self.mask[i, 0, 2, 1] = 1.0
+                self.mask[i, 0, 1, 1] = -4.0
+            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+        else:
+            raise ValueError('The type of seqconv is not supported!')
+
+    def forward(self, x):
+        if self.seq_type == 'conv1x1-conv3x3':
+            # conv-1x1
+            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+            # explicitly padding with bias
+            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+            b0_pad = self.b0.view(1, -1, 1, 1)
+            y0[:, :, 0:1, :] = b0_pad
+            y0[:, :, -1:, :] = b0_pad
+            y0[:, :, :, 0:1] = b0_pad
+            y0[:, :, :, -1:] = b0_pad
+            # conv-3x3
+            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
+        else:
+            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+            # explicitly padding with bias
+            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+            b0_pad = self.b0.view(1, -1, 1, 1)
+            y0[:, :, 0:1, :] = b0_pad
+            y0[:, :, -1:, :] = b0_pad
+            y0[:, :, :, 0:1] = b0_pad
+            y0[:, :, :, -1:] = b0_pad
+            # conv-3x3
+            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
+        return y1
+
+    def rep_params(self):
+        device = self.k0.get_device()
+        if device < 0:
+            device = None
+
+        if self.seq_type == 'conv1x1-conv3x3':
+            # re-param conv kernel
+            rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
+            # re-param conv bias
+            rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+            rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
+        else:
+            tmp = self.scale * self.mask
+            k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
+            for i in range(self.out_channels):
+                k1[i, i, :, :] = tmp[i, 0, :, :]
+            b1 = self.bias
+            # re-param conv kernel
+            rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
+            # re-param conv bias
+            rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+            rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
+        return rep_weight, rep_bias
+
+
+class ECB(nn.Module):
+    """The ECB block used in the ECBSR architecture.
+
+    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+    Ref git repo: https://github.com/xindongzhang/ECBSR
+
+    Args:
+        in_channels (int): Channel number of input.
+        out_channels (int): Channel number of output.
+        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+        act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
+        with_idt (bool): Whether to use identity connection. Default: False.
+    """
+
+    def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
+        super(ECB, self).__init__()
+
+        self.depth_multiplier = depth_multiplier
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.act_type = act_type
+
+        if with_idt and (self.in_channels == self.out_channels):
+            self.with_idt = True
+        else:
+            self.with_idt = False
+
+        self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
+        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
+        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
+        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
+        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
+
+        if self.act_type == 'prelu':
+            self.act = nn.PReLU(num_parameters=self.out_channels)
+        elif self.act_type == 'relu':
+            self.act = nn.ReLU(inplace=True)
+        elif self.act_type == 'rrelu':
+            self.act = nn.RReLU(lower=-0.05, upper=0.05)
+        elif self.act_type == 'softplus':
+            self.act = nn.Softplus()
+        elif self.act_type == 'linear':
+            pass
+        else:
+            raise ValueError(f'The activation type {act_type} is not supported!')
+
+    def forward(self, x):
+        if self.training:
+            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
+            if self.with_idt:
+                y += x
+        else:
+            rep_weight, rep_bias = self.rep_params()
+            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
+        if self.act_type != 'linear':
+            y = self.act(y)
+        return y
+
+    def rep_params(self):
+        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
+        weight1, bias1 = self.conv1x1_3x3.rep_params()
+        weight2, bias2 = self.conv1x1_sbx.rep_params()
+        weight3, bias3 = self.conv1x1_sby.rep_params()
+        weight4, bias4 = self.conv1x1_lpl.rep_params()
+        rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
+            bias0 + bias1 + bias2 + bias3 + bias4)
+
+        if self.with_idt:
+            device = rep_weight.get_device()
+            if device < 0:
+                device = None
+            weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
+            for i in range(self.out_channels):
+                weight_idt[i, i, 1, 1] = 1.0
+            bias_idt = 0.0
+            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
+        return rep_weight, rep_bias
+
+
+@ARCH_REGISTRY.register()
+class ECBSR(nn.Module):
+    """ECBSR architecture.
+
+    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+    Ref git repo: https://github.com/xindongzhang/ECBSR
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_block (int): Block number in the trunk network.
+        num_channel (int): Channel number.
+        with_idt (bool): Whether use identity in convolution layers.
+        act_type (str): Activation type.
+        scale (int): Upsampling factor.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
+        super(ECBSR, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.scale = scale
+
+        backbone = []
+        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+        for _ in range(num_block):
+            backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+        backbone += [
+            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
+        ]
+
+        self.backbone = nn.Sequential(*backbone)
+        self.upsampler = nn.PixelShuffle(scale)
+
+    def forward(self, x):
+        if self.num_in_ch > 1:
+            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
+        else:
+            shortcut = x  # broadcasting will repeat the single input channel scale * scale times
+        y = self.backbone(x) + shortcut
+        y = self.upsampler(y)
+        return y
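+
+# A re-parameterization self-check sketch (illustrative; the channel count and input size are
+# arbitrary): in eval mode the multi-branch ECB collapses into a single 3x3 convolution via
+# rep_params(), so train-mode and eval-mode outputs should agree up to floating-point rounding.
+# if __name__ == '__main__':
+#     block = ECB(8, 8, depth_multiplier=2.0, act_type='linear', with_idt=True)
+#     x = torch.rand(1, 8, 16, 16)
+#     block.train()
+#     y_train = block(x)
+#     block.eval()
+#     with torch.no_grad():
+#         y_eval = block(x)
+#     print(torch.max(torch.abs(y_train - y_eval)).item())  # expected: a tiny value (~1e-6)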
diff --git a/basicsr/archs/edsr_arch.py b/basicsr/archs/edsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80566f11fbd4782d68eee8fbf7da686f89dc4e7
--- /dev/null
+++ b/basicsr/archs/edsr_arch.py
@@ -0,0 +1,61 @@
+import torch
+from torch import nn as nn
+
+from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class EDSR(nn.Module):
+    """EDSR network structure.
+
+    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
+    Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        num_block (int): Block number in the trunk network. Default: 16.
+        upscale (int): Upsampling factor. Support 2^n and 3.
+            Default: 4.
+        res_scale (float): Used to scale the residual in residual block.
+            Default: 1.
+        img_range (float): Image range. Default: 255.
+        rgb_mean (tuple[float]): Image mean in RGB orders.
+            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_out_ch,
+                 num_feat=64,
+                 num_block=16,
+                 upscale=4,
+                 res_scale=1,
+                 img_range=255.,
+                 rgb_mean=(0.4488, 0.4371, 0.4040)):
+        super(EDSR, self).__init__()
+
+        self.img_range = img_range
+        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
+        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.upsample = Upsample(upscale, num_feat)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+    def forward(self, x):
+        self.mean = self.mean.type_as(x)
+
+        x = (x - self.mean) * self.img_range
+        x = self.conv_first(x)
+        res = self.conv_after_body(self.body(x))
+        res += x
+
+        x = self.conv_last(self.upsample(res))
+        x = x / self.img_range + self.mean
+
+        return x
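+
+# A minimal shape sketch (illustrative; the input size is arbitrary): with upscale=4,
+# a (1, 3, h, w) input produces a (1, 3, 4h, 4w) output.
+# if __name__ == '__main__':
+#     edsr = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
+#     print(edsr(torch.rand(1, 3, 32, 32)).shape)  # expected: torch.Size([1, 3, 128, 128])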
diff --git a/basicsr/archs/edvr_arch.py b/basicsr/archs/edvr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c4f47deb383d4fe6108b97436c9dfb1e541583
--- /dev/null
+++ b/basicsr/archs/edvr_arch.py
@@ -0,0 +1,382 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
+
+
+class PCDAlignment(nn.Module):
+    """Alignment module using Pyramid, Cascading and Deformable convolution
+    (PCD). It is used in EDVR.
+
+    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
+
+    Args:
+        num_feat (int): Channel number of middle features. Default: 64.
+        deformable_groups (int): Deformable groups. Defaults: 8.
+    """
+
+    def __init__(self, num_feat=64, deformable_groups=8):
+        super(PCDAlignment, self).__init__()
+
+        # Pyramid has three levels:
+        # L3: level 3, 1/4 spatial size
+        # L2: level 2, 1/2 spatial size
+        # L1: level 1, original spatial size
+        self.offset_conv1 = nn.ModuleDict()
+        self.offset_conv2 = nn.ModuleDict()
+        self.offset_conv3 = nn.ModuleDict()
+        self.dcn_pack = nn.ModuleDict()
+        self.feat_conv = nn.ModuleDict()
+
+        # Pyramids
+        for i in range(3, 0, -1):
+            level = f'l{i}'
+            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+            if i == 3:
+                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            else:
+                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+            if i < 3:
+                self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+
+        # Cascading dcn
+        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def forward(self, nbr_feat_l, ref_feat_l):
+        """Align neighboring frame features to the reference frame features.
+
+        Args:
+            nbr_feat_l (list[Tensor]): Neighboring feature list. It
+                contains three pyramid levels (L1, L2, L3),
+                each with shape (b, c, h, w).
+            ref_feat_l (list[Tensor]): Reference feature list. It
+                contains three pyramid levels (L1, L2, L3),
+                each with shape (b, c, h, w).
+
+        Returns:
+            Tensor: Aligned features.
+        """
+        # Pyramids
+        upsampled_offset, upsampled_feat = None, None
+        for i in range(3, 0, -1):
+            level = f'l{i}'
+            offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
+            offset = self.lrelu(self.offset_conv1[level](offset))
+            if i == 3:
+                offset = self.lrelu(self.offset_conv2[level](offset))
+            else:
+                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
+                offset = self.lrelu(self.offset_conv3[level](offset))
+
+            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
+            if i < 3:
+                feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
+            if i > 1:
+                feat = self.lrelu(feat)
+
+            if i > 1:  # upsample offset and features
+                # x2: when we upsample the offset, we should also enlarge
+                # the magnitude.
+                upsampled_offset = self.upsample(offset) * 2
+                upsampled_feat = self.upsample(feat)
+
+        # Cascading
+        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
+        offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
+        feat = self.lrelu(self.cas_dcnpack(feat, offset))
+        return feat
+
+
+class TSAFusion(nn.Module):
+    """Temporal Spatial Attention (TSA) fusion module.
+
+    Temporal: Calculate the correlation between center frame and
+        neighboring frames;
+    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
+        (SFT: Recovering realistic texture in image super-resolution by deep
+            spatial feature transform.)
+
+    Args:
+        num_feat (int): Channel number of middle features. Default: 64.
+        num_frame (int): Number of frames. Default: 5.
+        center_frame_idx (int): The index of center frame. Default: 2.
+    """
+
+    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
+        super(TSAFusion, self).__init__()
+        self.center_frame_idx = center_frame_idx
+        # temporal attention (before fusion conv)
+        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+        # spatial attention (after fusion conv)
+        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
+        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
+        self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
+        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
+        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
+        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
+        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
+        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+
+    def forward(self, aligned_feat):
+        """
+        Args:
+            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
+
+        Returns:
+            Tensor: Features after TSA with the shape (b, c, h, w).
+        """
+        b, t, c, h, w = aligned_feat.size()
+        # temporal attention
+        embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
+        embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
+        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)
+
+        corr_l = []  # correlation list
+        for i in range(t):
+            emb_neighbor = embedding[:, i, :, :, :]
+            corr = torch.sum(emb_neighbor * embedding_ref, 1)  # (b, h, w)
+            corr_l.append(corr.unsqueeze(1))  # (b, 1, h, w)
+        corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1))  # (b, t, h, w)
+        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
+        corr_prob = corr_prob.contiguous().view(b, -1, h, w)  # (b, t*c, h, w)
+        aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
+
+        # fusion
+        feat = self.lrelu(self.feat_fusion(aligned_feat))
+
+        # spatial attention
+        attn = self.lrelu(self.spatial_attn1(aligned_feat))
+        attn_max = self.max_pool(attn)
+        attn_avg = self.avg_pool(attn)
+        attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
+        # pyramid levels
+        attn_level = self.lrelu(self.spatial_attn_l1(attn))
+        attn_max = self.max_pool(attn_level)
+        attn_avg = self.avg_pool(attn_level)
+        attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
+        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
+        attn_level = self.upsample(attn_level)
+
+        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
+        attn = self.lrelu(self.spatial_attn4(attn))
+        attn = self.upsample(attn)
+        attn = self.spatial_attn5(attn)
+        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
+        attn = torch.sigmoid(attn)
+
+        # After initialization, attn is close to 0.5, so multiplying by 2 keeps (attn * 2) close to 1.
+        feat = feat * attn * 2 + attn_add
+        return feat
+
+
+class PredeblurModule(nn.Module):
+    """Pre-dublur module.
+
+    Args:
+        num_in_ch (int): Channel number of input image. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        hr_in (bool): Whether the input has high resolution. Default: False.
+    """
+
+    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
+        super(PredeblurModule, self).__init__()
+        self.hr_in = hr_in
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        if self.hr_in:
+            # downsample x4 by stride conv
+            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+        # generate feature pyramid
+        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
+        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
+        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
+        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
+
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def forward(self, x):
+        feat_l1 = self.lrelu(self.conv_first(x))
+        if self.hr_in:
+            feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
+            feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
+
+        # generate feature pyramid
+        feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
+        feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
+
+        feat_l3 = self.upsample(self.resblock_l3(feat_l3))
+        feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
+        feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
+
+        for i in range(2):
+            feat_l1 = self.resblock_l1[i](feat_l1)
+        feat_l1 = feat_l1 + feat_l2
+        for i in range(2, 5):
+            feat_l1 = self.resblock_l1[i](feat_l1)
+        return feat_l1
+
+
+@ARCH_REGISTRY.register()
+class EDVR(nn.Module):
+    """EDVR network structure for video super-resolution.
+
+    Now only supports the x4 upsampling factor.
+
+    ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
+
+    Args:
+        num_in_ch (int): Channel number of input image. Default: 3.
+        num_out_ch (int): Channel number of output image. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_frame (int): Number of input frames. Default: 5.
+        deformable_groups (int): Deformable groups. Default: 8.
+        num_extract_block (int): Number of blocks for feature extraction.
+            Default: 5.
+        num_reconstruct_block (int): Number of blocks for reconstruction.
+            Default: 10.
+        center_frame_idx (int): The index of center frame. Frame counting from
+            0. Default: Middle of input frames.
+        hr_in (bool): Whether the input has high resolution. Default: False.
+        with_predeblur (bool): Whether to use the predeblur module.
+            Default: False.
+        with_tsa (bool): Whether to use the TSA module. Default: True.
+    """
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_out_ch=3,
+                 num_feat=64,
+                 num_frame=5,
+                 deformable_groups=8,
+                 num_extract_block=5,
+                 num_reconstruct_block=10,
+                 center_frame_idx=None,
+                 hr_in=False,
+                 with_predeblur=False,
+                 with_tsa=True):
+        super(EDVR, self).__init__()
+        if center_frame_idx is None:
+            self.center_frame_idx = num_frame // 2
+        else:
+            self.center_frame_idx = center_frame_idx
+        self.hr_in = hr_in
+        self.with_predeblur = with_predeblur
+        self.with_tsa = with_tsa
+
+        # extract features for each frame
+        if self.with_predeblur:
+            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
+            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
+        else:
+            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+
+        # extract pyramid features
+        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
+        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+        # pcd and tsa module
+        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
+        if self.with_tsa:
+            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
+        else:
+            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+        # reconstruction
+        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
+        # upsample
+        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
+        self.pixel_shuffle = nn.PixelShuffle(2)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+    def forward(self, x):
+        b, t, c, h, w = x.size()
+        if self.hr_in:
+            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiples of 16.')
+        else:
+            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiples of 4.')
+
+        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
+
+        # extract features for each frame
+        # L1
+        if self.with_predeblur:
+            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
+            if self.hr_in:
+                h, w = h // 4, w // 4
+        else:
+            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+
+        feat_l1 = self.feature_extraction(feat_l1)
+        # L2
+        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+        # L3
+        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+        feat_l1 = feat_l1.view(b, t, -1, h, w)
+        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
+        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
+
+        # PCD alignment
+        ref_feat_l = [  # reference feature list
+            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+            feat_l3[:, self.center_frame_idx, :, :, :].clone()
+        ]
+        aligned_feat = []
+        for i in range(t):
+            nbr_feat_l = [  # neighboring feature list
+                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+            ]
+            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)
+
+        if not self.with_tsa:
+            aligned_feat = aligned_feat.view(b, -1, h, w)
+        feat = self.fusion(aligned_feat)
+
+        out = self.reconstruction(feat)
+        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+        out = self.lrelu(self.conv_hr(out))
+        out = self.conv_last(out)
+        if self.hr_in:
+            base = x_center
+        else:
+            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
+        out += base
+        return out
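As a quick sanity check of the EDVR forward pass above, a minimal usage sketch (assumptions: the class is importable as basicsr.archs.edvr_arch.EDVR, and the deformable-convolution op behind DCNv2Pack is available in this environment):

```python
import torch
from basicsr.archs.edvr_arch import EDVR  # assumed module path for the class above

# 5-frame low-resolution clip; height/width must be multiples of 4 (see the assert in forward)
model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True).eval()
clip = torch.rand(1, 5, 3, 64, 64)  # (b, t, c, h, w)
with torch.no_grad():
    sr = model(clip)
print(sr.shape)  # expected: torch.Size([1, 3, 256, 256]), i.e. x4 of the center frame
```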
diff --git a/basicsr/archs/hifacegan_arch.py b/basicsr/archs/hifacegan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..098e3ed4306eb19ae9da705c0af580a6f74c6cb9
--- /dev/null
+++ b/basicsr/archs/hifacegan_arch.py
@@ -0,0 +1,260 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
+
+
+class SPADEGenerator(BaseNetwork):
+    """Generator with SPADEResBlock"""
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_feat=64,
+                 use_vae=False,
+                 z_dim=256,
+                 crop_size=512,
+                 norm_g='spectralspadesyncbatch3x3',
+                 is_train=True,
+                 init_train_phase=3):  # progressive training disabled
+        super().__init__()
+        self.nf = num_feat
+        self.input_nc = num_in_ch
+        self.is_train = is_train
+        self.train_phase = init_train_phase
+
+        self.scale_ratio = 5  # hardcoded now
+        self.sw = crop_size // (2**self.scale_ratio)
+        self.sh = self.sw  # 20210519: By default use square image, aspect_ratio = 1.0
+
+        if use_vae:
+            # In case of VAE, we will sample from random z vector
+            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
+        else:
+            # Otherwise, we make the network deterministic by starting with
+            # downsampled segmentation map instead of random z
+            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
+
+        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+        self.ups = nn.ModuleList([
+            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
+            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
+            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
+            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
+        ])
+
+        self.to_rgbs = nn.ModuleList([
+            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
+            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
+            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
+            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
+        ])
+
+        self.up = nn.Upsample(scale_factor=2)
+
+    def encode(self, input_tensor):
+        """
+        Encode input_tensor into feature maps, can be overridden in derived classes
+        Default: nearest downsampling of 2**5 = 32 times
+        """
+        h, w = input_tensor.size()[-2:]
+        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
+        x = F.interpolate(input_tensor, size=(sh, sw))
+        return self.fc(x)
+
+    def forward(self, x):
+        # In the original SPADE, seg is a segmentation map, but here we use x itself instead.
+        seg = x
+
+        x = self.encode(x)
+        x = self.head_0(x, seg)
+
+        x = self.up(x)
+        x = self.g_middle_0(x, seg)
+        x = self.g_middle_1(x, seg)
+
+        if self.is_train:
+            phase = self.train_phase + 1
+        else:
+            phase = len(self.to_rgbs)
+
+        for i in range(phase):
+            x = self.up(x)
+            x = self.ups[i](x, seg)
+
+        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+        x = torch.tanh(x)
+
+        return x
+
+    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
+        """
+        A helper method for subspace visualization. input_x and seg are different images.
+        For the first n levels (including encoder) we use input, for the rest we use seg.
+
+        If mode == 'progressive', the guidance pattern is: AAABBB
+        If mode == 'one_plug', the guidance pattern is:    AAABAA
+        If mode == 'one_ablate', the guidance pattern is:  BBBABB
+        """
+
+        if seg is None:
+            return self.forward(input_x)
+
+        if self.is_train:
+            phase = self.train_phase + 1
+        else:
+            phase = len(self.to_rgbs)
+
+        if mode == 'progressive':
+            n = max(min(n, 4 + phase), 0)
+            guide_list = [input_x] * n + [seg] * (4 + phase - n)
+        elif mode == 'one_plug':
+            n = max(min(n, 4 + phase - 1), 0)
+            guide_list = [seg] * (4 + phase)
+            guide_list[n] = input_x
+        elif mode == 'one_ablate':
+            if n > 3 + phase:
+                return self.forward(input_x)
+            guide_list = [input_x] * (4 + phase)
+            guide_list[n] = seg
+
+        x = self.encode(guide_list[0])
+        x = self.head_0(x, guide_list[1])
+
+        x = self.up(x)
+        x = self.g_middle_0(x, guide_list[2])
+        x = self.g_middle_1(x, guide_list[3])
+
+        for i in range(phase):
+            x = self.up(x)
+            x = self.ups[i](x, guide_list[4 + i])
+
+        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+        x = torch.tanh(x)
+
+        return x
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGAN(SPADEGenerator):
+    """
+    HiFaceGAN: SPADEGenerator with a learnable feature encoder
+    Current encoder design: LIPEncoder
+    """
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_feat=64,
+                 use_vae=False,
+                 z_dim=256,
+                 crop_size=512,
+                 norm_g='spectralspadesyncbatch3x3',
+                 is_train=True,
+                 init_train_phase=3):
+        super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
+        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
+
+    def encode(self, input_tensor):
+        return self.lip_encoder(input_tensor)
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGANDiscriminator(BaseNetwork):
+    """
+    Inspired by pix2pixHD multiscale discriminator.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        conditional_d (bool): Whether to use a conditional discriminator.
+            Default: True.
+        num_d (int): Number of multiscale discriminators. Default: 2.
+        n_layers_d (int): Number of downsample layers in each D. Default: 4.
+        num_feat (int): Channel number of base intermediate features.
+            Default: 64.
+        norm_d (str): String to determine normalization layers in D.
+            Choices: [spectral][instance/batch/syncbatch]
+            Default: 'spectralinstance'.
+        keep_features (bool): Keep intermediate features for matching loss, etc.
+            Default: True.
+    """
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_out_ch=3,
+                 conditional_d=True,
+                 num_d=2,
+                 n_layers_d=4,
+                 num_feat=64,
+                 norm_d='spectralinstance',
+                 keep_features=True):
+        super().__init__()
+        self.num_d = num_d
+
+        input_nc = num_in_ch
+        if conditional_d:
+            input_nc += num_out_ch
+
+        for i in range(num_d):
+            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
+            self.add_module(f'discriminator_{i}', subnet_d)
+
+    def downsample(self, x):
+        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+
+    # Returns a list of lists of discriminator outputs.
+    # The final result is of size num_d x n_layers_d.
+    def forward(self, x):
+        result = []
+        for _, _net_d in self.named_children():
+            out = _net_d(x)
+            result.append(out)
+            x = self.downsample(x)
+
+        return result
+
+
+class NLayerDiscriminator(BaseNetwork):
+    """Defines the PatchGAN discriminator with the specified arguments."""
+
+    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
+        super().__init__()
+        kw = 4
+        padw = int(np.ceil((kw - 1.0) / 2))
+        nf = num_feat
+        self.keep_features = keep_features
+
+        norm_layer = get_nonspade_norm_layer(norm_d)
+        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
+
+        for n in range(1, n_layers_d):
+            nf_prev = nf
+            nf = min(nf * 2, 512)
+            stride = 1 if n == n_layers_d - 1 else 2
+            sequence += [[
+                norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
+                nn.LeakyReLU(0.2, False)
+            ]]
+
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+        # We divide the layers into groups to extract intermediate layer outputs
+        for n in range(len(sequence)):
+            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
+
+    def forward(self, x):
+        results = [x]
+        for submodel in self.children():
+            intermediate_output = submodel(results[-1])
+            results.append(intermediate_output)
+
+        if self.keep_features:
+            return results[1:]
+        else:
+            return results[-1]
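A minimal sketch of how the generator and discriminator above fit together (assuming the classes are importable from basicsr.archs.hifacegan_arch; the 512x512 size matches the default crop_size):

```python
import torch
from basicsr.archs.hifacegan_arch import HiFaceGAN, HiFaceGANDiscriminator  # assumed module path

net_g = HiFaceGAN(num_in_ch=3, num_feat=64, crop_size=512, is_train=False)
net_d = HiFaceGANDiscriminator(num_in_ch=3, num_out_ch=3, num_d=2, n_layers_d=4)

lq = torch.rand(1, 3, 512, 512)        # degraded input, spatial size = crop_size
with torch.no_grad():
    restored = net_g(lq)               # (1, 3, 512, 512), in (-1, 1) after tanh
    # conditional discriminator: condition and image are concatenated along channels
    preds = net_d(torch.cat([lq, restored], dim=1))
print(restored.shape, len(preds))      # preds: one list of intermediate features per scale
```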
diff --git a/basicsr/archs/hifacegan_util.py b/basicsr/archs/hifacegan_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..35cbef3f532fcc6aab0fa57ab316a546d3a17bd5
--- /dev/null
+++ b/basicsr/archs/hifacegan_util.py
@@ -0,0 +1,255 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+# Warning: spectral norm could be buggy
+# under eval mode and multi-GPU inference
+# A workaround is sticking to single-GPU inference and train mode
+from torch.nn.utils import spectral_norm
+
+
+class SPADE(nn.Module):
+
+    def __init__(self, config_text, norm_nc, label_nc):
+        super().__init__()
+
+        assert config_text.startswith('spade')
+        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
+        param_free_norm_type = str(parsed.group(1))
+        ks = int(parsed.group(2))
+
+        if param_free_norm_type == 'instance':
+            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+        elif param_free_norm_type == 'syncbatch':
+            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+        elif param_free_norm_type == 'batch':
+            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+        else:
+            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
+
+        # The dimension of the intermediate embedding space. Yes, hardcoded.
+        nhidden = 128 if norm_nc > 128 else norm_nc
+
+        pw = ks // 2
+        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
+        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+
+    def forward(self, x, segmap):
+
+        # Part 1. generate parameter-free normalized activations
+        normalized = self.param_free_norm(x)
+
+        # Part 2. produce scaling and bias conditioned on semantic map
+        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+        actv = self.mlp_shared(segmap)
+        gamma = self.mlp_gamma(actv)
+        beta = self.mlp_beta(actv)
+
+        # apply scale and bias
+        out = normalized * gamma + beta
+
+        return out
+
+
+class SPADEResnetBlock(nn.Module):
+    """
+    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
+    it takes in the segmentation map as input, learns the skip connection if necessary,
+    and applies normalization first and then convolution.
+    This is a fairly standard residual-block design for unconditional or
+    class-conditional GANs.
+    The code was inspired by https://github.com/LMescheder/GAN_stability.
+    """
+
+    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
+        super().__init__()
+        # Attributes
+        self.learned_shortcut = (fin != fout)
+        fmiddle = min(fin, fout)
+
+        # create conv layers
+        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+        if self.learned_shortcut:
+            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+        # apply spectral norm if specified
+        if 'spectral' in norm_g:
+            self.conv_0 = spectral_norm(self.conv_0)
+            self.conv_1 = spectral_norm(self.conv_1)
+            if self.learned_shortcut:
+                self.conv_s = spectral_norm(self.conv_s)
+
+        # define normalization layers
+        spade_config_str = norm_g.replace('spectral', '')
+        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
+        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
+        if self.learned_shortcut:
+            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
+
+    # note the resnet block with SPADE also takes in |seg|,
+    # the semantic segmentation map as input
+    def forward(self, x, seg):
+        x_s = self.shortcut(x, seg)
+        dx = self.conv_0(self.act(self.norm_0(x, seg)))
+        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
+        out = x_s + dx
+        return out
+
+    def shortcut(self, x, seg):
+        if self.learned_shortcut:
+            x_s = self.conv_s(self.norm_s(x, seg))
+        else:
+            x_s = x
+        return x_s
+
+    def act(self, x):
+        return F.leaky_relu(x, 2e-1)
+
+
+class BaseNetwork(nn.Module):
+    """ A basis for hifacegan archs with custom initialization """
+
+    def init_weights(self, init_type='normal', gain=0.02):
+
+        def init_func(m):
+            classname = m.__class__.__name__
+            if classname.find('BatchNorm2d') != -1:
+                if hasattr(m, 'weight') and m.weight is not None:
+                    init.normal_(m.weight.data, 1.0, gain)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    init.constant_(m.bias.data, 0.0)
+            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+                if init_type == 'normal':
+                    init.normal_(m.weight.data, 0.0, gain)
+                elif init_type == 'xavier':
+                    init.xavier_normal_(m.weight.data, gain=gain)
+                elif init_type == 'xavier_uniform':
+                    init.xavier_uniform_(m.weight.data, gain=1.0)
+                elif init_type == 'kaiming':
+                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                elif init_type == 'orthogonal':
+                    init.orthogonal_(m.weight.data, gain=gain)
+                elif init_type == 'none':  # uses pytorch's default init method
+                    m.reset_parameters()
+                else:
+                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
+                if hasattr(m, 'bias') and m.bias is not None:
+                    init.constant_(m.bias.data, 0.0)
+
+        self.apply(init_func)
+
+        # propagate to children
+        for m in self.children():
+            if hasattr(m, 'init_weights'):
+                m.init_weights(init_type, gain)
+
+    def forward(self, x):
+        pass
+
+
+def lip2d(x, logit, kernel=3, stride=2, padding=1):
+    weight = logit.exp()
+    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
+
+
+class SoftGate(nn.Module):
+    COEFF = 12.0
+
+    def forward(self, x):
+        return torch.sigmoid(x).mul(self.COEFF)
+
+
+class SimplifiedLIP(nn.Module):
+
+    def __init__(self, channels):
+        super(SimplifiedLIP, self).__init__()
+        self.logit = nn.Sequential(
+            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
+            SoftGate())
+
+    def init_layer(self):
+        self.logit[0].weight.data.fill_(0.0)
+
+    def forward(self, x):
+        frac = lip2d(x, self.logit(x))
+        return frac
+
+
+class LIPEncoder(BaseNetwork):
+    """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
+
+    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
+        super().__init__()
+        self.sw = sw
+        self.sh = sh
+        self.max_ratio = 16
+        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
+        kw = 3
+        pw = (kw - 1) // 2
+
+        model = [
+            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
+            norm_layer(ngf),
+            nn.ReLU(),
+        ]
+        cur_ratio = 1
+        for i in range(n_2xdown):
+            next_ratio = min(cur_ratio * 2, self.max_ratio)
+            model += [
+                SimplifiedLIP(ngf * cur_ratio),
+                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
+                norm_layer(ngf * next_ratio),
+            ]
+            cur_ratio = next_ratio
+            if i < n_2xdown - 1:
+                model += [nn.ReLU(inplace=True)]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+def get_nonspade_norm_layer(norm_type='instance'):
+    # helper function to get # output channels of the previous layer
+    def get_out_channel(layer):
+        if hasattr(layer, 'out_channels'):
+            return getattr(layer, 'out_channels')
+        return layer.weight.size(0)
+
+    # this function will be returned
+    def add_norm_layer(layer):
+        nonlocal norm_type
+        subnorm_type = norm_type  # fall back to the full norm_type when there is no 'spectral' prefix
+        if norm_type.startswith('spectral'):
+            layer = spectral_norm(layer)
+            subnorm_type = norm_type[len('spectral'):]
+
+        if subnorm_type == 'none' or len(subnorm_type) == 0:
+            return layer
+
+        # remove bias in the previous layer, which is meaningless
+        # since it has no effect after normalization
+        if getattr(layer, 'bias', None) is not None:
+            delattr(layer, 'bias')
+            layer.register_parameter('bias', None)
+
+        if subnorm_type == 'batch':
+            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+        elif subnorm_type == 'sync_batch':
+            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+            # norm_layer = SynchronizedBatchNorm2d(
+            #    get_out_channel(layer), affine=True)
+            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+        elif subnorm_type == 'instance':
+            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+        else:
+            raise ValueError(f'normalization layer {subnorm_type} is not recognized')
+
+        return nn.Sequential(layer, norm_layer)
+
+    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
+    return add_norm_layer
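A small sketch of how the returned add_norm_layer closure is used when building discriminator layers (module path assumed to be basicsr.archs.hifacegan_util):

```python
import torch
import torch.nn as nn
from basicsr.archs.hifacegan_util import get_nonspade_norm_layer  # assumed module path

add_norm = get_nonspade_norm_layer('spectralinstance')
# wraps the conv with spectral norm, drops its now-redundant bias and appends InstanceNorm2d
layer = add_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
out = layer(torch.rand(1, 3, 64, 64))
print(type(layer), out.shape)  # nn.Sequential, torch.Size([1, 64, 32, 32])
```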
diff --git a/basicsr/archs/inception.py b/basicsr/archs/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1abef67270dc1aba770943b53577029141f527
--- /dev/null
+++ b/basicsr/archs/inception.py
@@ -0,0 +1,307 @@
+# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
+# For FID metric
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url
+from torchvision import models
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,  # First max pooling features
+        192: 1,  # Second max pooling features
+        768: 2,  # Pre-aux classifier features
+        2048: 3  # Final average pooling features
+    }
+
+    def __init__(self,
+                 output_blocks=(DEFAULT_BLOCK_INDEX, ),
+                 resize_input=True,
+                 normalize_input=True,
+                 requires_grad=False,
+                 use_fid_inception=True):
+        """Build pretrained InceptionV3.
+
+        Args:
+            output_blocks (list[int]): Indices of blocks to return features of.
+                Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+            resize_input (bool): If true, bilinearly resizes input to width and
+                height 299 before feeding input to model. As the network
+                without fully connected layers is fully convolutional, it
+                should be able to handle inputs of arbitrary size, so resizing
+                might not be strictly needed. Default: True.
+            normalize_input (bool): If true, scales the input from range (0, 1)
+                to the range the pretrained Inception network expects,
+                namely (-1, 1). Default: True.
+            requires_grad (bool): If true, parameters of the model require
+                gradients. Possibly useful for finetuning the network.
+                Default: False.
+            use_fid_inception (bool): If true, uses the pretrained Inception
+                model used in Tensorflow's FID implementation.
+                If false, uses the pretrained Inception model available in
+                torchvision. The FID Inception model has different weights
+                and a slightly different structure from torchvision's
+                Inception model. If you want to compute FID scores, you are
+                strongly advised to set this parameter to true to get
+                comparable results. Default: True.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        assert self.last_needed_block <= 3, ('Last possible output block index is 3')
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            try:
+                inception = models.inception_v3(pretrained=True, init_weights=False)
+            except TypeError:
+                # pytorch < 1.5 does not have init_weights for inception_v3
+                inception = models.inception_v3(pretrained=True)
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2)
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, x):
+        """Get Inception feature maps.
+
+        Args:
+            x (Tensor): Input tensor of shape (b, 3, h, w).
+                Values are expected to be in range (-1, 1). You can also input
+                (0, 1) with setting normalize_input = True.
+
+        Returns:
+            list[Tensor]: Corresponding to the selected output block, sorted
+            ascending by index.
+        """
+        output = []
+
+        if self.resize_input:
+            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                output.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        return output
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation.
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    try:
+        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
+    except TypeError:
+        # pytorch < 1.5 does not have init_weights for inception_v3
+        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
+
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    if os.path.exists(LOCAL_FID_WEIGHTS):
+        state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
+    else:
+        state_dict = load_url(FID_WEIGHTS_URL, progress=True)
+
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
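A minimal sketch of pulling the 2048-d pool3 features used for FID from the model above (assuming the module path basicsr.archs.inception; the FID weights are downloaded on first use):

```python
import torch
from basicsr.archs.inception import InceptionV3  # assumed module path

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average-pooling block
model = InceptionV3(output_blocks=[block_idx]).eval()

imgs = torch.rand(4, 3, 299, 299)                  # values in (0, 1); normalize_input maps them to (-1, 1)
with torch.no_grad():
    feat = model(imgs)[0]                          # (4, 2048, 1, 1)
feat = feat.squeeze(-1).squeeze(-1)                # rows to accumulate for the FID statistics
print(feat.shape)                                  # torch.Size([4, 2048])
```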
diff --git a/basicsr/archs/rcan_arch.py b/basicsr/archs/rcan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..48872e6800006d885f56f90dd2f0a2bd16e513d9
--- /dev/null
+++ b/basicsr/archs/rcan_arch.py
@@ -0,0 +1,135 @@
+import torch
+from torch import nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import Upsample, make_layer
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention used in RCAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+    """
+
+    def __init__(self, num_feat, squeeze_factor=16):
+        super(ChannelAttention, self).__init__()
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
+            nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.attention(x)
+        return x * y
+
+
+class RCAB(nn.Module):
+    """Residual Channel Attention Block (RCAB) used in RCAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        res_scale (float): Scale the residual. Default: 1.
+    """
+
+    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
+        super(RCAB, self).__init__()
+        self.res_scale = res_scale
+
+        self.rcab = nn.Sequential(
+            nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+            ChannelAttention(num_feat, squeeze_factor))
+
+    def forward(self, x):
+        res = self.rcab(x) * self.res_scale
+        return res + x
+
+
+class ResidualGroup(nn.Module):
+    """Residual Group of RCAB.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_block (int): Block number in the body network.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        res_scale (float): Scale the residual. Default: 1.
+    """
+
+    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
+        super(ResidualGroup, self).__init__()
+
+        self.residual_group = make_layer(
+            RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
+        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+    def forward(self, x):
+        res = self.conv(self.residual_group(x))
+        return res + x
+
+
+@ARCH_REGISTRY.register()
+class RCAN(nn.Module):
+    """Residual Channel Attention Networks.
+
+    ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
+
+    Reference: https://github.com/yulunzhang/RCAN
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        num_group (int): Number of ResidualGroup. Default: 10.
+        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+        upscale (int): Upsampling factor. Supports 2^n and 3.
+            Default: 4.
+        res_scale (float): Used to scale the residual in residual block.
+            Default: 1.
+        img_range (float): Image range. Default: 255.
+        rgb_mean (tuple[float]): Image mean in RGB orders.
+            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+    """
+
+    def __init__(self,
+                 num_in_ch,
+                 num_out_ch,
+                 num_feat=64,
+                 num_group=10,
+                 num_block=16,
+                 squeeze_factor=16,
+                 upscale=4,
+                 res_scale=1,
+                 img_range=255.,
+                 rgb_mean=(0.4488, 0.4371, 0.4040)):
+        super(RCAN, self).__init__()
+
+        self.img_range = img_range
+        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(
+            ResidualGroup,
+            num_group,
+            num_feat=num_feat,
+            num_block=num_block,
+            squeeze_factor=squeeze_factor,
+            res_scale=res_scale)
+        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.upsample = Upsample(upscale, num_feat)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+    def forward(self, x):
+        self.mean = self.mean.type_as(x)
+
+        x = (x - self.mean) * self.img_range
+        x = self.conv_first(x)
+        res = self.conv_after_body(self.body(x))
+        res += x
+
+        x = self.conv_last(self.upsample(res))
+        x = x / self.img_range + self.mean
+
+        return x
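A minimal usage sketch of RCAN (assumed module path basicsr.archs.rcan_arch; a reduced num_group/num_block is used here only to keep the forward pass light):

```python
import torch
from basicsr.archs.rcan_arch import RCAN  # assumed module path

# small configuration for a quick shape check; the paper-scale model uses num_group=10, num_block=16
model = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=2, num_block=4, upscale=4).eval()
lr = torch.rand(1, 3, 48, 48)  # input in (0, 1); img_range rescales internally
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # torch.Size([1, 3, 192, 192])
```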
diff --git a/basicsr/archs/ridnet_arch.py b/basicsr/archs/ridnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..85bb9ae0348e27dd6c797c03f8d9ec43f8b0b829
--- /dev/null
+++ b/basicsr/archs/ridnet_arch.py
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, make_layer
+
+
+class MeanShift(nn.Conv2d):
+    """ Data normalization with mean and std.
+
+    Args:
+        rgb_range (int): Maximum value of RGB.
+        rgb_mean (list[float]): Mean for RGB channels.
+        rgb_std (list[float]): Std for RGB channels.
+        sign (int): For subtraction, sign is -1, for addition, sign is 1.
+            Default: -1.
+        requires_grad (bool): Whether to update the self.weight and self.bias.
+            Default: True.
+    """
+
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+        self.weight.data.div_(std.view(3, 1, 1, 1))
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        self.requires_grad = requires_grad
+
+
+class EResidualBlockNoBN(nn.Module):
+    """Enhanced Residual block without BN.
+
+    There are three convolution layers in the residual branch.
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(EResidualBlockNoBN, self).__init__()
+
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.body(x)
+        out = self.relu(out + x)
+        return out
+
+
+class MergeRun(nn.Module):
+    """ Merge-and-run unit.
+
+    This unit contains two branches with different dilated convolutions,
+    followed by a convolution to process the concatenated features.
+
+    Paper: Real Image Denoising with Feature Attention
+    Ref git repo: https://github.com/saeed-anwar/RIDNet
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
+        super(MergeRun, self).__init__()
+
+        self.dilation1 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
+        self.dilation2 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
+
+        self.aggregation = nn.Sequential(
+            nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        dilation1 = self.dilation1(x)
+        dilation2 = self.dilation2(x)
+        out = torch.cat([dilation1, dilation2], dim=1)
+        out = self.aggregation(out)
+        out = out + x
+        return out
+
+
+class ChannelAttention(nn.Module):
+    """Channel attention.
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+    """
+
+    def __init__(self, mid_channels, squeeze_factor=16):
+        super(ChannelAttention, self).__init__()
+        self.attention = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
+            nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.attention(x)
+        return x * y
+
+
+class EAM(nn.Module):
+    """Enhancement attention modules (EAM) in RIDNet.
+
+    This module contains a merge-and-run unit, a residual block,
+    an enhanced residual block and a feature attention unit.
+
+    Attributes:
+        merge: The merge-and-run unit.
+        block1: The residual block.
+        block2: The enhanced residual block.
+        ca: The feature/channel attention unit.
+    """
+
+    def __init__(self, in_channels, mid_channels, out_channels):
+        super(EAM, self).__init__()
+
+        self.merge = MergeRun(in_channels, mid_channels)
+        self.block1 = ResidualBlockNoBN(mid_channels)
+        self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
+        self.ca = ChannelAttention(out_channels)
+        # The residual block in the paper contains a relu after addition.
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        out = self.merge(x)
+        out = self.relu(self.block1(out))
+        out = self.block2(out)
+        out = self.ca(out)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class RIDNet(nn.Module):
+    """RIDNet: Real Image Denoising with Feature Attention.
+
+    Ref git repo: https://github.com/saeed-anwar/RIDNet
+
+    Args:
+        in_channels (int): Channel number of inputs.
+        mid_channels (int): Channel number of EAM modules.
+            Default: 64.
+        out_channels (int): Channel number of outputs.
+        num_block (int): Number of EAM. Default: 4.
+        img_range (float): Image range. Default: 255.
+        rgb_mean (tuple[float]): Image mean in RGB orders.
+            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+        rgb_std (tuple[float]): Image std in RGB orders. Default: (1.0, 1.0, 1.0).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 mid_channels,
+                 out_channels,
+                 num_block=4,
+                 img_range=255.,
+                 rgb_mean=(0.4488, 0.4371, 0.4040),
+                 rgb_std=(1.0, 1.0, 1.0)):
+        super(RIDNet, self).__init__()
+
+        self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
+        self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
+
+        self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
+        self.body = make_layer(
+            EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
+        self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        res = self.sub_mean(x)
+        res = self.tail(self.body(self.relu(self.head(res))))
+        res = self.add_mean(res)
+
+        out = x + res
+        return out
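A minimal sketch of running the denoiser above (assumed module path basicsr.archs.ridnet_arch); the output keeps the input's spatial size since RIDNet predicts a residual:

```python
import torch
from basicsr.archs.ridnet_arch import RIDNet  # assumed module path

model = RIDNet(in_channels=3, mid_channels=64, out_channels=3, num_block=4).eval()
noisy = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    denoised = model(noisy)   # residual prediction added back onto the input
print(denoised.shape)         # torch.Size([1, 3, 128, 128])
```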
diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d07080c2ec1305090c59b7bfbbda2b003b18e4
--- /dev/null
+++ b/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32):
+        super(ResidualDenseBlock, self).__init__()
+        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat, num_grow_ch=32):
+        super(RRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+    def forward(self, x):
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Empirically, we use 0.2 to scale the residual for better performance
+        return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+    """Networks consisting of Residual in Residual Dense Block, which is used
+    in ESRGAN.
+
+    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+    We extend ESRGAN to scale x2 and scale x1.
+    Note: This is one option for scale 1 and scale 2 in RRDBNet.
+    We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        scale (int): Upsampling scale factor. Supported: 1, 2 and 4. Default: 4.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        num_block (int): Block number in the trunk network. Default: 23.
+        num_grow_ch (int): Channels for each growth. Default: 32.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+        super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        # upsample
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        if self.scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        elif self.scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        else:
+            feat = x
+        feat = self.conv_first(feat)
+        body_feat = self.conv_body(self.body(feat))
+        feat = feat + body_feat
+        # upsample
+        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+        return out
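+
+
+# A minimal usage sketch (illustrative only, not part of the upstream BasicSR file).
+# For scale=2 the input is pixel-unshuffled by 2 (channels x4) and for scale=1 by 4 (channels x16),
+# so the two fixed x2 upsampling stages in forward() still produce the requested overall scale.
+if __name__ == '__main__':
+    lq = torch.rand(1, 3, 32, 32)
+    for scale in (1, 2, 4):
+        net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=scale, num_feat=64, num_block=2)
+        with torch.no_grad():
+            sr = net(lq)
+        print(scale, sr.shape)  # expected spatial size: 32 * scale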
diff --git a/basicsr/archs/spynet_arch.py b/basicsr/archs/spynet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c7af133daef0496b79a57517e1942d06f2d0061
--- /dev/null
+++ b/basicsr/archs/spynet_arch.py
@@ -0,0 +1,96 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+    """Basic Module for SpyNet.
+    """
+
+    def __init__(self):
+        super(BasicModule, self).__init__()
+
+        self.basic_module = nn.Sequential(
+            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+    def forward(self, tensor_input):
+        return self.basic_module(tensor_input)
+
+
+@ARCH_REGISTRY.register()
+class SpyNet(nn.Module):
+    """SpyNet architecture.
+
+    Args:
+        load_path (str): path for pretrained SpyNet. Default: None.
+    """
+
+    def __init__(self, load_path=None):
+        super(SpyNet, self).__init__()
+        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
+        if load_path:
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def preprocess(self, tensor_input):
+        tensor_output = (tensor_input - self.mean) / self.std
+        return tensor_output
+
+    def process(self, ref, supp):
+        flow = []
+
+        ref = [self.preprocess(ref)]
+        supp = [self.preprocess(supp)]
+
+        for level in range(5):
+            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+        flow = ref[0].new_zeros(
+            [ref[0].size(0), 2,
+             int(math.floor(ref[0].size(2) / 2.0)),
+             int(math.floor(ref[0].size(3) / 2.0))])
+
+        for level in range(len(ref)):
+            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+
+            if upsampled_flow.size(2) != ref[level].size(2):
+                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
+            if upsampled_flow.size(3) != ref[level].size(3):
+                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
+
+            flow = self.basic_module[level](torch.cat([
+                ref[level],
+                flow_warp(
+                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
+                upsampled_flow
+            ], 1)) + upsampled_flow
+
+        return flow
+
+    def forward(self, ref, supp):
+        assert ref.size() == supp.size()
+
+        h, w = ref.size(2), ref.size(3)
+        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
+        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
+
+        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+
+        flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
+
+        flow[:, 0, :, :] *= float(w) / float(w_floor)
+        flow[:, 1, :, :] *= float(h) / float(h_floor)
+
+        return flow
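+
+
+# A minimal usage sketch (illustrative only, not part of the upstream BasicSR file):
+# SpyNet expects two RGB frames in [0, 1] with identical sizes and returns a 2-channel flow map
+# at the input resolution. With load_path=None the weights are random, so the flow is meaningless;
+# pass a pretrained checkpoint path for real flow estimation.
+if __name__ == '__main__':
+    net = SpyNet(load_path=None)
+    ref = torch.rand(1, 3, 96, 96)
+    supp = torch.rand(1, 3, 96, 96)
+    with torch.no_grad():
+        flow = net(ref, supp)
+    print(flow.shape)  # expected: torch.Size([1, 2, 96, 96])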
diff --git a/basicsr/archs/srresnet_arch.py b/basicsr/archs/srresnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f571557cd7d9ba8791bd6462fccf648c57186d2
--- /dev/null
+++ b/basicsr/archs/srresnet_arch.py
@@ -0,0 +1,65 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
+
+
+@ARCH_REGISTRY.register()
+class MSRResNet(nn.Module):
+    """Modified SRResNet.
+
+    A compact version modified from SRResNet in
+    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".
+    It uses residual blocks without BN, similar to EDSR.
+    Currently, it supports x2, x3 and x4 upsampling scale factors.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_block (int): Block number in the body network. Default: 16.
+        upscale (int): Upsampling factor. Supported: x2, x3 and x4. Default: 4.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
+        super(MSRResNet, self).__init__()
+        self.upscale = upscale
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
+
+        # upsampling
+        if self.upscale in [2, 3]:
+            self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
+            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
+        elif self.upscale == 4:
+            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+            self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+            self.pixel_shuffle = nn.PixelShuffle(2)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
+        if self.upscale == 4:
+            default_init_weights(self.upconv2, 0.1)
+
+    def forward(self, x):
+        feat = self.lrelu(self.conv_first(x))
+        out = self.body(feat)
+
+        if self.upscale == 4:
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+        elif self.upscale in [2, 3]:
+            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+        out = self.conv_last(self.lrelu(self.conv_hr(out)))
+        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+        out += base
+        return out
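+
+
+# A minimal usage sketch (illustrative only, not part of the upstream BasicSR file):
+# the bilinear-upsampled input is added as a base, so the network only learns the residual.
+if __name__ == '__main__':
+    import torch
+
+    net = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=4, upscale=3)
+    lq = torch.rand(1, 3, 40, 40)
+    with torch.no_grad():
+        sr = net(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 120, 120])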
diff --git a/basicsr/archs/srvgg_arch.py b/basicsr/archs/srvgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8fe5ceb40ed9edd35d81ee17aff86f2e3d9adb4
--- /dev/null
+++ b/basicsr/archs/srvgg_arch.py
@@ -0,0 +1,70 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+
+    It is a compact network structure that performs upsampling in the last layer, so no convolution is
+    conducted in the HR feature space.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_conv (int): Number of convolution layers in the body network. Default: 16.
+        upscale (int): Upsampling factor. Default: 4.
+        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for layer in self.body:
+            out = layer(out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
+        return out
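+
+
+# A minimal usage sketch (illustrative only, not part of the upstream BasicSR file):
+# all convolutions run at LR resolution and a single PixelShuffle performs the upsampling,
+# with a nearest-upsampled copy of the input added as the residual base.
+if __name__ == '__main__':
+    import torch
+
+    net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=8, upscale=4, act_type='prelu')
+    lq = torch.rand(1, 3, 24, 24)
+    with torch.no_grad():
+        sr = net(lq)
+    print(sr.shape)  # expected: torch.Size([1, 3, 96, 96])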
diff --git a/basicsr/archs/stylegan2_arch.py b/basicsr/archs/stylegan2_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab37f5a33a2ef21641de35109c16b511a6df163
--- /dev/null
+++ b/basicsr/archs/stylegan2_arch.py
@@ -0,0 +1,799 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.ops.upfirdn2d import upfirdn2d
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+    def forward(self, x):
+        """Normalize the style codes.
+
+        Args:
+            x (Tensor): Style codes with shape (b, c).
+
+        Returns:
+            Tensor: Normalized tensor.
+        """
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_resample_kernel(k):
+    """Make resampling kernel for UpFirDn.
+
+    Args:
+        k (list[int]): A list indicating the 1D resample kernel magnitude.
+
+    Returns:
+        Tensor: 2D resampled kernel.
+    """
+    k = torch.tensor(k, dtype=torch.float32)
+    if k.ndim == 1:
+        k = k[None, :] * k[:, None]  # to 2D kernel, outer product
+    # normalize
+    k /= k.sum()
+    return k
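+
+# Worked example (illustrative): make_resample_kernel([1, 3, 3, 1]) forms the outer product
+#     [[1, 3, 3, 1],
+#      [3, 9, 9, 3],
+#      [3, 9, 9, 3],
+#      [1, 3, 3, 1]]
+# and divides by its sum (64), i.e. a normalized 4x4 binomial smoothing kernel.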
+
+
+class UpFirDnUpsample(nn.Module):
+    """Upsample, FIR filter, and downsample (upsampole version).
+
+    References:
+    1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html  # noqa: E501
+    2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html  # noqa: E501
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        factor (int): Upsampling scale factor. Default: 2.
+    """
+
+    def __init__(self, resample_kernel, factor=2):
+        super(UpFirDnUpsample, self).__init__()
+        self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
+        self.factor = factor
+
+        pad = self.kernel.shape[0] - factor
+        self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnDownsample(nn.Module):
+    """Upsample, FIR filter, and downsample (downsampole version).
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        factor (int): Downsampling scale factor. Default: 2.
+    """
+
+    def __init__(self, resample_kernel, factor=2):
+        super(UpFirDnDownsample, self).__init__()
+        self.kernel = make_resample_kernel(resample_kernel)
+        self.factor = factor
+
+        pad = self.kernel.shape[0] - factor
+        self.pad = ((pad + 1) // 2, pad // 2)
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnSmooth(nn.Module):
+    """Upsample, FIR filter, and downsample (smooth version).
+
+    Args:
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude.
+        upsample_factor (int): Upsampling scale factor. Default: 1.
+        downsample_factor (int): Downsampling scale factor. Default: 1.
+        kernel_size (int): Kernel size. Default: 1.
+    """
+
+    def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
+        super(UpFirDnSmooth, self).__init__()
+        self.upsample_factor = upsample_factor
+        self.downsample_factor = downsample_factor
+        self.kernel = make_resample_kernel(resample_kernel)
+        if upsample_factor > 1:
+            self.kernel = self.kernel * (upsample_factor**2)
+
+        if upsample_factor > 1:
+            pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
+            self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
+        elif downsample_factor > 1:
+            pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
+            self.pad = ((pad + 1) // 2, pad // 2)
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
+                f', downsample_factor={self.downsample_factor})')
+
+
+class EqualLinear(nn.Module):
+    """Equalized Linear as StyleGAN2.
+
+    Args:
+        in_channels (int): Size of each sample.
+        out_channels (int): Size of each output sample.
+        bias (bool): If set to ``False``, the layer will not learn an additive
+            bias. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        lr_mul (float): Learning rate multiplier. Default: 1.
+        activation (None | str): The activation after ``linear`` operation.
+            Supported: 'fused_lrelu', None. Default: None.
+    """
+
+    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+        super(EqualLinear, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.lr_mul = lr_mul
+        self.activation = activation
+        if self.activation not in ['fused_lrelu', None]:
+            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
+                             "Supported ones are: ['fused_lrelu', None].")
+        self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        if self.bias is None:
+            bias = None
+        else:
+            bias = self.bias * self.lr_mul
+        if self.activation == 'fused_lrelu':
+            out = F.linear(x, self.weight * self.scale)
+            out = fused_leaky_relu(out, bias)
+        else:
+            out = F.linear(x, self.weight * self.scale, bias=bias)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+    """Modulated Conv2d used in StyleGAN2.
+
+    There is no bias in ModulatedConv2d.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether to demodulate in the conv layer.
+            Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+        eps (float): A value added to the denominator for numerical stability.
+            Default: 1e-8.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 resample_kernel=(1, 3, 3, 1),
+                 eps=1e-8):
+        super(ModulatedConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.demodulate = demodulate
+        self.sample_mode = sample_mode
+        self.eps = eps
+
+        if self.sample_mode == 'upsample':
+            self.smooth = UpFirDnSmooth(
+                resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
+        elif self.sample_mode == 'downsample':
+            self.smooth = UpFirDnSmooth(
+                resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+        elif self.sample_mode is None:
+            pass
+        else:
+            raise ValueError(f'Wrong sample mode {self.sample_mode}, '
+                             "supported ones are ['upsample', 'downsample', None].")
+
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+        # modulation inside each modulated conv
+        self.modulation = EqualLinear(
+            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+        self.padding = kernel_size // 2
+
+    def forward(self, x, style):
+        """Forward function.
+
+        Args:
+            x (Tensor): Tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+
+        Returns:
+            Tensor: Modulated tensor after convolution.
+        """
+        b, c, h, w = x.shape  # c = c_in
+        # weight modulation
+        style = self.modulation(style).view(b, 1, c, 1, 1)
+        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+        if self.sample_mode == 'upsample':
+            x = x.view(1, b * c, h, w)
+            weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
+            weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
+            out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+            out = self.smooth(out)
+        elif self.sample_mode == 'downsample':
+            x = self.smooth(x)
+            x = x.view(1, b * c, *x.shape[2:4])
+            out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+        else:
+            x = x.view(1, b * c, h, w)
+            # weight: (b*c_out, c_in, k, k), groups=b
+            out = F.conv2d(x, weight, padding=self.padding, groups=b)
+            out = out.view(b, self.out_channels, *out.shape[2:4])
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, '
+                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+    """Style conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 resample_kernel=(1, 3, 3, 1)):
+        super(StyleConv, self).__init__()
+        self.modulated_conv = ModulatedConv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            num_style_feat,
+            demodulate=demodulate,
+            sample_mode=sample_mode,
+            resample_kernel=resample_kernel)
+        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+        self.activate = FusedLeakyReLU(out_channels)
+
+    def forward(self, x, style, noise=None):
+        # modulate
+        out = self.modulated_conv(x, style)
+        # noise injection
+        if noise is None:
+            b, _, h, w = out.shape
+            noise = out.new_empty(b, 1, h, w).normal_()
+        out = out + self.weight * noise
+        # activation (with bias)
+        out = self.activate(out)
+        return out
+
+
+class ToRGB(nn.Module):
+    """To RGB from features.
+
+    Args:
+        in_channels (int): Channel number of input.
+        num_style_feat (int): Channel number of style features.
+        upsample (bool): Whether to upsample. Default: True.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. Default: (1, 3, 3, 1).
+    """
+
+    def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
+        super(ToRGB, self).__init__()
+        if upsample:
+            self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
+        else:
+            self.upsample = None
+        self.modulated_conv = ModulatedConv2d(
+            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, x, style, skip=None):
+        """Forward function.
+
+        Args:
+            x (Tensor): Feature tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+            skip (Tensor): Base/skip tensor. Default: None.
+
+        Returns:
+            Tensor: RGB images.
+        """
+        out = self.modulated_conv(x, style)
+        out = out + self.bias
+        if skip is not None:
+            if self.upsample:
+                skip = self.upsample(skip)
+            out = out + skip
+        return out
+
+
+class ConstantInput(nn.Module):
+    """Constant input.
+
+    Args:
+        num_channel (int): Channel number of constant input.
+        size (int): Spatial size of constant input.
+    """
+
+    def __init__(self, num_channel, size):
+        super(ConstantInput, self).__init__()
+        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+    def forward(self, batch):
+        out = self.weight.repeat(batch, 1, 1, 1)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Generator(nn.Module):
+    """StyleGAN2 Generator.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        channel_multiplier (int): Channel multiplier for large networks of
+            StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. An outer product will be applied to extend the 1D resample
+            kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): Narrow ratio for channels. Default: 1.0.
+    """
+
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 num_mlp=8,
+                 channel_multiplier=2,
+                 resample_kernel=(1, 3, 3, 1),
+                 lr_mlp=0.01,
+                 narrow=1):
+        super(StyleGAN2Generator, self).__init__()
+        # Style MLP layers
+        self.num_style_feat = num_style_feat
+        style_mlp_layers = [NormStyleCode()]
+        for i in range(num_mlp):
+            style_mlp_layers.append(
+                EqualLinear(
+                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
+                    activation='fused_lrelu'))
+        self.style_mlp = nn.Sequential(*style_mlp_layers)
+
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+        self.channels = channels
+
+        self.constant_input = ConstantInput(channels['4'], size=4)
+        self.style_conv1 = StyleConv(
+            channels['4'],
+            channels['4'],
+            kernel_size=3,
+            num_style_feat=num_style_feat,
+            demodulate=True,
+            sample_mode=None,
+            resample_kernel=resample_kernel)
+        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
+
+        self.log_size = int(math.log(out_size, 2))
+        self.num_layers = (self.log_size - 2) * 2 + 1
+        self.num_latent = self.log_size * 2 - 2
+
+        self.style_convs = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        in_channels = channels['4']
+        # noise
+        for layer_idx in range(self.num_layers):
+            resolution = 2**((layer_idx + 5) // 2)
+            shape = [1, 1, resolution, resolution]
+            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+        # style convs and to_rgbs
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.style_convs.append(
+                StyleConv(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode='upsample',
+                    resample_kernel=resample_kernel,
+                ))
+            self.style_convs.append(
+                StyleConv(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode=None,
+                    resample_kernel=resample_kernel))
+            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
+            in_channels = out_channels
+
+    def make_noise(self):
+        """Make noise for noise injection."""
+        device = self.constant_input.weight.device
+        noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
+        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2Generator.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether input is latent style.
+                Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Whether to randomize noise, used when 'noise'
+                is None. Default: True.
+            truncation (float): Truncation ratio for the truncation trick. Default: 1.
+            truncation_latent (Tensor | None): Mean latent used when truncation < 1.
+                Default: None.
+            inject_index (int | None): The injection index for mixing noise.
+                Default: None.
+            return_latents (bool): Whether to return style latents.
+                Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latent with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
+
+
+class ScaledLeakyReLU(nn.Module):
+    """Scaled LeakyReLU.
+
+    Args:
+        negative_slope (float): Negative slope. Default: 0.2.
+    """
+
+    def __init__(self, negative_slope=0.2):
+        super(ScaledLeakyReLU, self).__init__()
+        self.negative_slope = negative_slope
+
+    def forward(self, x):
+        out = F.leaky_relu(x, negative_slope=self.negative_slope)
+        return out * math.sqrt(2)
+
+
+class EqualConv2d(nn.Module):
+    """Equalized Linear as StyleGAN2.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        stride (int): Stride of the convolution. Default: 1
+        padding (int): Zero-padding added to both sides of the input.
+            Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output.
+            Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
+        super(EqualConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        out = F.conv2d(
+            x,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size},'
+                f' stride={self.stride}, padding={self.padding}, '
+                f'bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+    """Conv Layer used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Kernel size.
+        downsample (bool): Whether to downsample by a factor of 2.
+            Default: False.
+        resample_kernel (list[int]): A list indicating the 1D resample
+            kernel magnitude. An outer product will be applied to
+            extend the 1D resample kernel to a 2D resample kernel.
+            Default: (1, 3, 3, 1).
+        bias (bool): Whether to use bias. Default: True.
+        activate (bool): Whether to use activation. Default: True.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 downsample=False,
+                 resample_kernel=(1, 3, 3, 1),
+                 bias=True,
+                 activate=True):
+        layers = []
+        # downsample
+        if downsample:
+            layers.append(
+                UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
+            stride = 2
+            self.padding = 0
+        else:
+            stride = 1
+            self.padding = kernel_size // 2
+        # conv
+        layers.append(
+            EqualConv2d(
+                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+                and not activate))
+        # activation
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channels))
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+
+        super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+    """Residual block used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        resample_kernel (list[int]): A list indicating the 1D resample
+            kernel magnitude. An outer product will be applied to
+            extend the 1D resample kernel to a 2D resample kernel.
+            Default: (1, 3, 3, 1).
+    """
+
+    def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
+        super(ResBlock, self).__init__()
+
+        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+        self.conv2 = ConvLayer(
+            in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
+        self.skip = ConvLayer(
+            in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        skip = self.skip(x)
+        out = (out + skip) / math.sqrt(2)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Discriminator(nn.Module):
+    """StyleGAN2 Discriminator.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        channel_multiplier (int): Channel multiplier for large networks of
+            StyleGAN2. Default: 2.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel
+            magnitude. An outer product will be applied to extend the 1D resample
+            kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        stddev_group (int): For group stddev statistics. Default: 4.
+        narrow (float): Narrow ratio for channels. Default: 1.0.
+    """
+
+    def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
+        super(StyleGAN2Discriminator, self).__init__()
+
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+
+        log_size = int(math.log(out_size, 2))
+
+        conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
+
+        in_channels = channels[f'{out_size}']
+        for i in range(log_size, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
+            in_channels = out_channels
+        self.conv_body = nn.Sequential(*conv_body)
+
+        self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
+        self.final_linear = nn.Sequential(
+            EqualLinear(
+                channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
+            EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
+        )
+        self.stddev_group = stddev_group
+        self.stddev_feat = 1
+
+    def forward(self, x):
+        out = self.conv_body(x)
+
+        b, c, h, w = out.shape
+        # concatenate a group stddev statistics to out
+        group = min(b, self.stddev_group)  # Minibatch must be divisible by (or smaller than) group_size
+        stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
+        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
+        stddev = stddev.repeat(group, 1, h, w)
+        out = torch.cat([out, stddev], 1)
+
+        out = self.final_conv(out)
+        out = out.view(b, -1)
+        out = self.final_linear(out)
+
+        return out
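+
+
+# A minimal usage sketch (illustrative only, not part of the upstream BasicSR file).
+# Note: it requires the compiled fused_act/upfirdn2d ops from basicsr.ops to be available.
+if __name__ == '__main__':
+    g = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8)
+    d = StyleGAN2Discriminator(out_size=256)
+    z = torch.randn(2, 512)
+    fake, _ = g([z])  # styles are passed as a list to allow style mixing
+    score = d(fake)
+    print(fake.shape, score.shape)  # expected: (2, 3, 256, 256) and (2, 1)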
diff --git a/basicsr/archs/stylegan2_bilinear_arch.py b/basicsr/archs/stylegan2_bilinear_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2395170411f9d11f2798ac03cf6ec6eb32fe5e43
--- /dev/null
+++ b/basicsr/archs/stylegan2_bilinear_arch.py
@@ -0,0 +1,614 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+    def forward(self, x):
+        """Normalize the style codes.
+
+        Args:
+            x (Tensor): Style codes with shape (b, c).
+
+        Returns:
+            Tensor: Normalized tensor.
+        """
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+class EqualLinear(nn.Module):
+    """Equalized Linear as StyleGAN2.
+
+    Args:
+        in_channels (int): Size of each sample.
+        out_channels (int): Size of each output sample.
+        bias (bool): If set to ``False``, the layer will not learn an additive
+            bias. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        lr_mul (float): Learning rate multiplier. Default: 1.
+        activation (None | str): The activation after ``linear`` operation.
+            Supported: 'fused_lrelu', None. Default: None.
+    """
+
+    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+        super(EqualLinear, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.lr_mul = lr_mul
+        self.activation = activation
+        if self.activation not in ['fused_lrelu', None]:
+            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
+                             "Supported ones are: ['fused_lrelu', None].")
+        self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        if self.bias is None:
+            bias = None
+        else:
+            bias = self.bias * self.lr_mul
+        if self.activation == 'fused_lrelu':
+            out = F.linear(x, self.weight * self.scale)
+            out = fused_leaky_relu(out, bias)
+        else:
+            out = F.linear(x, self.weight * self.scale, bias=bias)
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+    """Modulated Conv2d used in StyleGAN2.
+
+    There is no bias in ModulatedConv2d.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether to demodulate in the conv layer.
+            Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+        eps (float): A value added to the denominator for numerical stability.
+            Default: 1e-8.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 eps=1e-8,
+                 interpolation_mode='bilinear'):
+        super(ModulatedConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.demodulate = demodulate
+        self.sample_mode = sample_mode
+        self.eps = eps
+        self.interpolation_mode = interpolation_mode
+        if self.interpolation_mode == 'nearest':
+            self.align_corners = None
+        else:
+            self.align_corners = False
+
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+        # modulation inside each modulated conv
+        self.modulation = EqualLinear(
+            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+        self.padding = kernel_size // 2
+
+    def forward(self, x, style):
+        """Forward function.
+
+        Args:
+            x (Tensor): Tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+
+        Returns:
+            Tensor: Modulated tensor after convolution.
+        """
+        b, c, h, w = x.shape  # c = c_in
+        # weight modulation
+        style = self.modulation(style).view(b, 1, c, 1, 1)
+        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+        if self.sample_mode == 'upsample':
+            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
+        elif self.sample_mode == 'downsample':
+            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
+
+        b, c, h, w = x.shape
+        x = x.view(1, b * c, h, w)
+        # weight: (b*c_out, c_in, k, k), groups=b
+        out = F.conv2d(x, weight, padding=self.padding, groups=b)
+        out = out.view(b, self.out_channels, *out.shape[2:4])
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size}, '
+                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+    """Style conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        num_style_feat (int): Channel number of style features.
+        demodulate (bool): Whether demodulate in the conv layer. Default: True.
+        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 num_style_feat,
+                 demodulate=True,
+                 sample_mode=None,
+                 interpolation_mode='bilinear'):
+        super(StyleConv, self).__init__()
+        self.modulated_conv = ModulatedConv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            num_style_feat,
+            demodulate=demodulate,
+            sample_mode=sample_mode,
+            interpolation_mode=interpolation_mode)
+        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+        self.activate = FusedLeakyReLU(out_channels)
+
+    def forward(self, x, style, noise=None):
+        # modulate
+        out = self.modulated_conv(x, style)
+        # noise injection
+        if noise is None:
+            b, _, h, w = out.shape
+            noise = out.new_empty(b, 1, h, w).normal_()
+        out = out + self.weight * noise
+        # activation (with bias)
+        out = self.activate(out)
+        return out
+
+
+class ToRGB(nn.Module):
+    """To RGB from features.
+
+    Args:
+        in_channels (int): Channel number of input.
+        num_style_feat (int): Channel number of style features.
+        upsample (bool): Whether to upsample. Default: True.
+    """
+
+    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
+        super(ToRGB, self).__init__()
+        self.upsample = upsample
+        self.interpolation_mode = interpolation_mode
+        if self.interpolation_mode == 'nearest':
+            self.align_corners = None
+        else:
+            self.align_corners = False
+        self.modulated_conv = ModulatedConv2d(
+            in_channels,
+            3,
+            kernel_size=1,
+            num_style_feat=num_style_feat,
+            demodulate=False,
+            sample_mode=None,
+            interpolation_mode=interpolation_mode)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, x, style, skip=None):
+        """Forward function.
+
+        Args:
+            x (Tensor): Feature tensor with shape (b, c, h, w).
+            style (Tensor): Tensor with shape (b, num_style_feat).
+            skip (Tensor): Base/skip tensor. Default: None.
+
+        Returns:
+            Tensor: RGB images.
+        """
+        out = self.modulated_conv(x, style)
+        out = out + self.bias
+        if skip is not None:
+            if self.upsample:
+                skip = F.interpolate(
+                    skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
+            out = out + skip
+        return out
+
+
+class ConstantInput(nn.Module):
+    """Constant input.
+
+    Args:
+        num_channel (int): Channel number of constant input.
+        size (int): Spatial size of constant input.
+    """
+
+    def __init__(self, num_channel, size):
+        super(ConstantInput, self).__init__()
+        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+    def forward(self, batch):
+        out = self.weight.repeat(batch, 1, 1, 1)
+        return out
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class StyleGAN2GeneratorBilinear(nn.Module):
+    """StyleGAN2 Generator.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        channel_multiplier (int): Channel multiplier for large networks of
+            StyleGAN2. Default: 2.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        narrow (float): Narrow ratio for channels. Default: 1.0.
+        interpolation_mode (str): Interpolation mode used for upsampling and
+            downsampling. Default: 'bilinear'.
+    """
+
+    def __init__(self,
+                 out_size,
+                 num_style_feat=512,
+                 num_mlp=8,
+                 channel_multiplier=2,
+                 lr_mlp=0.01,
+                 narrow=1,
+                 interpolation_mode='bilinear'):
+        super(StyleGAN2GeneratorBilinear, self).__init__()
+        # Style MLP layers
+        self.num_style_feat = num_style_feat
+        style_mlp_layers = [NormStyleCode()]
+        for i in range(num_mlp):
+            style_mlp_layers.append(
+                EqualLinear(
+                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
+                    activation='fused_lrelu'))
+        self.style_mlp = nn.Sequential(*style_mlp_layers)
+
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+        self.channels = channels
+
+        self.constant_input = ConstantInput(channels['4'], size=4)
+        self.style_conv1 = StyleConv(
+            channels['4'],
+            channels['4'],
+            kernel_size=3,
+            num_style_feat=num_style_feat,
+            demodulate=True,
+            sample_mode=None,
+            interpolation_mode=interpolation_mode)
+        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
+
+        self.log_size = int(math.log(out_size, 2))
+        self.num_layers = (self.log_size - 2) * 2 + 1
+        self.num_latent = self.log_size * 2 - 2
+
+        self.style_convs = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        in_channels = channels['4']
+        # noise
+        for layer_idx in range(self.num_layers):
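+            # layer 0 runs at 4x4; every subsequent pair of layers shares one resolution (8, 8, 16, 16, ...)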
+            resolution = 2**((layer_idx + 5) // 2)
+            shape = [1, 1, resolution, resolution]
+            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+        # style convs and to_rgbs
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.style_convs.append(
+                StyleConv(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode='upsample',
+                    interpolation_mode=interpolation_mode))
+            self.style_convs.append(
+                StyleConv(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode=None,
+                    interpolation_mode=interpolation_mode))
+            self.to_rgbs.append(
+                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
+            in_channels = out_channels
+
+    def make_noise(self):
+        """Make noise for noise injection."""
+        device = self.constant_input.weight.device
+        noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
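+        # estimate the center of the latent (W) space by averaging the mapped codes
+        # of `num_latent` random samples; typically used for truncation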
+        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2Generator.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether input is latent style.
+                Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is
+                None. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor.
+                Default: None.
+            inject_index (int | None): The injection index for mixing noise.
+                Default: None.
+            return_latents (bool): Whether to return style latents.
+                Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latent with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
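+        # each level uses latent indices i and i+1 for its two style convs and i+2
+        # (which is also the next level's first index) for its to_rgb layer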
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
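+
+# Minimal usage sketch (illustrative; not part of the original file), using the default 512-dim style space:
+#     g = StyleGAN2GeneratorBilinear(out_size=256)
+#     z = torch.randn(1, 512)
+#     img, _ = g([z])  # img: (1, 3, 256, 256)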
+
+
+class ScaledLeakyReLU(nn.Module):
+    """Scaled LeakyReLU.
+
+    Args:
+        negative_slope (float): Negative slope. Default: 0.2.
+    """
+
+    def __init__(self, negative_slope=0.2):
+        super(ScaledLeakyReLU, self).__init__()
+        self.negative_slope = negative_slope
+
+    def forward(self, x):
+        out = F.leaky_relu(x, negative_slope=self.negative_slope)
+        return out * math.sqrt(2)
+
+
+class EqualConv2d(nn.Module):
+    """Equalized Linear as StyleGAN2.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        stride (int): Stride of the convolution. Default: 1
+        padding (int): Zero-padding added to both sides of the input.
+            Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output.
+            Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
+        super(EqualConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
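+        # equalized learning rate: weights are stored at unit variance and rescaled
+        # by 1 / sqrt(fan_in) at runtime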
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, x):
+        out = F.conv2d(
+            x,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+
+        return out
+
+    def __repr__(self):
+        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+                f'out_channels={self.out_channels}, '
+                f'kernel_size={self.kernel_size},'
+                f' stride={self.stride}, padding={self.padding}, '
+                f'bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+    """Conv Layer used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Kernel size.
+        downsample (bool): Whether to downsample by a factor of 2.
+            Default: False.
+        bias (bool): Whether to use bias. Default: True.
+        activate (bool): Whether to use activation. Default: True.
+        interpolation_mode (str): Interpolation mode used for downsampling. Default: 'bilinear'.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 downsample=False,
+                 bias=True,
+                 activate=True,
+                 interpolation_mode='bilinear'):
+        layers = []
+        self.interpolation_mode = interpolation_mode
+        # downsample
+        if downsample:
+            if self.interpolation_mode == 'nearest':
+                self.align_corners = None
+            else:
+                self.align_corners = False
+
+            layers.append(
+                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
+        stride = 1
+        self.padding = kernel_size // 2
+        # conv
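+        # when an activation follows, the bias is applied inside FusedLeakyReLU instead of the conv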
+        layers.append(
+            EqualConv2d(
+                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+                and not activate))
+        # activation
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channels))
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+
+        super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+    """Residual block used in StyleGAN2 Discriminator.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+    """
+
+    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
+        super(ResBlock, self).__init__()
+
+        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+        self.conv2 = ConvLayer(
+            in_channels,
+            out_channels,
+            3,
+            downsample=True,
+            interpolation_mode=interpolation_mode,
+            bias=True,
+            activate=True)
+        self.skip = ConvLayer(
+            in_channels,
+            out_channels,
+            1,
+            downsample=True,
+            interpolation_mode=interpolation_mode,
+            bias=False,
+            activate=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        skip = self.skip(x)
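+        # scale by 1 / sqrt(2) to keep the activation variance roughly constant after the residual sum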
+        out = (out + skip) / math.sqrt(2)
+        return out
diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3917fa2c7408e1f5b55b9930c643a9af920a4d81
--- /dev/null
+++ b/basicsr/archs/swinir_arch.py
@@ -0,0 +1,956 @@
+# Modified from https://github.com/JingyunLiang/SwinIR
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import to_2tuple, trunc_normal_
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+
+    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (b, h, w, c)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*b, window_size, window_size, c)
+    """
+    b, h, w, c = x.shape
+    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+    return windows
+
+
+def window_reverse(windows, window_size, h, w):
+    """
+    Args:
+        windows: (num_windows*b, window_size, window_size, c)
+        window_size (int): Window size
+        h (int): Height of image
+        w (int): Width of image
+
+    Returns:
+        x: (b, h, w, c)
+    """
+    b = int(windows.shape[0] / (h * w / window_size / window_size))
+    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
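+        # summing the shifted/scaled offsets gives a unique 1D index in [0, (2*Wh-1)*(2*Ww-1))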
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer('relative_position_index', relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*b, n, c)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        b_, n, c = x.shape
+        qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nw = mask.shape[0]
+            attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, n, n)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, n):
+        # calculate flops for 1 window with token length of n
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += n * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * n * (self.dim // self.num_heads) * n
+        #  x = (attn @ v)
+        flops += self.num_heads * n * n * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += n * self.dim * self.dim
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self,
+                 dim,
+                 input_resolution,
+                 num_heads,
+                 window_size=7,
+                 shift_size=0,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            attn_mask = self.calculate_mask(self.input_resolution)
+        else:
+            attn_mask = None
+
+        self.register_buffer('attn_mask', attn_mask)
+
+    def calculate_mask(self, x_size):
+        # calculate attention mask for SW-MSA
+        h, w = x_size
+        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
+        h_slices = (slice(0, -self.window_size), slice(-self.window_size,
+                                                       -self.shift_size), slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size), slice(-self.window_size,
+                                                       -self.shift_size), slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nw, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
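+        # token pairs from different regions get -100.0 (suppressed by the softmax); same-region pairs get 0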
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+        return attn_mask
+
+    def forward(self, x, x_size):
+        h, w = x_size
+        b, _, c = x.shape
+        # assert seq_len == h * w, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(b, h, w, c)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_x = x
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c
+
+        # W-MSA/SW-MSA (the precomputed mask is only valid for the training resolution;
+        # recompute it when testing on images of a different size)
+        if self.input_resolution == x_size:
+            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nw*b, window_size*window_size, c
+        else:
+            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h' w' c
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+        x = x.view(b, h * w, c)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+    def extra_repr(self) -> str:
+        return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
+                f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
+
+    def flops(self):
+        flops = 0
+        h, w = self.input_resolution
+        # norm1
+        flops += self.dim * h * w
+        # W-MSA/SW-MSA
+        nw = h * w / self.window_size / self.window_size
+        flops += nw * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * h * w
+        return flops
+
+
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: b, h*w, c
+        """
+        h, w = self.input_resolution
+        b, seq_len, c = x.shape
+        assert seq_len == h * w, 'input feature has wrong size'
+        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'
+
+        x = x.view(b, h, w, c)
+
+        x0 = x[:, 0::2, 0::2, :]  # b h/2 w/2 c
+        x1 = x[:, 1::2, 0::2, :]  # b h/2 w/2 c
+        x2 = x[:, 0::2, 1::2, :]  # b h/2 w/2 c
+        x3 = x[:, 1::2, 1::2, :]  # b h/2 w/2 c
+        x = torch.cat([x0, x1, x2, x3], -1)  # b h/2 w/2 4*c
+        x = x.view(b, -1, 4 * c)  # b h/2*w/2 4*c
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f'input_resolution={self.input_resolution}, dim={self.dim}'
+
+    def flops(self):
+        h, w = self.input_resolution
+        flops = h * w * self.dim
+        flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
+        return flops
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 input_resolution,
+                 depth,
+                 num_heads,
+                 window_size,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                input_resolution=input_resolution,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer) for i in range(depth)
+        ])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, x_size):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, x_size)
+            else:
+                x = blk(x, x_size)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
+
+
+class RSTB(nn.Module):
+    """Residual Swin Transformer Block (RSTB).
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        img_size (int | tuple[int]): Input image size. Default: 224.
+        patch_size (int | tuple[int]): Patch size. Default: 4.
+        resi_connection (str): The convolutional block before the residual connection. '1conv'/'3conv'. Default: '1conv'.
+    """
+
+    def __init__(self,
+                 dim,
+                 input_resolution,
+                 depth,
+                 num_heads,
+                 window_size,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False,
+                 img_size=224,
+                 patch_size=4,
+                 resi_connection='1conv'):
+        super(RSTB, self).__init__()
+
+        self.dim = dim
+        self.input_resolution = input_resolution
+
+        self.residual_group = BasicLayer(
+            dim=dim,
+            input_resolution=input_resolution,
+            depth=depth,
+            num_heads=num_heads,
+            window_size=window_size,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            drop=drop,
+            attn_drop=attn_drop,
+            drop_path=drop_path,
+            norm_layer=norm_layer,
+            downsample=downsample,
+            use_checkpoint=use_checkpoint)
+
+        if resi_connection == '1conv':
+            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+        elif resi_connection == '3conv':
+            # to save parameters and memory
+            self.conv = nn.Sequential(
+                nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+        self.patch_unembed = PatchUnEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+    def forward(self, x, x_size):
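+        # unembed tokens to a feature map, apply the conv block, re-embed, and add the block input (residual)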
+        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+    def flops(self):
+        flops = 0
+        flops += self.residual_group.flops()
+        h, w = self.input_resolution
+        flops += h * w * self.dim * self.dim * 9
+        flops += self.patch_embed.flops()
+        flops += self.patch_unembed.flops()
+
+        return flops
+
+
+class PatchEmbed(nn.Module):
+    r""" Image to Patch Embedding
+
+    Args:
+        img_size (int): Image size.  Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patches_resolution = patches_resolution
+        self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        x = x.flatten(2).transpose(1, 2)  # b Ph*Pw c
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
+
+    def flops(self):
+        flops = 0
+        h, w = self.img_size
+        if self.norm is not None:
+            flops += h * w * self.embed_dim
+        return flops
+
+
+class PatchUnEmbed(nn.Module):
+    r""" Image to Patch Unembedding
+
+    Args:
+        img_size (int): Image size.  Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patches_resolution = patches_resolution
+        self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+    def forward(self, x, x_size):
+        x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1])  # b Ph*Pw c
+        return x
+
+    def flops(self):
+        flops = 0
+        return flops
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+       Used in lightweight SR to save parameters.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+
+    """
+
+    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+        self.num_feat = num_feat
+        self.input_resolution = input_resolution
+        m = []
+        m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
+        m.append(nn.PixelShuffle(scale))
+        super(UpsampleOneStep, self).__init__(*m)
+
+    def flops(self):
+        h, w = self.input_resolution
+        flops = h * w * self.num_feat * 3 * 9
+        return flops
+
+
+@ARCH_REGISTRY.register()
+class SwinIR(nn.Module):
+    r""" SwinIR
+        A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
+
+    Args:
+        img_size (int | tuple(int)): Input image size. Default 64
+        patch_size (int | tuple(int)): Patch size. Default: 1
+        in_chans (int): Number of input image channels. Default: 3
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction.
+        img_range (float): Image range. 1. or 255.
+        upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None.
+        resi_connection (str): The convolutional block before the residual connection. '1conv'/'3conv'.
+    """
+
+    def __init__(self,
+                 img_size=64,
+                 patch_size=1,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=(6, 6, 6, 6),
+                 num_heads=(6, 6, 6, 6),
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.1,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 use_checkpoint=False,
+                 upscale=2,
+                 img_range=1.,
+                 upsampler='',
+                 resi_connection='1conv',
+                 **kwargs):
+        super(SwinIR, self).__init__()
+        num_in_ch = in_chans
+        num_out_ch = in_chans
+        num_feat = 64
+        self.img_range = img_range
+        if in_chans == 3:
+            rgb_mean = (0.4488, 0.4371, 0.4040)
+            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+        else:
+            self.mean = torch.zeros(1, 1, 1, 1)
+        self.upscale = upscale
+        self.upsampler = upsampler
+
+        # ------------------------- 1, shallow feature extraction ------------------------- #
+        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+        # ------------------------- 2, deep feature extraction ------------------------- #
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.num_features = embed_dim
+        self.mlp_ratio = mlp_ratio
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=embed_dim,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # merge non-overlapping patches into image
+        self.patch_unembed = PatchUnEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=embed_dim,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build Residual Swin Transformer blocks (RSTB)
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = RSTB(
+                dim=embed_dim,
+                input_resolution=(patches_resolution[0], patches_resolution[1]),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=self.mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
+                norm_layer=norm_layer,
+                downsample=None,
+                use_checkpoint=use_checkpoint,
+                img_size=img_size,
+                patch_size=patch_size,
+                resi_connection=resi_connection)
+            self.layers.append(layer)
+        self.norm = norm_layer(self.num_features)
+
+        # build the last conv layer in deep feature extraction
+        if resi_connection == '1conv':
+            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+        elif resi_connection == '3conv':
+            # to save parameters and memory
+            self.conv_after_body = nn.Sequential(
+                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+        # ------------------------- 3, high quality image reconstruction ------------------------- #
+        if self.upsampler == 'pixelshuffle':
+            # for classical SR
+            self.conv_before_upsample = nn.Sequential(
+                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+            self.upsample = Upsample(upscale, num_feat)
+            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+        elif self.upsampler == 'pixelshuffledirect':
+            # for lightweight SR (to save parameters)
+            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+                                            (patches_resolution[0], patches_resolution[1]))
+        elif self.upsampler == 'nearest+conv':
+            # for real-world SR (less artifacts)
+            assert self.upscale == 4, 'only support x4 now.'
+            self.conv_before_upsample = nn.Sequential(
+                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        else:
+            # for image denoising and JPEG compression artifact reduction
+            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'absolute_pos_embed'}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'relative_position_bias_table'}
+
+    def forward_features(self, x):
+        x_size = (x.shape[2], x.shape[3])
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x, x_size)
+
+        x = self.norm(x)  # b seq_len c
+        x = self.patch_unembed(x, x_size)
+
+        return x
+
+    def forward(self, x):
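+        # normalize by the dataset mean and scale to the working range; undone after reconstruction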
+        self.mean = self.mean.type_as(x)
+        x = (x - self.mean) * self.img_range
+
+        if self.upsampler == 'pixelshuffle':
+            # for classical SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.conv_before_upsample(x)
+            x = self.conv_last(self.upsample(x))
+        elif self.upsampler == 'pixelshuffledirect':
+            # for lightweight SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.upsample(x)
+        elif self.upsampler == 'nearest+conv':
+            # for real-world SR
+            x = self.conv_first(x)
+            x = self.conv_after_body(self.forward_features(x)) + x
+            x = self.conv_before_upsample(x)
+            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+            x = self.conv_last(self.lrelu(self.conv_hr(x)))
+        else:
+            # for image denoising and JPEG compression artifact reduction
+            x_first = self.conv_first(x)
+            res = self.conv_after_body(self.forward_features(x_first)) + x_first
+            x = x + self.conv_last(res)
+
+        x = x / self.img_range + self.mean
+
+        return x
+
+    def flops(self):
+        flops = 0
+        h, w = self.patches_resolution
+        flops += h * w * 3 * self.embed_dim * 9
+        flops += self.patch_embed.flops()
+        for layer in self.layers:
+            flops += layer.flops()
+        flops += h * w * 3 * self.embed_dim * self.embed_dim
+        flops += self.upsample.flops()
+        return flops
+
+
+if __name__ == '__main__':
+    upscale = 4
+    window_size = 8
+    height = (1024 // upscale // window_size + 1) * window_size
+    width = (720 // upscale // window_size + 1) * window_size
+    model = SwinIR(
+        upscale=2,
+        img_size=(height, width),
+        window_size=window_size,
+        img_range=1.,
+        depths=[6, 6, 6, 6],
+        embed_dim=60,
+        num_heads=[6, 6, 6, 6],
+        mlp_ratio=2,
+        upsampler='pixelshuffledirect')
+    print(model)
+    print(height, width, model.flops() / 1e9)
+
+    x = torch.randn((1, 3, height, width))
+    x = model(x)
+    print(x.shape)
diff --git a/basicsr/archs/tof_arch.py b/basicsr/archs/tof_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90a64d89386e19f92c987bbe2133472991d764a
--- /dev/null
+++ b/basicsr/archs/tof_arch.py
@@ -0,0 +1,172 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+    """Basic module of SPyNet.
+
+    Note that unlike the architecture in spynet_arch.py, the basic module
+    here contains batch normalization.
+    """
+
+    def __init__(self):
+        super(BasicModule, self).__init__()
+        self.basic_module = nn.Sequential(
+            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
+            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
+            nn.BatchNorm2d(16), nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+    def forward(self, tensor_input):
+        """
+        Args:
+            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
+                8 channels contain:
+                [reference image (3), neighbor image (3), initial flow (2)].
+
+        Returns:
+            Tensor: Estimated flow with shape (b, 2, h, w)
+        """
+        return self.basic_module(tensor_input)
+
+
+class SPyNetTOF(nn.Module):
+    """SPyNet architecture for TOF.
+
+    Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
+    They differ in the following aspects:
+
+    1. The basic modules here contain BatchNorm.
+    2. Normalization and denormalization are not done here, as they are done in TOFlow.
+
+    ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``
+
+    Reference: https://github.com/Coldog2333/pytoflow
+
+    Args:
+        load_path (str): Path for pretrained SPyNet. Default: None.
+    """
+
+    def __init__(self, load_path=None):
+        super(SPyNetTOF, self).__init__()
+
+        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
+        if load_path:
+            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+    def forward(self, ref, supp):
+        """
+        Args:
+            ref (Tensor): Reference image with shape of (b, 3, h, w).
+            supp (Tensor): Supporting image to be warped, with shape of (b, 3, h, w).
+
+        Returns:
+            Tensor: Estimated optical flow: (b, 2, h, w).
+        """
+        num_batches, _, h, w = ref.size()
+        ref = [ref]
+        supp = [supp]
+
+        # generate downsampled frames
+        for _ in range(3):
+            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+        # flow computation
+        flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
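+        # coarse-to-fine refinement: upsample the current flow (values doubled to match the new
+        # resolution), warp the support frame with it, and predict a residual flow at each level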
+        for i in range(4):
+            flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+            flow = flow_up + self.basic_module[i](
+                torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
+        return flow
+
+
+@ARCH_REGISTRY.register()
+class TOFlow(nn.Module):
+    """PyTorch implementation of TOFlow.
+
+    In TOFlow, the LR frames are pre-upsampled and have the same size as the GT frames.
+
+    ``Paper: Video Enhancement with Task-Oriented Flow``
+
+    Reference: https://github.com/anchen1011/toflow
+
+    Reference: https://github.com/Coldog2333/pytoflow
+
+    Args:
+        adapt_official_weights (bool): Whether to adapt the weights translated
+            from the official implementation. Set to false if you want to
+            train from scratch. Default: False
+    """
+
+    def __init__(self, adapt_official_weights=False):
+        super(TOFlow, self).__init__()
+        self.adapt_official_weights = adapt_official_weights
+        self.ref_idx = 0 if adapt_official_weights else 3
+
+        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+        # flow estimation module
+        self.spynet = SPyNetTOF()
+
+        # reconstruction module
+        self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
+        self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
+        self.conv_3 = nn.Conv2d(64, 64, 1)
+        self.conv_4 = nn.Conv2d(64, 3, 1)
+
+        # activation function
+        self.relu = nn.ReLU(inplace=True)
+
+    def normalize(self, img):
+        return (img - self.mean) / self.std
+
+    def denormalize(self, img):
+        return img * self.std + self.mean
+
+    def forward(self, lrs):
+        """
+        Args:
+            lrs: Input lr frames: (b, 7, 3, h, w).
+
+        Returns:
+            Tensor: SR frame: (b, 3, h, w).
+        """
+        # In the official implementation, the 0-th frame is the reference frame
+        if self.adapt_official_weights:
+            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
+
+        num_batches, num_lrs, _, h, w = lrs.size()
+
+        lrs = self.normalize(lrs.view(-1, 3, h, w))
+        lrs = lrs.view(num_batches, num_lrs, 3, h, w)
+
+        lr_ref = lrs[:, self.ref_idx, :, :, :]
+        lr_aligned = []
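+        # warp every non-reference frame towards the reference frame using the estimated flow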
+        for i in range(7):  # 7 frames
+            if i == self.ref_idx:
+                lr_aligned.append(lr_ref)
+            else:
+                lr_supp = lrs[:, i, :, :, :]
+                flow = self.spynet(lr_ref, lr_supp)
+                lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
+
+        # reconstruction
+        hr = torch.stack(lr_aligned, dim=1)
+        hr = hr.view(num_batches, -1, h, w)
+        hr = self.relu(self.conv_1(hr))
+        hr = self.relu(self.conv_2(hr))
+        hr = self.relu(self.conv_3(hr))
+        hr = self.conv_4(hr) + lr_ref
+
+        return self.denormalize(hr)
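+
+
+# Minimal usage sketch (illustrative; sizes are hypothetical). TOFlow expects 7 pre-upsampled LR frames:
+#     model = TOFlow()
+#     lrs = torch.randn(1, 7, 3, 64, 64)
+#     sr = model(lrs)  # (1, 3, 64, 64)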
diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..05200334e477e59feefd1e4a0b5e94204e4eb2fa
--- /dev/null
+++ b/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+    'vgg11': [
+        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+        'pool5'
+    ],
+    'vgg13': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+    ],
+    'vgg16': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+        'pool5'
+    ],
+    'vgg19': [
+        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+    ]
+}
+
+
+def insert_bn(names):
+    """Insert bn layer after each conv.
+
+    Args:
+        names (list): The list of layer names.
+
+    Returns:
+        list: The list of layer names with bn layers.
+    """
+    names_bn = []
+    for name in names:
+        names_bn.append(name)
+        if 'conv' in name:
+            position = name.replace('conv', '')
+            names_bn.append('bn' + position)
+    return names_bn
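+
+# Reviewer note: worked example (illustrative only). insert_bn(['conv1_1', 'relu1_1', 'pool1'])
+# returns ['conv1_1', 'bn1_1', 'relu1_1', 'pool1'], i.e. a bn entry is inserted right after
+# each conv entry, matching the layer naming of the *_bn torchvision models.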
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+    """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to use normalization
+    on the input feature and the type of VGG network. Note that the pretrained
+    path must match the VGG type.
+
+    Args:
+        layer_name_list (list[str]): Forward function returns the corresponding
+            features according to the layer_name_list.
+            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+        range_norm (bool): If True, normalize images from the range [-1, 1] to [0, 1].
+            Default: False.
+        requires_grad (bool): If True, the parameters of the VGG network will be
+            optimized. Default: False.
+        remove_pooling (bool): If True, the max pooling operations in the VGG net
+            will be removed. Default: False.
+        pooling_stride (int): The stride of max pooling operation. Default: 2.
+    """
+
+    def __init__(self,
+                 layer_name_list,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 range_norm=False,
+                 requires_grad=False,
+                 remove_pooling=False,
+                 pooling_stride=2):
+        super(VGGFeatureExtractor, self).__init__()
+
+        self.layer_name_list = layer_name_list
+        self.use_input_norm = use_input_norm
+        self.range_norm = range_norm
+
+        self.names = NAMES[vgg_type.replace('_bn', '')]
+        if 'bn' in vgg_type:
+            self.names = insert_bn(self.names)
+
+        # only borrow layers that will be used to avoid unused params
+        max_idx = 0
+        for v in layer_name_list:
+            idx = self.names.index(v)
+            if idx > max_idx:
+                max_idx = idx
+
+        if os.path.exists(VGG_PRETRAIN_PATH):
+            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+            vgg_net.load_state_dict(state_dict)
+        else:
+            vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+        features = vgg_net.features[:max_idx + 1]
+
+        modified_net = OrderedDict()
+        for k, v in zip(self.names, features):
+            if 'pool' in k:
+                # if remove_pooling is true, pooling operation will be removed
+                if remove_pooling:
+                    continue
+                else:
+                    # in some cases, we may want to change the default stride
+                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+            else:
+                modified_net[k] = v
+
+        self.vgg_net = nn.Sequential(modified_net)
+
+        if not requires_grad:
+            self.vgg_net.eval()
+            for param in self.parameters():
+                param.requires_grad = False
+        else:
+            self.vgg_net.train()
+            for param in self.parameters():
+                param.requires_grad = True
+
+        if self.use_input_norm:
+            # the mean is for image with range [0, 1]
+            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            # the std is for image with range [0, 1]
+            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        if self.range_norm:
+            x = (x + 1) / 2
+        if self.use_input_norm:
+            x = (x - self.mean) / self.std
+
+        output = {}
+        for key, layer in self.vgg_net._modules.items():
+            x = layer(x)
+            if key in self.layer_name_list:
+                output[key] = x.clone()
+
+        return output
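+
+
+# --- Reviewer note: illustrative usage sketch, not part of the upstream file. ---
+# Extract intermediate relu features from a VGG19, e.g. for a perceptual-loss style
+# comparison; the layer names come from NAMES['vgg19'] above.
+#
+#   import torch
+#
+#   extractor = VGGFeatureExtractor(
+#       layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'], vgg_type='vgg19')
+#   x = torch.rand(1, 3, 224, 224)   # RGB image in [0, 1] (use_input_norm=True by default)
+#   feats = extractor(x)             # dict with keys 'relu1_1', 'relu2_1', 'relu3_1'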
diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..510df16771d153f61fbf2126baac24f69d3de7e4
--- /dev/null
+++ b/basicsr/data/__init__.py
@@ -0,0 +1,101 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+    """Build dataset from options.
+
+    Args:
+        dataset_opt (dict): Configuration for dataset. It must contain:
+            name (str): Dataset name.
+            type (str): Dataset type.
+    """
+    dataset_opt = deepcopy(dataset_opt)
+    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+    logger = get_root_logger()
+    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
+    return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+    """Build dataloader.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset.
+        dataset_opt (dict): Dataset options. It contains the following keys:
+            phase (str): 'train' or 'val'.
+            num_worker_per_gpu (int): Number of workers for each GPU.
+            batch_size_per_gpu (int): Training batch size for each GPU.
+        num_gpu (int): Number of GPUs. Used only in the train phase.
+            Default: 1.
+        dist (bool): Whether in distributed training. Used only in the train
+            phase. Default: False.
+        sampler (torch.utils.data.sampler): Data sampler. Default: None.
+        seed (int | None): Seed. Default: None
+    """
+    phase = dataset_opt['phase']
+    rank, _ = get_dist_info()
+    if phase == 'train':
+        if dist:  # distributed training
+            batch_size = dataset_opt['batch_size_per_gpu']
+            num_workers = dataset_opt['num_worker_per_gpu']
+        else:  # non-distributed training
+            multiplier = 1 if num_gpu == 0 else num_gpu
+            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+        dataloader_args = dict(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            sampler=sampler,
+            drop_last=True)
+        if sampler is None:
+            dataloader_args['shuffle'] = True
+        dataloader_args['worker_init_fn'] = partial(
+            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+    elif phase in ['val', 'test']:  # validation
+        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    else:
+        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
+
+    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
+
+    prefetch_mode = dataset_opt.get('prefetch_mode')
+    if prefetch_mode == 'cpu':  # CPUPrefetcher
+        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+        logger = get_root_logger()
+        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
+        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+    else:
+        # prefetch_mode=None: Normal dataloader
+        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+        return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # Set the worker seed to num_workers * rank + worker_id + seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
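+
+
+# --- Reviewer note: illustrative usage sketch, not part of the upstream file. ---
+# The option keys follow the usual BasicSR layout; the dataset type and values below
+# are placeholders and must match a dataset registered in DATASET_REGISTRY.
+#
+#   dataset_opt = dict(
+#       name='toy', type='PairedImageDataset', phase='train', scale=4,
+#       dataroot_gt='datasets/toy/gt', dataroot_lq='datasets/toy/lq',
+#       io_backend=dict(type='disk'), gt_size=128, use_hflip=True, use_rot=True,
+#       num_worker_per_gpu=2, batch_size_per_gpu=4)
+#   train_set = build_dataset(dataset_opt)
+#   train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, seed=0)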
diff --git a/basicsr/data/__pycache__/__init__.cpython-310.pyc b/basicsr/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c7f9cfdf882bfb9d0627d31ee728ed5c645d428
Binary files /dev/null and b/basicsr/data/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/data_sampler.cpython-310.pyc b/basicsr/data/__pycache__/data_sampler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7ac9e4ac3973c193e220e4459b7fcd8c549a654
Binary files /dev/null and b/basicsr/data/__pycache__/data_sampler.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/data_util.cpython-310.pyc b/basicsr/data/__pycache__/data_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40ef9471e96a8317ae1a6a793e08cce784688d19
Binary files /dev/null and b/basicsr/data/__pycache__/data_util.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/degradations.cpython-310.pyc b/basicsr/data/__pycache__/degradations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c94bbd66b6b316e0eb35cd70c95a093c4556668
Binary files /dev/null and b/basicsr/data/__pycache__/degradations.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc b/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d263bfc6d1308ec9ca4dab514c17f1c8a9f31366
Binary files /dev/null and b/basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc b/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49d19b3e137ff62c17beea5d13ad04e8aa83d7d4
Binary files /dev/null and b/basicsr/data/__pycache__/ffhq_degradation_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc b/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4ab8b0f0b82e30c2f21ee78ccbcf3ef77be4783
Binary files /dev/null and b/basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc b/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..196fd1a8987c762e39c2d77a36da5763b100410e
Binary files /dev/null and b/basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc b/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f7b511a2ac470286f8125434fcb10e4b7481b23
Binary files /dev/null and b/basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc b/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd702193fbe239420cab4137136ad872d3882257
Binary files /dev/null and b/basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc b/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc10848e313a329293fb3411f15d8844e962033a
Binary files /dev/null and b/basicsr/data/__pycache__/reds_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc b/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49f2b0117e6230bd9f45b78cfde824bf40a9e5bb
Binary files /dev/null and b/basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/transforms.cpython-310.pyc b/basicsr/data/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..996a496d3c5ef61da47a8477358ef07712c93658
Binary files /dev/null and b/basicsr/data/__pycache__/transforms.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc b/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32bedb9ec89d091a5d0ec73dc2256e3991c23bb2
Binary files /dev/null and b/basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc b/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db49fa38ec3d1855ce1058aca0175d71952640a2
Binary files /dev/null and b/basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc differ
diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2
--- /dev/null
+++ b/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+
+    Modified from torch.utils.data.distributed.DistributedSampler.
+    Supports enlarging the dataset for iteration-based training, which saves
+    time when restarting the dataloader after each epoch.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset used for sampling.
+        num_replicas (int | None): Number of processes participating in
+            the training. It is usually the world_size.
+        rank (int | None): Rank of the current process within num_replicas.
+        ratio (int): Enlarging ratio. Default: 1.
+    """
+
+    def __init__(self, dataset, num_replicas, rank, ratio=1):
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = torch.randperm(self.total_size, generator=g).tolist()
+
+        dataset_size = len(self.dataset)
+        indices = [v % dataset_size for v in indices]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
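+
+
+# --- Reviewer note: illustrative usage sketch, not part of the upstream file. ---
+# With ratio > 1 the sampler wraps indices around the dataset so one "epoch" yields
+# ratio * len(dataset) samples, avoiding frequent dataloader restarts in
+# iteration-based training.
+#
+#   sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)
+#   for epoch in range(total_epochs):
+#       sampler.set_epoch(epoch)   # re-seed the deterministic shuffle each epoch
+#       for batch in loader:
+#           ...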
diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce2562fb9f99475c44e9185f50018a428859214
--- /dev/null
+++ b/basicsr/data/data_util.py
@@ -0,0 +1,362 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+    """Read a sequence of images from a given folder path.
+
+    Args:
+        path (list[str] | str): List of image paths or image folder path.
+        require_mod_crop (bool): Require mod crop for each image.
+            Default: False.
+        scale (int): Scale factor for mod_crop. Default: 1.
+        return_imgname (bool): Whether to return image names. Default: False.
+
+    Returns:
+        Tensor: size (t, c, h, w), RGB, [0, 1].
+        list[str]: Returned image name list.
+    """
+    if isinstance(path, list):
+        img_paths = path
+    else:
+        img_paths = sorted(list(scandir(path, full_path=True)))
+    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+    if require_mod_crop:
+        imgs = [mod_crop(img, scale) for img in imgs]
+    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+    imgs = torch.stack(imgs, dim=0)
+
+    if return_imgname:
+        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+        return imgs, imgnames
+    else:
+        return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+    """Generate an index list for reading `num_frames` frames from a sequence
+    of images.
+
+    Args:
+        crt_idx (int): Current center index.
+        max_frame_num (int): Max number of the sequence of images (from 1).
+        num_frames (int): Reading num_frames frames.
+        padding (str): Padding mode, one of
+            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+            Examples: current_idx = 0, num_frames = 5
+            The generated frame indices under different padding mode:
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            reflection_circle: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        list[int]: A list of indices.
+    """
+    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+    max_frame_num = max_frame_num - 1  # start from 0
+    num_pad = num_frames // 2
+
+    indices = []
+    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                pad_idx = 0
+            elif padding == 'reflection':
+                pad_idx = -i
+            elif padding == 'reflection_circle':
+                pad_idx = crt_idx + num_pad - i
+            else:
+                pad_idx = num_frames + i
+        elif i > max_frame_num:
+            if padding == 'replicate':
+                pad_idx = max_frame_num
+            elif padding == 'reflection':
+                pad_idx = max_frame_num * 2 - i
+            elif padding == 'reflection_circle':
+                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+            else:
+                pad_idx = i - num_frames
+        else:
+            pad_idx = i
+        indices.append(pad_idx)
+    return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+    """Generate paired paths from lmdb files.
+
+    Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
+
+    ::
+
+        lq.lmdb
+        ├── data.mdb
+        ├── lock.mdb
+        ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a text file that records the meta information
+    of our datasets. It will be automatically created when preparing
+    datasets with our provided dataset tools.
+    Each line in the txt file records
+    1) image name (with extension),
+    2) image shape,
+    3) compression level, separated by a white space.
+    Example: `baboon.png (120,125,3) 1`
+
+    We use the image name without extension as the lmdb key.
+    Note that we use the same key for the corresponding lq and gt images.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+            Note that this key is different from lmdb keys.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
+                         f'formats. But received {input_key}: {input_folder}; '
+                         f'{gt_key}: {gt_folder}')
+    # ensure that the two meta_info files are the same
+    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+        input_lmdb_keys = [line.split('.')[0] for line in fin]
+    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+        gt_lmdb_keys = [line.split('.')[0] for line in fin]
+    if set(input_lmdb_keys) != set(gt_lmdb_keys):
+        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+    else:
+        paths = []
+        for lmdb_key in sorted(input_lmdb_keys):
+            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+        return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from an meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    with open(meta_info_file, 'r') as fin:
+        gt_names = [line.strip().split(' ')[0] for line in fin]
+
+    paths = []
+    for gt_name in gt_names:
+        basename, ext = osp.splitext(osp.basename(gt_name))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        gt_path = osp.join(gt_folder, gt_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
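+
+# Reviewer note: worked example (illustrative only). With keys=['lq', 'gt'] and
+# filename_tmpl='{}_x4', a meta line '0001.png (480,480,3)' produces the entry
+# {'lq_path': '<input_folder>/0001_x4.png', 'gt_path': '<gt_folder>/0001.png'}.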
+
+def paired_paths_from_meta_info_file_2(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from an meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of an meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    # read the gt and lq names in a single pass over the meta information file
+    with open(meta_info_file, 'r') as fin:
+        lines = [line.strip().split(' ') for line in fin]
+    gt_names = [line[0] for line in lines]
+    input_names = [line[1] for line in lines]
+    paths = []
+    for gt_name, lq_name in zip(gt_names, input_names):
+        gt_path = osp.join(gt_folder, gt_name)
+        input_path = osp.join(input_folder, lq_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+    """Generate paired paths from folders.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
+    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+                                               f'{len(input_paths)}, {len(gt_paths)}.')
+    paths = []
+    for gt_path in gt_paths:
+        basename, ext = osp.splitext(osp.basename(gt_path))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
+        gt_path = osp.join(gt_folder, gt_path)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paths_from_folder(folder):
+    """Generate paths from folder.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+
+    paths = list(scandir(folder))
+    paths = [osp.join(folder, path) for path in paths]
+    return paths
+
+
+def paths_from_lmdb(folder):
+    """Generate paths from lmdb.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+    with open(osp.join(folder, 'meta_info.txt')) as fin:
+        paths = [line.split('.')[0] for line in fin]
+    return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+    """Generate Gaussian kernel used in `duf_downsample`.
+
+    Args:
+        kernel_size (int): Kernel size. Default: 13.
+        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+    Returns:
+        np.array: The Gaussian kernel.
+    """
+    from scipy.ndimage import gaussian_filter
+    kernel = np.zeros((kernel_size, kernel_size))
+    # set element at the middle to one, a dirac delta
+    kernel[kernel_size // 2, kernel_size // 2] = 1
+    # gaussian-smooth the dirac, resulting in a gaussian filter
+    return gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+    """Downsamping with Gaussian kernel used in the DUF official code.
+
+    Args:
+        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+        kernel_size (int): Kernel size. Default: 13.
+        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+            Default: 4.
+
+    Returns:
+        Tensor: DUF downsampled frames.
+    """
+    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+    squeeze_flag = False
+    if x.ndim == 4:
+        squeeze_flag = True
+        x = x.unsqueeze(0)
+    b, t, c, h, w = x.size()
+    x = x.view(-1, 1, h, w)
+    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+    x = F.conv2d(x, gaussian_filter, stride=scale)
+    x = x[:, :, 2:-2, 2:-2]
+    x = x.view(b, t, c, x.size(2), x.size(3))
+    if squeeze_flag:
+        x = x.squeeze(0)
+    return x
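+
+
+# --- Reviewer note: illustrative usage sketch, not part of the upstream file. ---
+#
+#   import torch
+#   frames = torch.rand(2, 7, 3, 256, 256)              # (b, t, c, h, w) in [0, 1]
+#   lr_frames = duf_downsample(frames, kernel_size=13, scale=4)
+#   # lr_frames.shape == (2, 7, 3, 64, 64): Gaussian smoothing with sigma = 0.4 * scale,
+#   # followed by stride-4 sampling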
diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db40fb080908e9a0de503b9c9518710f89e2e0d
--- /dev/null
+++ b/basicsr/data/degradations.py
@@ -0,0 +1,935 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from scipy import special
+from scipy.stats import multivariate_normal
+from torchvision.transforms.functional import rgb_to_grayscale
+
+# -------------------------------------------------------------------- #
+# --------------------------- blur kernels --------------------------- #
+# -------------------------------------------------------------------- #
+
+
+# --------------------------- util functions --------------------------- #
+def sigma_matrix2(sig_x, sig_y, theta):
+    """Calculate the rotated sigma matrix (two dimensional matrix).
+
+    Args:
+        sig_x (float): Standard deviation along the horizontal axis.
+        sig_y (float): Standard deviation along the vertical axis.
+        theta (float): Rotation angle (radian measurement).
+
+    Returns:
+        ndarray: Rotated sigma matrix.
+    """
+    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+def mesh_grid(kernel_size):
+    """Generate the mesh grid, centering at zero.
+
+    Args:
+        kernel_size (int):
+
+    Returns:
+        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+        xx (ndarray): with the shape (kernel_size, kernel_size)
+        yy (ndarray): with the shape (kernel_size, kernel_size)
+    """
+    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+    xx, yy = np.meshgrid(ax, ax)
+    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
+                                                                           1))).reshape(kernel_size, kernel_size, 2)
+    return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+    """Calculate PDF of the bivariate Gaussian distribution.
+
+    Args:
+        sigma_matrix (ndarray): with the shape (2, 2)
+        grid (ndarray): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size.
+
+    Returns:
+        kernel (ndarray): un-normalized kernel.
+    """
+    inverse_sigma = np.linalg.inv(sigma_matrix)
+    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+    return kernel
+
+
+def cdf2(d_matrix, grid):
+    """Calculate the CDF of the standard bivariate Gaussian distribution.
+        Used in the skewed Gaussian distribution.
+
+    Args:
+        d_matrix (ndarray): skew matrix.
+        grid (ndarray): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size.
+
+    Returns:
+        cdf (ndarray): skewed cdf.
+    """
+    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+    grid = np.dot(grid, d_matrix)
+    cdf = rv.cdf(grid)
+    return cdf
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+    """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+    Args:
+        kernel_size (int):
+        sig_x (float):
+        sig_y (float):
+        theta (float): Radian measurement.
+        grid (ndarray, optional): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size. Default: None
+        isotropic (bool):
+
+    Returns:
+        kernel (ndarray): normalized kernel.
+    """
+    if grid is None:
+        grid, _, _ = mesh_grid(kernel_size)
+    if isotropic:
+        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+    else:
+        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+    kernel = pdf2(sigma_matrix, grid)
+    kernel = kernel / np.sum(kernel)
+    return kernel
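+
+# Reviewer note: illustrative usage sketch (not part of the upstream file). An isotropic
+# 21x21 Gaussian blur kernel that sums to 1 and can be applied with cv2.filter2D:
+#
+#   kernel = bivariate_Gaussian(kernel_size=21, sig_x=2.0, sig_y=2.0, theta=0, isotropic=True)
+#   blurred = cv2.filter2D(img, -1, kernel)   # img: (h, w, 3) float32 in [0, 1]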
+
+
+def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+    """Generate a bivariate generalized Gaussian kernel.
+
+    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+    Args:
+        kernel_size (int):
+        sig_x (float):
+        sig_y (float):
+        theta (float): Radian measurement.
+        beta (float): shape parameter, beta = 1 is the normal distribution.
+        grid (ndarray, optional): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size. Default: None
+
+    Returns:
+        kernel (ndarray): normalized kernel.
+    """
+    if grid is None:
+        grid, _, _ = mesh_grid(kernel_size)
+    if isotropic:
+        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+    else:
+        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+    inverse_sigma = np.linalg.inv(sigma_matrix)
+    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+    kernel = kernel / np.sum(kernel)
+    return kernel
+
+
+def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+    """Generate a plateau-like anisotropic kernel.
+
+    1 / (1+x^(beta))
+
+    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+    Args:
+        kernel_size (int):
+        sig_x (float):
+        sig_y (float):
+        theta (float): Radian measurement.
+        beta (float): shape parameter of the plateau distribution.
+        grid (ndarray, optional): generated by :func:`mesh_grid`,
+            with the shape (K, K, 2), K is the kernel size. Default: None
+
+    Returns:
+        kernel (ndarray): normalized kernel.
+    """
+    if grid is None:
+        grid, _, _ = mesh_grid(kernel_size)
+    if isotropic:
+        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+    else:
+        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+    inverse_sigma = np.linalg.inv(sigma_matrix)
+    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+    kernel = kernel / np.sum(kernel)
+    return kernel
+
+
+def random_bivariate_Gaussian(kernel_size,
+                              sigma_x_range,
+                              sigma_y_range,
+                              rotation_range,
+                              noise_range=None,
+                              isotropic=True,
+                              return_sigma=False):
+    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        noise_range(tuple, optional): multiplicative kernel noise,
+            [0.75, 1.25]. Default: None
+
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    if isotropic is False:
+        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+    else:
+        sigma_y = sigma_x
+        rotation = 0
+
+    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+    kernel = kernel / np.sum(kernel)
+    if not return_sigma:
+        return kernel
+    else:
+        return kernel, [sigma_x, sigma_y]
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+                                          sigma_x_range,
+                                          sigma_y_range,
+                                          rotation_range,
+                                          beta_range,
+                                          noise_range=None,
+                                          isotropic=True,
+                                          return_sigma=False):
+    """Randomly generate bivariate generalized Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        beta_range (tuple): [0.5, 8]
+        noise_range(tuple, optional): multiplicative kernel noise,
+            [0.75, 1.25]. Default: None
+
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    if isotropic is False:
+        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+    else:
+        sigma_y = sigma_x
+        rotation = 0
+
+    # assume beta_range[0] < 1 < beta_range[1]
+    if np.random.uniform() < 0.5:
+        beta = np.random.uniform(beta_range[0], 1)
+    else:
+        beta = np.random.uniform(1, beta_range[1])
+
+    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+    kernel = kernel / np.sum(kernel)
+    if not return_sigma:
+        return kernel
+    else:
+        return kernel, [sigma_x, sigma_y]
+
+
+def random_bivariate_plateau(kernel_size,
+                             sigma_x_range,
+                             sigma_y_range,
+                             rotation_range,
+                             beta_range,
+                             noise_range=None,
+                             isotropic=True,
+                             return_sigma=False):
+    """Randomly generate bivariate plateau kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+    Args:
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi/2, math.pi/2]
+        beta_range (tuple): [1, 4]
+        noise_range(tuple, optional): multiplicative kernel noise,
+            [0.75, 1.25]. Default: None
+
+    Returns:
+        kernel (ndarray):
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+    if isotropic is False:
+        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+    else:
+        sigma_y = sigma_x
+        rotation = 0
+
+    # TODO: this may not be proper
+    if np.random.uniform() < 0.5:
+        beta = np.random.uniform(beta_range[0], 1)
+    else:
+        beta = np.random.uniform(1, beta_range[1])
+
+    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+    # add multiplicative noise
+    if noise_range is not None:
+        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+        kernel = kernel * noise
+    kernel = kernel / np.sum(kernel)
+
+    if not return_sigma:
+        return kernel
+    else:
+        return kernel, [sigma_x, sigma_y]
+
+
+def random_mixed_kernels(kernel_list,
+                         kernel_prob,
+                         kernel_size=21,
+                         sigma_x_range=(0.6, 5),
+                         sigma_y_range=(0.6, 5),
+                         rotation_range=(-math.pi, math.pi),
+                         betag_range=(0.5, 8),
+                         betap_range=(0.5, 8),
+                         noise_range=None,
+                         return_sigma=False):
+    """Randomly generate mixed kernels.
+
+    Args:
+        kernel_list (tuple): a list of kernel type names; supported types are
+            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
+            'plateau_iso', 'plateau_aniso']
+        kernel_prob (tuple): corresponding sampling probability for each
+            kernel type
+        kernel_size (int):
+        sigma_x_range (tuple): [0.6, 5]
+        sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        betag_range (tuple): shape range for the generalized Gaussian kernels, [0.5, 8]
+        betap_range (tuple): shape range for the plateau kernels, [0.5, 8]
+        noise_range (tuple, optional): multiplicative kernel noise,
+            [0.75, 1.25]. Default: None
+
+    Returns:
+        kernel (ndarray):
+    """
+    kernel_type = random.choices(kernel_list, kernel_prob)[0]
+    if not return_sigma:
+        if kernel_type == 'iso':
+            kernel = random_bivariate_Gaussian(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
+        elif kernel_type == 'aniso':
+            kernel = random_bivariate_Gaussian(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
+        elif kernel_type == 'generalized_iso':
+            kernel = random_bivariate_generalized_Gaussian(
+                kernel_size,
+                sigma_x_range,
+                sigma_y_range,
+                rotation_range,
+                betag_range,
+                noise_range=noise_range,
+                isotropic=True,
+                return_sigma=return_sigma)
+        elif kernel_type == 'generalized_aniso':
+            kernel = random_bivariate_generalized_Gaussian(
+                kernel_size,
+                sigma_x_range,
+                sigma_y_range,
+                rotation_range,
+                betag_range,
+                noise_range=noise_range,
+                isotropic=False,
+                return_sigma=return_sigma)
+        elif kernel_type == 'plateau_iso':
+            kernel = random_bivariate_plateau(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma)
+        elif kernel_type == 'plateau_aniso':
+            kernel = random_bivariate_plateau(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma)
+        return kernel
+    else:
+        if kernel_type == 'iso':
+            kernel, sigma_list = random_bivariate_Gaussian(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
+        elif kernel_type == 'aniso':
+            kernel, sigma_list = random_bivariate_Gaussian(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
+        elif kernel_type == 'generalized_iso':
+            kernel, sigma_list = random_bivariate_generalized_Gaussian(
+                kernel_size,
+                sigma_x_range,
+                sigma_y_range,
+                rotation_range,
+                betag_range,
+                noise_range=noise_range,
+                isotropic=True,
+                return_sigma=return_sigma)
+        elif kernel_type == 'generalized_aniso':
+            kernel, sigma_list = random_bivariate_generalized_Gaussian(
+                kernel_size,
+                sigma_x_range,
+                sigma_y_range,
+                rotation_range,
+                betag_range,
+                noise_range=noise_range,
+                isotropic=False,
+                return_sigma=return_sigma)
+        elif kernel_type == 'plateau_iso':
+            kernel, sigma_list = random_bivariate_plateau(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma)
+        elif kernel_type == 'plateau_aniso':
+            kernel, sigma_list = random_bivariate_plateau(
+                kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma)
+        return kernel, sigma_list
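+
+# Reviewer note: illustrative usage sketch (not part of the upstream file). The kernel
+# settings below are placeholders in the style of Real-ESRGAN-like degradation configs.
+#
+#   kernel = random_mixed_kernels(
+#       kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
+#                    'plateau_iso', 'plateau_aniso'],
+#       kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
+#       kernel_size=21,
+#       sigma_x_range=(0.2, 3),
+#       sigma_y_range=(0.2, 3),
+#       rotation_range=(-math.pi, math.pi),
+#       betag_range=(0.5, 4),
+#       betap_range=(1, 2))
+#   # kernel: 21x21 ndarray that sums to 1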
+
+
+np.seterr(divide='ignore', invalid='ignore')
+
+
+def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
+    """2D sinc filter
+
+    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
+
+    Args:
+        cutoff (float): cutoff frequency in radians (pi is max)
+        kernel_size (int): horizontal and vertical size, must be odd.
+        pad_to (int): pad kernel size to desired size, must be odd or zero.
+    """
+    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+    kernel = np.fromfunction(
+        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
+            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
+                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
+    kernel = kernel / np.sum(kernel)
+    if pad_to > kernel_size:
+        pad_size = (pad_to - kernel_size) // 2
+        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+    return kernel
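+
+# Reviewer note: illustrative usage sketch (not part of the upstream file). A 2D sinc
+# (circular low-pass) kernel with a random cutoff, padded to 21x21 so it can be used
+# interchangeably with the blur kernels above:
+#
+#   omega_c = np.random.uniform(np.pi / 3, np.pi)
+#   sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size=13, pad_to=21)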
+
+
+# ------------------------------------------------------------- #
+# --------------------------- noise --------------------------- #
+# ------------------------------------------------------------- #
+
+# ----------------------- Gaussian Noise ----------------------- #
+
+
+def generate_gaussian_noise(img, sigma=10, gray_noise=False):
+    """Generate Gaussian noise.
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        sigma (float): Noise scale (measured in range 255). Default: 10.
+        gray_noise (bool): Whether to generate gray (channel-shared) noise. Default: False.
+
+    Returns:
+        (Numpy array): Returned noise, shape (h, w, c), range [0, 1],
+            float32.
+    """
+    if gray_noise:
+        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
+        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
+    else:
+        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
+    return noise
+
+
+def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
+    """Add Gaussian noise.
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        sigma (float): Noise scale (measured in range 255). Default: 10.
+
+    Returns:
+        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+            float32.
+    """
+    noise = generate_gaussian_noise(img, sigma, gray_noise)
+    out = img + noise
+    if clip and rounds:
+        out = np.clip((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = np.clip(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
+
+
+def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
+    """Add Gaussian noise (PyTorch version).
+
+    Args:
+        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+        scale (float | Tensor): Noise scale. Default: 1.0.
+
+    Returns:
+        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+            float32.
+    """
+    b, _, h, w = img.size()
+    if not isinstance(sigma, (float, int)):
+        sigma = sigma.view(img.size(0), 1, 1, 1)
+    if isinstance(gray_noise, (float, int)):
+        cal_gray_noise = gray_noise > 0
+    else:
+        gray_noise = gray_noise.view(b, 1, 1, 1)
+        cal_gray_noise = torch.sum(gray_noise) > 0
+
+    if cal_gray_noise:
+        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
+        noise_gray = noise_gray.view(b, 1, h, w)
+
+    # always calculate color noise
+    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
+
+    if cal_gray_noise:
+        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+    return noise
+
+
+def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
+    """Add Gaussian noise (PyTorch version).
+
+    Args:
+        img (Tensor): Shape (b, c, h, w), range [0, 1], float32.
+        sigma (float | Tensor): Noise scale (measured in range 255). Number or
+            Tensor with shape (b). Default: 10.
+        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+            0 for False, 1 for True. Default: 0.
+
+    Returns:
+        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+            float32.
+    """
+    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
+    out = img + noise
+    if clip and rounds:
+        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = torch.clamp(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
+
+
+# ----------------------- Random Gaussian Noise ----------------------- #
+def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0, return_sigma=False):
+    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+    if np.random.uniform() < gray_prob:
+        gray_noise = True
+    else:
+        gray_noise = False
+    if return_sigma:
+        return generate_gaussian_noise(img, sigma, gray_noise), sigma
+    else:
+        return generate_gaussian_noise(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False, return_sigma=False):
+    if return_sigma:
+        noise, sigma = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
+    else:
+        noise = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
+    out = img + noise
+    if clip and rounds:
+        out = np.clip((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = np.clip(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    if return_sigma:
+        return out, sigma
+    else:
+        return out
+
+
+def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
+    sigma = torch.rand(
+        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+    gray_noise = (gray_noise < gray_prob).float()
+    return generate_gaussian_noise_pt(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
+    out = img + noise
+    if clip and rounds:
+        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = torch.clamp(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
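+
+# Reviewer note: illustrative usage sketch (not part of the upstream file). Add per-sample
+# random Gaussian noise to a batch, with a 40% chance of gray (channel-shared) noise;
+# sigma_range is measured on the 0-255 scale.
+#
+#   img = torch.rand(4, 3, 128, 128)   # (b, c, h, w) in [0, 1]
+#   out = random_add_gaussian_noise_pt(img, sigma_range=(1, 30), gray_prob=0.4, clip=True)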
+
+# ----------------------- Poisson (Shot) Noise ----------------------- #
+
+
+def generate_poisson_noise(img, scale=1.0, gray_noise=False):
+    """Generate poisson noise.
+
+    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        scale (float): Noise scale. Default: 1.0.
+        gray_noise (bool): Whether generate gray noise. Default: False.
+
+    Returns:
+        (Numpy array): Returned noise, shape (h, w, c), range [0, 1],
+            float32.
+    """
+    if gray_noise:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    # round and clip image for counting vals correctly
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = len(np.unique(img))
+    vals = 2**np.ceil(np.log2(vals))
+    out = np.float32(np.random.poisson(img * vals) / float(vals))
+    noise = out - img
+    if gray_noise:
+        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
+    return noise * scale
+
+
+def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
+    """Add poisson noise.
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        scale (float): Noise scale. Default: 1.0.
+        gray_noise (bool): Whether to generate gray noise. Default: False.
+        clip (bool): Whether to clip the result to [0, 1]. Default: True.
+        rounds (bool): Whether to round to the 1/255 grid. Default: False.
+
+    Returns:
+        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+            float32.
+    """
+    noise = generate_poisson_noise(img, scale, gray_noise)
+    out = img + noise
+    if clip and rounds:
+        out = np.clip((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = np.clip(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
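+
+
+# Illustrative usage sketch for the NumPy Poisson-noise helpers above (the
+# image shape and scale value are assumptions for demonstration):
+#   >>> img = np.random.rand(64, 64, 3).astype(np.float32)   # (h, w, c), range [0, 1]
+#   >>> noisy = add_poisson_noise(img, scale=1.0, gray_noise=False)
+#   >>> noisy.shape                                           # (64, 64, 3), clipped to [0, 1]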
+
+
+def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
+    """Generate a batch of poisson noise (PyTorch version)
+
+    Args:
+        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+            Default: 1.0.
+        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+            0 for False, 1 for True. Default: 0.
+
+    Returns:
+        (Tensor): Generated noise (not the noisy image), shape (b, c, h, w),
+            float32.
+    """
+    b, _, h, w = img.size()
+    if isinstance(gray_noise, (float, int)):
+        cal_gray_noise = gray_noise > 0
+    else:
+        gray_noise = gray_noise.view(b, 1, 1, 1)
+        cal_gray_noise = torch.sum(gray_noise) > 0
+    if cal_gray_noise:
+        img_gray = rgb_to_grayscale(img, num_output_channels=1)
+        # round and clip image for counting vals correctly
+        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
+        # use for-loop to get the unique values for each sample
+        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
+        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
+        out = torch.poisson(img_gray * vals) / vals
+        noise_gray = out - img_gray
+        noise_gray = noise_gray.expand(b, 3, h, w)
+
+    # always calculate color noise
+    # round and clip image for counting vals correctly
+    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
+    # use for-loop to get the unique values for each sample
+    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
+    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
+    out = torch.poisson(img * vals) / vals
+    noise = out - img
+    if cal_gray_noise:
+        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+    if not isinstance(scale, (float, int)):
+        scale = scale.view(b, 1, 1, 1)
+    return noise * scale
+
+
+def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
+    """Add poisson noise to a batch of images (PyTorch version).
+
+    Args:
+        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+            Default: 1.0.
+        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+            0 for False, 1 for True. Default: 0.
+
+    Returns:
+        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+            float32.
+    """
+    noise = generate_poisson_noise_pt(img, scale, gray_noise)
+    out = img + noise
+    if clip and rounds:
+        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = torch.clamp(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
+
+
+# ----------------------- Random Poisson (Shot) Noise ----------------------- #
+
+
+def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
+    scale = np.random.uniform(scale_range[0], scale_range[1])
+    if np.random.uniform() < gray_prob:
+        gray_noise = True
+    else:
+        gray_noise = False
+    return generate_poisson_noise(img, scale, gray_noise)
+
+
+def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
+    out = img + noise
+    if clip and rounds:
+        out = np.clip((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = np.clip(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
+
+
+def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
+    scale = torch.rand(
+        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+    gray_noise = (gray_noise < gray_prob).float()
+    return generate_poisson_noise_pt(img, scale, gray_noise)
+
+
+def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
+    out = img + noise
+    if clip and rounds:
+        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+    elif clip:
+        out = torch.clamp(out, 0, 1)
+    elif rounds:
+        out = (out * 255.0).round() / 255.
+    return out
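+
+
+# Illustrative usage sketch for the batched Poisson-noise helpers above (the
+# tensor shape and scale range are assumptions for demonstration, chosen in the
+# spirit of Real-ESRGAN-style degradation configs):
+#   >>> imgs = torch.rand(4, 3, 64, 64)
+#   >>> out = random_add_poisson_noise_pt(imgs, scale_range=(0.05, 3), gray_prob=0.4)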
+
+# ----------------------- Random speckle Noise ----------------------- #
+
+def random_add_speckle_noise(imgs, speckle_std):
+    """Add multiplicative (speckle) noise to a list of images in [0, 1], shape (h, w, c)."""
+    std_l, std_r = speckle_std
+    mean = 0
+    std = random.uniform(std_l / 255., std_r / 255.)
+
+    outputs = []
+    for img in imgs:
+        gauss = np.random.normal(loc=mean, scale=std, size=img.shape)
+        noisy = img + gauss * img
+        noisy = np.clip(noisy, 0, 1).astype(np.float32)
+        outputs.append(noisy)
+    return outputs
+
+
+def random_add_speckle_noise_pt(img, speckle_std):
+    """Add multiplicative (speckle) noise to an image tensor in [0, 1] (e.g. shape (b, c, h, w))."""
+    std_l, std_r = speckle_std
+    mean = 0
+    std = random.uniform(std_l / 255., std_r / 255.)
+    gauss = torch.normal(mean=mean, std=std, size=img.size()).to(img.device)
+    noisy = img + gauss * img
+    noisy = torch.clamp(noisy, 0, 1)
+    return noisy
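+
+
+# Illustrative usage sketch for the speckle-noise helpers above (the std range
+# is an assumption; like the Gaussian sigma, speckle_std is specified on a
+# 0-255 scale and divided by 255 internally):
+#   >>> imgs = torch.rand(4, 3, 64, 64)
+#   >>> out = random_add_speckle_noise_pt(imgs, speckle_std=(5, 40))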
+
+# ----------------------- Random saltpepper Noise ----------------------- #
+
+def random_add_saltpepper_noise(imgs, saltpepper_amount, saltpepper_svsp):
+    p_range = saltpepper_amount
+    p = random.uniform(p_range[0], p_range[1])
+    q_range = saltpepper_svsp
+    q = random.uniform(q_range[0], q_range[1])
+
+    outputs = []
+    for img in imgs:
+        out = img.copy()
+        flipped = np.random.choice([True, False], size=img.shape,
+                            p=[p, 1 - p])
+        salted = np.random.choice([True, False], size=img.shape,
+                            p=[q, 1 - q])
+        peppered = ~salted
+        out[flipped & salted] = 1
+        out[flipped & peppered] = 0.
+        noisy = np.clip(out, 0, 1).astype(np.float32)
+
+        outputs.append(noisy)
+
+    return outputs
+
+def random_add_saltpepper_noise_pt(imgs, saltpepper_amount, saltpepper_svsp):
+    p_range = saltpepper_amount
+    p = random.uniform(p_range[0], p_range[1])
+    q_range = saltpepper_svsp
+    q = random.uniform(q_range[0], q_range[1])
+
+    imgs = imgs.permute(0,2,3,1)
+
+    outputs = []
+    for i in range(imgs.size(0)):
+        img = imgs[i]
+        out = img.clone()
+        flipped = np.random.choice([True, False], size=img.shape,
+                            p=[p, 1 - p])
+        salted = np.random.choice([True, False], size=img.shape,
+                            p=[q, 1 - q])
+        peppered = ~salted
+        out[flipped & salted] = 1
+        out[flipped & peppered] = 0.
+        noisy = torch.clamp(out, 0, 1)
+
+        outputs.append(noisy.permute(2, 0, 1))
+    # stack back into a (b, c, h, w) batch; concatenating along dim=0 would
+    # merge the channel dimensions of different samples
+    return torch.stack(outputs, dim=0)
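+
+
+# Illustrative usage sketch for the salt-and-pepper helpers above (the amount
+# and salt-vs-pepper ranges are assumptions for demonstration):
+#   >>> imgs = torch.rand(4, 3, 64, 64)
+#   >>> out = random_add_saltpepper_noise_pt(imgs, saltpepper_amount=(0.001, 0.015),
+#   ...                                      saltpepper_svsp=(0.3, 0.7))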
+
+# ----------------------- Random screen Noise ----------------------- #
+
+def random_add_screen_noise(imgs, linewidth, space):
+    """Overlay a vertical screen-stripe pattern on a list of images in [0, 1], shape (h, w, c)."""
+    linewidth = int(np.random.uniform(linewidth[0], linewidth[1]))
+    space = int(np.random.uniform(space[0], space[1]))
+    center_color = [213, 230, 230]  # RGB
+    outputs = []
+    for img in imgs:
+        noise = img.copy()
+        h, w = img.shape[0], img.shape[1]
+        # binary mask of vertical stripes: `linewidth` columns on, `space` columns off
+        tmp_mask = np.zeros((h, w), dtype=np.float32)
+        for i in range(0, w, space + linewidth):
+            tmp_mask[:, i:(i + linewidth)] = 1
+        colour_masks = np.zeros((h, w, 3), dtype=np.float32)
+        colour_masks[:, :, 0] = (center_color[0] + np.random.uniform(-20, 20)) / 255.
+        colour_masks[:, :, 1] = (center_color[1] + np.random.uniform(0, 20)) / 255.
+        colour_masks[:, :, 2] = (center_color[2] + np.random.uniform(0, 20)) / 255.
+        noise_color = cv2.addWeighted(noise, 0.6, colour_masks, 0.4, 0.0)
+        noise = noise * (1 - tmp_mask[:, :, np.newaxis]) + noise_color * tmp_mask[:, :, np.newaxis]
+        outputs.append(noise)
+    return outputs
+
+
+# ------------------------------------------------------------------------ #
+# --------------------------- JPEG compression --------------------------- #
+# ------------------------------------------------------------------------ #
+
+
+def add_jpg_compression(img, quality=90):
+    """Add JPG compression artifacts.
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        quality (float): JPG compression quality. 0 for lowest quality, 100 for
+            best quality. Default: 90.
+
+    Returns:
+        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+            float32.
+    """
+    img = np.clip(img, 0, 1)
+    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
+    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
+    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
+    return img
+
+
+def random_add_jpg_compression(img, quality_range=(90, 100), return_q=False):
+    """Randomly add JPG compression artifacts.
+
+    Args:
+        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+        quality_range (tuple[float] | list[float]): JPG compression quality
+            range. 0 for lowest quality, 100 for best quality.
+            Default: (90, 100).
+        return_q (bool): Whether to also return the sampled quality value.
+            Default: False.
+
+    Returns:
+        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+            float32.
+    """
+    quality = np.random.uniform(quality_range[0], quality_range[1])
+    if return_q:
+        return add_jpg_compression(img, quality), quality
+    else:
+        return add_jpg_compression(img, quality)
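+
+
+# Illustrative usage sketch for the JPEG helpers above (the quality range is an
+# assumption for demonstration; img is a BGR float32 array in [0, 1]):
+#   >>> img = np.random.rand(64, 64, 3).astype(np.float32)
+#   >>> lq = random_add_jpg_compression(img, quality_range=(30, 95))
+#   >>> lq, q = random_add_jpg_compression(img, quality_range=(30, 95), return_q=True)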
diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..23992eb877f6b7b46cf5f40ed3667fc10916269b
--- /dev/null
+++ b/basicsr/data/ffhq_dataset.py
@@ -0,0 +1,80 @@
+import random
+import time
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class FFHQDataset(data.Dataset):
+    """FFHQ dataset for StyleGAN.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwarg.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+
+    """
+
+    def __init__(self, opt):
+        super(FFHQDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            # FFHQ has 70000 images in total
+            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load gt image
+        gt_path = self.paths[index]
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                img_bytes = self.file_client.get(gt_path)
+            except Exception as e:
+                logger = get_root_logger()
+                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # change another file to read
+                index = random.randint(0, self.__len__() - 1)
+                gt_path = self.paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # random horizontal flip
+        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+        # normalize
+        normalize(img_gt, self.mean, self.std, inplace=True)
+        return {'gt': img_gt, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/basicsr/data/ffhq_degradation_dataset.py b/basicsr/data/ffhq_degradation_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d8c934ddd68816f2d2f0038191530056fe912c
--- /dev/null
+++ b/basicsr/data/ffhq_degradation_dataset.py
@@ -0,0 +1,232 @@
+import cv2
+import math
+import numpy as np
+import os.path as osp
+import torch
+import torch.utils.data as data
+import random
+from basicsr.data import degradations as degradations
+from basicsr.data.data_util import paths_from_folder
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from pathlib import Path
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+                                               normalize)
+
+@DATASET_REGISTRY.register()
+class FFHQDegradationDataset(data.Dataset):
+    """FFHQ dataset for GFPGAN.
+    It reads high-resolution images and then generates low-quality (LQ) images on-the-fly.
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwarg.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+            Please see more options in the codes.
+    """
+
+    def __init__(self, opt):
+        super(FFHQDegradationDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        if 'image_type' not in opt:
+            opt['image_type'] = 'png'
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+        self.out_size = opt['out_size']
+
+        self.crop_components = opt.get('crop_components', False)  # facial components
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions
+
+        if self.crop_components:
+            # load component list from a pre-process pth files
+            self.components_list = torch.load(opt.get('component_path'))
+
+        # file client (lmdb io backend)
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            # disk backend: scan file list from a folder
+            self.paths = sorted([str(x) for x in Path(self.gt_folder).glob('*.' + opt['image_type'])])
+
+        # degradation configurations
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']
+        self.blur_sigma = opt['blur_sigma']
+        self.downsample_range = opt['downsample_range']
+        self.noise_range = opt['noise_range']
+        self.jpeg_range = opt['jpeg_range']
+
+        # color jitter
+        self.color_jitter_prob = opt.get('color_jitter_prob')
+        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+        # to gray
+        self.gray_prob = opt.get('gray_prob')
+
+        logger = get_root_logger()
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+        if self.color_jitter_prob is not None:
+            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+        if self.gray_prob is not None:
+            logger.info(f'Use random gray. Prob: {self.gray_prob}')
+        if self.color_jitter_shift is not None:
+            self.color_jitter_shift /= 255.
+
+    @staticmethod
+    def color_jitter(img, shift):
+        """jitter color: randomly jitter the RGB values, in numpy formats"""
+        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+        img = img + jitter_val
+        img = np.clip(img, 0, 1)
+        return img
+
+    @staticmethod
+    def color_jitter_pt(img, brightness, contrast, saturation, hue):
+        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
+        fn_idx = torch.randperm(4)
+        for fn_id in fn_idx:
+            if fn_id == 0 and brightness is not None:
+                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+                img = adjust_brightness(img, brightness_factor)
+
+            if fn_id == 1 and contrast is not None:
+                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+                img = adjust_contrast(img, contrast_factor)
+
+            if fn_id == 2 and saturation is not None:
+                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+                img = adjust_saturation(img, saturation_factor)
+
+            if fn_id == 3 and hue is not None:
+                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+                img = adjust_hue(img, hue_factor)
+        return img
+
+    def get_component_coordinates(self, index, status):
+        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
+        components_bbox = self.components_list[f'{index:08d}']
+        if status[0]:  # hflip
+            # exchange right and left eye
+            tmp = components_bbox['left_eye']
+            components_bbox['left_eye'] = components_bbox['right_eye']
+            components_bbox['right_eye'] = tmp
+            # modify the width coordinate
+            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
+            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
+            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
+
+        # get coordinates
+        locations = []
+        for part in ['left_eye', 'right_eye', 'mouth']:
+            mean = components_bbox[part][0:2]
+            half_len = components_bbox[part][2]
+            if 'eye' in part:
+                half_len *= self.eye_enlarge_ratio
+            loc = np.hstack((mean - half_len + 1, mean + half_len))
+            loc = torch.from_numpy(loc).float()
+            locations.append(loc)
+        return locations
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load gt image
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+        gt_path = self.paths[index]
+        img_bytes = self.file_client.get(gt_path)
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # random horizontal flip
+        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+        h, w, _ = img_gt.shape
+
+        # get facial component coordinates
+        if self.crop_components:
+            locations = self.get_component_coordinates(index, status)
+            loc_left_eye, loc_right_eye, loc_mouth = locations
+
+        # ------------------------ generate lq image ------------------------ #
+        # blur
+        kernel = degradations.random_mixed_kernels(
+            self.kernel_list,
+            self.kernel_prob,
+            self.blur_kernel_size,
+            self.blur_sigma,
+            self.blur_sigma, [-math.pi, math.pi],
+            noise_range=None)
+        img_lq = cv2.filter2D(img_gt, -1, kernel)
+        # downsample
+        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
+        # noise
+        if self.noise_range is not None:
+            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
+        # jpeg compression
+        if self.jpeg_range is not None:
+            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
+
+        # resize to original size
+        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
+
+        # random color jitter (only for lq)
+        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
+        # random to gray (only for lq)
+        if self.gray_prob and np.random.uniform() < self.gray_prob:
+            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
+            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
+            if self.opt.get('gt_gray'):  # whether convert GT to gray images
+                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+
+        # random color jitter (pytorch version) (only for lq)
+        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+            brightness = self.opt.get('brightness', (0.5, 1.5))
+            contrast = self.opt.get('contrast', (0.5, 1.5))
+            saturation = self.opt.get('saturation', (0, 1.5))
+            hue = self.opt.get('hue', (-0.1, 0.1))
+            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
+
+        # round and clip
+        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
+
+        # normalize
+        normalize(img_gt, self.mean, self.std, inplace=True)
+        normalize(img_lq, self.mean, self.std, inplace=True)
+
+        if self.crop_components:
+            return_dict = {
+                'lq': img_lq,
+                'gt': img_gt,
+                'gt_path': gt_path,
+                'loc_left_eye': loc_left_eye,
+                'loc_right_eye': loc_right_eye,
+                'loc_mouth': loc_mouth
+            }
+            return return_dict
+        else:
+            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..41965cd159ec539aca3d60f5a5ccd84736e13d61
--- /dev/null
+++ b/basicsr/data/paired_image_dataset.py
@@ -0,0 +1,115 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file, paired_paths_from_meta_info_file_2
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+import cv2
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+    """Paired image dataset for image restoration.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+    There are three modes:
+
+    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
+    2. **meta_info_file**: Use meta information file to generate paths. \
+        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+    3. **folder**: Scan folders to generate paths. The rest.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        meta_info_file (str): Path for meta information file.
+        io_backend (dict): IO backend type and other kwarg.
+        filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+            Default: '{}'.
+        gt_size (int): Cropped patched size for gt patches.
+        use_hflip (bool): Use horizontal flips.
+        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+        scale (bool): Scale, which will be added automatically.
+        phase (str): 'train' or 'val'.
+    """
+
+    def __init__(self, opt):
+        super(PairedImageDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+
+        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+        if 'filename_tmpl' in opt:
+            self.filename_tmpl = opt['filename_tmpl']
+        else:
+            self.filename_tmpl = '{}'
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+            self.io_backend_opt['client_keys'] = ['lq', 'gt']
+            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+            self.paths = paired_paths_from_meta_info_file_2([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+                                                          self.opt['meta_info_file'], self.filename_tmpl)
+        else:
+            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+        lq_path = self.paths[index]['lq_path']
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        h, w = img_gt.shape[0:2]
+        # pad
+        if h < self.opt['gt_size'] or w < self.opt['gt_size']:
+            pad_h = max(0, self.opt['gt_size'] - h)
+            pad_w = max(0, self.opt['gt_size'] - w)
+            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+            img_lq = cv2.copyMakeBorder(img_lq, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+
+        # augmentation for training
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # random crop
+            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+            # flip, rotation
+            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+        # color space transform
+        if 'color' in self.opt and self.opt['color'] == 'y':
+            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
+            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
+
+        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
+        # TODO: It is better to update the datasets, rather than force to crop
+        if self.opt['phase'] != 'train':
+            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+            normalize(img_gt, self.mean, self.std, inplace=True)
+
+        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..332abd32fcb004e6892d12dc69848a4454e3c503
--- /dev/null
+++ b/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,122 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+    """A general prefetch generator.
+
+    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+    Args:
+        generator: Python generator.
+        num_prefetch_queue (int): Number of prefetch queue.
+    """
+
+    def __init__(self, generator, num_prefetch_queue):
+        threading.Thread.__init__(self)
+        self.queue = Queue.Queue(num_prefetch_queue)
+        self.generator = generator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class PrefetchDataLoader(DataLoader):
+    """Prefetch version of dataloader.
+
+    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+    TODO:
+    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+    ddp.
+
+    Args:
+        num_prefetch_queue (int): Number of prefetch queue.
+        kwargs (dict): Other arguments for dataloader.
+    """
+
+    def __init__(self, num_prefetch_queue, **kwargs):
+        self.num_prefetch_queue = num_prefetch_queue
+        super(PrefetchDataLoader, self).__init__(**kwargs)
+
+    def __iter__(self):
+        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+    """CPU prefetcher.
+
+    Args:
+        loader: Dataloader.
+    """
+
+    def __init__(self, loader):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+
+    def next(self):
+        try:
+            return next(self.loader)
+        except StopIteration:
+            return None
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+    """CUDA prefetcher.
+
+    Reference: https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+    Args:
+        loader: Dataloader.
+        opt (dict): Options.
+    """
+
+    def __init__(self, loader, opt):
+        self.ori_loader = loader
+        self.loader = iter(loader)
+        self.opt = opt
+        self.stream = torch.cuda.Stream()
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.preload()
+
+    def preload(self):
+        try:
+            self.batch = next(self.loader)  # self.batch is a dict
+        except StopIteration:
+            self.batch = None
+            return None
+        # put tensors to gpu
+        with torch.cuda.stream(self.stream):
+            for k, v in self.batch.items():
+                if torch.is_tensor(v):
+                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        self.preload()
+        return batch
+
+    def reset(self):
+        self.loader = iter(self.ori_loader)
+        self.preload()
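+
+
+# Illustrative training-loop sketch for CUDAPrefetcher (the names `train_loader`,
+# `opt` and the batch keys are assumptions for demonstration):
+#   prefetcher = CUDAPrefetcher(train_loader, opt)
+#   batch = prefetcher.next()
+#   while batch is not None:
+#       ...  # run forward/backward on batch['lq'], batch['gt'], etc.
+#       batch = prefetcher.next()
+#   prefetcher.reset()  # restart for the next epoch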
diff --git a/basicsr/data/realesrgan_dataset.py b/basicsr/data/realesrgan_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7c0603d8353f5457b0dd96f9a9a876a192d113
--- /dev/null
+++ b/basicsr/data/realesrgan_dataset.py
@@ -0,0 +1,242 @@
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANDataset(data.Dataset):
+    """Modified dataset based on the dataset used for Real-ESRGAN model:
+    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It loads gt (Ground-Truth) images, and augments them.
+    It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUs for faster processing.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            meta_info (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            use_hflip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            Please see more options in the codes.
+    """
+
+    def __init__(self, opt):
+        super(RealESRGANDataset, self).__init__()
+        self.opt = opt
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        if 'crop_size' in opt:
+            self.crop_size = opt['crop_size']
+        else:
+            self.crop_size = 512
+        if 'image_type' not in opt:
+            opt['image_type'] = 'png'
+
+        # support multiple types of data: file paths and meta info files; lmdb is not supported
+        self.paths = []
+        if 'meta_info' in opt:
+            with open(self.opt['meta_info']) as fin:
+                paths = [line.strip().split(' ')[0] for line in fin]
+                self.paths = [v for v in paths]
+            if 'meta_num' in opt:
+                self.paths = sorted(self.paths)[:opt['meta_num']]
+        if 'gt_path' in opt:
+            if isinstance(opt['gt_path'], str):
+                self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])]))
+            else:
+                self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])]))
+                if len(opt['gt_path']) > 1:
+                    for i in range(len(opt['gt_path'])-1):
+                        self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])]))
+        if 'imagenet_path' in opt:
+            class_list = os.listdir(opt['imagenet_path'])
+            for class_file in class_list:
+                self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')]))
+        if 'face_gt_path' in opt:
+            if isinstance(opt['face_gt_path'], str):
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+            else:
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+                if len(opt['face_gt_path']) > 1:
+                    for i in range(len(opt['face_gt_path'])-1):
+                        self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][i+1]).glob('*.'+opt['image_type'])])[:opt['num_face']])
+
+        # limit number of pictures for test
+        if 'num_pic' in opt:
+            if 'val' in opt or 'test' in opt:
+                random.shuffle(self.paths)
+                self.paths = self.paths[:opt['num_pic']]
+            else:
+                self.paths = self.paths[:opt['num_pic']]
+
+        if 'mul_num' in opt:
+            self.paths = self.paths * opt['mul_num']
+
+        # blur settings for the first degradation
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
+        self.blur_sigma = opt['blur_sigma']
+        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
+        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
+        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters
+
+        # blur settings for the second degradation
+        self.blur_kernel_size2 = opt['blur_kernel_size2']
+        self.kernel_list2 = opt['kernel_list2']
+        self.kernel_prob2 = opt['kernel_prob2']
+        self.blur_sigma2 = opt['blur_sigma2']
+        self.betag_range2 = opt['betag_range2']
+        self.betap_range2 = opt['betap_range2']
+        self.sinc_prob2 = opt['sinc_prob2']
+
+        # a final sinc filter
+        self.final_sinc_prob = opt['final_sinc_prob']
+
+        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
+        # TODO: the kernel range is hard-coded; it should come from the config file
+        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
+        self.pulse_tensor[10, 10] = 1
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # -------------------------------- Load gt images -------------------------------- #
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+        gt_path = self.paths[index]
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                img_bytes = self.file_client.get(gt_path, 'gt')
+            except (IOError, OSError) as e:
+                # logger = get_root_logger()
+                # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # change another file to read
+                index = random.randint(0, self.__len__()-1)
+                gt_path = self.paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+        img_gt = imfrombytes(img_bytes, float32=True)
+        # filter the dataset and remove images with too low quality
+        img_size = os.path.getsize(gt_path)
+        img_size = img_size / 1024  # file size in KB
+
+        while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100:
+            index = random.randint(0, self.__len__()-1)
+            gt_path = self.paths[index]
+
+            time.sleep(0.1)  # wait briefly in case of occasional server congestion
+            img_bytes = self.file_client.get(gt_path, 'gt')
+            img_gt = imfrombytes(img_bytes, float32=True)
+            img_size = os.path.getsize(gt_path)
+            img_size = img_size/1024
+
+        # -------------------- Do augmentation for training: flip, rotation -------------------- #
+        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+        # crop or pad to self.crop_size (default: 512)
+        h, w = img_gt.shape[0:2]
+        crop_pad_size = self.crop_size
+        # pad
+        if h < crop_pad_size or w < crop_pad_size:
+            pad_h = max(0, crop_pad_size - h)
+            pad_w = max(0, crop_pad_size - w)
+            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+        # crop
+        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+            h, w = img_gt.shape[0:2]
+            # randomly choose top and left coordinates
+            top = random.randint(0, h - crop_pad_size)
+            left = random.randint(0, w - crop_pad_size)
+            # top = (h - crop_pad_size) // 2 -1
+            # left = (w - crop_pad_size) // 2 -1
+            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob']:
+            # this sinc filter setting is for kernels ranging from [7, 21]
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel = random_mixed_kernels(
+                self.kernel_list,
+                self.kernel_prob,
+                kernel_size,
+                self.blur_sigma,
+                self.blur_sigma, [-math.pi, math.pi],
+                self.betag_range,
+                self.betap_range,
+                noise_range=None)
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob2']:
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel2 = random_mixed_kernels(
+                self.kernel_list2,
+                self.kernel_prob2,
+                kernel_size,
+                self.blur_sigma2,
+                self.blur_sigma2, [-math.pi, math.pi],
+                self.betag_range2,
+                self.betap_range2,
+                noise_range=None)
+
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+        # ------------------------------------- the final sinc kernel ------------------------------------- #
+        if np.random.uniform() < self.opt['final_sinc_prob']:
+            kernel_size = random.choice(self.kernel_range)
+            omega_c = np.random.uniform(np.pi / 3, np.pi)
+            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+            sinc_kernel = torch.FloatTensor(sinc_kernel)
+        else:
+            sinc_kernel = self.pulse_tensor
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+        kernel = torch.FloatTensor(kernel)
+        kernel2 = torch.FloatTensor(kernel2)
+
+        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
+        return return_d
+
+    def __len__(self):
+        return len(self.paths)
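+
+
+# Note: this dataset returns only the GT image plus the sampled blur/sinc
+# kernels ('kernel1', 'kernel2', 'sinc_kernel'); the actual degradation
+# (blur, resize, noise, JPEG) is expected to be applied on GPU inside the
+# training model, following the Real-ESRGAN pipeline.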
diff --git a/basicsr/data/realesrgan_paired_dataset.py b/basicsr/data/realesrgan_paired_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0c6159d448f26fc8a256d6a9d0c51096b78fe0
--- /dev/null
+++ b/basicsr/data/realesrgan_paired_dataset.py
@@ -0,0 +1,114 @@
+import os
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANPairedDataset(data.Dataset):
+    """Paired image dataset for image restoration.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+    There are three modes:
+
+    1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
+    2. **meta_info_file**: Use meta information file to generate paths. \
+        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+    3. **folder**: Scan folders to generate paths. The rest.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        meta_info (str): Path for meta information file.
+        io_backend (dict): IO backend type and other kwarg.
+        filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+            Default: '{}'.
+        gt_size (int): Cropped patched size for gt patches.
+        use_hflip (bool): Use horizontal flips.
+        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+        scale (bool): Scale, which will be added automatically.
+        phase (str): 'train' or 'val'.
+    """
+
+    def __init__(self, opt):
+        super(RealESRGANPairedDataset, self).__init__()
+        self.opt = opt
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        # mean and std for normalizing the input images
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+
+        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
+
+        # file client (lmdb io backend)
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+            self.io_backend_opt['client_keys'] = ['lq', 'gt']
+            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
+            # disk backend with meta_info
+            # Each line in the meta_info describes the relative path to an image
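+            # e.g. an illustrative line: "0001_gt.png, 0001_lq.png" (GT path, then LQ path)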
+            with open(self.opt['meta_info']) as fin:
+                paths = [line.strip() for line in fin]
+            self.paths = []
+            for path in paths:
+                gt_path, lq_path = path.split(', ')
+                gt_path = os.path.join(self.gt_folder, gt_path)
+                lq_path = os.path.join(self.lq_folder, lq_path)
+                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
+        else:
+            # disk backend
+            # it will scan the whole folder to get meta info
+            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
+            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+        if 'num_pic' in self.opt:
+            self.paths = self.paths[:self.opt['num_pic']]
+        if 'phase' not in self.opt:
+            self.opt['phase'] = 'test'
+        if 'scale' not in self.opt:
+            self.opt['scale'] = 1
+
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        gt_path = self.paths[index]['gt_path']
+        img_bytes = self.file_client.get(gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+        lq_path = self.paths[index]['lq_path']
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        # augmentation for training
+        if self.opt['phase'] == 'train':
+            gt_size = self.opt['gt_size']
+            # random crop
+            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+            # flip, rotation
+            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+            normalize(img_gt, self.mean, self.std, inplace=True)
+
+        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+    def __len__(self):
+        return len(self.paths)
diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabef1d7e80866888f3b57ecfeb4d97c93bcb5cd
--- /dev/null
+++ b/basicsr/data/reds_dataset.py
@@ -0,0 +1,352 @@
+import numpy as np
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.flow_util import dequantize_flow
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class REDSDataset(data.Dataset):
+    """REDS dataset for training.
+
+    The keys are generated from a meta info txt file.
+    basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+    Each line contains:
+    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+    a white space.
+    Examples:
+    000 100 (720,1280,3)
+    001 100 (720,1280,3)
+    ...
+
+    Key examples: "000/00000000"
+    GT (gt): Ground-Truth;
+    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        dataroot_flow (str, optional): Data root path for flow.
+        meta_info_file (str): Path for meta information file.
+        val_partition (str): Validation partition types. 'REDS4' or 'official'.
+        io_backend (dict): IO backend type and other kwarg.
+        num_frame (int): Window size for input frames.
+        gt_size (int): Cropped patched size for gt patches.
+        interval_list (list): Interval list for temporal augmentation.
+        random_reverse (bool): Random reverse input frames.
+        use_hflip (bool): Use horizontal flips.
+        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+        scale (bool): Scale, which will be added automatically.
+    """
+
+    def __init__(self, opt):
+        super(REDSDataset, self).__init__()
+        self.opt = opt
+        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+        self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
+        assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
+        self.num_frame = opt['num_frame']
+        self.num_half_frames = opt['num_frame'] // 2
+
+        self.keys = []
+        with open(opt['meta_info_file'], 'r') as fin:
+            for line in fin:
+                folder, frame_num, _ = line.split(' ')
+                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+        # remove the video clips used in validation
+        if opt['val_partition'] == 'REDS4':
+            val_partition = ['000', '011', '015', '020']
+        elif opt['val_partition'] == 'official':
+            val_partition = [f'{v:03d}' for v in range(240, 270)]
+        else:
+            raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
+                             f"Supported ones are ['official', 'REDS4'].")
+        self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.is_lmdb = False
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.is_lmdb = True
+            if self.flow_root is not None:
+                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+            else:
+                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+        # temporal augmentation configs
+        self.interval_list = opt['interval_list']
+        self.random_reverse = opt['random_reverse']
+        interval_str = ','.join(str(x) for x in opt['interval_list'])
+        logger = get_root_logger()
+        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+                    f'random reverse is {self.random_reverse}.')
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip_name, frame_name = key.split('/')  # key example: 000/00000000
+        center_frame_idx = int(frame_name)
+
+        # determine the neighboring frames
+        interval = random.choice(self.interval_list)
+
+        # ensure not exceeding the borders
+        start_frame_idx = center_frame_idx - self.num_half_frames * interval
+        end_frame_idx = center_frame_idx + self.num_half_frames * interval
+        # each clip has 100 frames starting from 0 to 99
+        while (start_frame_idx < 0) or (end_frame_idx > 99):
+            center_frame_idx = random.randint(0, 99)
+            start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
+            end_frame_idx = center_frame_idx + self.num_half_frames * interval
+        frame_name = f'{center_frame_idx:08d}'
+        neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            neighbor_list.reverse()
+
+        assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+        # get the GT frame (as the center frame)
+        if self.is_lmdb:
+            img_gt_path = f'{clip_name}/{frame_name}'
+        else:
+            img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
+        img_bytes = self.file_client.get(img_gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # get the neighboring LQ frames
+        img_lqs = []
+        for neighbor in neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip_name}/{neighbor:08d}'
+            else:
+                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+        # get flows
+        if self.flow_root is not None:
+            img_flows = []
+            # read previous flows
+            for i in range(self.num_half_frames, 0, -1):
+                if self.is_lmdb:
+                    flow_path = f'{clip_name}/{frame_name}_p{i}'
+                else:
+                    flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+                img_bytes = self.file_client.get(flow_path, 'flow')
+                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
+                dx, dy = np.split(cat_flow, 2, axis=0)
+                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
+                img_flows.append(flow)
+            # read next flows
+            for i in range(1, self.num_half_frames + 1):
+                if self.is_lmdb:
+                    flow_path = f'{clip_name}/{frame_name}_n{i}'
+                else:
+                    flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+                img_bytes = self.file_client.get(flow_path, 'flow')
+                cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False)  # uint8, [0, 255]
+                dx, dy = np.split(cat_flow, 2, axis=0)
+                flow = dequantize_flow(dx, dy, max_val=20, denorm=False)  # we use max_val 20 here.
+                img_flows.append(flow)
+
+            # for random crop, here, img_flows and img_lqs have the same
+            # spatial size
+            img_lqs.extend(img_flows)
+
+        # randomly crop
+        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+        if self.flow_root is not None:
+            img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
+
+        # augmentation - flip, rotate
+        img_lqs.append(img_gt)
+        if self.flow_root is not None:
+            img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
+        else:
+            img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_lqs = torch.stack(img_results[0:-1], dim=0)
+        img_gt = img_results[-1]
+
+        if self.flow_root is not None:
+            img_flows = img2tensor(img_flows)
+            # add the zero center flow
+            img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
+            img_flows = torch.stack(img_flows, dim=0)
+
+        # img_lqs: (t, c, h, w)
+        # img_flows: (t, 2, h, w)
+        # img_gt: (c, h, w)
+        # key: str
+        if self.flow_root is not None:
+            return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
+        else:
+            return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class REDSRecurrentDataset(data.Dataset):
+    """REDS dataset for training recurrent networks.
+
+    The keys are generated from a meta info txt file.
+    basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+    Each line contains:
+    1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+    a white space.
+    Examples:
+    000 100 (720,1280,3)
+    001 100 (720,1280,3)
+    ...
+
+    Key examples: "000/00000000"
+    GT (gt): Ground-Truth;
+    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        dataroot_flow (str, optional): Data root path for flow.
+        meta_info_file (str): Path for meta information file.
+        val_partition (str): Validation partition types. 'REDS4' or 'official'.
+        io_backend (dict): IO backend type and other kwarg.
+        num_frame (int): Window size for input frames.
+        gt_size (int): Cropped patch size for gt patches.
+        interval_list (list): Interval list for temporal augmentation.
+        random_reverse (bool): Random reverse input frames.
+        use_hflip (bool): Use horizontal flips.
+        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+        scale (int): Scale factor, which will be added automatically.
+    """
+
+    def __init__(self, opt):
+        super(REDSRecurrentDataset, self).__init__()
+        self.opt = opt
+        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+        self.num_frame = opt['num_frame']
+
+        self.keys = []
+        with open(opt['meta_info_file'], 'r') as fin:
+            for line in fin:
+                folder, frame_num, _ = line.split(' ')
+                self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+        # remove the video clips used in validation
+        if opt['val_partition'] == 'REDS4':
+            val_partition = ['000', '011', '015', '020']
+        elif opt['val_partition'] == 'official':
+            val_partition = [f'{v:03d}' for v in range(240, 270)]
+        else:
+            raise ValueError(f'Wrong validation partition {opt["val_partition"]}. '
+                             f"Supported ones are ['official', 'REDS4'].")
+        if opt['test_mode']:
+            self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
+        else:
+            self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.is_lmdb = False
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.is_lmdb = True
+            if hasattr(self, 'flow_root') and self.flow_root is not None:
+                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+            else:
+                self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+                self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+        # temporal augmentation configs
+        self.interval_list = opt.get('interval_list', [1])
+        self.random_reverse = opt.get('random_reverse', False)
+        interval_str = ','.join(str(x) for x in self.interval_list)
+        logger = get_root_logger()
+        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+                    f'random reverse is {self.random_reverse}.')
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip_name, frame_name = key.split('/')  # key example: 000/00000000
+
+        # determine the neighboring frames
+        interval = random.choice(self.interval_list)
+
+        # ensure not exceeding the borders
+        start_frame_idx = int(frame_name)
+        if start_frame_idx > 100 - self.num_frame * interval:
+            start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
+        end_frame_idx = start_frame_idx + self.num_frame * interval
+
+        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
+
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            neighbor_list.reverse()
+
+        # get the neighboring LQ and GT frames
+        img_lqs = []
+        img_gts = []
+        for neighbor in neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip_name}/{neighbor:08d}'
+                img_gt_path = f'{clip_name}/{neighbor:08d}'
+            else:
+                img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+                img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
+
+            # get LQ
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+            # get GT
+            img_bytes = self.file_client.get(img_gt_path, 'gt')
+            img_gt = imfrombytes(img_bytes, float32=True)
+            img_gts.append(img_gt)
+
+        # randomly crop
+        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+        # augmentation - flip, rotate
+        img_lqs.extend(img_gts)
+        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
+        img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
+
+        # img_lqs: (t, c, h, w)
+        # img_gts: (t, c, h, w)
+        # key: str
+        return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
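For reference, a minimal sketch of the options dict that REDSRecurrentDataset consumes; the dataset roots below are hypothetical placeholders, and only the keys mirror the Args documented above.

opt = {
    'dataroot_gt': 'datasets/REDS/train_sharp',             # placeholder path
    'dataroot_lq': 'datasets/REDS/train_sharp_bicubic/X4',  # placeholder path
    'meta_info_file': 'basicsr/data/meta_info/meta_info_REDS_GT.txt',
    'val_partition': 'REDS4',
    'test_mode': False,
    'io_backend': {'type': 'disk'},
    'num_frame': 15,
    'gt_size': 256,
    'interval_list': [1],
    'random_reverse': False,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
}
# dataset = REDSRecurrentDataset(opt)
# sample = dataset[0]  # {'lq': (t, c, h, w), 'gt': (t, c, h, w), 'key': '000/00000000'}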
diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d1a94d1723fb832b0c6fc897e72e0081c4a399
--- /dev/null
+++ b/basicsr/data/single_image_dataset.py
@@ -0,0 +1,164 @@
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paths_from_lmdb
+from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+
+from pathlib import Path
+import random
+import cv2
+import numpy as np
+import torch
+
+@DATASET_REGISTRY.register()
+class SingleImageDataset(data.Dataset):
+    """Read only lq images in the test phase.
+
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
+
+    There are two modes:
+    1. 'meta_info_file': Use meta information file to generate paths.
+    2. 'folder': Scan folders to generate paths.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_lq (str): Data root path for lq.
+            meta_info_file (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+    """
+
+    def __init__(self, opt):
+        super(SingleImageDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+        self.lq_folder = opt['dataroot_lq']
+
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = [self.lq_folder]
+            self.io_backend_opt['client_keys'] = ['lq']
+            self.paths = paths_from_lmdb(self.lq_folder)
+        elif 'meta_info_file' in self.opt:
+            with open(self.opt['meta_info_file'], 'r') as fin:
+                self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
+        else:
+            self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load lq image
+        lq_path = self.paths[index]
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        # color space transform
+        if 'color' in self.opt and self.opt['color'] == 'y':
+            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+        return {'lq': img_lq, 'lq_path': lq_path}
+
+    def __len__(self):
+        return len(self.paths)
+
+@DATASET_REGISTRY.register()
+class SingleImageNPDataset(data.Dataset):
+    """Read only lq images in the test phase.
+
+    Read diffusion generated data for training CFW.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            gt_path: Data root path for training data. The path needs to contain the following folders:
+                gts: Ground-truth images.
+                inputs: Input LQ images.
+                latents: The corresponding HQ latent code generated by diffusion model given the input LQ image.
+                samples: The corresponding HQ image given the HQ latent code, just for verification.
+            io_backend (dict): IO backend type and other kwarg.
+    """
+
+    def __init__(self, opt):
+        super(SingleImageNPDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.mean = opt['mean'] if 'mean' in opt else None
+        self.std = opt['std'] if 'std' in opt else None
+        if 'image_type' not in opt:
+            opt['image_type'] = 'png'
+
+        gt_roots = [opt['gt_path']] if isinstance(opt['gt_path'], str) else list(opt['gt_path'])
+        self.gt_paths, self.lq_paths, self.np_paths, self.sample_paths = [], [], [], []
+        for root in gt_roots:
+            self.gt_paths.extend(sorted(str(x) for x in Path(root + '/gts').glob('*.' + opt['image_type'])))
+            self.lq_paths.extend(sorted(str(x) for x in Path(root + '/inputs').glob('*.' + opt['image_type'])))
+            self.np_paths.extend(sorted(str(x) for x in Path(root + '/latents').glob('*.npy')))
+            self.sample_paths.extend(sorted(str(x) for x in Path(root + '/samples').glob('*.' + opt['image_type'])))
+
+        assert len(self.gt_paths) == len(self.lq_paths)
+        assert len(self.gt_paths) == len(self.np_paths)
+        assert len(self.gt_paths) == len(self.sample_paths)
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # load lq image
+        lq_path = self.lq_paths[index]
+        gt_path = self.gt_paths[index]
+        sample_path = self.sample_paths[index]
+        np_path = self.np_paths[index]
+
+        img_bytes = self.file_client.get(lq_path, 'lq')
+        img_lq = imfrombytes(img_bytes, float32=True)
+
+        img_bytes_gt = self.file_client.get(gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes_gt, float32=True)
+
+        img_bytes_sample = self.file_client.get(sample_path, 'sample')
+        img_sample = imfrombytes(img_bytes_sample, float32=True)
+
+        latent_np = np.load(np_path)
+
+        # color space transform
+        if 'color' in self.opt and self.opt['color'] == 'y':
+            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+            img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
+            img_sample = rgb2ycbcr(img_sample, y_only=True)[..., None]
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+        img_sample = img2tensor(img_sample, bgr2rgb=True, float32=True)
+        latent_np = torch.from_numpy(latent_np).float()
+        latent_np = latent_np.to(img_gt.device)
+        # normalize
+        if self.mean is not None or self.std is not None:
+            normalize(img_lq, self.mean, self.std, inplace=True)
+            normalize(img_gt, self.mean, self.std, inplace=True)
+            normalize(img_sample, self.mean, self.std, inplace=True)
+        return {'lq': img_lq, 'lq_path': lq_path, 'gt': img_gt, 'gt_path': gt_path, 'latent': latent_np[0], 'latent_path': np_path, 'sample': img_sample, 'sample_path': sample_path}
+
+    def __len__(self):
+        return len(self.gt_paths)
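A minimal usage sketch for SingleImageDataset in folder-scan mode; the LQ folder path is a placeholder and the 'disk' IO backend is assumed.

opt = {
    'dataroot_lq': 'datasets/test_lq',   # hypothetical folder of LQ images
    'io_backend': {'type': 'disk'},
}
# dataset = SingleImageDataset(opt)
# item = dataset[0]   # {'lq': tensor of shape (c, h, w), 'lq_path': 'datasets/test_lq/...'}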
diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c700a399bb737a2286ea705fcebd937e6fb54ca7
--- /dev/null
+++ b/basicsr/data/transforms.py
@@ -0,0 +1,240 @@
+import cv2
+import random
+import torch
+
+
+def mod_crop(img, scale):
+    """Mod crop images, used during testing.
+
+    Args:
+        img (ndarray): Input image.
+        scale (int): Scale factor.
+
+    Returns:
+        ndarray: Result image.
+    """
+    img = img.copy()
+    if img.ndim in (2, 3):
+        h, w = img.shape[0], img.shape[1]
+        h_remainder, w_remainder = h % scale, w % scale
+        img = img[:h - h_remainder, :w - w_remainder, ...]
+    else:
+        raise ValueError(f'Wrong img ndim: {img.ndim}.')
+    return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
+    """Paired random crop. Support Numpy array and Tensor inputs.
+
+    It crops lists of lq and gt images with corresponding locations.
+
+    Args:
+        img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+            should have the same shape. If the input is an ndarray, it will
+            be transformed to a list containing itself.
+        gt_patch_size (int): GT patch size.
+        scale (int): Scale factor.
+        gt_path (str): Path to ground-truth. Default: None.
+
+    Returns:
+        list[ndarray] | ndarray: GT images and LQ images. If returned results
+            only have one element, just return ndarray.
+    """
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+
+    # determine input type: Numpy array or Tensor
+    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+    if input_type == 'Tensor':
+        h_lq, w_lq = img_lqs[0].size()[-2:]
+        h_gt, w_gt = img_gts[0].size()[-2:]
+    else:
+        h_lq, w_lq = img_lqs[0].shape[0:2]
+        h_gt, w_gt = img_gts[0].shape[0:2]
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    if input_type == 'Tensor':
+        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+    else:
+        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    if input_type == 'Tensor':
+        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+    else:
+        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    return img_gts, img_lqs
+
+def triplet_random_crop(img_gts, img_lqs, img_segs, gt_patch_size, scale, gt_path=None):
+
+    if not isinstance(img_gts, list):
+        img_gts = [img_gts]
+    if not isinstance(img_lqs, list):
+        img_lqs = [img_lqs]
+    if not isinstance(img_segs, list):
+        img_segs = [img_segs]
+
+    # determine input type: Numpy array or Tensor
+    input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+    if input_type == 'Tensor':
+        h_lq, w_lq = img_lqs[0].size()[-2:]
+        h_gt, w_gt = img_gts[0].size()[-2:]
+        h_seg, w_seg = img_segs[0].size()[-2:]
+    else:
+        h_lq, w_lq = img_lqs[0].shape[0:2]
+        h_gt, w_gt = img_gts[0].shape[0:2]
+        h_seg, w_seg = img_segs[0].shape[0:2]
+    lq_patch_size = gt_patch_size // scale
+
+    if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+    if h_lq < lq_patch_size or w_lq < lq_patch_size:
+        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+                         f'({lq_patch_size}, {lq_patch_size}). '
+                         f'Please remove {gt_path}.')
+
+    # randomly choose top and left coordinates for lq patch
+    top = random.randint(0, h_lq - lq_patch_size)
+    left = random.randint(0, w_lq - lq_patch_size)
+
+    # crop lq patch
+    if input_type == 'Tensor':
+        img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+    else:
+        img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+    # crop corresponding gt patch
+    top_gt, left_gt = int(top * scale), int(left * scale)
+    if input_type == 'Tensor':
+        img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+    else:
+        img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+
+    if input_type == 'Tensor':
+        img_segs = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_segs]
+    else:
+        img_segs = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_segs]
+
+    if len(img_gts) == 1:
+        img_gts = img_gts[0]
+    if len(img_lqs) == 1:
+        img_lqs = img_lqs[0]
+    if len(img_segs) == 1:
+        img_segs = img_segs[0]
+
+    return img_gts, img_lqs, img_segs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+    We use vertical flip and transpose for rotation implementation.
+    All the images in the list use the same augmentation.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+            is an ndarray, it will be transformed to a list.
+        hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+            ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+        return_status (bool): Return the status of flip and rotation.
+            Default: False.
+
+    Returns:
+        list[ndarray] | ndarray: Augmented images and flows. If returned
+            results only have one element, just return ndarray.
+
+    """
+    hflip = hflip and random.random() < 0.5
+    vflip = rotation and random.random() < 0.5
+    rot90 = rotation and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:  # horizontal
+            cv2.flip(img, 1, img)
+        if vflip:  # vertical
+            cv2.flip(img, 0, img)
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    def _augment_flow(flow):
+        if hflip:  # horizontal
+            cv2.flip(flow, 1, flow)
+            flow[:, :, 0] *= -1
+        if vflip:  # vertical
+            cv2.flip(flow, 0, flow)
+            flow[:, :, 1] *= -1
+        if rot90:
+            flow = flow.transpose(1, 0, 2)
+            flow = flow[:, :, [1, 0]]
+        return flow
+
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    imgs = [_augment(img) for img in imgs]
+    if len(imgs) == 1:
+        imgs = imgs[0]
+
+    if flows is not None:
+        if not isinstance(flows, list):
+            flows = [flows]
+        flows = [_augment_flow(flow) for flow in flows]
+        if len(flows) == 1:
+            flows = flows[0]
+        return imgs, flows
+    else:
+        if return_status:
+            return imgs, (hflip, vflip, rot90)
+        else:
+            return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+    """Rotate image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees. Positive values mean
+            counter-clockwise rotation.
+        center (tuple[int]): Rotation center. If the center is None,
+            initialize it as the center of the image. Default: None.
+        scale (float): Isotropic scale factor. Default: 1.0.
+    """
+    (h, w) = img.shape[:2]
+
+    if center is None:
+        center = (w // 2, h // 2)
+
+    matrix = cv2.getRotationMatrix2D(center, angle, scale)
+    rotated_img = cv2.warpAffine(img, matrix, (w, h))
+    return rotated_img
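A quick shape-level check of paired_random_crop and augment on random arrays (illustrative only; the sizes are arbitrary):

import numpy as np

scale, gt_size = 4, 128
img_gt = np.random.rand(256, 256, 3).astype(np.float32)
img_lq = np.random.rand(64, 64, 3).astype(np.float32)   # 256 / scale

gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_size, scale)
assert gt_patch.shape[:2] == (128, 128)
assert lq_patch.shape[:2] == (32, 32)   # gt_size // scale

# the same random flip/rotation is applied to every image in the list
gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)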
diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..929f7d97472a0eb810e33e694d5362a6749ab4b6
--- /dev/null
+++ b/basicsr/data/video_test_dataset.py
@@ -0,0 +1,283 @@
+import glob
+import torch
+from os import path as osp
+from torch.utils import data as data
+
+from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDataset(data.Dataset):
+    """Video test dataset.
+
+    Supported datasets: Vid4, REDS4, REDSofficial.
+    More generally, it supports testing dataset with following structures:
+
+    ::
+
+        dataroot
+        ├── subfolder1
+            ├── frame000
+            ├── frame001
+            ├── ...
+        ├── subfolder2
+            ├── frame000
+            ├── frame001
+            ├── ...
+        ├── ...
+
+    For testing datasets, there is no need to prepare LMDB files.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        io_backend (dict): IO backend type and other kwarg.
+        cache_data (bool): Whether to cache testing datasets.
+        name (str): Dataset name.
+        meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
+            in the dataroot will be used.
+        num_frame (int): Window size for input frames.
+        padding (str): Padding mode.
+    """
+
+    def __init__(self, opt):
+        super(VideoTestDataset, self).__init__()
+        self.opt = opt
+        self.cache_data = opt['cache_data']
+        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+        logger = get_root_logger()
+        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+        self.imgs_lq, self.imgs_gt = {}, {}
+        if 'meta_info_file' in opt:
+            with open(opt['meta_info_file'], 'r') as fin:
+                subfolders = [line.split(' ')[0] for line in fin]
+                subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
+                subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
+        else:
+            subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+            subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+        if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
+            for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
+                # get frame list for lq and gt
+                subfolder_name = osp.basename(subfolder_lq)
+                img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
+                img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
+
+                max_idx = len(img_paths_lq)
+                assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
+                                                      f' and gt folders ({len(img_paths_gt)})')
+
+                self.data_info['lq_path'].extend(img_paths_lq)
+                self.data_info['gt_path'].extend(img_paths_gt)
+                self.data_info['folder'].extend([subfolder_name] * max_idx)
+                for i in range(max_idx):
+                    self.data_info['idx'].append(f'{i}/{max_idx}')
+                border_l = [0] * max_idx
+                for i in range(self.opt['num_frame'] // 2):
+                    border_l[i] = 1
+                    border_l[max_idx - i - 1] = 1
+                self.data_info['border'].extend(border_l)
+
+                # cache data or save the frame list
+                if self.cache_data:
+                    logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
+                    self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+                    self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+                else:
+                    self.imgs_lq[subfolder_name] = img_paths_lq
+                    self.imgs_gt[subfolder_name] = img_paths_gt
+        else:
+            raise ValueError(f'Non-supported video test dataset: {opt["name"]}')
+
+    def __getitem__(self, index):
+        folder = self.data_info['folder'][index]
+        idx, max_idx = self.data_info['idx'][index].split('/')
+        idx, max_idx = int(idx), int(max_idx)
+        border = self.data_info['border'][index]
+        lq_path = self.data_info['lq_path'][index]
+
+        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+        if self.cache_data:
+            imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+            img_gt = self.imgs_gt[folder][idx]
+        else:
+            img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+            imgs_lq = read_img_seq(img_paths_lq)
+            img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+            img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': folder,  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/99
+            'border': border,  # 1 for border, 0 for non-border
+            'lq_path': lq_path  # center frame
+        }
+
+    def __len__(self):
+        return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestVimeo90KDataset(data.Dataset):
+    """Video test dataset for Vimeo90k-Test dataset.
+
+    It only keeps the center frame for testing.
+    For testing datasets, there is no need to prepare LMDB files.
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        io_backend (dict): IO backend type and other kwarg.
+        cache_data (bool): Whether to cache testing datasets.
+        name (str): Dataset name.
+        meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
+            in the dataroot will be used.
+        num_frame (int): Window size for input frames.
+        padding (str): Padding mode.
+    """
+
+    def __init__(self, opt):
+        super(VideoTestVimeo90KDataset, self).__init__()
+        self.opt = opt
+        self.cache_data = opt['cache_data']
+        if self.cache_data:
+            raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
+        self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+        logger = get_root_logger()
+        logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+        with open(opt['meta_info_file'], 'r') as fin:
+            subfolders = [line.split(' ')[0] for line in fin]
+        for idx, subfolder in enumerate(subfolders):
+            gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
+            self.data_info['gt_path'].append(gt_path)
+            lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
+            self.data_info['lq_path'].append(lq_paths)
+            self.data_info['folder'].append('vimeo90k')
+            self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
+            self.data_info['border'].append(0)
+
+    def __getitem__(self, index):
+        lq_path = self.data_info['lq_path'][index]
+        gt_path = self.data_info['gt_path'][index]
+        imgs_lq = read_img_seq(lq_path)
+        img_gt = read_img_seq([gt_path])
+        img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': self.data_info['folder'][index],  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/843
+            'border': self.data_info['border'][index],  # 0 for non-border
+            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
+        }
+
+    def __len__(self):
+        return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDUFDataset(VideoTestDataset):
+    """ Video test dataset for DUF dataset.
+
+    Args:
+        opt (dict): Config for train dataset. Most of the keys are the same as VideoTestDataset.
+            It has the following extra keys:
+        use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
+        scale (int): Scale factor, which will be added automatically.
+    """
+
+    def __getitem__(self, index):
+        folder = self.data_info['folder'][index]
+        idx, max_idx = self.data_info['idx'][index].split('/')
+        idx, max_idx = int(idx), int(max_idx)
+        border = self.data_info['border'][index]
+        lq_path = self.data_info['lq_path'][index]
+
+        select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+        if self.cache_data:
+            if self.opt['use_duf_downsampling']:
+                # read imgs_gt to generate low-resolution frames
+                imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
+                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+            else:
+                imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+            img_gt = self.imgs_gt[folder][idx]
+        else:
+            if self.opt['use_duf_downsampling']:
+                img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
+                # read imgs_gt to generate low-resolution frames
+                imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
+                imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+            else:
+                img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+                imgs_lq = read_img_seq(img_paths_lq)
+            img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
+            img_gt.squeeze_(0)
+
+        return {
+            'lq': imgs_lq,  # (t, c, h, w)
+            'gt': img_gt,  # (c, h, w)
+            'folder': folder,  # folder name
+            'idx': self.data_info['idx'][index],  # e.g., 0/99
+            'border': border,  # 1 for border, 0 for non-border
+            'lq_path': lq_path  # center frame
+        }
+
+
+@DATASET_REGISTRY.register()
+class VideoRecurrentTestDataset(VideoTestDataset):
+    """Video test dataset for recurrent architectures, which takes LR video
+    frames as input and outputs the corresponding HR video frames.
+
+    Args:
+        opt (dict): Same as VideoTestDataset. Unused opt:
+        padding (str): Padding mode.
+
+    """
+
+    def __init__(self, opt):
+        super(VideoRecurrentTestDataset, self).__init__(opt)
+        # Find unique folder strings
+        self.folders = sorted(list(set(self.data_info['folder'])))
+
+    def __getitem__(self, index):
+        folder = self.folders[index]
+
+        if self.cache_data:
+            imgs_lq = self.imgs_lq[folder]
+            imgs_gt = self.imgs_gt[folder]
+        else:
+            raise NotImplementedError('Without cache_data is not implemented.')
+
+        return {
+            'lq': imgs_lq,
+            'gt': imgs_gt,
+            'folder': folder,
+        }
+
+    def __len__(self):
+        return len(self.folders)
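A minimal sketch of the test options VideoTestDataset expects for a REDS4-style layout; the dataroots are placeholders, and 'reflection' is assumed to be one of the padding modes accepted by generate_frame_indices.

opt = {
    'name': 'REDS4',
    'dataroot_gt': 'datasets/REDS4/GT',              # placeholder path
    'dataroot_lq': 'datasets/REDS4/sharp_bicubic',   # placeholder path
    'io_backend': {'type': 'disk'},
    'cache_data': False,
    'num_frame': 5,
    'padding': 'reflection',
}
# dataset = VideoTestDataset(opt)
# item = dataset[0]  # keys: 'lq' (t, c, h, w), 'gt' (c, h, w), 'folder', 'idx', 'border', 'lq_path'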
diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e33e1082667aeee61fecf2436fb287e82e0936
--- /dev/null
+++ b/basicsr/data/vimeo90k_dataset.py
@@ -0,0 +1,199 @@
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KDataset(data.Dataset):
+    """Vimeo90K dataset for training.
+
+    The keys are generated from a meta info txt file.
+    basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
+
+    Each line contains the following items, separated by a white space.
+
+    1. clip name;
+    2. frame number;
+    3. image shape
+
+    Examples:
+
+    ::
+
+        00001/0001 7 (256,448,3)
+        00001/0002 7 (256,448,3)
+
+    - Key examples: "00001/0001"
+    - GT (gt): Ground-Truth;
+    - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+    The neighboring frame list for different num_frame:
+
+    ::
+
+        num_frame | frame list
+                1 | 4
+                3 | 3,4,5
+                5 | 2,3,4,5,6
+                7 | 1,2,3,4,5,6,7
+
+    Args:
+        opt (dict): Config for train dataset. It contains the following keys:
+        dataroot_gt (str): Data root path for gt.
+        dataroot_lq (str): Data root path for lq.
+        meta_info_file (str): Path for meta information file.
+        io_backend (dict): IO backend type and other kwarg.
+        num_frame (int): Window size for input frames.
+        gt_size (int): Cropped patch size for gt patches.
+        random_reverse (bool): Random reverse input frames.
+        use_hflip (bool): Use horizontal flips.
+        use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+        scale (int): Scale factor, which will be added automatically.
+    """
+
+    def __init__(self, opt):
+        super(Vimeo90KDataset, self).__init__()
+        self.opt = opt
+        self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+
+        with open(opt['meta_info_file'], 'r') as fin:
+            self.keys = [line.split(' ')[0] for line in fin]
+
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        self.is_lmdb = False
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.is_lmdb = True
+            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+            self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+        # indices of input images
+        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+        # temporal augmentation configs
+        self.random_reverse = opt['random_reverse']
+        logger = get_root_logger()
+        logger.info(f'Random reverse is {self.random_reverse}.')
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            self.neighbor_list.reverse()
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip, seq = key.split('/')  # key example: 00001/0001
+
+        # get the GT frame (im4.png)
+        if self.is_lmdb:
+            img_gt_path = f'{key}/im4'
+        else:
+            img_gt_path = self.gt_root / clip / seq / 'im4.png'
+        img_bytes = self.file_client.get(img_gt_path, 'gt')
+        img_gt = imfrombytes(img_bytes, float32=True)
+
+        # get the neighboring LQ frames
+        img_lqs = []
+        for neighbor in self.neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip}/{seq}/im{neighbor}'
+            else:
+                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+        # randomly crop
+        img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+
+        # augmentation - flip, rotate
+        img_lqs.append(img_gt)
+        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_lqs = torch.stack(img_results[0:-1], dim=0)
+        img_gt = img_results[-1]
+
+        # img_lqs: (t, c, h, w)
+        # img_gt: (c, h, w)
+        # key: str
+        return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KRecurrentDataset(Vimeo90KDataset):
+
+    def __init__(self, opt):
+        super(Vimeo90KRecurrentDataset, self).__init__(opt)
+
+        self.flip_sequence = opt['flip_sequence']
+        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            self.neighbor_list.reverse()
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip, seq = key.split('/')  # key example: 00001/0001
+
+        # get the neighboring LQ and GT frames
+        img_lqs = []
+        img_gts = []
+        for neighbor in self.neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip}/{seq}/im{neighbor}'
+                img_gt_path = f'{clip}/{seq}/im{neighbor}'
+            else:
+                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+                img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
+            # LQ
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = imfrombytes(img_bytes, float32=True)
+            # GT
+            img_bytes = self.file_client.get(img_gt_path, 'gt')
+            img_gt = imfrombytes(img_bytes, float32=True)
+
+            img_lqs.append(img_lq)
+            img_gts.append(img_gt)
+
+        # randomly crop
+        img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+        # augmentation - flip, rotate
+        img_lqs.extend(img_gts)
+        img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+        img_results = img2tensor(img_results)
+        img_lqs = torch.stack(img_results[:7], dim=0)
+        img_gts = torch.stack(img_results[7:], dim=0)
+
+        if self.flip_sequence:  # flip the sequence: 7 frames to 14 frames
+            img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
+            img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
+
+        # img_lqs: (t, c, h, w)
+        # img_gts: (t, c, h, w)
+        # key: str
+        return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+    def __len__(self):
+        return len(self.keys)
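The neighbor_list formula used by Vimeo90KDataset can be checked quickly against the table in its docstring:

for num_frame in (1, 3, 5, 7):
    print(num_frame, [i + (9 - num_frame) // 2 for i in range(num_frame)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]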
diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a172aeed5b388ae102466eb1f02d40ba30e9b4
--- /dev/null
+++ b/basicsr/losses/__init__.py
@@ -0,0 +1,31 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import LOSS_REGISTRY
+from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
+
+__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
+
+# automatically scan and import loss modules for registry
+# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
+
+
+def build_loss(opt):
+    """Build loss from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    loss_type = opt.pop('type')
+    loss = LOSS_REGISTRY.get(loss_type)(**opt)
+    logger = get_root_logger()
+    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+    return loss
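build_loss looks up the 'type' key in LOSS_REGISTRY and forwards the remaining keys to the constructor; for example, with the losses registered in basic_loss.py further below:

from basicsr.losses import build_loss

l1 = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
charb = build_loss({'type': 'CharbonnierLoss', 'loss_weight': 1.0, 'eps': 1e-12})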
diff --git a/basicsr/losses/__pycache__/__init__.cpython-310.pyc b/basicsr/losses/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45eab683236c4c18b5b7113cab433c8ab99d3b80
Binary files /dev/null and b/basicsr/losses/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc b/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4d7b64853935dfecbe9502acf59fc4d9bf32a8b
Binary files /dev/null and b/basicsr/losses/__pycache__/basic_loss.cpython-310.pyc differ
diff --git a/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc b/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..56923128d39ab39a1a0c7b81a43d869e83cd0ac2
Binary files /dev/null and b/basicsr/losses/__pycache__/gan_loss.cpython-310.pyc differ
diff --git a/basicsr/losses/__pycache__/loss_util.cpython-310.pyc b/basicsr/losses/__pycache__/loss_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3cd31179ff2c3d240c340b2f5f004d76002c2103
Binary files /dev/null and b/basicsr/losses/__pycache__/loss_util.cpython-310.pyc differ
diff --git a/basicsr/losses/basic_loss.py b/basicsr/losses/basic_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e965526a9b0e2686575bf93f0173cc2664d9bb
--- /dev/null
+++ b/basicsr/losses/basic_loss.py
@@ -0,0 +1,253 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+    return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+    return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+    return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+    """L1 (mean absolute error, MAE) loss.
+
+    Args:
+        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(L1Loss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+        """
+        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+    """MSE (L2) loss.
+
+    Args:
+        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        super(MSELoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+        """
+        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+    variant of L1Loss).
+
+    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+        Super-Resolution".
+
+    Args:
+        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
+        reduction (str): Specifies the reduction to apply to the output.
+            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+        eps (float): A value used to control the curvature near zero. Default: 1e-12.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+        super(CharbonnierLoss, self).__init__()
+        if reduction not in ['none', 'mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+        self.loss_weight = loss_weight
+        self.reduction = reduction
+        self.eps = eps
+
+    def forward(self, pred, target, weight=None, **kwargs):
+        """
+        Args:
+            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+        """
+        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+    """Weighted TV loss.
+
+    Args:
+        loss_weight (float): Loss weight. Default: 1.0.
+    """
+
+    def __init__(self, loss_weight=1.0, reduction='mean'):
+        if reduction not in ['mean', 'sum']:
+            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
+        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
+
+    def forward(self, pred, weight=None):
+        if weight is None:
+            y_weight = None
+            x_weight = None
+        else:
+            y_weight = weight[:, :, :-1, :]
+            x_weight = weight[:, :, :, :-1]
+
+        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+        loss = x_diff + y_diff
+
+        return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+    """Perceptual loss with commonly used style loss.
+
+    Args:
+        layer_weights (dict): The weight for each layer of vgg feature.
+            Here is an example: {'conv5_4': 1.}, which means the conv5_4
+            feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+        vgg_type (str): The type of vgg network used as feature extractor.
+            Default: 'vgg19'.
+        use_input_norm (bool):  If True, normalize the input image in vgg.
+            Default: True.
+        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+            Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and multiplied by this weight. Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and multiplied by this weight. Default: 0.
+        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+    """
+
+    def __init__(self,
+                 layer_weights,
+                 vgg_type='vgg19',
+                 use_input_norm=True,
+                 range_norm=False,
+                 perceptual_weight=1.0,
+                 style_weight=0.,
+                 criterion='l1'):
+        super(PerceptualLoss, self).__init__()
+        self.perceptual_weight = perceptual_weight
+        self.style_weight = style_weight
+        self.layer_weights = layer_weights
+        self.vgg = VGGFeatureExtractor(
+            layer_name_list=list(layer_weights.keys()),
+            vgg_type=vgg_type,
+            use_input_norm=use_input_norm,
+            range_norm=range_norm)
+
+        self.criterion_type = criterion
+        if self.criterion_type == 'l1':
+            self.criterion = torch.nn.L1Loss()
+        elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()
+        elif self.criterion_type == 'fro':
+            self.criterion = None
+        else:
+            raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+    def forward(self, x, gt):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+        # extract vgg features
+        x_features = self.vgg(x)
+        gt_features = self.vgg(gt.detach())
+
+        # calculate perceptual loss
+        if self.perceptual_weight > 0:
+            percep_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+                else:
+                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+            percep_loss *= self.perceptual_weight
+        else:
+            percep_loss = None
+
+        # calculate style loss
+        if self.style_weight > 0:
+            style_loss = 0
+            for k in x_features.keys():
+                if self.criterion_type == 'fro':
+                    style_loss += torch.norm(
+                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+                else:
+                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+                        gt_features[k])) * self.layer_weights[k]
+            style_loss *= self.style_weight
+        else:
+            style_loss = None
+
+        return percep_loss, style_loss
+
+    def _gram_mat(self, x):
+        """Calculate Gram matrix.
+
+        Args:
+            x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+        Returns:
+            torch.Tensor: Gram matrix.
+        """
+        n, c, h, w = x.size()
+        features = x.view(n, c, w * h)
+        features_t = features.transpose(1, 2)
+        gram = features.bmm(features_t) / (c * h * w)
+        return gram
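+
+
+# Usage sketch (illustrative; in BasicSR these losses are normally built through
+# LOSS_REGISTRY from the YAML options rather than constructed by hand). Assuming
+# 4-D (N, C, H, W) tensors `output` and `gt`:
+#
+#     pix_criterion = CharbonnierLoss(loss_weight=1.0, eps=1e-12)
+#     l_pix = pix_criterion(output, gt)
+#
+#     percep_criterion = PerceptualLoss(layer_weights={'conv5_4': 1.0})
+#     l_percep, l_style = percep_criterion(output, gt)  # l_style is None when style_weight == 0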
diff --git a/basicsr/losses/gan_loss.py b/basicsr/losses/gan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..870baa2227b79eab29a3141a216b4b614e2bcdf3
--- /dev/null
+++ b/basicsr/losses/gan_loss.py
@@ -0,0 +1,207 @@
+import math
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import LOSS_REGISTRY
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+    """Define GAN loss.
+
+    Args:
+        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+        real_label_val (float): The value for real label. Default: 1.0.
+        fake_label_val (float): The value for fake label. Default: 0.0.
+        loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only applied to generators; it is always
+            1.0 for discriminators.
+    """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(GANLoss, self).__init__()
+        self.gan_type = gan_type
+        self.loss_weight = loss_weight
+        self.real_label_val = real_label_val
+        self.fake_label_val = fake_label_val
+
+        if self.gan_type == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif self.gan_type == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif self.gan_type == 'wgan':
+            self.loss = self._wgan_loss
+        elif self.gan_type == 'wgan_softplus':
+            self.loss = self._wgan_softplus_loss
+        elif self.gan_type == 'hinge':
+            self.loss = nn.ReLU()
+        else:
+            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+    def _wgan_loss(self, input, target):
+        """wgan loss.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return -input.mean() if target else input.mean()
+
+    def _wgan_softplus_loss(self, input, target):
+        """wgan loss with soft plus. softplus is a smooth approximation to the
+        ReLU function.
+
+        In StyleGAN2, it is called:
+            Logistic loss for discriminator;
+            Non-saturating loss for generator.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+        """
+        return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+    def get_target_label(self, input, target_is_real):
+        """Get target label.
+
+        Args:
+            input (Tensor): Input tensor.
+            target_is_real (bool): Whether the target is real or fake.
+
+        Returns:
+            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+                return Tensor.
+        """
+
+        if self.gan_type in ['wgan', 'wgan_softplus']:
+            return target_is_real
+        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+        return input.new_ones(input.size()) * target_val
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        Args:
+            input (Tensor): The input for the loss module, i.e., the network
+                prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss for discriminators or not.
+                Default: False.
+
+        Returns:
+            Tensor: GAN loss value.
+        """
+        target_label = self.get_target_label(input, target_is_real)
+        if self.gan_type == 'hinge':
+            if is_disc:  # for discriminators in hinge-gan
+                input = -input if target_is_real else input
+                loss = self.loss(1 + input).mean()
+            else:  # for generators in hinge-gan
+                loss = -input.mean()
+        else:  # other gan types
+            loss = self.loss(input, target_label)
+
+        # loss_weight is always 1.0 for discriminators
+        return loss if is_disc else loss * self.loss_weight
+
+
+@LOSS_REGISTRY.register()
+class MultiScaleGANLoss(GANLoss):
+    """
+    MultiScaleGANLoss accepts a list of predictions
+    """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
+
+    def forward(self, input, target_is_real, is_disc=False):
+        """
+        The input is a list of tensors, or a list of (a list of tensors)
+        """
+        if isinstance(input, list):
+            loss = 0
+            for pred_i in input:
+                if isinstance(pred_i, list):
+                    # Only compute GAN loss for the last layer
+                    # in case of multiscale feature matching
+                    pred_i = pred_i[-1]
+                # Safe operation: 0-dim tensor calling self.mean() does nothing
+                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
+                loss += loss_tensor
+            return loss / len(input)
+        else:
+            return super().forward(input, target_is_real, is_disc)
+
+
+def r1_penalty(real_pred, real_img):
+    """R1 regularization for discriminator. The core idea is to
+        penalize the gradient on real data alone: when the
+        generator distribution produces the true data distribution
+        and the discriminator is equal to 0 on the data manifold, the
+        gradient penalty ensures that the discriminator cannot create
+        a non-zero gradient orthogonal to the data manifold without
+        suffering a loss in the GAN game.
+
+        Reference: Eq. 9 in Which training methods for GANs do actually converge.
+        """
+    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+    return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+    path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+    """Calculate gradient penalty for wgan-gp.
+
+    Args:
+        discriminator (nn.Module): Network for the discriminator.
+        real_data (Tensor): Real input data.
+        fake_data (Tensor): Fake input data.
+        weight (Tensor): Weight tensor. Default: None.
+
+    Returns:
+        Tensor: A tensor for gradient penalty.
+    """
+
+    batch_size = real_data.size(0)
+    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+    # interpolate between real_data and fake_data
+    interpolates = alpha * real_data + (1. - alpha) * fake_data
+    interpolates = interpolates.detach().requires_grad_(True)  # autograd.Variable is deprecated
+
+    disc_interpolates = discriminator(interpolates)
+    gradients = autograd.grad(
+        outputs=disc_interpolates,
+        inputs=interpolates,
+        grad_outputs=torch.ones_like(disc_interpolates),
+        create_graph=True,
+        retain_graph=True,
+        only_inputs=True)[0]
+
+    if weight is not None:
+        gradients = gradients * weight
+
+    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+    if weight is not None:
+        gradients_penalty /= torch.mean(weight)
+
+    return gradients_penalty
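+
+
+# Usage sketch (illustrative; GANLoss is normally built via LOSS_REGISTRY from the
+# training options). `fake_pred` / `real_pred` stand for discriminator outputs:
+#
+#     cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)
+#     # generator step (loss_weight is applied)
+#     l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)
+#     # discriminator step (loss_weight is fixed to 1.0; fake images are usually detached beforehand)
+#     l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
+#     l_d_fake = cri_gan(fake_pred, target_is_real=False, is_disc=True)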
diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd293ff9e6a22814e5aeff6ae11fb54d2e4bafff
--- /dev/null
+++ b/basicsr/losses/loss_util.py
@@ -0,0 +1,145 @@
+import functools
+import torch
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are 'none', 'mean' and 'sum'.
+
+    Returns:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    else:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights. Default: None.
+        reduction (str): Same as built-in losses of PyTorch. Options are
+            'none', 'mean' and 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor: Loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        assert weight.dim() == loss.dim()
+        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+        loss = loss * weight
+
+    # if weight is not specified or reduction is sum, just reduce the loss
+    if weight is None or reduction == 'sum':
+        loss = reduce_loss(loss, reduction)
+    # if reduction is mean, then compute mean over weight region
+    elif reduction == 'mean':
+        if weight.size(1) > 1:
+            weight = weight.sum()
+        else:
+            weight = weight.sum() * loss.size(1)
+        loss = loss.sum() / weight
+
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have the signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    the signature like `loss_func(pred, target, weight=None, reduction='mean',
+    **kwargs)`.
+
+    :Example:
+
+    >>> import torch
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.5000)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, reduction='sum')
+    tensor(3.)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction)
+        return loss
+
+    return wrapper
+
+
+def get_local_weights(residual, ksize):
+    """Get local weights for generating the artifact map of LDL.
+
+    It is only called by the `get_refined_artifact_map` function.
+
+    Args:
+        residual (Tensor): Residual between predicted and ground truth images.
+        ksize (Int): size of the local window.
+
+    Returns:
+        Tensor: weight for each pixel to be discriminated as an artifact pixel
+    """
+
+    pad = (ksize - 1) // 2
+    residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
+
+    unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
+    pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
+
+    return pixel_level_weight
+
+
+def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
+    """Calculate the artifact map of LDL
+    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
+
+    Args:
+        img_gt (Tensor): ground truth images.
+        img_output (Tensor): output images given by the optimizing model.
+        img_ema (Tensor): output images given by the ema model.
+        ksize (Int): size of the local window.
+
+    Returns:
+        overall_weight: weight for each pixel to be discriminated as an artifact pixel
+        (calculated based on both local and global observations).
+    """
+
+    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
+    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
+
+    patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
+    pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
+    overall_weight = patch_level_weight * pixel_level_weight
+
+    overall_weight[residual_sr < residual_ema] = 0
+
+    return overall_weight
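+
+
+# Usage sketch for the LDL artifact map (illustrative; `output`, `output_ema`, `gt`
+# and `l1_criterion` are placeholders for the model outputs and an L1 loss):
+#
+#     pixel_weight = get_refined_artifact_map(gt, output, output_ema, ksize=7)
+#     l_g_artifacts = l1_criterion(pixel_weight * output, pixel_weight * gt)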
diff --git a/basicsr/metrics/README.md b/basicsr/metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d
--- /dev/null
+++ b/basicsr/metrics/README.md
@@ -0,0 +1,48 @@
+# Metrics
+
+[English](README.md) **|** [简体中文](README_CN.md)
+
+- [Conventions](#conventions)
+- [PSNR and SSIM](#psnr-and-ssim)
+
+## Conventions
+
+Since different input types lead to different results, we adopt the following conventions for inputs:
+
+- Numpy type (usually the output of cv2)
+  - UINT8: BGR, [0, 255], (h, w, c)
+  - float: BGR, [0, 1], (h, w, c). Usually an intermediate result
+- Tensor type
+  - float: RGB, [0, 1], (n, c, h, w)
+
+Other conventions:
+
+- Results ending with `_pt` come from the PyTorch implementation
+- The PyTorch version supports batch computation
+- Color conversion is done in float32; metric calculation is done in float64
+
+## PSNR and SSIM
+
+PSNR and SSIM generally follow the same trend: a higher PSNR usually comes with a higher SSIM.
+The various PSNR implementations agree with each other. SSIM has many different implementations; ours is kept consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate)).
+
+The comparison of the different implementations is listed below.
+Summary: the PyTorch implementation is essentially consistent with the MATLAB one, with slight differences when running on GPU.
+
+- PSNR comparison
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU  |
+|:---| :---: | :---:  | :---:      |     :---:      | :---: |
+|baboon| RGB |  20.419710  | 20.419710 | 20.419710 |20.419710 |
+|baboon| Y | - |22.441898 | 22.441899 |  22.444916|
+|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
+|comic | Y | - | 21.720398 | 21.720398  | 21.721663|
+
+- SSIM comparison
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU  |
+|:---| :---: | :---:  | :---:      |     :---:      | :---: |
+|baboon| RGB |  0.391853  | 0.391853 | 0.391853|0.391853 |
+|baboon| Y | - |0.453097| 0.453097 |  0.453171|
+|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
+|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
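+
+A minimal usage sketch of the functions in this folder (the paths follow the `tests/data` layout used by `test_metrics/test_psnr_ssim.py`; adjust them to your own data):
+
+```python
+import cv2
+
+from basicsr.metrics import calculate_psnr, calculate_ssim
+from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
+from basicsr.utils import img2tensor
+
+# Numpy inputs: uint8 BGR, (h, w, c), range [0, 255]
+img = cv2.imread('tests/data/bic/baboon.png', cv2.IMREAD_UNCHANGED)
+img2 = cv2.imread('tests/data/gt/baboon.png', cv2.IMREAD_UNCHANGED)
+psnr = calculate_psnr(img, img2, crop_border=4, input_order='HWC', test_y_channel=True)
+ssim = calculate_ssim(img, img2, crop_border=4, input_order='HWC', test_y_channel=True)
+
+# Tensor inputs (`_pt` versions): float RGB, (n, c, h, w), range [0, 1]; batches are supported
+t1 = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+t2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+psnr_pt = calculate_psnr_pt(t1, t2, crop_border=4, test_y_channel=True)
+ssim_pt = calculate_ssim_pt(t1, t2, crop_border=4, test_y_channel=True)
+```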
diff --git a/basicsr/metrics/README_CN.md b/basicsr/metrics/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d
--- /dev/null
+++ b/basicsr/metrics/README_CN.md
@@ -0,0 +1,48 @@
+# Metrics
+
+[English](README.md) **|** [简体中文](README_CN.md)
+
+- [约定](#约定)
+- [PSNR 和 SSIM](#psnr-和-ssim)
+
+## 约定
+
+因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
+
+- Numpy 类型 (一般是 cv2 的结果)
+  - UINT8: BGR, [0, 255], (h, w, c)
+  - float: BGR, [0, 1], (h, w, c). 一般作为中间结果
+- Tensor 类型
+  - float: RGB, [0, 1], (n, c, h, w)
+
+其他约定:
+
+- 以 `_pt` 结尾的是 PyTorch 结果
+- PyTorch version 支持 batch 计算
+- 颜色转换在 float32 上做;metric计算在 float64 上做
+
+## PSNR 和 SSIM
+
+PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
+在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))
+
+下面列了各个实现的结果比对.
+总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异
+
+- PSNR 比对
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU  |
+|:---| :---: | :---:  | :---:      |     :---:      | :---: |
+|baboon| RGB |  20.419710  | 20.419710 | 20.419710 |20.419710 |
+|baboon| Y | - |22.441898 | 22.441899 |  22.444916|
+|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
+|comic | Y | - | 21.720398 | 21.720398  | 21.721663|
+
+- SSIM 比对
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU  |
+|:---| :---: | :---:  | :---:      |     :---:      | :---: |
+|baboon| RGB |  0.391853  | 0.391853 | 0.391853|0.391853 |
+|baboon| Y | - |0.453097| 0.453097 |  0.453171|
+|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
+|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..330f3c863f66a98d41942c6995837283265d94ef
--- /dev/null
+++ b/basicsr/metrics/__init__.py
@@ -0,0 +1,20 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .niqe import calculate_niqe
+from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_pt, calculate_psnr_pt
+
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_psnr_pt', 'calculate_ssim_pt']
+
+
+def calculate_metric(data, opt):
+    """Calculate metric from data and options.
+
+    Args:
+        data (dict): Arguments passed to the metric function (e.g. images).
+        opt (dict): Configuration. It must contain:
+            type (str): Metric type (the registered function name).
+    """
+    opt = deepcopy(opt)
+    metric_type = opt.pop('type')
+    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+    return metric
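+
+
+# Usage sketch (illustrative): `calculate_metric` dispatches on the registered
+# function name given in opt['type'], e.g. PSNR on a pair of numpy images:
+#
+#     metric = calculate_metric(dict(img=img, img2=img2), dict(type='calculate_psnr', crop_border=4))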
diff --git a/basicsr/metrics/__pycache__/__init__.cpython-310.pyc b/basicsr/metrics/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4d597fb39f26d9f4379800a6de6eb22cf7b6ee0
Binary files /dev/null and b/basicsr/metrics/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc b/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdfa57ae90e81ef558cd76a4811a95b3e9af1290
Binary files /dev/null and b/basicsr/metrics/__pycache__/metric_util.cpython-310.pyc differ
diff --git a/basicsr/metrics/__pycache__/niqe.cpython-310.pyc b/basicsr/metrics/__pycache__/niqe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3564457c955f01d0f9b0bc87df528c9c1c928d34
Binary files /dev/null and b/basicsr/metrics/__pycache__/niqe.cpython-310.pyc differ
diff --git a/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc b/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5772df06bc1e353cbd85ecf3707b1cca3fcb1f95
Binary files /dev/null and b/basicsr/metrics/__pycache__/psnr_ssim.cpython-310.pyc differ
diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0ba6df1de96d93a60c1cfd3dc1fcf4d3d31533
--- /dev/null
+++ b/basicsr/metrics/fid.py
@@ -0,0 +1,89 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy import linalg
+from tqdm import tqdm
+
+from basicsr.archs.inception import InceptionV3
+
+
+def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
+    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
+    # does resize the input.
+    inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
+    inception = nn.DataParallel(inception).eval().to(device)
+    return inception
+
+
+@torch.no_grad()
+def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
+    """Extract inception features.
+
+    Args:
+        data_generator (generator): A data generator.
+        inception (nn.Module): Inception model.
+        len_generator (int): Length of the data_generator to show the
+            progressbar. Default: None.
+        device (str): Device. Default: cuda.
+
+    Returns:
+        Tensor: Extracted features.
+    """
+    if len_generator is not None:
+        pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
+    else:
+        pbar = None
+    features = []
+
+    for data in data_generator:
+        if pbar:
+            pbar.update(1)
+        data = data.to(device)
+        feature = inception(data)[0].view(data.shape[0], -1)
+        features.append(feature.to('cpu'))
+    if pbar:
+        pbar.close()
+    features = torch.cat(features, 0)
+    return features
+
+
+def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
+    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+    Stable version by Dougal J. Sutherland.
+
+    Args:
+        mu1 (np.array): The sample mean over activations.
+        sigma1 (np.array): The covariance matrix over activations for generated samples.
+        mu2 (np.array): The sample mean over activations, precalculated on a representative dataset.
+        sigma2 (np.array): The covariance matrix over activations, precalculated on a representative dataset.
+
+    Returns:
+        float: The Frechet Distance.
+    """
+    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
+
+    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
+
+    # Product might be almost singular
+    if not np.isfinite(cov_sqrt).all():
+        print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
+        offset = np.eye(sigma1.shape[0]) * eps
+        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(cov_sqrt):
+        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
+            m = np.max(np.abs(cov_sqrt.imag))
+            raise ValueError(f'Imaginary component {m}')
+        cov_sqrt = cov_sqrt.real
+
+    mean_diff = mu1 - mu2
+    mean_norm = mean_diff @ mean_diff
+    trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
+    fid = mean_norm + trace
+
+    return fid
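+
+
+# Usage sketch (illustrative): features are extracted with the patched InceptionV3,
+# and their mean/covariance statistics then feed `calculate_fid`. `data_generator`,
+# `mu_ref` and `sigma_ref` (precomputed reference statistics) are assumed to exist:
+#
+#     inception = load_patched_inception_v3()
+#     feats = extract_inception_features(data_generator, inception).numpy()
+#     mu, sigma = np.mean(feats, axis=0), np.cov(feats, rowvar=False)
+#     fid = calculate_fid(mu, sigma, mu_ref, sigma_ref)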
diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a27c70a043beeeb59cfaf533079492293065448
--- /dev/null
+++ b/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+    """Reorder images to 'HWC' order.
+
+    If the input_order is (h, w), return (h, w, 1);
+    If the input_order is (c, h, w), return (h, w, c);
+    If the input_order is (h, w, c), return as it is.
+
+    Args:
+        img (ndarray): Input image.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order will not have
+            effects. Default: 'HWC'.
+
+    Returns:
+        ndarray: reordered image.
+    """
+
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
+    if len(img.shape) == 2:
+        img = img[..., None]
+    if input_order == 'CHW':
+        img = img.transpose(1, 2, 0)
+    return img
+
+
+def to_y_channel(img):
+    """Change to Y channel of YCbCr.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+
+    Returns:
+        (ndarray): Images with range [0, 255] (float type) without round.
+    """
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 3 and img.shape[2] == 3:
+        img = bgr2ycbcr(img, y_only=True)
+        img = img[..., None]
+    return img * 255.
diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c1467f61d809ec3b2630073118460d9d61a861
--- /dev/null
+++ b/basicsr/metrics/niqe.py
@@ -0,0 +1,199 @@
+import cv2
+import math
+import numpy as np
+import os
+from scipy.ndimage import convolve
+from scipy.special import gamma
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.matlab_functions import imresize
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+def estimate_aggd_param(block):
+    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
+
+    Args:
+        block (ndarray): 2D Image block.
+
+    Returns:
+        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
+            distribution (Estimating the parameters in Equation 7 of the paper).
+    """
+    block = block.flatten()
+    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
+    gam_reciprocal = np.reciprocal(gam)
+    r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
+
+    left_std = np.sqrt(np.mean(block[block < 0]**2))
+    right_std = np.sqrt(np.mean(block[block > 0]**2))
+    gammahat = left_std / right_std
+    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
+    rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
+    array_position = np.argmin((r_gam - rhatnorm)**2)
+
+    alpha = gam[array_position]
+    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+    return (alpha, beta_l, beta_r)
+
+
+def compute_feature(block):
+    """Compute features.
+
+    Args:
+        block (ndarray): 2D Image block.
+
+    Returns:
+        list: Features with length of 18.
+    """
+    feat = []
+    alpha, beta_l, beta_r = estimate_aggd_param(block)
+    feat.extend([alpha, (beta_l + beta_r) / 2])
+
+    # distortions disturb the fairly regular structure of natural images.
+    # This deviation can be captured by analyzing the sample distribution of
+    # the products of pairs of adjacent coefficients computed along
+    # horizontal, vertical and diagonal orientations.
+    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
+    for i in range(len(shifts)):
+        shifted_block = np.roll(block, shifts[i], axis=(0, 1))
+        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
+        # Eq. 8
+        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
+        feat.extend([alpha, mean, beta_l, beta_r])
+    return feat
+
+
+def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
+    """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+    ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+
+    This implementation could produce almost the same results as the official
+    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+    Note that we do not include block overlap height and width, since they are
+    always 0 in the official implementation.
+
+    For good performance, the official implementation advises dividing the
+    distorted image into patches of the same size as those used for the
+    construction of the multivariate Gaussian model.
+
+    Args:
+        img (ndarray): Input image whose quality needs to be computed. The
+            image must be a gray or Y (of YCbCr) image with shape (h, w).
+            Range [0, 255] with float type.
+        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
+            model calculated on the pristine dataset.
+        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
+            Gaussian model calculated on the pristine dataset.
+        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
+            image.
+        block_size_h (int): Height of the blocks into which the image is divided.
+            Default: 96 (the official recommended value).
+        block_size_w (int): Width of the blocks into which the image is divided.
+            Default: 96 (the official recommended value).
+    """
+    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+    # crop image
+    h, w = img.shape
+    num_block_h = math.floor(h / block_size_h)
+    num_block_w = math.floor(w / block_size_w)
+    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+    distparam = []  # dist param is actually the multiscale features
+    for scale in (1, 2):  # perform on two scales (1, 2)
+        mu = convolve(img, gaussian_window, mode='nearest')
+        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
+        # normalize, as in Eq. 1 in the paper
+        img_normalized = (img - mu) / (sigma + 1)
+
+        feat = []
+        for idx_w in range(num_block_w):
+            for idx_h in range(num_block_h):
+                # process each block
+                block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
+                                       idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+                feat.append(compute_feature(block))
+
+        distparam.append(np.array(feat))
+
+        if scale == 1:
+            img = imresize(img / 255., scale=0.5, antialiasing=True)
+            img = img * 255.
+
+    distparam = np.concatenate(distparam, axis=1)
+
+    # fit a MVG (multivariate Gaussian) model to distorted patch features
+    mu_distparam = np.nanmean(distparam, axis=0)
+    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+    cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+    # compute niqe quality, Eq. 10 in the paper
+    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+    quality = np.matmul(
+        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+
+    quality = np.sqrt(quality)
+    quality = float(np.squeeze(quality))
+    return quality
+
+
+@METRIC_REGISTRY.register()
+def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
+    """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+    ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+
+    This implementation could produce almost the same results as the official
+    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
+    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
+
+    We use the official params estimated from the pristine dataset.
+    We use the recommended block size (96, 96) without overlaps.
+
+    Args:
+        img (ndarray): Input image whose quality needs to be computed.
+            The input image must be in range [0, 255] with float/int type.
+            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
+            If the input order is 'HWC' or 'CHW', it will be converted to gray
+            or Y (of YCbCr) image according to the ``convert_to`` argument.
+        crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the metric calculation.
+        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
+            Default: 'y'.
+
+    Returns:
+        float: NIQE result.
+    """
+    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+    # we use the official params estimated from the pristine dataset.
+    niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
+    mu_pris_param = niqe_pris_params['mu_pris_param']
+    cov_pris_param = niqe_pris_params['cov_pris_param']
+    gaussian_window = niqe_pris_params['gaussian_window']
+
+    img = img.astype(np.float32)
+    if input_order != 'HW':
+        img = reorder_image(img, input_order=input_order)
+        if convert_to == 'y':
+            img = to_y_channel(img)
+        elif convert_to == 'gray':
+            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+        img = np.squeeze(img)
+
+    if crop_border != 0:
+        img = img[crop_border:-crop_border, crop_border:-crop_border]
+
+    # round is necessary for being consistent with MATLAB's result
+    img = img.round()
+
+    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+
+    return niqe_result
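+
+
+# Usage sketch (illustrative): NIQE is no-reference, so a single BGR image in
+# [0, 255] is enough:
+#
+#     img = cv2.imread('tests/data/baboon.png', cv2.IMREAD_UNCHANGED)
+#     score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')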
diff --git a/basicsr/metrics/niqe_pris_params.npz b/basicsr/metrics/niqe_pris_params.npz
new file mode 100644
index 0000000000000000000000000000000000000000..42f06a9a18e6ed8bbf7933bec1477b189ef798de
--- /dev/null
+++ b/basicsr/metrics/niqe_pris_params.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
+size 11850
diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab03113f89805c990ff22795601274bf45db23a1
--- /dev/null
+++ b/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,231 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.color_util import rgb2ycbcr_pt
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: PSNR result.
+    """
+
+    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+    img = reorder_image(img, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+
+    if crop_border != 0:
+        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img = to_y_channel(img)
+        img2 = to_y_channel(img2)
+
+    img = img.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    mse = np.mean((img - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 10. * np.log10(255. * 255. / mse)
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
+
+    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: PSNR result.
+    """
+
+    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+    if crop_border != 0:
+        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+    if test_y_channel:
+        img = rgb2ycbcr_pt(img, y_only=True)
+        img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+    img = img.to(torch.float64)
+    img2 = img2.to(torch.float64)
+
+    mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
+    return 10. * torch.log10(1. / (mse + 1e-8))
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+    """Calculate SSIM (structural similarity).
+
+    ``Paper: Image quality assessment: From error visibility to structural similarity``
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img (ndarray): Images with range [0, 255].
+        img2 (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: SSIM result.
+    """
+
+    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+    if input_order not in ['HWC', 'CHW']:
+        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+    img = reorder_image(img, input_order=input_order)
+    img2 = reorder_image(img2, input_order=input_order)
+
+    if crop_border != 0:
+        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+    if test_y_channel:
+        img = to_y_channel(img)
+        img2 = to_y_channel(img2)
+
+    img = img.astype(np.float64)
+    img2 = img2.astype(np.float64)
+
+    ssims = []
+    for i in range(img.shape[2]):
+        ssims.append(_ssim(img[..., i], img2[..., i]))
+    return np.array(ssims).mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+    """Calculate SSIM (structural similarity) (PyTorch version).
+
+    ``Paper: Image quality assessment: From error visibility to structural similarity``
+
+    The results are the same as that of the official released MATLAB code in
+    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+    For three-channel images, SSIM is calculated for each channel and then
+    averaged.
+
+    Args:
+        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+    Returns:
+        float: SSIM result.
+    """
+
+    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+    if crop_border != 0:
+        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+    if test_y_channel:
+        img = rgb2ycbcr_pt(img, y_only=True)
+        img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+    img = img.to(torch.float64)
+    img2 = img2.to(torch.float64)
+
+    ssim = _ssim_pth(img * 255., img2 * 255.)
+    return ssim
+
+
+def _ssim(img, img2):
+    """Calculate SSIM (structural similarity) for one channel images.
+
+    It is called by func:`calculate_ssim`.
+
+    Args:
+        img (ndarray): Images with range [0, 255] with order 'HWC'.
+        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+    Returns:
+        float: SSIM result.
+    """
+
+    c1 = (0.01 * 255)**2
+    c2 = (0.03 * 255)**2
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]  # valid mode for window size 11
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
+    return ssim_map.mean()
+
+
+def _ssim_pth(img, img2):
+    """Calculate SSIM (structural similarity) (PyTorch version).
+
+    It is called by func:`calculate_ssim_pt`.
+
+    Args:
+        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+
+    Returns:
+        float: SSIM result.
+    """
+    c1 = (0.01 * 255)**2
+    c2 = (0.03 * 255)**2
+
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+    window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)
+
+    mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1])  # valid mode
+    mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1])  # valid mode
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
+    sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2
+
+    cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
+    ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
+    return ssim_map.mean([1, 2, 3])
diff --git a/basicsr/metrics/test_metrics/test_psnr_ssim.py b/basicsr/metrics/test_metrics/test_psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b05a73a0e38e89b2321ddc9415123a92f5c5a4
--- /dev/null
+++ b/basicsr/metrics/test_metrics/test_psnr_ssim.py
@@ -0,0 +1,52 @@
+import cv2
+import torch
+
+from basicsr.metrics import calculate_psnr, calculate_ssim
+from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
+from basicsr.utils import img2tensor
+
+
+def test(img_path, img_path2, crop_border, test_y_channel=False):
+    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+    img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)
+
+    # --------------------- Numpy ---------------------
+    psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
+    ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
+    print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
+
+    # --------------------- PyTorch (CPU) ---------------------
+    img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+    img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+
+    psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+    ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+    print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
+
+    # --------------------- PyTorch (GPU) ---------------------
+    img = img.cuda()
+    img2 = img2.cuda()
+    psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+    ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+    print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
+
+    psnr_pth = calculate_psnr_pt(
+        torch.repeat_interleave(img, 2, dim=0),
+        torch.repeat_interleave(img2, 2, dim=0),
+        crop_border=crop_border,
+        test_y_channel=test_y_channel)
+    ssim_pth = calculate_ssim_pt(
+        torch.repeat_interleave(img, 2, dim=0),
+        torch.repeat_interleave(img2, 2, dim=0),
+        crop_border=crop_border,
+        test_y_channel=test_y_channel)
+    print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
+          f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
+
+
+if __name__ == '__main__':
+    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
+    test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)
+
+    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
+    test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85796deae014c20a9aa600133468d04900c4fb89
--- /dev/null
+++ b/basicsr/models/__init__.py
@@ -0,0 +1,29 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+    """Build model from options.
+
+    Args:
+        opt (dict): Configuration. It must contain:
+            model_type (str): Model type.
+    """
+    opt = deepcopy(opt)
+    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+    logger = get_root_logger()
+    logger.info(f'Model [{model.__class__.__name__}] is created.')
+    return model
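+
+
+# Usage sketch (illustrative): `opt` is the full option dict, usually obtained by
+# parsing a YAML config (see basicsr.utils.options); build_model only reads
+# opt['model_type'] directly, e.g. 'SRModel' or 'RealESRGANModel':
+#
+#     model = build_model(opt)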
diff --git a/basicsr/models/__pycache__/__init__.cpython-310.pyc b/basicsr/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba727bf1df1871606d58e9ed3fb09d78dd475c75
Binary files /dev/null and b/basicsr/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/base_model.cpython-310.pyc b/basicsr/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0bac3c7f0da27060809f0e360e906c1641456b6
Binary files /dev/null and b/basicsr/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/edvr_model.cpython-310.pyc b/basicsr/models/__pycache__/edvr_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9a2fd412f040e103fc453e3d4fcdb2bee147a26
Binary files /dev/null and b/basicsr/models/__pycache__/edvr_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc b/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a29bab0dec8f42183211c68c41703c6a21054266
Binary files /dev/null and b/basicsr/models/__pycache__/esrgan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc b/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c76fa597dbfcccee32d3497130bcd89feca35e72
Binary files /dev/null and b/basicsr/models/__pycache__/hifacegan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc b/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7aa57c5591728845ca57b616e0d170dd78f08159
Binary files /dev/null and b/basicsr/models/__pycache__/lr_scheduler.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc b/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d87e0fb55748a1f84aa1a6a0972b0b0413cbd9d1
Binary files /dev/null and b/basicsr/models/__pycache__/realesrgan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc b/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ede21c212796fa90149610073d5b3ce04529f1c
Binary files /dev/null and b/basicsr/models/__pycache__/realesrnet_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/sr_model.cpython-310.pyc b/basicsr/models/__pycache__/sr_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d3b40c61913eb7fdb946423281da584a58e8fb3
Binary files /dev/null and b/basicsr/models/__pycache__/sr_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/srgan_model.cpython-310.pyc b/basicsr/models/__pycache__/srgan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ab82228e4acdd5bba78db05feacd793e2be2d81
Binary files /dev/null and b/basicsr/models/__pycache__/srgan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc b/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f083ad65878672a0631fcc943c46b5afa066f03
Binary files /dev/null and b/basicsr/models/__pycache__/stylegan2_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/swinir_model.cpython-310.pyc b/basicsr/models/__pycache__/swinir_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c9c20da06300db42bec0ea2827d8dec178e9591
Binary files /dev/null and b/basicsr/models/__pycache__/swinir_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/video_base_model.cpython-310.pyc b/basicsr/models/__pycache__/video_base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e669db1dff4e1644abee687b1df84514756c2437
Binary files /dev/null and b/basicsr/models/__pycache__/video_base_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc b/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a48bd51a9bd87a6b9fbb5bcc25063e4059c6b87
Binary files /dev/null and b/basicsr/models/__pycache__/video_gan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc b/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5dbefc3f99e6734de5b1bed5d79d00c58129655c
Binary files /dev/null and b/basicsr/models/__pycache__/video_recurrent_gan_model.cpython-310.pyc differ
diff --git a/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc b/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..209d9ff667bd722f36273ee01d3798544375607c
Binary files /dev/null and b/basicsr/models/__pycache__/video_recurrent_model.cpython-310.pyc differ
diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf8229f59dee86a7f9f95c1d07da785fb5f15b3
--- /dev/null
+++ b/basicsr/models/base_model.py
@@ -0,0 +1,392 @@
+import os
+import time
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from basicsr.models import lr_scheduler as lr_scheduler
+from basicsr.utils import get_root_logger
+from basicsr.utils.dist_util import master_only
+
+
+class BaseModel():
+    """Base model."""
+
+    def __init__(self, opt):
+        self.opt = opt
+        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+        self.is_train = opt['is_train']
+        self.schedulers = []
+        self.optimizers = []
+
+    def feed_data(self, data):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        pass
+
+    def save(self, epoch, current_iter):
+        """Save networks and training state."""
+        pass
+
+    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+        """Validation function.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Validation dataloader.
+            current_iter (int): Current iteration.
+            tb_logger (tensorboard logger): Tensorboard logger.
+            save_img (bool): Whether to save images. Default: False.
+        """
+        if self.opt['dist']:
+            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+        else:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def _initialize_best_metric_results(self, dataset_name):
+        """Initialize the best metric results dict for recording the best metric value and iteration."""
+        if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
+            return
+        elif not hasattr(self, 'best_metric_results'):
+            self.best_metric_results = dict()
+
+        # add a dataset record
+        record = dict()
+        for metric, content in self.opt['val']['metrics'].items():
+            better = content.get('better', 'higher')
+            init_val = float('-inf') if better == 'higher' else float('inf')
+            record[metric] = dict(better=better, val=init_val, iter=-1)
+        self.best_metric_results[dataset_name] = record
+
+    def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
+        if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
+            if val >= self.best_metric_results[dataset_name][metric]['val']:
+                self.best_metric_results[dataset_name][metric]['val'] = val
+                self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+        else:
+            if val <= self.best_metric_results[dataset_name][metric]['val']:
+                self.best_metric_results[dataset_name][metric]['val'] = val
+                self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+
+    def model_ema(self, decay=0.999):
+        net_g = self.get_bare_model(self.net_g)
+
+        net_g_params = dict(net_g.named_parameters())
+        net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+        for k in net_g_ema_params.keys():
+            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def model_to_device(self, net):
+        """Model to device. It also warps models with DistributedDataParallel
+        or DataParallel.
+
+        Args:
+            net (nn.Module)
+        """
+        net = net.to(self.device)
+        if self.opt['dist']:
+            find_unused_parameters = self.opt.get('find_unused_parameters', False)
+            net = DistributedDataParallel(
+                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+        elif self.opt['num_gpu'] > 1:
+            net = DataParallel(net)
+        return net
+
+    def get_optimizer(self, optim_type, params, lr, **kwargs):
+        if optim_type == 'Adam':
+            optimizer = torch.optim.Adam(params, lr, **kwargs)
+        elif optim_type == 'AdamW':
+            optimizer = torch.optim.AdamW(params, lr, **kwargs)
+        elif optim_type == 'Adamax':
+            optimizer = torch.optim.Adamax(params, lr, **kwargs)
+        elif optim_type == 'SGD':
+            optimizer = torch.optim.SGD(params, lr, **kwargs)
+        elif optim_type == 'ASGD':
+            optimizer = torch.optim.ASGD(params, lr, **kwargs)
+        elif optim_type == 'RMSprop':
+            optimizer = torch.optim.RMSprop(params, lr, **kwargs)
+        elif optim_type == 'Rprop':
+            optimizer = torch.optim.Rprop(params, lr, **kwargs)
+        else:
+            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+        return optimizer
+
+    def setup_schedulers(self):
+        """Set up schedulers."""
+        train_opt = self.opt['train']
+        scheduler_type = train_opt['scheduler'].pop('type')
+        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+            for optimizer in self.optimizers:
+                self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+        elif scheduler_type == 'CosineAnnealingRestartLR':
+            for optimizer in self.optimizers:
+                self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+        else:
+            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+    def get_bare_model(self, net):
+        """Get bare model, especially under wrapping with
+        DistributedDataParallel or DataParallel.
+        """
+        if isinstance(net, (DataParallel, DistributedDataParallel)):
+            net = net.module
+        return net
+
+    @master_only
+    def print_network(self, net):
+        """Print the str and parameter number of a network.
+
+        Args:
+            net (nn.Module)
+        """
+        if isinstance(net, (DataParallel, DistributedDataParallel)):
+            net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
+        else:
+            net_cls_str = f'{net.__class__.__name__}'
+
+        net = self.get_bare_model(net)
+        net_str = str(net)
+        net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+        logger = get_root_logger()
+        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+        logger.info(net_str)
+
+    def _set_lr(self, lr_groups_l):
+        """Set learning rate for warm-up.
+
+        Args:
+            lr_groups_l (list): List for lr_groups, each for an optimizer.
+        """
+        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+            for param_group, lr in zip(optimizer.param_groups, lr_groups):
+                param_group['lr'] = lr
+
+    def _get_init_lr(self):
+        """Get the initial lr, which is set by the scheduler.
+        """
+        init_lr_groups_l = []
+        for optimizer in self.optimizers:
+            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+        return init_lr_groups_l
+
+    def update_learning_rate(self, current_iter, warmup_iter=-1):
+        """Update learning rate.
+
+        Args:
+            current_iter (int): Current iteration.
+            warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
+                Default: -1.
+        """
+        if current_iter > 1:
+            for scheduler in self.schedulers:
+                scheduler.step()
+        # set up warm-up learning rate
+        if current_iter < warmup_iter:
+            # get initial lr for each group
+            init_lr_g_l = self._get_init_lr()
+            # modify warming-up learning rates
+            # currently only support linearly warm up
+            warm_up_lr_l = []
+            for init_lr_g in init_lr_g_l:
+                warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+            # set learning rate
+            self._set_lr(warm_up_lr_l)
+
+    def get_current_learning_rate(self):
+        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+    @master_only
+    def save_network(self, net, net_label, current_iter, param_key='params'):
+        """Save networks.
+
+        Args:
+            net (nn.Module | list[nn.Module]): Network(s) to be saved.
+            net_label (str): Network label.
+            current_iter (int): Current iter number.
+            param_key (str | list[str]): The parameter key(s) to save network.
+                Default: 'params'.
+        """
+        if current_iter == -1:
+            current_iter = 'latest'
+        save_filename = f'{net_label}_{current_iter}.pth'
+        save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+        net = net if isinstance(net, list) else [net]
+        param_key = param_key if isinstance(param_key, list) else [param_key]
+        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+        save_dict = {}
+        for net_, param_key_ in zip(net, param_key):
+            net_ = self.get_bare_model(net_)
+            state_dict = net_.state_dict()
+            for key, param in state_dict.items():
+                if key.startswith('module.'):  # remove unnecessary 'module.'
+                    key = key[7:]
+                state_dict[key] = param.cpu()
+            save_dict[param_key_] = state_dict
+
+        # avoid occasional writing errors
+        retry = 3
+        while retry > 0:
+            try:
+                torch.save(save_dict, save_path)
+            except Exception as e:
+                logger = get_root_logger()
+                logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
+                time.sleep(1)
+            else:
+                break
+            finally:
+                retry -= 1
+        if retry == 0:
+            logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+            # raise IOError(f'Cannot save {save_path}.')
+
+    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+        """Print keys with different name or different size when loading models.
+
+        1. Print keys with different names.
+        2. If strict=False, print the same key but with different tensor size.
+            It also ignores keys with different sizes (these keys are not loaded).
+
+        Args:
+            crt_net (torch model): Current network.
+            load_net (dict): Loaded network.
+            strict (bool): Whether strictly loaded. Default: True.
+        """
+        crt_net = self.get_bare_model(crt_net)
+        crt_net = crt_net.state_dict()
+        crt_net_keys = set(crt_net.keys())
+        load_net_keys = set(load_net.keys())
+
+        logger = get_root_logger()
+        if crt_net_keys != load_net_keys:
+            logger.warning('Current net - loaded net:')
+            for v in sorted(list(crt_net_keys - load_net_keys)):
+                logger.warning(f'  {v}')
+            logger.warning('Loaded net - current net:')
+            for v in sorted(list(load_net_keys - crt_net_keys)):
+                logger.warning(f'  {v}')
+
+        # check the size for the same keys
+        if not strict:
+            common_keys = crt_net_keys & load_net_keys
+            for k in common_keys:
+                if crt_net[k].size() != load_net[k].size():
+                    logger.warning(f'Size different, ignore [{k}]: crt_net: '
+                                   f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+                    load_net[k + '.ignore'] = load_net.pop(k)
+
+    def load_network(self, net, load_path, strict=True, param_key='params'):
+        """Load network.
+
+        Args:
+            load_path (str): The path of networks to be loaded.
+            net (nn.Module): Network.
+            strict (bool): Whether strictly loaded.
+            param_key (str): The parameter key of loaded network. If set to
+                None, use the root 'path'.
+                Default: 'params'.
+        """
+        logger = get_root_logger()
+        net = self.get_bare_model(net)
+        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+        if param_key is not None:
+            if param_key not in load_net and 'params' in load_net:
+                logger.info(f'Loading: {param_key} does not exist, use params.')
+                param_key = 'params'
+            load_net = load_net[param_key]
+        logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
+        # remove unnecessary 'module.'
+        for k, v in deepcopy(load_net).items():
+            if k.startswith('module.'):
+                load_net[k[7:]] = v
+                load_net.pop(k)
+        self._print_different_keys_loading(net, load_net, strict)
+        net.load_state_dict(load_net, strict=strict)
+
+    @master_only
+    def save_training_state(self, epoch, current_iter):
+        """Save training states during training, which will be used for
+        resuming.
+
+        Args:
+            epoch (int): Current epoch.
+            current_iter (int): Current iteration.
+        """
+        if current_iter != -1:
+            state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
+            for o in self.optimizers:
+                state['optimizers'].append(o.state_dict())
+            for s in self.schedulers:
+                state['schedulers'].append(s.state_dict())
+            save_filename = f'{current_iter}.state'
+            save_path = os.path.join(self.opt['path']['training_states'], save_filename)
+
+            # avoid occasional writing errors
+            retry = 3
+            while retry > 0:
+                try:
+                    torch.save(state, save_path)
+                except Exception as e:
+                    logger = get_root_logger()
+                    logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
+                    time.sleep(1)
+                else:
+                    break
+                finally:
+                    retry -= 1
+            if retry == 0:
+                logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+                # raise IOError(f'Cannot save {save_path}.')
+
+    def resume_training(self, resume_state):
+        """Reload the optimizers and schedulers for resumed training.
+
+        Args:
+            resume_state (dict): Resume state.
+        """
+        resume_optimizers = resume_state['optimizers']
+        resume_schedulers = resume_state['schedulers']
+        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+        for i, o in enumerate(resume_optimizers):
+            self.optimizers[i].load_state_dict(o)
+        for i, s in enumerate(resume_schedulers):
+            self.schedulers[i].load_state_dict(s)
+
+    def reduce_loss_dict(self, loss_dict):
+        """reduce loss dict.
+
+        In distributed training, it averages the losses among different GPUs .
+
+        Args:
+            loss_dict (OrderedDict): Loss dict.
+        """
+        with torch.no_grad():
+            if self.opt['dist']:
+                keys = []
+                losses = []
+                for name, value in loss_dict.items():
+                    keys.append(name)
+                    losses.append(value)
+                losses = torch.stack(losses, 0)
+                torch.distributed.reduce(losses, dst=0)
+                if self.opt['rank'] == 0:
+                    losses /= self.opt['world_size']
+                loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+            log_dict = OrderedDict()
+            for name, value in loss_dict.items():
+                log_dict[name] = value.mean().item()
+
+            return log_dict
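+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# `model_ema` above blends each generator parameter into its EMA copy as
+#     ema_param = decay * ema_param + (1 - decay) * param
+# A minimal, self-contained demonstration of that update with toy modules
+# (the module names below are made up for illustration only):
+if __name__ == '__main__':
+    import torch.nn as nn
+    net = nn.Linear(4, 4)      # stands in for net_g
+    net_ema = nn.Linear(4, 4)  # stands in for net_g_ema (different init)
+    decay = 0.999
+    before = net_ema.weight.detach().clone()
+    with torch.no_grad():
+        for p_ema, p in zip(net_ema.parameters(), net.parameters()):
+            p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay)
+    # each EMA weight moved by (1 - decay) of the gap towards net's weight
+    print(torch.allclose(net_ema.weight, before * decay + net.weight.detach() * (1 - decay)))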
diff --git a/basicsr/models/edvr_model.py b/basicsr/models/edvr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bdbf7b94fe3f06c76fbf2a4941621f64e0003e7
--- /dev/null
+++ b/basicsr/models/edvr_model.py
@@ -0,0 +1,62 @@
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class EDVRModel(VideoBaseModel):
+    """EDVR Model.
+
+    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501
+    """
+
+    def __init__(self, opt):
+        super(EDVRModel, self).__init__(opt)
+        if self.is_train:
+            self.train_tsa_iter = opt['train'].get('tsa_iter')
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
+        logger = get_root_logger()
+        logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
+        if dcn_lr_mul == 1:
+            optim_params = self.net_g.parameters()
+        else:  # separate dcn params and normal params for different lr
+            normal_params = []
+            dcn_params = []
+            for name, param in self.net_g.named_parameters():
+                if 'dcn' in name:
+                    dcn_params.append(param)
+                else:
+                    normal_params.append(param)
+            optim_params = [
+                {  # add normal params first
+                    'params': normal_params,
+                    'lr': train_opt['optim_g']['lr']
+                },
+                {
+                    'params': dcn_params,
+                    'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+                },
+            ]
+
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+    def optimize_parameters(self, current_iter):
+        if self.train_tsa_iter:
+            if current_iter == 1:
+                logger = get_root_logger()
+                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
+                for name, param in self.net_g.named_parameters():
+                    if 'fusion' not in name:
+                        param.requires_grad = False
+            elif current_iter == self.train_tsa_iter:
+                logger = get_root_logger()
+                logger.warning('Train all the parameters.')
+                for param in self.net_g.parameters():
+                    param.requires_grad = True
+
+        super(EDVRModel, self).optimize_parameters(current_iter)
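+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# `setup_optimizers` above scales the learning rate of the DCN parameters by
+# `dcn_lr_mul` through a second optimizer parameter group. The same mechanism,
+# demonstrated standalone with a toy module (all names below are made up):
+if __name__ == '__main__':
+    import torch
+    net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
+    base_lr, dcn_lr_mul = 4e-4, 0.25
+    optimizer = torch.optim.Adam([
+        {'params': net[0].parameters(), 'lr': base_lr},               # "normal" params
+        {'params': net[1].parameters(), 'lr': base_lr * dcn_lr_mul},  # "dcn" params
+    ])
+    print([group['lr'] for group in optimizer.param_groups])  # -> [0.0004, 0.0001]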
diff --git a/basicsr/models/esrgan_model.py b/basicsr/models/esrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d746d0e29418d9e8f35fa9c1e3a315d694075be
--- /dev/null
+++ b/basicsr/models/esrgan_model.py
@@ -0,0 +1,83 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+
+
+@MODEL_REGISTRY.register()
+class ESRGANModel(SRGANModel):
+    """ESRGAN model for single image super-resolution."""
+
+    def optimize_parameters(self, current_iter):
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+            # gan loss (relativistic gan)
+            real_d_pred = self.net_d(self.gt).detach()
+            fake_g_pred = self.net_d(self.output)
+            l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
+            l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
+            l_g_gan = (l_g_real + l_g_fake) / 2
+
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+
+        self.optimizer_d.zero_grad()
+        # gan loss (relativistic gan)
+
+        # In order to avoid the error in distributed training:
+        # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
+        # the variables needed for gradient computation has been modified by
+        # an inplace operation",
+        # we separate the backwards for real and fake, and also detach the
+        # tensor for calculating mean.
+
+        # real
+        fake_d_pred = self.net_d(self.output).detach()
+        real_d_pred = self.net_d(self.gt)
+        l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
+        l_d_real.backward()
+        # fake
+        fake_d_pred = self.net_d(self.output.detach())
+        l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
+        l_d_fake.backward()
+        self.optimizer_d.step()
+
+        loss_dict['l_d_real'] = l_d_real
+        loss_dict['l_d_fake'] = l_d_fake
+        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
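+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# The relativistic GAN terms above compare each logit with the mean logit of
+# the opposite set before applying the GAN loss. Assuming the common 'vanilla'
+# BCE-with-logits GAN loss, the generator-side computation looks like this
+# (dummy discriminator outputs, made up for illustration):
+if __name__ == '__main__':
+    bce = torch.nn.BCEWithLogitsLoss()
+    real_d_pred = torch.tensor([2.0, 1.5, 1.0])   # D(x_real)
+    fake_g_pred = torch.tensor([-1.0, 0.0, 0.5])  # D(G(x_lq))
+    # generator wants real logits to look "fake" relative to the fake mean ...
+    l_g_real = bce(real_d_pred - fake_g_pred.mean(), torch.zeros(3))
+    # ... and fake logits to look "real" relative to the real mean
+    l_g_fake = bce(fake_g_pred - real_d_pred.mean(), torch.ones(3))
+    print('l_g_gan =', ((l_g_real + l_g_fake) / 2).item())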
diff --git a/basicsr/models/hifacegan_model.py b/basicsr/models/hifacegan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..435a2b179d6b7c670fe96a83ce45b461300b2c89
--- /dev/null
+++ b/basicsr/models/hifacegan_model.py
@@ -0,0 +1,288 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class HiFaceGANModel(SRModel):
+    """HiFaceGAN model for generic-purpose face restoration.
+    No prior modeling is required, and it works for arbitrary degradations.
+    EMA is currently not supported for inference.
+    """
+
+    def init_training_settings(self):
+
+        train_opt = self.opt['train']
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            raise NotImplementedError('HiFaceGAN does not support EMA yet.')
+
+        self.net_g.train()
+
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+
+        # define losses
+        # HiFaceGAN does not use pixel loss by default
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if train_opt.get('feature_matching_opt'):
+            self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device)
+        else:
+            self.cri_feat = None
+
+        if self.cri_pix is None and self.cri_perceptual is None:
+            raise ValueError('Both pixel and perceptual losses are None.')
+
+        if train_opt.get('gan_opt'):
+            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+        self.net_d_iters = train_opt.get('net_d_iters', 1)
+        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        # optimizer g
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+        # optimizer d
+        optim_type = train_opt['optim_d'].pop('type')
+        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+        self.optimizers.append(self.optimizer_d)
+
+    def discriminate(self, input_lq, output, ground_truth):
+        """
+        This is a conditional (on the input) discriminator
+        In Batch Normalization, the fake and real images are
+        recommended to be in the same batch to avoid disparate
+        statistics in fake and real images.
+        So both fake and real images are fed to D all at once.
+        """
+        h, w = output.shape[-2:]
+        if output.shape[-2:] != input_lq.shape[-2:]:
+            lq = torch.nn.functional.interpolate(input_lq, (h, w))
+            real = torch.nn.functional.interpolate(ground_truth, (h, w))
+            fake_concat = torch.cat([lq, output], dim=1)
+            real_concat = torch.cat([lq, real], dim=1)
+        else:
+            fake_concat = torch.cat([input_lq, output], dim=1)
+            real_concat = torch.cat([input_lq, ground_truth], dim=1)
+
+        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
+        discriminator_out = self.net_d(fake_and_real)
+        pred_fake, pred_real = self._divide_pred(discriminator_out)
+        return pred_fake, pred_real
+
+    @staticmethod
+    def _divide_pred(pred):
+        """
+        Take the prediction of fake and real images from the combined batch.
+        The prediction contains the intermediate outputs of the multiscale GAN
+        discriminator, so it is usually a list.
+        """
+        if isinstance(pred, list):
+            fake = []
+            real = []
+            for p in pred:
+                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+        else:
+            fake = pred[:pred.size(0) // 2]
+            real = pred[pred.size(0) // 2:]
+
+        return fake, real
+
+    def optimize_parameters(self, current_iter):
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+
+            # Requires real prediction for feature matching loss
+            pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt)
+            l_g_gan = self.cri_gan(pred_fake, True, is_disc=False)
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            # feature matching loss
+            if self.cri_feat:
+                l_g_feat = self.cri_feat(pred_fake, pred_real)
+                l_g_total += l_g_feat
+                loss_dict['l_g_feat'] = l_g_feat
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+
+        self.optimizer_d.zero_grad()
+        # TODO: Benchmark test between HiFaceGAN and SRGAN implementation:
+        # SRGAN use the same fake output for discriminator update
+        # while HiFaceGAN regenerate a new output using updated net_g
+        # This should not make too much difference though. Stick to SRGAN now.
+        # -------------------------------------------------------------------
+        # ---------- Below are original HiFaceGAN code snippet --------------
+        # -------------------------------------------------------------------
+        # with torch.no_grad():
+        #    fake_image = self.net_g(self.lq)
+        #    fake_image = fake_image.detach()
+        #    fake_image.requires_grad_()
+        #    pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt)
+
+        # real
+        pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt)
+        l_d_real = self.cri_gan(pred_real, True, is_disc=True)
+        loss_dict['l_d_real'] = l_d_real
+        # fake
+        l_d_fake = self.cri_gan(pred_fake, False, is_disc=True)
+        loss_dict['l_d_fake'] = l_d_fake
+
+        l_d_total = (l_d_real + l_d_fake) / 2
+        l_d_total.backward()
+        self.optimizer_d.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            print('HiFaceGAN does not support EMA now. Skipping EMA update.')
+
+    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+        """
+        Warning: HiFaceGAN requires train() mode even for validation
+        For more info, see https://github.com/Lotayou/Face-Renovation/issues/31
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Validation dataloader.
+            current_iter (int): Current iteration.
+            tb_logger (tensorboard logger): Tensorboard logger.
+            save_img (bool): Whether to save images. Default: False.
+        """
+
+        if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'):
+            self.net_g.train()
+
+        if self.opt['dist']:
+            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+        else:
+            print('In HiFaceGANModel: The new metrics package is under development. '
+                  'Using the super method for now (only PSNR & SSIM are supported).')
+            super().nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        """
+        TODO: Validation using updated metric system
+        The metrics are now evaluated after all images have been tested
+        This allows batch processing, and also allows evaluation of
+        distributional metrics, such as:
+
+        @ Frechet Inception Distance: FID
+        @ Maximum Mean Discrepancy: MMD
+
+        Warning:
+            Need careful batch management for different inference settings.
+
+        """
+        dataset_name = dataloader.dataset.opt['name']
+        with_metrics = self.opt['val'].get('metrics') is not None
+        if with_metrics:
+            self.metric_results = dict()  # {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+            sr_tensors = []
+            gt_tensors = []
+
+        pbar = tqdm(total=len(dataloader), unit='image')
+        for val_data in dataloader:
+            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+            self.feed_data(val_data)
+            self.test()
+
+            visuals = self.get_current_visuals()  # detached cpu tensor, non-squeeze
+            sr_tensors.append(visuals['result'])
+            if 'gt' in visuals:
+                gt_tensors.append(visuals['gt'])
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+                                             f'{img_name}_{current_iter}.png')
+                else:
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+
+                imwrite(tensor2img(visuals['result']), save_img_path)
+
+            pbar.update(1)
+            pbar.set_description(f'Test {img_name}')
+        pbar.close()
+
+        if with_metrics:
+            sr_pack = torch.cat(sr_tensors, dim=0)
+            gt_pack = torch.cat(gt_tensors, dim=0)
+            # calculate metrics
+            for name, opt_ in self.opt['val']['metrics'].items():
+                # The new metric caller automatically returns mean value
+                # FIXME: calculate_metric only supports two arguments, so this call currently fails.
+                self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
+            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def save(self, epoch, current_iter):
+        if hasattr(self, 'net_g_ema'):
+            print('HiFaceGAN does not support EMA now. Fallback to normal mode.')
+
+        self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_network(self.net_d, 'net_d', current_iter)
+        self.save_training_state(epoch, current_iter)
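+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# `discriminate` feeds the fake pair and the real pair through the
+# discriminator as one combined batch, and `_divide_pred` splits the result
+# back into its fake half and real half. A minimal standalone illustration
+# with a dummy "discriminator" output (assumes basicsr is importable):
+if __name__ == '__main__':
+    fake_concat = torch.zeros(2, 6, 8, 8)  # e.g. [lq, output] along channels
+    real_concat = torch.ones(2, 6, 8, 8)   # e.g. [lq, gt] along channels
+    fake_and_real = torch.cat([fake_concat, real_concat], dim=0)  # batch of 4
+    pred = fake_and_real.mean(dim=(1, 2, 3), keepdim=True)        # dummy D output
+    pred_fake, pred_real = HiFaceGANModel._divide_pred(pred)
+    print(pred_fake.shape, pred_real.shape)  # torch.Size([2, 1, 1, 1]) each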
diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf
--- /dev/null
+++ b/basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+    """ MultiStep with restarts learning rate scheme.
+
+    Args:
+        optimizer (torch.nn.optimizer): Torch optimizer.
+        milestones (list): Iterations that will decrease learning rate.
+        gamma (float): Decrease ratio. Default: 0.1.
+        restarts (list): Restart iterations. Default: [0].
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+        self.milestones = Counter(milestones)
+        self.gamma = gamma
+        self.restarts = restarts
+        self.restart_weights = restart_weights
+        assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch in self.restarts:
+            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+        if self.last_epoch not in self.milestones:
+            return [group['lr'] for group in self.optimizer.param_groups]
+        return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+    """Get the position from a period list.
+
+    It will return the index of the right-closest number in the period list.
+    For example, the cumulative_period = [100, 200, 300, 400],
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 2.
+
+    Args:
+        iteration (int): Current iteration.
+        cumulative_period (list[int]): Cumulative period list.
+
+    Returns:
+        int: The position of the right-closest number in the period list.
+    """
+    for i, period in enumerate(cumulative_period):
+        if iteration <= period:
+            return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+    """ Cosine annealing with restarts learning rate scheme.
+
+    An example of config:
+    periods = [10, 10, 10, 10]
+    restart_weights = [1, 0.5, 0.5, 0.5]
+    eta_min=1e-7
+
+    It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
+    iterations, the scheduler restarts with the corresponding weight in restart_weights.
+
+    Args:
+        optimizer (torch.nn.optimizer): Torch optimizer.
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_min (float): The minimum lr. Default: 0.
+        last_epoch (int): Used in _LRScheduler. Default: -1.
+    """
+
+    def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+        self.periods = periods
+        self.restart_weights = restart_weights
+        self.eta_min = eta_min
+        assert (len(self.periods) == len(
+            self.restart_weights)), 'periods and restart_weights should have the same length.'
+        self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+        current_period = self.periods[idx]
+
+        return [
+            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+            (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+            for base_lr in self.base_lrs
+        ]
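+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# A minimal demonstration of CosineAnnealingRestartLR using the example config
+# from its docstring (four 10-iteration cycles with restart weights):
+if __name__ == '__main__':
+    import torch
+    net = torch.nn.Linear(2, 2)
+    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
+    scheduler = CosineAnnealingRestartLR(
+        optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)
+    for it in range(40):
+        optimizer.step()
+        scheduler.step()
+        if (it + 1) % 10 == 0:  # lr at the end of each cosine cycle (annealed down to ~eta_min)
+            print(it + 1, [group['lr'] for group in optimizer.param_groups])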
diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74b28fb1dc6a7f5c5ad3f7d8bb96c19c52ee92b
--- /dev/null
+++ b/basicsr/models/realesrgan_model.py
@@ -0,0 +1,267 @@
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from torch.nn import functional as F
+
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.losses.loss_util import get_refined_artifact_map
+from basicsr.models.srgan_model import SRGANModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRGANModel(SRGANModel):
+    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the networks with GAN training.
+    """
+
+    def __init__(self, opt):
+        super(RealESRGANModel, self).__init__(opt)
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
+        self.queue_size = opt.get('queue_size', 180)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self):
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
+        # initialize
+        b, c, h, w = self.lq.size()
+        if not hasattr(self, 'queue_lr'):
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+            _, c, h, w = self.gt.size()
+            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+            self.queue_ptr = 0
+        if self.queue_ptr == self.queue_size:  # the pool is full
+            # do dequeue and enqueue
+            # shuffle
+            idx = torch.randperm(self.queue_size)
+            self.queue_lr = self.queue_lr[idx]
+            self.queue_gt = self.queue_gt[idx]
+            # get first b samples
+            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+            # update the queue
+            self.queue_lr[0:b, :, :, :] = self.lq.clone()
+            self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+            self.lq = lq_dequeue
+            self.gt = gt_dequeue
+        else:
+            # only do enqueue
+            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+            self.queue_ptr = self.queue_ptr + b
+
+    @torch.no_grad()
+    def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
+        if self.is_train and self.opt.get('high_order_degradation', True):
+            # training data synthesis
+            self.gt = data['gt'].to(self.device)
+            self.gt_usm = self.usm_sharpener(self.gt)
+
+            self.kernel1 = data['kernel1'].to(self.device)
+            self.kernel2 = data['kernel2'].to(self.device)
+            self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+            ori_h, ori_w = self.gt.size()[2:4]
+
+            # ----------------------- The first degradation process ----------------------- #
+            # blur
+            out = filter2D(self.gt_usm, self.kernel1)
+            # random resize
+            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+            if updown_type == 'up':
+                scale = np.random.uniform(1, self.opt['resize_range'][1])
+            elif updown_type == 'down':
+                scale = np.random.uniform(self.opt['resize_range'][0], 1)
+            else:
+                scale = 1
+            mode = random.choice(['area', 'bilinear', 'bicubic'])
+            out = F.interpolate(out, scale_factor=scale, mode=mode)
+            # add noise
+            gray_noise_prob = self.opt['gray_noise_prob']
+            if np.random.uniform() < self.opt['gaussian_noise_prob']:
+                out = random_add_gaussian_noise_pt(
+                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+            else:
+                out = random_add_poisson_noise_pt(
+                    out,
+                    scale_range=self.opt['poisson_scale_range'],
+                    gray_prob=gray_noise_prob,
+                    clip=True,
+                    rounds=False)
+            # JPEG compression
+            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+            out = self.jpeger(out, quality=jpeg_p)
+
+            # ----------------------- The second degradation process ----------------------- #
+            # blur
+            if np.random.uniform() < self.opt['second_blur_prob']:
+                out = filter2D(out, self.kernel2)
+            # random resize
+            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+            if updown_type == 'up':
+                scale = np.random.uniform(1, self.opt['resize_range2'][1])
+            elif updown_type == 'down':
+                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+            else:
+                scale = 1
+            mode = random.choice(['area', 'bilinear', 'bicubic'])
+            out = F.interpolate(
+                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+            # add noise
+            gray_noise_prob = self.opt['gray_noise_prob2']
+            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+                out = random_add_gaussian_noise_pt(
+                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+            else:
+                out = random_add_poisson_noise_pt(
+                    out,
+                    scale_range=self.opt['poisson_scale_range2'],
+                    gray_prob=gray_noise_prob,
+                    clip=True,
+                    rounds=False)
+
+            # JPEG compression + the final sinc filter
+            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+            # as one operation.
+            # We consider two orders:
+            #   1. [resize back + sinc filter] + JPEG compression
+            #   2. JPEG compression + [resize back + sinc filter]
+            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+            if np.random.uniform() < 0.5:
+                # resize back + the final sinc filter
+                mode = random.choice(['area', 'bilinear', 'bicubic'])
+                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+                out = filter2D(out, self.sinc_kernel)
+                # JPEG compression
+                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+                out = torch.clamp(out, 0, 1)
+                out = self.jpeger(out, quality=jpeg_p)
+            else:
+                # JPEG compression
+                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+                out = torch.clamp(out, 0, 1)
+                out = self.jpeger(out, quality=jpeg_p)
+                # resize back + the final sinc filter
+                mode = random.choice(['area', 'bilinear', 'bicubic'])
+                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+                out = filter2D(out, self.sinc_kernel)
+
+            # clamp and round
+            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+            # random crop
+            gt_size = self.opt['gt_size']
+            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
+                                                                 self.opt['scale'])
+
+            # training pair pool
+            self._dequeue_and_enqueue()
+            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
+            self.gt_usm = self.usm_sharpener(self.gt)
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
+        else:
+            # for paired training or validation
+            self.lq = data['lq'].to(self.device)
+            if 'gt' in data:
+                self.gt = data['gt'].to(self.device)
+                self.gt_usm = self.usm_sharpener(self.gt)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        # do not use the synthetic process during validation
+        self.is_train = False
+        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+        self.is_train = True
+
+    def optimize_parameters(self, current_iter):
+        # usm sharpening
+        l1_gt = self.gt_usm
+        percep_gt = self.gt_usm
+        gan_gt = self.gt_usm
+        if self.opt['l1_gt_usm'] is False:
+            l1_gt = self.gt
+        if self.opt['percep_gt_usm'] is False:
+            percep_gt = self.gt
+        if self.opt['gan_gt_usm'] is False:
+            gan_gt = self.gt
+
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+        if self.cri_ldl:
+            self.output_ema = self.net_g_ema(self.lq)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, l1_gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+            if self.cri_ldl:
+                pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
+                l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
+                l_g_total += l_g_ldl
+                loss_dict['l_g_ldl'] = l_g_ldl
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+            # gan loss
+            fake_g_pred = self.net_d(self.output)
+            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+
+        self.optimizer_d.zero_grad()
+        # real
+        real_d_pred = self.net_d(gan_gt)
+        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+        loss_dict['l_d_real'] = l_d_real
+        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+        l_d_real.backward()
+        # fake
+        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
+        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+        loss_dict['l_d_fake'] = l_d_fake
+        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+        l_d_fake.backward()
+        self.optimizer_d.step()
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
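+
+
+# --- Editor's note: illustrative sketch, not part of the original BasicSR file. ---
+# The training pair pool in `_dequeue_and_enqueue` keeps a fixed-size buffer of
+# past (lq, gt) pairs; once full, it shuffles the buffer, returns the first b
+# entries as the current batch and stores the freshly synthesized batch in
+# their place, so one mini-batch can mix different degradation settings.
+# A stripped-down version of the same bookkeeping with plain tensors:
+if __name__ == '__main__':
+    queue_size, b = 8, 2
+    queue = torch.zeros(queue_size, 3, 4, 4)
+    ptr = 0
+    for step in range(8):
+        batch = torch.full((b, 3, 4, 4), float(step))  # stand-in for self.lq
+        if ptr == queue_size:  # pool is full: shuffle, dequeue b, enqueue batch
+            idx = torch.randperm(queue_size)
+            queue = queue[idx]
+            dequeued = queue[0:b].clone()
+            queue[0:b] = batch
+            batch = dequeued
+        else:  # still filling the pool
+            queue[ptr:ptr + b] = batch
+            ptr += b
+        print(step, batch[:, 0, 0, 0].tolist())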
diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5790918b969682a0db0e2ed9236b7046d627b90
--- /dev/null
+++ b/basicsr/models/realesrnet_model.py
@@ -0,0 +1,189 @@
+import numpy as np
+import random
+import torch
+from torch.nn import functional as F
+
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.models.sr_model import SRModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRNetModel(SRModel):
+    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It is trained without GAN losses.
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the network with pixel-wise losses, without GAN training.
+    """
+
+    def __init__(self, opt):
+        super(RealESRNetModel, self).__init__(opt)
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
+        self.queue_size = opt.get('queue_size', 180)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self):
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
+        # initialize
+        b, c, h, w = self.lq.size()
+        if not hasattr(self, 'queue_lr'):
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+            _, c, h, w = self.gt.size()
+            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+            self.queue_ptr = 0
+        if self.queue_ptr == self.queue_size:  # the pool is full
+            # do dequeue and enqueue
+            # shuffle
+            idx = torch.randperm(self.queue_size)
+            self.queue_lr = self.queue_lr[idx]
+            self.queue_gt = self.queue_gt[idx]
+            # get first b samples
+            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+            # update the queue
+            self.queue_lr[0:b, :, :, :] = self.lq.clone()
+            self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+            self.lq = lq_dequeue
+            self.gt = gt_dequeue
+        else:
+            # only do enqueue
+            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+            self.queue_ptr = self.queue_ptr + b
+
+    @torch.no_grad()
+    def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
+        if self.is_train and self.opt.get('high_order_degradation', True):
+            # training data synthesis
+            self.gt = data['gt'].to(self.device)
+            # USM sharpen the GT images
+            if self.opt['gt_usm'] is True:
+                self.gt = self.usm_sharpener(self.gt)
+
+            self.kernel1 = data['kernel1'].to(self.device)
+            self.kernel2 = data['kernel2'].to(self.device)
+            self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+            ori_h, ori_w = self.gt.size()[2:4]
+
+            # ----------------------- The first degradation process ----------------------- #
+            # blur
+            out = filter2D(self.gt, self.kernel1)
+            # random resize
+            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+            if updown_type == 'up':
+                scale = np.random.uniform(1, self.opt['resize_range'][1])
+            elif updown_type == 'down':
+                scale = np.random.uniform(self.opt['resize_range'][0], 1)
+            else:
+                scale = 1
+            mode = random.choice(['area', 'bilinear', 'bicubic'])
+            out = F.interpolate(out, scale_factor=scale, mode=mode)
+            # add noise
+            gray_noise_prob = self.opt['gray_noise_prob']
+            if np.random.uniform() < self.opt['gaussian_noise_prob']:
+                out = random_add_gaussian_noise_pt(
+                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+            else:
+                out = random_add_poisson_noise_pt(
+                    out,
+                    scale_range=self.opt['poisson_scale_range'],
+                    gray_prob=gray_noise_prob,
+                    clip=True,
+                    rounds=False)
+            # JPEG compression
+            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+            out = self.jpeger(out, quality=jpeg_p)
+
+            # ----------------------- The second degradation process ----------------------- #
+            # blur
+            if np.random.uniform() < self.opt['second_blur_prob']:
+                out = filter2D(out, self.kernel2)
+            # random resize
+            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+            if updown_type == 'up':
+                scale = np.random.uniform(1, self.opt['resize_range2'][1])
+            elif updown_type == 'down':
+                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+            else:
+                scale = 1
+            mode = random.choice(['area', 'bilinear', 'bicubic'])
+            out = F.interpolate(
+                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+            # add noise
+            gray_noise_prob = self.opt['gray_noise_prob2']
+            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+                out = random_add_gaussian_noise_pt(
+                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+            else:
+                out = random_add_poisson_noise_pt(
+                    out,
+                    scale_range=self.opt['poisson_scale_range2'],
+                    gray_prob=gray_noise_prob,
+                    clip=True,
+                    rounds=False)
+
+            # JPEG compression + the final sinc filter
+            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+            # as one operation.
+            # We consider two orders:
+            #   1. [resize back + sinc filter] + JPEG compression
+            #   2. JPEG compression + [resize back + sinc filter]
+            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+            if np.random.uniform() < 0.5:
+                # resize back + the final sinc filter
+                mode = random.choice(['area', 'bilinear', 'bicubic'])
+                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+                out = filter2D(out, self.sinc_kernel)
+                # JPEG compression
+                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+                out = torch.clamp(out, 0, 1)
+                out = self.jpeger(out, quality=jpeg_p)
+            else:
+                # JPEG compression
+                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+                out = torch.clamp(out, 0, 1)
+                out = self.jpeger(out, quality=jpeg_p)
+                # resize back + the final sinc filter
+                mode = random.choice(['area', 'bilinear', 'bicubic'])
+                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+                out = filter2D(out, self.sinc_kernel)
+
+            # clamp and round
+            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+            # random crop
+            gt_size = self.opt['gt_size']
+            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
+
+            # training pair pool
+            self._dequeue_and_enqueue()
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
+        else:
+            # for paired training or validation
+            self.lq = data['lq'].to(self.device)
+            if 'gt' in data:
+                self.gt = data['gt'].to(self.device)
+                self.gt_usm = self.usm_sharpener(self.gt)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        # do not use the synthetic process during validation
+        self.is_train = False
+        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+        self.is_train = True
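
The `_dequeue_and_enqueue` pool at the top of this file lets a single optimization step mix LQ/GT pairs synthesized at different iterations: each new batch is pushed into a fixed-size queue, and once the queue is full it is shuffled and its front entries are swapped with the incoming batch, increasing degradation diversity inside every batch. Below is a minimal stand-alone sketch of the same idea, assuming fixed LQ/GT shapes; the names `PairPool` and `push` are illustrative and not part of BasicSR.

import torch

class PairPool:
    """Fixed-size pool of (lq, gt) batches: a simplified training-pair queue."""

    def __init__(self, queue_size, lq_shape, gt_shape):
        self.queue_lr = torch.zeros(queue_size, *lq_shape)
        self.queue_gt = torch.zeros(queue_size, *gt_shape)
        self.ptr = 0

    def push(self, lq, gt):
        b = lq.size(0)
        assert self.queue_lr.size(0) % b == 0, 'queue size should be divisible by batch size'
        if self.ptr == self.queue_lr.size(0):
            # full: shuffle, hand back the first b entries, store the new batch in their place
            idx = torch.randperm(self.ptr)
            self.queue_lr, self.queue_gt = self.queue_lr[idx], self.queue_gt[idx]
            out_lq, out_gt = self.queue_lr[:b].clone(), self.queue_gt[:b].clone()
            self.queue_lr[:b], self.queue_gt[:b] = lq.clone(), gt.clone()
            return out_lq, out_gt
        # not full yet: enqueue and keep training on the current batch
        self.queue_lr[self.ptr:self.ptr + b] = lq.clone()
        self.queue_gt[self.ptr:self.ptr + b] = gt.clone()
        self.ptr += b
        return lq, gt
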
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..787f1fd2eab5963579c764c1bfb87199b7dd196f
--- /dev/null
+++ b/basicsr/models/sr_model.py
@@ -0,0 +1,279 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+    """Base SR model for single image super-resolution."""
+
+    def __init__(self, opt):
+        super(SRModel, self).__init__(opt)
+
+        # define network
+        self.net_g = build_network(opt['network_g'])
+        self.net_g = self.model_to_device(self.net_g)
+        self.print_network(self.net_g)
+
+        # load pretrained models
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_g', 'params')
+            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+        if self.is_train:
+            self.init_training_settings()
+
+    def init_training_settings(self):
+        self.net_g.train()
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger = get_root_logger()
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if self.cri_pix is None and self.cri_perceptual is None:
+            raise ValueError('Both pixel and perceptual losses are None.')
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        optim_params = []
+        for k, v in self.net_g.named_parameters():
+            if v.requires_grad:
+                optim_params.append(v)
+            else:
+                logger = get_root_logger()
+                logger.warning(f'Params {k} will not be optimized.')
+
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+    def feed_data(self, data):
+        self.lq = data['lq'].to(self.device)
+        if 'gt' in data:
+            self.gt = data['gt'].to(self.device)
+
+    def optimize_parameters(self, current_iter):
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        l_total = 0
+        loss_dict = OrderedDict()
+        # pixel loss
+        if self.cri_pix:
+            l_pix = self.cri_pix(self.output, self.gt)
+            l_total += l_pix
+            loss_dict['l_pix'] = l_pix
+        # perceptual loss
+        if self.cri_perceptual:
+            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+            if l_percep is not None:
+                l_total += l_percep
+                loss_dict['l_percep'] = l_percep
+            if l_style is not None:
+                l_total += l_style
+                loss_dict['l_style'] = l_style
+
+        l_total.backward()
+        self.optimizer_g.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+    def test(self):
+        if hasattr(self, 'net_g_ema'):
+            self.net_g_ema.eval()
+            with torch.no_grad():
+                self.output = self.net_g_ema(self.lq)
+        else:
+            self.net_g.eval()
+            with torch.no_grad():
+                self.output = self.net_g(self.lq)
+            self.net_g.train()
+
+    def test_selfensemble(self):
+        # TODO: to be tested
+        # 8 augmentations
+        # modified from https://github.com/thstkdgus35/EDSR-PyTorch
+
+        def _transform(v, op):
+            # if self.precision != 'single': v = v.float()
+            v2np = v.data.cpu().numpy()
+            if op == 'v':
+                tfnp = v2np[:, :, :, ::-1].copy()
+            elif op == 'h':
+                tfnp = v2np[:, :, ::-1, :].copy()
+            elif op == 't':
+                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+            ret = torch.Tensor(tfnp).to(self.device)
+            # if self.precision == 'half': ret = ret.half()
+
+            return ret
+
+        # prepare augmented data
+        lq_list = [self.lq]
+        for tf in 'v', 'h', 't':
+            lq_list.extend([_transform(t, tf) for t in lq_list])
+
+        # inference
+        if hasattr(self, 'net_g_ema'):
+            self.net_g_ema.eval()
+            with torch.no_grad():
+                out_list = [self.net_g_ema(aug) for aug in lq_list]
+        else:
+            self.net_g.eval()
+            with torch.no_grad():
+                out_list = [self.net_g(aug) for aug in lq_list]
+            self.net_g.train()
+
+        # merge results
+        for i in range(len(out_list)):
+            if i > 3:
+                out_list[i] = _transform(out_list[i], 't')
+            if i % 4 > 1:
+                out_list[i] = _transform(out_list[i], 'h')
+            if (i % 4) % 2 == 1:
+                out_list[i] = _transform(out_list[i], 'v')
+        output = torch.cat(out_list, dim=0)
+
+        self.output = output.mean(dim=0, keepdim=True)
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        if self.opt['rank'] == 0:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset_name = dataloader.dataset.opt['name']
+        with_metrics = self.opt['val'].get('metrics') is not None
+        use_pbar = self.opt['val'].get('pbar', False)
+
+        if with_metrics:
+            if not hasattr(self, 'metric_results'):  # only execute in the first run
+                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+            self._initialize_best_metric_results(dataset_name)
+        # zero self.metric_results
+        if with_metrics:
+            self.metric_results = {metric: 0 for metric in self.metric_results}
+
+        metric_data = dict()
+        if use_pbar:
+            pbar = tqdm(total=len(dataloader), unit='image')
+
+        for idx, val_data in enumerate(dataloader):
+            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+            self.feed_data(val_data)
+            self.test()
+
+            visuals = self.get_current_visuals()
+            sr_img = tensor2img([visuals['result']])
+            metric_data['img'] = sr_img
+            if 'gt' in visuals:
+                gt_img = tensor2img([visuals['gt']])
+                metric_data['img2'] = gt_img
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+                                             f'{img_name}_{current_iter}.png')
+                else:
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+                imwrite(sr_img, save_img_path)
+
+            if with_metrics:
+                # calculate metrics
+                for name, opt_ in self.opt['val']['metrics'].items():
+                    self.metric_results[name] += calculate_metric(metric_data, opt_)
+            if use_pbar:
+                pbar.update(1)
+                pbar.set_description(f'Test {img_name}')
+        if use_pbar:
+            pbar.close()
+
+        if with_metrics:
+            for metric in self.metric_results.keys():
+                self.metric_results[metric] /= (idx + 1)
+                # update the best metric result
+                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
+
+            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+        log_str = f'Validation {dataset_name}\n'
+        for metric, value in self.metric_results.items():
+            log_str += f'\t # {metric}: {value:.4f}'
+            if hasattr(self, 'best_metric_results'):
+                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+            log_str += '\n'
+
+        logger = get_root_logger()
+        logger.info(log_str)
+        if tb_logger:
+            for metric, value in self.metric_results.items():
+                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
+
+    def get_current_visuals(self):
+        out_dict = OrderedDict()
+        out_dict['lq'] = self.lq.detach().cpu()
+        out_dict['result'] = self.output.detach().cpu()
+        if hasattr(self, 'gt'):
+            out_dict['gt'] = self.gt.detach().cpu()
+        return out_dict
+
+    def save(self, epoch, current_iter):
+        if hasattr(self, 'net_g_ema'):
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_training_state(epoch, current_iter)
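
`test_selfensemble` above averages predictions over the eight flip/transpose augmentations of the input. The same x8 scheme can be written with pure tensor operations and no numpy round-trip; this is a sketch under the assumption that `model` is any NCHW-to-NCHW callable (e.g. `net_g` or `net_g_ema`), and the helper name `self_ensemble_x8` is illustrative.

import torch

def self_ensemble_x8(model, lq):
    outs = []
    for t in (False, True):           # transpose H and W
        for hf in (False, True):      # horizontal flip
            for vf in (False, True):  # vertical flip
                x = lq
                if t:
                    x = x.transpose(2, 3)
                if hf:
                    x = x.flip(3)
                if vf:
                    x = x.flip(2)
                with torch.no_grad():
                    y = model(x)
                # undo the augmentations in reverse order
                if vf:
                    y = y.flip(2)
                if hf:
                    y = y.flip(3)
                if t:
                    y = y.transpose(2, 3)
                outs.append(y)
    return torch.stack(outs, 0).mean(0)

Each augmentation is undone in reverse order before averaging, so the scheme only assumes the network is (approximately) equivariant to flips and transposes.
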
diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..45387ca7908e3f38f59a605adb8242ad12fcf1a1
--- /dev/null
+++ b/basicsr/models/srgan_model.py
@@ -0,0 +1,149 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SRGANModel(SRModel):
+    """SRGAN model for single image super-resolution."""
+
+    def init_training_settings(self):
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger = get_root_logger()
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # define network net_g with Exponential Moving Average (EMA)
+            # net_g_ema is used only for testing on one GPU and saving
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define network net_d
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+
+        # load pretrained models
+        load_path = self.opt['path'].get('pretrain_network_d', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_d', 'params')
+            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+        self.net_g.train()
+        self.net_d.train()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('ldl_opt'):
+            self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
+        else:
+            self.cri_ldl = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if train_opt.get('gan_opt'):
+            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+        self.net_d_iters = train_opt.get('net_d_iters', 1)
+        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        # optimizer g
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+        # optimizer d
+        optim_type = train_opt['optim_d'].pop('type')
+        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+        self.optimizers.append(self.optimizer_d)
+
+    def optimize_parameters(self, current_iter):
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+            # gan loss
+            fake_g_pred = self.net_d(self.output)
+            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+
+        self.optimizer_d.zero_grad()
+        # real
+        real_d_pred = self.net_d(self.gt)
+        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+        loss_dict['l_d_real'] = l_d_real
+        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+        l_d_real.backward()
+        # fake
+        fake_d_pred = self.net_d(self.output.detach())
+        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+        loss_dict['l_d_fake'] = l_d_fake
+        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+        l_d_fake.backward()
+        self.optimizer_d.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+    def save(self, epoch, current_iter):
+        if hasattr(self, 'net_g_ema'):
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_network(self.net_d, 'net_d', current_iter)
+        self.save_training_state(epoch, current_iter)
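
`SRGANModel.optimize_parameters` alternates a generator step (discriminator frozen, generator output kept in the graph) with a discriminator step (generator output detached). Below is a toy, self-contained sketch of that alternation; the tiny conv networks and the L1 + softplus terms are stand-ins only, since the real losses are built from `pixel_opt`, `perceptual_opt` and `gan_opt` in the config.

import torch
import torch.nn as nn
import torch.nn.functional as F

# toy (non-super-resolving) networks so the sketch runs stand-alone
net_g = nn.Conv2d(3, 3, 3, padding=1)
net_d = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1), nn.AdaptiveAvgPool2d(1))
opt_g = torch.optim.Adam(net_g.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(net_d.parameters(), lr=1e-4)

lq = torch.rand(2, 3, 32, 32)
gt = torch.rand(2, 3, 32, 32)

# generator step: freeze net_d so only net_g accumulates gradients
for p in net_d.parameters():
    p.requires_grad = False
opt_g.zero_grad()
output = net_g(lq)
l_g = F.l1_loss(output, gt) + F.softplus(-net_d(output)).mean()  # pixel + adversarial term
l_g.backward()
opt_g.step()

# discriminator step: unfreeze net_d and detach the generator output
for p in net_d.parameters():
    p.requires_grad = True
opt_d.zero_grad()
l_d = F.softplus(-net_d(gt)).mean() + F.softplus(net_d(output.detach())).mean()
l_d.backward()
opt_d.step()
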
diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7da708122160f2be51a98a6a635349f34ee042e
--- /dev/null
+++ b/basicsr/models/stylegan2_model.py
@@ -0,0 +1,283 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from os import path as osp
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.losses.gan_loss import g_path_regularize, r1_penalty
+from basicsr.utils import imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class StyleGAN2Model(BaseModel):
+    """StyleGAN2 model."""
+
+    def __init__(self, opt):
+        super(StyleGAN2Model, self).__init__(opt)
+
+        # define network net_g
+        self.net_g = build_network(opt['network_g'])
+        self.net_g = self.model_to_device(self.net_g)
+        self.print_network(self.net_g)
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_g', 'params')
+            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+        # latent dimension: self.num_style_feat
+        self.num_style_feat = opt['network_g']['num_style_feat']
+        num_val_samples = self.opt['val'].get('num_val_samples', 16)
+        self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device)
+
+        if self.is_train:
+            self.init_training_settings()
+
+    def init_training_settings(self):
+        train_opt = self.opt['train']
+
+        # define network net_d
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_d', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_d', 'params')
+            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+        # define network net_g with Exponential Moving Average (EMA)
+        # net_g_ema is only used for testing on one GPU and for saving;
+        # there is no need to wrap it with DistributedDataParallel
+        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+        # load pretrained model
+        load_path = self.opt['path'].get('pretrain_network_g', None)
+        if load_path is not None:
+            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+        else:
+            self.model_ema(0)  # copy net_g weight
+
+        self.net_g.train()
+        self.net_d.train()
+        self.net_g_ema.eval()
+
+        # define losses
+        # gan loss (wgan)
+        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+        # regularization weights
+        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
+        self.path_reg_weight = train_opt['path_reg_weight']  # for generator
+
+        self.net_g_reg_every = train_opt['net_g_reg_every']
+        self.net_d_reg_every = train_opt['net_d_reg_every']
+        self.mixing_prob = train_opt['mixing_prob']
+
+        self.mean_path_length = 0
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        # optimizer g
+        net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1)
+        if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC':
+            normal_params = []
+            style_mlp_params = []
+            modulation_conv_params = []
+            for name, param in self.net_g.named_parameters():
+                if 'modulation' in name:
+                    normal_params.append(param)
+                elif 'style_mlp' in name:
+                    style_mlp_params.append(param)
+                elif 'modulated_conv' in name:
+                    modulation_conv_params.append(param)
+                else:
+                    normal_params.append(param)
+            optim_params_g = [
+                {  # add normal params first
+                    'params': normal_params,
+                    'lr': train_opt['optim_g']['lr']
+                },
+                {
+                    'params': style_mlp_params,
+                    'lr': train_opt['optim_g']['lr'] * 0.01
+                },
+                {
+                    'params': modulation_conv_params,
+                    'lr': train_opt['optim_g']['lr'] / 3
+                }
+            ]
+        else:
+            normal_params = []
+            for name, param in self.net_g.named_parameters():
+                normal_params.append(param)
+            optim_params_g = [{  # add normal params first
+                'params': normal_params,
+                'lr': train_opt['optim_g']['lr']
+            }]
+
+        optim_type = train_opt['optim_g'].pop('type')
+        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
+        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
+        self.optimizers.append(self.optimizer_g)
+
+        # optimizer d
+        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
+        if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC':
+            normal_params = []
+            linear_params = []
+            for name, param in self.net_d.named_parameters():
+                if 'final_linear' in name:
+                    linear_params.append(param)
+                else:
+                    normal_params.append(param)
+            optim_params_d = [
+                {  # add normal params first
+                    'params': normal_params,
+                    'lr': train_opt['optim_d']['lr']
+                },
+                {
+                    'params': linear_params,
+                    'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
+                }
+            ]
+        else:
+            normal_params = []
+            for name, param in self.net_d.named_parameters():
+                normal_params.append(param)
+            optim_params_d = [{  # add normal params first
+                'params': normal_params,
+                'lr': train_opt['optim_d']['lr']
+            }]
+
+        optim_type = train_opt['optim_d'].pop('type')
+        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
+        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
+        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
+        self.optimizers.append(self.optimizer_d)
+
+    def feed_data(self, data):
+        self.real_img = data['gt'].to(self.device)
+
+    def make_noise(self, batch, num_noise):
+        if num_noise == 1:
+            noises = torch.randn(batch, self.num_style_feat, device=self.device)
+        else:
+            noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0)
+        return noises
+
+    def mixing_noise(self, batch, prob):
+        if random.random() < prob:
+            return self.make_noise(batch, 2)
+        else:
+            return [self.make_noise(batch, 1)]
+
+    def optimize_parameters(self, current_iter):
+        loss_dict = OrderedDict()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+        self.optimizer_d.zero_grad()
+
+        batch = self.real_img.size(0)
+        noise = self.mixing_noise(batch, self.mixing_prob)
+        fake_img, _ = self.net_g(noise)
+        fake_pred = self.net_d(fake_img.detach())
+
+        real_pred = self.net_d(self.real_img)
+        # wgan loss with softplus (logistic loss) for discriminator
+        l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True)
+        loss_dict['l_d'] = l_d
+        # In wgan, real_score should be positive and fake_score should be
+        # negative
+        loss_dict['real_score'] = real_pred.detach().mean()
+        loss_dict['fake_score'] = fake_pred.detach().mean()
+        l_d.backward()
+
+        if current_iter % self.net_d_reg_every == 0:
+            self.real_img.requires_grad = True
+            real_pred = self.net_d(self.real_img)
+            l_d_r1 = r1_penalty(real_pred, self.real_img)
+            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
+            # TODO: why do we need to add 0 * real_pred? Without it, a runtime
+            # error arises: RuntimeError: Expected to have finished
+            # reduction in the prior iteration before starting a new one.
+            # This error indicates that your module has parameters that were
+            # not used in producing loss.
+            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
+            l_d_r1.backward()
+
+        self.optimizer_d.step()
+
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+        self.optimizer_g.zero_grad()
+
+        noise = self.mixing_noise(batch, self.mixing_prob)
+        fake_img, _ = self.net_g(noise)
+        fake_pred = self.net_d(fake_img)
+
+        # wgan loss with softplus (non-saturating loss) for generator
+        l_g = self.cri_gan(fake_pred, True, is_disc=False)
+        loss_dict['l_g'] = l_g
+        l_g.backward()
+
+        if current_iter % self.net_g_reg_every == 0:
+            path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
+            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
+            fake_img, latents = self.net_g(noise, return_latents=True)
+            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
+
+            l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
+            # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]?
+            l_g_path.backward()
+            loss_dict['l_g_path'] = l_g_path.detach().mean()
+            loss_dict['path_length'] = path_lengths
+
+        self.optimizer_g.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        # EMA
+        self.model_ema(decay=0.5**(32 / (10 * 1000)))
+
+    def test(self):
+        with torch.no_grad():
+            self.net_g_ema.eval()
+            self.output, _ = self.net_g_ema([self.fixed_sample])
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        if self.opt['rank'] == 0:
+            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        assert dataloader is None, 'Validation dataloader should be None.'
+        self.test()
+        result = tensor2img(self.output, min_max=(-1, 1))
+        if self.opt['is_train']:
+            save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
+        else:
+            save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
+        imwrite(result, save_img_path)
+        # add sample images to tb_logger
+        result = (result / 255.).astype(np.float32)
+        result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+        if tb_logger is not None:
+            tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
+
+    def save(self, epoch, current_iter):
+        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        self.save_network(self.net_d, 'net_d', current_iter)
+        self.save_training_state(epoch, current_iter)
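
The lazy R1 regularization above penalizes the discriminator's gradient on real images only every `net_d_reg_every` iterations, scaled by that interval so the effective strength is unchanged. The `r1_penalty` imported from `basicsr.losses.gan_loss` plays that role; the sketch below shows the usual StyleGAN2 formulation it corresponds to, with an illustrative name.

import torch
from torch import autograd

def r1_penalty_sketch(real_pred, real_img):
    """R1 = E_x[ ||dD(x)/dx||^2 ], computed on real images only."""
    grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    return grad_real.pow(2).reshape(grad_real.size(0), -1).sum(1).mean()

# toy usage: the input must require grad so autograd can differentiate w.r.t. it
# real_img = torch.randn(4, 3, 64, 64, requires_grad=True)
# l_d_r1 = r1_penalty_sketch(net_d(real_img), real_img)
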
diff --git a/basicsr/models/swinir_model.py b/basicsr/models/swinir_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac182f23b4a300aff14b2b45fcdca8c00da90c1
--- /dev/null
+++ b/basicsr/models/swinir_model.py
@@ -0,0 +1,33 @@
+import torch
+from torch.nn import functional as F
+
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SwinIRModel(SRModel):
+
+    def test(self):
+        # pad to multiplication of window_size
+        window_size = self.opt['network_g']['window_size']
+        scale = self.opt.get('scale', 1)
+        mod_pad_h, mod_pad_w = 0, 0
+        _, _, h, w = self.lq.size()
+        if h % window_size != 0:
+            mod_pad_h = window_size - h % window_size
+        if w % window_size != 0:
+            mod_pad_w = window_size - w % window_size
+        img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+        if hasattr(self, 'net_g_ema'):
+            self.net_g_ema.eval()
+            with torch.no_grad():
+                self.output = self.net_g_ema(img)
+        else:
+            self.net_g.eval()
+            with torch.no_grad():
+                self.output = self.net_g(img)
+            self.net_g.train()
+
+        _, _, h, w = self.output.size()
+        self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
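
SwinIR only accepts inputs whose height and width are multiples of `window_size`, so `test` reflect-pads the LQ image up to the next multiple and crops the corresponding (scale-multiplied) border off the output. A stand-alone sketch of that pad/crop pattern, assuming `net` is any x`scale` SR callable and the input is larger than the window:

import torch
import torch.nn.functional as F

def test_pad_to_window(net, lq, window_size, scale):
    _, _, h, w = lq.size()
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    img = F.pad(lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')  # pad right/bottom only
    with torch.no_grad():
        out = net(img)
    return out[:, :, :h * scale, :w * scale]  # crop back to the unpadded size
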
diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7993a15e585526135d1ede094f4dcff47f64db
--- /dev/null
+++ b/basicsr/models/video_base_model.py
@@ -0,0 +1,160 @@
+import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VideoBaseModel(SRModel):
+    """Base video SR model."""
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset = dataloader.dataset
+        dataset_name = dataset.opt['name']
+        with_metrics = self.opt['val']['metrics'] is not None
+        # initialize self.metric_results
+        # It is a dict: {
+        #    'folder1': tensor (num_frame x len(metrics)),
+        #    'folder2': tensor (num_frame x len(metrics))
+        # }
+        if with_metrics:
+            if not hasattr(self, 'metric_results'):  # only execute in the first run
+                self.metric_results = {}
+                num_frame_each_folder = Counter(dataset.data_info['folder'])
+                for folder, num_frame in num_frame_each_folder.items():
+                    self.metric_results[folder] = torch.zeros(
+                        num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+            # initialize the best metric results
+            self._initialize_best_metric_results(dataset_name)
+        # zero self.metric_results
+        rank, world_size = get_dist_info()
+        if with_metrics:
+            for _, tensor in self.metric_results.items():
+                tensor.zero_()
+
+        metric_data = dict()
+        # record all frames (border and center frames)
+        if rank == 0:
+            pbar = tqdm(total=len(dataset), unit='frame')
+        for idx in range(rank, len(dataset), world_size):
+            val_data = dataset[idx]
+            val_data['lq'].unsqueeze_(0)
+            val_data['gt'].unsqueeze_(0)
+            folder = val_data['folder']
+            frame_idx, max_idx = val_data['idx'].split('/')
+            lq_path = val_data['lq_path']
+
+            self.feed_data(val_data)
+            self.test()
+            visuals = self.get_current_visuals()
+            result_img = tensor2img([visuals['result']])
+            metric_data['img'] = result_img
+            if 'gt' in visuals:
+                gt_img = tensor2img([visuals['gt']])
+                metric_data['img2'] = gt_img
+                del self.gt
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            torch.cuda.empty_cache()
+
+            if save_img:
+                if self.opt['is_train']:
+                    raise NotImplementedError('saving image is not supported during training.')
+                else:
+                    if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
+                        split_result = lq_path.split('/')
+                        img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
+                    else:  # other datasets, e.g., REDS, Vid4
+                        img_name = osp.splitext(osp.basename(lq_path))[0]
+
+                    if self.opt['val']['suffix']:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
+                    else:
+                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+                                                 f'{img_name}_{self.opt["name"]}.png')
+                imwrite(result_img, save_img_path)
+
+            if with_metrics:
+                # calculate metrics
+                for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+                    result = calculate_metric(metric_data, opt_)
+                    self.metric_results[folder][int(frame_idx), metric_idx] += result
+
+            # progress bar
+            if rank == 0:
+                for _ in range(world_size):
+                    pbar.update(1)
+                    pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
+        if rank == 0:
+            pbar.close()
+
+        if with_metrics:
+            if self.opt['dist']:
+                # collect data among GPUs
+                for _, tensor in self.metric_results.items():
+                    dist.reduce(tensor, 0)
+                dist.barrier()
+            else:
+                pass  # assume only one GPU is used in non-dist testing
+
+            if rank == 0:
+                self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        logger = get_root_logger()
+        logger.warning('nondist_validation is not implemented. Run dist_validation.')
+        self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+
+    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+        # ----------------- calculate the average values for each folder, and for each metric  ----------------- #
+        # average all frames for each sub-folder
+        # metric_results_avg is a dict:{
+        #    'folder1': tensor (len(metrics)),
+        #    'folder2': tensor (len(metrics))
+        # }
+        metric_results_avg = {
+            folder: torch.mean(tensor, dim=0).cpu()
+            for (folder, tensor) in self.metric_results.items()
+        }
+        # total_avg_results is a dict: {
+        #    'metric1': float,
+        #    'metric2': float
+        # }
+        total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+        for folder, tensor in metric_results_avg.items():
+            for idx, metric in enumerate(total_avg_results.keys()):
+                total_avg_results[metric] += metric_results_avg[folder][idx].item()
+        # average among folders
+        for metric in total_avg_results.keys():
+            total_avg_results[metric] /= len(metric_results_avg)
+            # update the best metric result
+            self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
+
+        # ------------------------------------------ log the metric ------------------------------------------ #
+        log_str = f'Validation {dataset_name}\n'
+        for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+            log_str += f'\t # {metric}: {value:.4f}'
+            for folder, tensor in metric_results_avg.items():
+                log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
+            if hasattr(self, 'best_metric_results'):
+                log_str += (f'\n\t    Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+            log_str += '\n'
+
+        logger = get_root_logger()
+        logger.info(log_str)
+        if tb_logger:
+            for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+                for folder, tensor in metric_results_avg.items():
+                    tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
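
`_log_validation_metric_values` averages in two stages: per-frame metric rows are first averaged within each clip/folder, and those folder means are then averaged across folders, so short and long clips carry equal weight. A toy illustration with made-up folder names and values:

import torch

metric_results = {                       # folder -> tensor of shape (num_frame, num_metrics)
    'clip_000': torch.tensor([[30.0, 0.90], [32.0, 0.92]]),
    'clip_011': torch.tensor([[28.0, 0.85], [29.0, 0.86], [30.0, 0.87]]),
}
metric_names = ['psnr', 'ssim']

folder_avg = {folder: t.mean(dim=0) for folder, t in metric_results.items()}
total_avg = {m: sum(v[i].item() for v in folder_avg.values()) / len(folder_avg)
             for i, m in enumerate(metric_names)}
print(total_avg)  # psnr = 30.0, ssim approx. 0.885
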
diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2adcdeee59e494dd7d1c285919fac5c99cd9efb
--- /dev/null
+++ b/basicsr/models/video_gan_model.py
@@ -0,0 +1,19 @@
+from basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoGANModel(SRGANModel, VideoBaseModel):
+    """Video GAN model.
+
+    Use multiple inheritance.
+    It will first use the functions of :class:`SRGANModel`:
+
+    - :func:`init_training_settings`
+    - :func:`setup_optimizers`
+    - :func:`optimize_parameters`
+    - :func:`save`
+
+    It then falls back to :class:`VideoBaseModel` for the remaining functions.
+    """
diff --git a/basicsr/models/video_recurrent_gan_model.py b/basicsr/models/video_recurrent_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..74cf81145c50ffafb220d22b51e56746dee5ba41
--- /dev/null
+++ b/basicsr/models/video_recurrent_gan_model.py
@@ -0,0 +1,180 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentGANModel(VideoRecurrentModel):
+
+    def init_training_settings(self):
+        train_opt = self.opt['train']
+
+        self.ema_decay = train_opt.get('ema_decay', 0)
+        if self.ema_decay > 0:
+            logger = get_root_logger()
+            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+            # build network net_g with Exponential Moving Average (EMA)
+            # net_g_ema only used for testing on one GPU and saving.
+            # There is no need to wrap with DistributedDataParallel
+            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+            # load pretrained model
+            load_path = self.opt['path'].get('pretrain_network_g', None)
+            if load_path is not None:
+                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+            else:
+                self.model_ema(0)  # copy net_g weight
+            self.net_g_ema.eval()
+
+        # define network net_d
+        self.net_d = build_network(self.opt['network_d'])
+        self.net_d = self.model_to_device(self.net_d)
+        self.print_network(self.net_d)
+
+        # load pretrained models
+        load_path = self.opt['path'].get('pretrain_network_d', None)
+        if load_path is not None:
+            param_key = self.opt['path'].get('param_key_d', 'params')
+            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+        self.net_g.train()
+        self.net_d.train()
+
+        # define losses
+        if train_opt.get('pixel_opt'):
+            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+        else:
+            self.cri_pix = None
+
+        if train_opt.get('perceptual_opt'):
+            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+        else:
+            self.cri_perceptual = None
+
+        if train_opt.get('gan_opt'):
+            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+        self.net_d_iters = train_opt.get('net_d_iters', 1)
+        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+        # set up optimizers and schedulers
+        self.setup_optimizers()
+        self.setup_schedulers()
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        if train_opt['fix_flow']:
+            normal_params = []
+            flow_params = []
+            for name, param in self.net_g.named_parameters():
+                if 'spynet' in name:  # The fix_flow now only works for spynet.
+                    flow_params.append(param)
+                else:
+                    normal_params.append(param)
+
+            optim_params = [
+                {  # add flow params first
+                    'params': flow_params,
+                    'lr': train_opt['lr_flow']
+                },
+                {
+                    'params': normal_params,
+                    'lr': train_opt['optim_g']['lr']
+                },
+            ]
+        else:
+            optim_params = self.net_g.parameters()
+
+        # optimizer g
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+        # optimizer d
+        optim_type = train_opt['optim_d'].pop('type')
+        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+        self.optimizers.append(self.optimizer_d)
+
+    def optimize_parameters(self, current_iter):
+        logger = get_root_logger()
+        # optimize net_g
+        for p in self.net_d.parameters():
+            p.requires_grad = False
+
+        if self.fix_flow_iter:
+            if current_iter == 1:
+                logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+                for name, param in self.net_g.named_parameters():
+                    if 'spynet' in name or 'edvr' in name:
+                        param.requires_grad_(False)
+            elif current_iter == self.fix_flow_iter:
+                logger.warning('Train all the parameters.')
+                self.net_g.requires_grad_(True)
+
+        self.optimizer_g.zero_grad()
+        self.output = self.net_g(self.lq)
+
+        _, _, c, h, w = self.output.size()
+
+        l_g_total = 0
+        loss_dict = OrderedDict()
+        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+            # pixel loss
+            if self.cri_pix:
+                l_g_pix = self.cri_pix(self.output, self.gt)
+                l_g_total += l_g_pix
+                loss_dict['l_g_pix'] = l_g_pix
+            # perceptual loss
+            if self.cri_perceptual:
+                l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
+                if l_g_percep is not None:
+                    l_g_total += l_g_percep
+                    loss_dict['l_g_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_g_total += l_g_style
+                    loss_dict['l_g_style'] = l_g_style
+            # gan loss
+            fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
+            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+            l_g_total += l_g_gan
+            loss_dict['l_g_gan'] = l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_g.step()
+
+        # optimize net_d
+        for p in self.net_d.parameters():
+            p.requires_grad = True
+
+        self.optimizer_d.zero_grad()
+        # real
+        # reshape to (b*n, c, h, w)
+        real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
+        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+        loss_dict['l_d_real'] = l_d_real
+        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+        l_d_real.backward()
+        # fake
+        # reshape to (b*n, c, h, w)
+        fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
+        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+        loss_dict['l_d_fake'] = l_d_fake
+        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+        l_d_fake.backward()
+        self.optimizer_d.step()
+
+        self.log_dict = self.reduce_loss_dict(loss_dict)
+
+        if self.ema_decay > 0:
+            self.model_ema(decay=self.ema_decay)
+
+    def save(self, epoch, current_iter):
+        if self.ema_decay > 0:
+            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+        else:
+            self.save_network(self.net_g, 'net_g', current_iter)
+        self.save_network(self.net_d, 'net_d', current_iter)
+        self.save_training_state(epoch, current_iter)
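
The discriminator here is an image discriminator, which is why the recurrent generator output and the GT, shaped `(b, n, c, h, w)`, are flattened to `(b*n, c, h, w)` before being scored (and before the perceptual loss). A toy illustration of that reshape, with made-up sizes:

import torch

b, n, c, h, w = 2, 5, 3, 64, 64
output = torch.rand(b, n, c, h, w)      # recurrent generator output: clips of n frames
frames = output.view(-1, c, h, w)       # (b*n, c, h, w): frames treated as a plain image batch
assert frames.shape == (b * n, c, h, w)

# per-frame discriminator scores can be regrouped per clip afterwards if needed
scores = torch.rand(b * n, 1)
per_clip = scores.view(b, n, 1)
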
diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..796ee57d5aeb84e81fe8dc769facc8339798cc3e
--- /dev/null
+++ b/basicsr/models/video_recurrent_model.py
@@ -0,0 +1,197 @@
+import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentModel(VideoBaseModel):
+
+    def __init__(self, opt):
+        super(VideoRecurrentModel, self).__init__(opt)
+        if self.is_train:
+            self.fix_flow_iter = opt['train'].get('fix_flow')
+
+    def setup_optimizers(self):
+        train_opt = self.opt['train']
+        flow_lr_mul = train_opt.get('flow_lr_mul', 1)
+        logger = get_root_logger()
+        logger.info(f'Multiply the learning rate of the flow network by {flow_lr_mul}.')
+        if flow_lr_mul == 1:
+            optim_params = self.net_g.parameters()
+        else:  # separate flow params and normal params for different lr
+            normal_params = []
+            flow_params = []
+            for name, param in self.net_g.named_parameters():
+                if 'spynet' in name:
+                    flow_params.append(param)
+                else:
+                    normal_params.append(param)
+            optim_params = [
+                {  # add normal params first
+                    'params': normal_params,
+                    'lr': train_opt['optim_g']['lr']
+                },
+                {
+                    'params': flow_params,
+                    'lr': train_opt['optim_g']['lr'] * flow_lr_mul
+                },
+            ]
+
+        optim_type = train_opt['optim_g'].pop('type')
+        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+        self.optimizers.append(self.optimizer_g)
+
+    def optimize_parameters(self, current_iter):
+        if self.fix_flow_iter:
+            logger = get_root_logger()
+            if current_iter == 1:
+                logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+                for name, param in self.net_g.named_parameters():
+                    if 'spynet' in name or 'edvr' in name:
+                        param.requires_grad_(False)
+            elif current_iter == self.fix_flow_iter:
+                logger.warning('Train all the parameters.')
+                self.net_g.requires_grad_(True)
+
+        super(VideoRecurrentModel, self).optimize_parameters(current_iter)
+
+    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+        dataset = dataloader.dataset
+        dataset_name = dataset.opt['name']
+        with_metrics = self.opt['val']['metrics'] is not None
+        # initialize self.metric_results
+        # It is a dict: {
+        #    'folder1': tensor (num_frame x len(metrics)),
+        #    'folder2': tensor (num_frame x len(metrics))
+        # }
+        if with_metrics:
+            if not hasattr(self, 'metric_results'):  # only execute in the first run
+                self.metric_results = {}
+                num_frame_each_folder = Counter(dataset.data_info['folder'])
+                for folder, num_frame in num_frame_each_folder.items():
+                    self.metric_results[folder] = torch.zeros(
+                        num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+            # initialize the best metric results
+            self._initialize_best_metric_results(dataset_name)
+        # zero self.metric_results
+        rank, world_size = get_dist_info()
+        if with_metrics:
+            for _, tensor in self.metric_results.items():
+                tensor.zero_()
+
+        metric_data = dict()
+        num_folders = len(dataset)
+        num_pad = (world_size - (num_folders % world_size)) % world_size
+        if rank == 0:
+            pbar = tqdm(total=len(dataset), unit='folder')
+        # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
+        # (To avoid deadlock: every rank must run the same number of iterations.)
+        for i in range(rank, num_folders + num_pad, world_size):
+            idx = min(i, num_folders - 1)
+            val_data = dataset[idx]
+            folder = val_data['folder']
+
+            # compute outputs
+            val_data['lq'].unsqueeze_(0)
+            val_data['gt'].unsqueeze_(0)
+            self.feed_data(val_data)
+            val_data['lq'].squeeze_(0)
+            val_data['gt'].squeeze_(0)
+
+            self.test()
+            visuals = self.get_current_visuals()
+
+            # tentative for out of GPU memory
+            del self.lq
+            del self.output
+            if 'gt' in visuals:
+                del self.gt
+            torch.cuda.empty_cache()
+
+            if self.center_frame_only:
+                visuals['result'] = visuals['result'].unsqueeze(1)
+                if 'gt' in visuals:
+                    visuals['gt'] = visuals['gt'].unsqueeze(1)
+
+            # evaluate
+            if i < num_folders:
+                for idx in range(visuals['result'].size(1)):
+                    result = visuals['result'][0, idx, :, :, :]
+                    result_img = tensor2img([result])  # uint8, bgr
+                    metric_data['img'] = result_img
+                    if 'gt' in visuals:
+                        gt = visuals['gt'][0, idx, :, :, :]
+                        gt_img = tensor2img([gt])  # uint8, bgr
+                        metric_data['img2'] = gt_img
+
+                    if save_img:
+                        if self.opt['is_train']:
+                            raise NotImplementedError('saving image is not supported during training.')
+                        else:
+                            if self.center_frame_only:  # vimeo-90k
+                                clip_ = val_data['lq_path'].split('/')[-3]
+                                seq_ = val_data['lq_path'].split('/')[-2]
+                                name_ = f'{clip_}_{seq_}'
+                                img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+                                                    f"{name_}_{self.opt['name']}.png")
+                            else:  # others
+                                img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+                                                    f"{idx:08d}_{self.opt['name']}.png")
+                            # image name only for REDS dataset
+                        imwrite(result_img, img_path)
+
+                    # calculate metrics
+                    if with_metrics:
+                        for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+                            result = calculate_metric(metric_data, opt_)
+                            self.metric_results[folder][idx, metric_idx] += result
+
+                # progress bar
+                if rank == 0:
+                    for _ in range(world_size):
+                        pbar.update(1)
+                        pbar.set_description(f'Folder: {folder}')
+
+        if rank == 0:
+            pbar.close()
+
+        if with_metrics:
+            if self.opt['dist']:
+                # collect data among GPUs
+                for _, tensor in self.metric_results.items():
+                    dist.reduce(tensor, 0)
+                dist.barrier()
+
+            if rank == 0:
+                self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+    def test(self):
+        n = self.lq.size(1)
+        self.net_g.eval()
+
+        flip_seq = self.opt['val'].get('flip_seq', False)
+        self.center_frame_only = self.opt['val'].get('center_frame_only', False)
+
+        if flip_seq:
+            self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
+
+        with torch.no_grad():
+            self.output = self.net_g(self.lq)
+
+        if flip_seq:
+            output_1 = self.output[:, :n, :, :, :]
+            output_2 = self.output[:, n:, :, :, :].flip(1)
+            self.output = 0.5 * (output_1 + output_2)
+
+        if self.center_frame_only:
+            self.output = self.output[:, n // 2, :, :, :]
+
+        self.net_g.train()
diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/basicsr/ops/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd770add360a14cfbfc2f9f6d6eaf9fc76707b13
Binary files /dev/null and b/basicsr/ops/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff
--- /dev/null
+++ b/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+                          modulated_deform_conv)
+
+__all__ = [
+    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+    'modulated_deform_conv'
+]
diff --git a/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00c76fac1cc097dc5695c9bab30c17258133c742
Binary files /dev/null and b/basicsr/ops/dcn/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc b/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a0c47a985176753cf6d28415c1dc14e964fd515
Binary files /dev/null and b/basicsr/ops/dcn/__pycache__/deform_conv.cpython-310.pyc differ
diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6268ca825d59ef4a30d4d2156c4438cbbe9b3c1e
--- /dev/null
+++ b/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,379 @@
+import math
+import os
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+    from torch.utils.cpp_extension import load
+    module_path = os.path.dirname(__file__)
+    deform_conv_ext = load(
+        'deform_conv',
+        sources=[
+            os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+            os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+            os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+        ],
+    )
+else:
+    try:
+        from . import deform_conv_ext
+    except ImportError:
+        pass
+        # avoid annoying print output
+        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+        #       '1. compile with BASICSR_EXT=True. or\n '
+        #       '2. set BASICSR_JIT=True during running')
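+        # Usage note (illustrative): the extension can either be prebuilt by
+        # installing with BASICSR_EXT=True, as the commented message above
+        # suggests, or JIT-compiled at import time by exporting
+        # BASICSR_JIT=True, e.g.
+        #     BASICSR_JIT=True python basicsr/test.py -opt <config.yml>
+        # (the -opt flag is assumed from BasicSR's parse_options).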
+
+
+class DeformConvFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                weight,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deformable_groups=1,
+                im2col_step=64):
+        if input is not None and input.dim() != 4:
+            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.im2col_step = im2col_step
+
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+
+        if not input.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+            deform_conv_ext.deform_conv_forward(input, weight,
+                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+                                                ctx.deformable_groups, cur_im2col_step)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors
+
+        grad_input = grad_offset = grad_weight = None
+
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        else:
+            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+                grad_input = torch.zeros_like(input)
+                grad_offset = torch.zeros_like(offset)
+                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
+                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+                                                           ctx.deformable_groups, cur_im2col_step)
+
+            if ctx.needs_input_grad[2]:
+                grad_weight = torch.zeros_like(weight)
+                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
+                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+                                                                cur_im2col_step)
+
+        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(input, weight, padding, dilation, stride):
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = padding[d]
+            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
+        return output_size
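+        # Sanity check of the arithmetic above: a 64x64 input with kernel 3,
+        # padding 1, dilation 1 and stride 2 gives (64 + 2 - 3) // 2 + 1 = 32
+        # per spatial dimension.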
+
+
+class ModulatedDeformConvFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deformable_groups=1):
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.groups = groups
+        ctx.deformable_groups = deformable_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(1)  # fake tensor
+        if not input.is_cuda:
+            raise NotImplementedError
+        if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
+            ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if not grad_output.is_cuda:
+            raise NotImplementedError
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+                                                       ctx.groups, ctx.deformable_groups, ctx.with_bias)
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+    @staticmethod
+    def _infer_shape(ctx, input, weight):
+        n = input.size(0)
+        channels_out = weight.size(0)
+        height, width = input.shape[2:4]
+        kernel_h, kernel_w = weight.shape[2:4]
+        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+        return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=False):
+        super(DeformConv, self).__init__()
+
+        assert not bias
+        assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
+        assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+
+    def forward(self, x, offset):
+        # Pad the input when it is smaller than the kernel, to avoid an
+        # assert error in deform_conv_cuda.cpp:128.
+        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+        if input_pad:
+            pad_h = max(self.kernel_size[0] - x.size(2), 0)
+            pad_w = max(self.kernel_size[1] - x.size(3), 0)
+            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+                          self.deformable_groups)
+        if input_pad:
+            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+        return out
+
+
+class DeformConvPack(DeformConv):
+    """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        deformable_groups (int): Number of deformable group partitions.
+        bias (bool): Whether to add a learnable bias. Must be False here,
+            since the underlying deform_conv op does not support a bias term
+            (see the assert in DeformConv.__init__).
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(DeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        offset = self.conv_offset(x)
+        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+                           self.deformable_groups)
+
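+# Minimal usage sketch (illustrative only; requires the compiled deform_conv
+# extension and CUDA tensors, since the CPU path raises NotImplementedError):
+#
+#     import torch
+#     from basicsr.ops.dcn import DeformConvPack
+#
+#     dcn = DeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
+#     x = torch.randn(2, 64, 32, 32, device='cuda')
+#     y = dcn(x)  # offsets are predicted internally by dcn.conv_offset
+#     assert y.shape == x.shape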
+
+class ModulatedDeformConv(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deformable_groups=1,
+                 bias=True):
+        super(ModulatedDeformConv, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.deformable_groups = deformable_groups
+        self.with_bias = bias
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.init_weights()
+
+    def init_weights(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        deformable_groups (int): Number of deformable group partitions.
+        bias (bool): Whether to add a learnable bias, same as nn.Conv2d.
+            Default: True.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        super(ModulatedDeformConvPack, self).init_weights()
+        if hasattr(self, 'conv_offset'):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+                                     self.groups, self.deformable_groups)
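+
+
+# Minimal usage sketch for the modulated (DCNv2-style) variant (illustrative
+# only; also requires the compiled CUDA extension):
+#
+#     import torch
+#     from basicsr.ops.dcn import ModulatedDeformConvPack
+#
+#     dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
+#     y = dcn(torch.randn(2, 64, 32, 32, device='cuda'))
+#
+# conv_offset predicts deformable_groups * 3 * kh * kw maps: two thirds are
+# x/y offsets and one third is the sigmoid-gated modulation mask.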
diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422
--- /dev/null
+++ b/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..841671207d25a059649c695f10d6676fa41db268
Binary files /dev/null and b/basicsr/ops/fused_act/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc b/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c481e75a5df24eacc08c57a01d1b74e60328004e
Binary files /dev/null and b/basicsr/ops/fused_act/__pycache__/fused_act.cpython-310.pyc differ
diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..88edc445484b71119dc22a258e83aef49ce39b07
--- /dev/null
+++ b/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,95 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import os
+import torch
+from torch import nn
+from torch.autograd import Function
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+    from torch.utils.cpp_extension import load
+    module_path = os.path.dirname(__file__)
+    fused_act_ext = load(
+        'fused',
+        sources=[
+            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+        ],
+    )
+else:
+    try:
+        from . import fused_act_ext
+    except ImportError:
+        pass
+        # avoid annoying print output
+        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+        #       '1. compile with BASICSR_EXT=True. or\n '
+        #       '2. set BASICSR_JIT=True during running')
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+                                                    ctx.scale)
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
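+
+
+# Minimal usage sketch (illustrative only; needs the compiled fused_act CUDA
+# extension). FusedLeakyReLU adds a learnable per-channel bias, applies a
+# LeakyReLU with slope 0.2 and rescales by sqrt(2), following StyleGAN2:
+#
+#     import torch
+#     from basicsr.ops.fused_act import FusedLeakyReLU
+#
+#     act = FusedLeakyReLU(channel=64).cuda()
+#     y = act(torch.randn(2, 64, 16, 16, device='cuda'))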
diff --git a/basicsr/ops/upfirdn2d/__init__.py b/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc b/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..160f3fc62d7714850fedc2303303884069f8d28c
Binary files /dev/null and b/basicsr/ops/upfirdn2d/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc b/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..001d0881eeb6bcb7011d47d06e9c48b8c14b628a
Binary files /dev/null and b/basicsr/ops/upfirdn2d/__pycache__/upfirdn2d.cpython-310.pyc differ
diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6122d59aa32fd52e956bd36200ba79af4a17b17
--- /dev/null
+++ b/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,192 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501
+
+import os
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+    from torch.utils.cpp_extension import load
+    module_path = os.path.dirname(__file__)
+    upfirdn2d_ext = load(
+        'upfirdn2d',
+        sources=[
+            os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+            os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+        ],
+    )
+else:
+    try:
+        from . import upfirdn2d_ext
+    except ImportError:
+        pass
+        # avoid annoying print output
+        # print(f'Cannot import upfirdn2d_ext. Error: {error}. You may need to: \n '
+        #       '1. compile with BASICSR_EXT=True. or\n '
+        #       '2. set BASICSR_JIT=True during running')
+
+
+class UpFirDn2dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_ext.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            ctx.up_x,
+            ctx.up_y,
+            ctx.down_x,
+            ctx.down_y,
+            ctx.pad_x0,
+            ctx.pad_x1,
+            ctx.pad_y0,
+            ctx.pad_y1,
+        )
+        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+        #                                  ctx.out_size[1], ctx.in_size[3])
+        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+        return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        _, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+        # out = out.view(major, out_h, out_w, minor)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    if input.device.type == 'cpu':
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+    else:
+        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+    return out
+
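+# Minimal usage sketch (illustrative only): upfirdn2d upsamples by `up`,
+# filters with the 2D FIR `kernel`, then downsamples by `down`. For example,
+# a 4-tap binomial blur that preserves the spatial size:
+#
+#     import torch
+#     k1d = torch.tensor([1., 3., 3., 1.])
+#     kernel = torch.outer(k1d, k1d)
+#     kernel = kernel / kernel.sum()
+#     x = torch.randn(2, 3, 32, 32)
+#     out = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))  # -> (2, 3, 32, 32)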
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+    out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
diff --git a/basicsr/test.py b/basicsr/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cb3b7aa860c90518e15ba76e1a55fdf404bcc2
--- /dev/null
+++ b/basicsr/test.py
@@ -0,0 +1,45 @@
+import logging
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.models import build_model
+from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
+from basicsr.utils.options import dict2str, parse_options
+
+
+def test_pipeline(root_path):
+    # parse options, set distributed setting, set random seed
+    opt, _ = parse_options(root_path, is_train=False)
+
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.deterministic = True
+
+    # mkdir and initialize loggers
+    make_exp_dirs(opt)
+    log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
+    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+    logger.info(get_env_info())
+    logger.info(dict2str(opt))
+
+    # create test dataset and dataloader
+    test_loaders = []
+    for _, dataset_opt in sorted(opt['datasets'].items()):
+        test_set = build_dataset(dataset_opt)
+        test_loader = build_dataloader(
+            test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+        logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
+        test_loaders.append(test_loader)
+
+    # create model
+    model = build_model(opt)
+
+    for test_loader in test_loaders:
+        test_set_name = test_loader.dataset.opt['name']
+        logger.info(f'Testing {test_set_name}...')
+        model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
+
+
+if __name__ == '__main__':
+    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+    test_pipeline(root_path)
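+
+
+# Typical invocation (illustrative; the -opt flag is assumed from BasicSR's
+# parse_options, which is not shown here):
+#     python basicsr/test.py -opt options/test/<your_config>.yml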
diff --git a/basicsr/train.py b/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02d98fe07f8c2924dda5b49f95adfa21990fa91
--- /dev/null
+++ b/basicsr/train.py
@@ -0,0 +1,215 @@
+import datetime
+import logging
+import math
+import time
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.models import build_model
+from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
+                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
+from basicsr.utils.options import copy_opt_file, dict2str, parse_options
+
+
+def init_tb_loggers(opt):
+    # initialize wandb logger before tensorboard logger to allow proper sync
+    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
+                                                     is not None) and ('debug' not in opt['name']):
+        assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+        init_wandb_logger(opt)
+    tb_logger = None
+    if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
+        tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
+    return tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+    # create train and val dataloaders
+    train_loader, val_loaders = None, []
+    for phase, dataset_opt in opt['datasets'].items():
+        if phase == 'train':
+            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+            train_set = build_dataset(dataset_opt)
+            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+            train_loader = build_dataloader(
+                train_set,
+                dataset_opt,
+                num_gpu=opt['num_gpu'],
+                dist=opt['dist'],
+                sampler=train_sampler,
+                seed=opt['manual_seed'])
+
+            num_iter_per_epoch = math.ceil(
+                len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+            total_iters = int(opt['train']['total_iter'])
+            total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+            logger.info('Training statistics:'
+                        f'\n\tNumber of train images: {len(train_set)}'
+                        f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+                        f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+                        f'\n\tWorld size (gpu number): {opt["world_size"]}'
+                        f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+                        f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+        elif phase.split('_')[0] == 'val':
+            val_set = build_dataset(dataset_opt)
+            val_loader = build_dataloader(
+                val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+            logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
+            val_loaders.append(val_loader)
+        else:
+            raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+    return train_loader, train_sampler, val_loaders, total_epochs, total_iters
+
+
+def load_resume_state(opt):
+    resume_state_path = None
+    if opt['auto_resume']:
+        state_path = osp.join('experiments', opt['name'], 'training_states')
+        if osp.isdir(state_path):
+            states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
+            if len(states) != 0:
+                states = [float(v.split('.state')[0]) for v in states]
+                resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
+                opt['path']['resume_state'] = resume_state_path
+    else:
+        if opt['path'].get('resume_state'):
+            resume_state_path = opt['path']['resume_state']
+
+    if resume_state_path is None:
+        resume_state = None
+    else:
+        device_id = torch.cuda.current_device()
+        resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
+        check_resume(opt, resume_state['iter'])
+    return resume_state
+
+
+def train_pipeline(root_path):
+    # parse options, set distributed setting, set random seed
+    opt, args = parse_options(root_path, is_train=True)
+    opt['root_path'] = root_path
+
+    torch.backends.cudnn.benchmark = True
+    # torch.backends.cudnn.deterministic = True
+
+    # load resume states if necessary
+    resume_state = load_resume_state(opt)
+    # mkdir for experiments and logger
+    if resume_state is None:
+        make_exp_dirs(opt)
+        if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
+            mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
+
+    # copy the yml file to the experiment root
+    copy_opt_file(args.opt, opt['path']['experiments_root'])
+
+    # WARNING: do not use get_root_logger in the code above (including the functions it calls);
+    # otherwise the logger will not be properly initialized
+    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
+    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+    logger.info(get_env_info())
+    logger.info(dict2str(opt))
+    # initialize wandb and tb loggers
+    tb_logger = init_tb_loggers(opt)
+
+    # create train and validation dataloaders
+    result = create_train_val_dataloader(opt, logger)
+    train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
+
+    # create model
+    model = build_model(opt)
+    if resume_state:  # resume training
+        model.resume_training(resume_state)  # handle optimizers and schedulers
+        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
+        start_epoch = resume_state['epoch']
+        current_iter = resume_state['iter']
+    else:
+        start_epoch = 0
+        current_iter = 0
+
+    # create message logger (formatted outputs)
+    msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+    # dataloader prefetcher
+    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+    if prefetch_mode is None or prefetch_mode == 'cpu':
+        prefetcher = CPUPrefetcher(train_loader)
+    elif prefetch_mode == 'cuda':
+        prefetcher = CUDAPrefetcher(train_loader, opt)
+        logger.info(f'Use {prefetch_mode} prefetch dataloader')
+        if opt['datasets']['train'].get('pin_memory') is not True:
+            raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+    else:
+        raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
+
+    # training
+    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
+    data_timer, iter_timer = AvgTimer(), AvgTimer()
+    start_time = time.time()
+
+    for epoch in range(start_epoch, total_epochs + 1):
+        train_sampler.set_epoch(epoch)
+        prefetcher.reset()
+        train_data = prefetcher.next()
+
+        while train_data is not None:
+            data_timer.record()
+
+            current_iter += 1
+            if current_iter > total_iters:
+                break
+            # update learning rate
+            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+            # training
+            model.feed_data(train_data)
+            model.optimize_parameters(current_iter)
+            iter_timer.record()
+            if current_iter == 1:
+                # reset start time in msg_logger for a more accurate eta_time
+                # (this does not work in resume mode)
+                msg_logger.reset_start_time()
+            # log
+            if current_iter % opt['logger']['print_freq'] == 0:
+                log_vars = {'epoch': epoch, 'iter': current_iter}
+                log_vars.update({'lrs': model.get_current_learning_rate()})
+                log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
+                log_vars.update(model.get_current_log())
+                msg_logger(log_vars)
+
+            # save models and training states
+            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+                logger.info('Saving models and training states.')
+                model.save(epoch, current_iter)
+
+            # validation
+            if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
+                if len(val_loaders) > 1:
+                    logger.warning('Multiple validation datasets are *only* supported by SRModel.')
+                for val_loader in val_loaders:
+                    model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+
+            data_timer.start()
+            iter_timer.start()
+            train_data = prefetcher.next()
+        # end of iter
+
+    # end of epoch
+
+    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+    logger.info(f'End of training. Time consumed: {consumed_time}')
+    logger.info('Save the latest model.')
+    model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
+    if opt.get('val') is not None:
+        for val_loader in val_loaders:
+            model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+    if tb_logger:
+        tb_logger.close()
+
+
+if __name__ == '__main__':
+    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+    train_pipeline(root_path)
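+
+
+# Typical launches (illustrative; -opt and --launcher are assumed from
+# BasicSR's parse_options, which is not shown here):
+#     python basicsr/train.py -opt options/train/<your_config>.yml
+#     python -m torch.distributed.launch --nproc_per_node=4 basicsr/train.py \
+#         -opt options/train/<your_config>.yml --launcher pytorch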
diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9569c50780415b356c8e06edac5d960cf1fe1e91
--- /dev/null
+++ b/basicsr/utils/__init__.py
@@ -0,0 +1,47 @@
+from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
+from .diffjpeg import DiffJPEG
+from .file_client import FileClient
+from .img_process_util import USMSharp, usm_sharp
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+from .options import yaml_load
+
+__all__ = [
+    #  color_util.py
+    'bgr2ycbcr',
+    'rgb2ycbcr',
+    'rgb2ycbcr_pt',
+    'ycbcr2bgr',
+    'ycbcr2rgb',
+    # file_client.py
+    'FileClient',
+    # img_util.py
+    'img2tensor',
+    'tensor2img',
+    'imfrombytes',
+    'imwrite',
+    'crop_border',
+    # logger.py
+    'MessageLogger',
+    'AvgTimer',
+    'init_tb_logger',
+    'init_wandb_logger',
+    'get_root_logger',
+    'get_env_info',
+    # misc.py
+    'set_random_seed',
+    'get_time_str',
+    'mkdir_and_rename',
+    'make_exp_dirs',
+    'scandir',
+    'check_resume',
+    'sizeof_fmt',
+    # diffjpeg
+    'DiffJPEG',
+    # img_process_util
+    'USMSharp',
+    'usm_sharp',
+    # options
+    'yaml_load'
+]
diff --git a/basicsr/utils/__pycache__/__init__.cpython-310.pyc b/basicsr/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef1a705154e127900a848704a4c76aafff0269d4
Binary files /dev/null and b/basicsr/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/color_util.cpython-310.pyc b/basicsr/utils/__pycache__/color_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..576a8cd964a3dd41e72a8716393da318a83a8273
Binary files /dev/null and b/basicsr/utils/__pycache__/color_util.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc b/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea950373c071f149e69493f9de069532ac0f8e5d
Binary files /dev/null and b/basicsr/utils/__pycache__/diffjpeg.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/dist_util.cpython-310.pyc b/basicsr/utils/__pycache__/dist_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8dbc93ef8125020211bd165e6f76dd5387d2ff3
Binary files /dev/null and b/basicsr/utils/__pycache__/dist_util.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/file_client.cpython-310.pyc b/basicsr/utils/__pycache__/file_client.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f627cdd480d8fb3481b480ce9b8b18a6c257dd0
Binary files /dev/null and b/basicsr/utils/__pycache__/file_client.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/flow_util.cpython-310.pyc b/basicsr/utils/__pycache__/flow_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76555d3a2048adca186eff33b7e514ab17944580
Binary files /dev/null and b/basicsr/utils/__pycache__/flow_util.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc b/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef699484d4152007212f77fd6deb861ee821c987
Binary files /dev/null and b/basicsr/utils/__pycache__/img_process_util.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/img_util.cpython-310.pyc b/basicsr/utils/__pycache__/img_util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f0022bc34275d5591e2b4f3113cc57a6325cddc
Binary files /dev/null and b/basicsr/utils/__pycache__/img_util.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/logger.cpython-310.pyc b/basicsr/utils/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a2ae8c7c35503a9da42e45b7ff5b43b0b97624a
Binary files /dev/null and b/basicsr/utils/__pycache__/logger.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc b/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2089612e4b1e67dbbe99c774d7bed56b7edda7cd
Binary files /dev/null and b/basicsr/utils/__pycache__/matlab_functions.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/misc.cpython-310.pyc b/basicsr/utils/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d14fbc10feb34106511e674dd21a50597644db22
Binary files /dev/null and b/basicsr/utils/__pycache__/misc.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/options.cpython-310.pyc b/basicsr/utils/__pycache__/options.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d68bd857da9c7eab8bbabb17329de9922f34f5a
Binary files /dev/null and b/basicsr/utils/__pycache__/options.cpython-310.pyc differ
diff --git a/basicsr/utils/__pycache__/registry.cpython-310.pyc b/basicsr/utils/__pycache__/registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4dd273d40e29345391f53ce916e55bf99f60f3b
Binary files /dev/null and b/basicsr/utils/__pycache__/registry.cpython-310.pyc differ
diff --git a/basicsr/utils/color_util.py b/basicsr/utils/color_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4740d5c98dd0680654e20d46b81ab30dfe936d6e
--- /dev/null
+++ b/basicsr/utils/color_util.py
@@ -0,0 +1,208 @@
+import numpy as np
+import torch
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert a RGB image to YCbCr image.
+
+    This function produces the same results as Matlab's `rgb2ycbcr` function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
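+# Sanity check for the BT.601 limited range used above: a pure-white uint8 RGB
+# pixel (255, 255, 255) maps to (Y, Cb, Cr) = (235, 128, 128) and pure black
+# maps to (16, 128, 128).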
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to RGB image.
+
+    This function produces the same results as Matlab's ycbcr2rgb function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to BGR image.
+
+    The bgr version of ycbcr2rgb.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def _convert_input_type_range(img):
+    """Convert the type and range of the input image.
+
+    It converts the input image to np.float32 type and range of [0, 1].
+    It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with type of np.float32 and range of
+            [0, 1].
+    """
+    img_type = img.dtype
+    img = img.astype(np.float32)
+    if img_type == np.float32:
+        pass
+    elif img_type == np.uint8:
+        img /= 255.
+    else:
+        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
+    return img
+
+
+def _convert_output_type_range(img, dst_type):
+    """Convert the type and range of the image according to dst_type.
+
+    It converts the image to the desired type and range. If `dst_type` is np.uint8,
+    images will be converted to np.uint8 type with range [0, 255]. If
+    `dst_type` is np.float32, it converts the image to np.float32 type with
+    range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+    functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+            converts the image to np.uint8 type with range [0, 255]. If
+            dst_type is np.float32, it converts the image to np.float32 type
+            with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
+    if dst_type == np.uint8:
+        img = img.round()
+    else:
+        img /= 255.
+    return img.astype(dst_type)
+
+
+def rgb2ycbcr_pt(img, y_only=False):
+    """Convert RGB images to YCbCr images (PyTorch version).
+
+    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    Args:
+        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
+        y_only (bool): Whether to only return the Y channel. Default: False.
+
+    Returns:
+        (Tensor): Converted images with shape (n, 3/1, h, w), the range [0, 1], float.
+    """
+    if y_only:
+        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
+        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
+    else:
+        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
+        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
+        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
+
+    out_img = out_img / 255.
+    return out_img
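A minimal usage sketch for the PyTorch colour conversion added above; the random input tensor is illustrative and the import path assumes the package layout introduced by this diff:

```python
import torch

from basicsr.utils.color_util import rgb2ycbcr_pt

img = torch.rand(2, 3, 32, 32)           # (n, 3, h, w), RGB, range [0, 1]
y = rgb2ycbcr_pt(img, y_only=True)       # (2, 1, 32, 32), luma only
ycbcr = rgb2ycbcr_pt(img, y_only=False)  # (2, 3, 32, 32), full YCbCr
print(y.shape, ycbcr.shape)
```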
diff --git a/basicsr/utils/diffjpeg.py b/basicsr/utils/diffjpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f96b44f9e7f3f8a589668f0003adf328cc5742
--- /dev/null
+++ b/basicsr/utils/diffjpeg.py
@@ -0,0 +1,515 @@
+"""
+Modified from https://github.com/mlomnitz/DiffJPEG
+
+For images not divisible by 8
+https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
+"""
+import itertools
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# ------------------------ utils ------------------------#
+y_table = np.array(
+    [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
+     [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
+     [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
+    dtype=np.float32).T
+y_table = nn.Parameter(torch.from_numpy(y_table))
+c_table = np.empty((8, 8), dtype=np.float32)
+c_table.fill(99)
+c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table = nn.Parameter(torch.from_numpy(c_table))
+
+
+def diff_round(x):
+    """ Differentiable rounding function
+    """
+    return torch.round(x) + (x - torch.round(x))**3
+
+
+def quality_to_factor(quality):
+    """ Calculate factor corresponding to quality
+
+    Args:
+        quality(float): Quality for jpeg compression.
+
+    Returns:
+        float: Compression factor.
+    """
+    if quality < 50:
+        quality = 5000. / quality
+    else:
+        quality = 200. - quality * 2
+    return quality / 100.
+
+
+# ------------------------ compression ------------------------#
+class RGB2YCbCrJpeg(nn.Module):
+    """ Converts RGB image to YCbCr
+    """
+
+    def __init__(self):
+        super(RGB2YCbCrJpeg, self).__init__()
+        matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
+                          dtype=np.float32).T
+        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+        self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+    def forward(self, image):
+        """
+        Args:
+            image(Tensor): batch x 3 x height x width
+
+        Returns:
+            Tensor: batch x height x width x 3
+        """
+        image = image.permute(0, 2, 3, 1)
+        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+        return result.view(image.shape)
+
+
+class ChromaSubsampling(nn.Module):
+    """ Chroma subsampling on CbCr channels
+    """
+
+    def __init__(self):
+        super(ChromaSubsampling, self).__init__()
+
+    def forward(self, image):
+        """
+        Args:
+            image(tensor): batch x height x width x 3
+
+        Returns:
+            y(tensor): batch x height x width
+            cb(tensor): batch x height/2 x width/2
+            cr(tensor): batch x height/2 x width/2
+        """
+        image_2 = image.permute(0, 3, 1, 2).clone()
+        cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+        cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+        cb = cb.permute(0, 2, 3, 1)
+        cr = cr.permute(0, 2, 3, 1)
+        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class BlockSplitting(nn.Module):
+    """ Splitting image into patches
+    """
+
+    def __init__(self):
+        super(BlockSplitting, self).__init__()
+        self.k = 8
+
+    def forward(self, image):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x h*w/64 x 8 x 8
+        """
+        height, _ = image.shape[1:3]
+        batch_size = image.shape[0]
+        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class DCT8x8(nn.Module):
+    """ Discrete Cosine Transformation
+    """
+
+    def __init__(self):
+        super(DCT8x8, self).__init__()
+        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+        for x, y, u, v in itertools.product(range(8), repeat=4):
+            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
+        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+        self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
+
+    def forward(self, image):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        image = image - 128
+        result = self.scale * torch.tensordot(image, self.tensor, dims=2)
+        return result
+
+
+class YQuantize(nn.Module):
+    """ JPEG Quantization for Y channel
+
+    Args:
+        rounding(function): rounding function to use
+    """
+
+    def __init__(self, rounding):
+        super(YQuantize, self).__init__()
+        self.rounding = rounding
+        self.y_table = y_table
+
+    def forward(self, image, factor=1):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        if isinstance(factor, (int, float)):
+            image = image.float() / (self.y_table * factor)
+        else:
+            b = factor.size(0)
+            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+            image = image.float() / table
+        image = self.rounding(image)
+        return image
+
+
+class CQuantize(nn.Module):
+    """ JPEG Quantization for CbCr channels
+
+    Args:
+        rounding(function): rounding function to use
+    """
+
+    def __init__(self, rounding):
+        super(CQuantize, self).__init__()
+        self.rounding = rounding
+        self.c_table = c_table
+
+    def forward(self, image, factor=1):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        if isinstance(factor, (int, float)):
+            image = image.float() / (self.c_table * factor)
+        else:
+            b = factor.size(0)
+            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+            image = image.float() / table
+        image = self.rounding(image)
+        return image
+
+
+class CompressJpeg(nn.Module):
+    """Full JPEG compression algorithm
+
+    Args:
+        rounding(function): rounding function to use
+    """
+
+    def __init__(self, rounding=torch.round):
+        super(CompressJpeg, self).__init__()
+        self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
+        self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
+        self.c_quantize = CQuantize(rounding=rounding)
+        self.y_quantize = YQuantize(rounding=rounding)
+
+    def forward(self, image, factor=1):
+        """
+        Args:
+            image(tensor): batch x 3 x height x width
+
+        Returns:
+            tuple(Tensor): Quantized y, cb and cr blocks, each of shape batch x num_blocks x 8 x 8.
+        """
+        y, cb, cr = self.l1(image * 255)
+        components = {'y': y, 'cb': cb, 'cr': cr}
+        for k in components.keys():
+            comp = self.l2(components[k])
+            if k in ('cb', 'cr'):
+                comp = self.c_quantize(comp, factor=factor)
+            else:
+                comp = self.y_quantize(comp, factor=factor)
+
+            components[k] = comp
+
+        return components['y'], components['cb'], components['cr']
+
+
+# ------------------------ decompression ------------------------#
+
+
+class YDequantize(nn.Module):
+    """Dequantize Y channel
+    """
+
+    def __init__(self):
+        super(YDequantize, self).__init__()
+        self.y_table = y_table
+
+    def forward(self, image, factor=1):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        if isinstance(factor, (int, float)):
+            out = image * (self.y_table * factor)
+        else:
+            b = factor.size(0)
+            table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+            out = image * table
+        return out
+
+
+class CDequantize(nn.Module):
+    """Dequantize CbCr channel
+    """
+
+    def __init__(self):
+        super(CDequantize, self).__init__()
+        self.c_table = c_table
+
+    def forward(self, image, factor=1):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        if isinstance(factor, (int, float)):
+            out = image * (self.c_table * factor)
+        else:
+            b = factor.size(0)
+            table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+            out = image * table
+        return out
+
+
+class iDCT8x8(nn.Module):
+    """Inverse discrete Cosine Transformation
+    """
+
+    def __init__(self):
+        super(iDCT8x8, self).__init__()
+        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+        for x, y, u, v in itertools.product(range(8), repeat=4):
+            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
+        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+    def forward(self, image):
+        """
+        Args:
+            image(tensor): batch x height x width
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        image = image * self.alpha
+        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+        return result
+
+
+class BlockMerging(nn.Module):
+    """Merge patches into image
+    """
+
+    def __init__(self):
+        super(BlockMerging, self).__init__()
+
+    def forward(self, patches, height, width):
+        """
+        Args:
+            patches(tensor): batch x height*width/64 x 8 x 8
+            height(int): Height of the merged image.
+            width(int): Width of the merged image.
+
+        Returns:
+            Tensor: batch x height x width
+        """
+        k = 8
+        batch_size = patches.shape[0]
+        image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
+        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+        return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class ChromaUpsampling(nn.Module):
+    """Upsample chroma layers
+    """
+
+    def __init__(self):
+        super(ChromaUpsampling, self).__init__()
+
+    def forward(self, y, cb, cr):
+        """
+        Args:
+            y(tensor): y channel image
+            cb(tensor): cb channel
+            cr(tensor): cr channel
+
+        Returns:
+            Tensor: batch x height x width x 3
+        """
+
+        def repeat(x, k=2):
+            height, width = x.shape[1:3]
+            x = x.unsqueeze(-1)
+            x = x.repeat(1, 1, k, k)
+            x = x.view(-1, height * k, width * k)
+            return x
+
+        cb = repeat(cb)
+        cr = repeat(cr)
+        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class YCbCr2RGBJpeg(nn.Module):
+    """Converts YCbCr image to RGB JPEG
+    """
+
+    def __init__(self):
+        super(YCbCr2RGBJpeg, self).__init__()
+
+        matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
+        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+        self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+    def forward(self, image):
+        """
+        Args:
+            image(tensor): batch x height x width x 3
+
+        Returns:
+            Tensor: batch x 3 x height x width
+        """
+        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+        return result.view(image.shape).permute(0, 3, 1, 2)
+
+
+class DeCompressJpeg(nn.Module):
+    """Full JPEG decompression algorithm
+
+    Args:
+        rounding(function): rounding function to use
+    """
+
+    def __init__(self, rounding=torch.round):
+        super(DeCompressJpeg, self).__init__()
+        self.c_dequantize = CDequantize()
+        self.y_dequantize = YDequantize()
+        self.idct = iDCT8x8()
+        self.merging = BlockMerging()
+        self.chroma = ChromaUpsampling()
+        self.colors = YCbCr2RGBJpeg()
+
+    def forward(self, y, cb, cr, imgh, imgw, factor=1):
+        """
+        Args:
+            y (tensor): Quantized Y blocks, batch x num_blocks x 8 x 8.
+            cb (tensor): Quantized Cb blocks, batch x num_blocks x 8 x 8.
+            cr (tensor): Quantized Cr blocks, batch x num_blocks x 8 x 8.
+            imgh (int): Height of the (padded) output image.
+            imgw (int): Width of the (padded) output image.
+            factor (float): Compression factor.
+
+        Returns:
+            Tensor: batch x 3 x height x width
+        """
+        components = {'y': y, 'cb': cb, 'cr': cr}
+        for k in components.keys():
+            if k in ('cb', 'cr'):
+                comp = self.c_dequantize(components[k], factor=factor)
+                height, width = int(imgh / 2), int(imgw / 2)
+            else:
+                comp = self.y_dequantize(components[k], factor=factor)
+                height, width = imgh, imgw
+            comp = self.idct(comp)
+            components[k] = self.merging(comp, height, width)
+        image = self.chroma(components['y'], components['cb'], components['cr'])
+        image = self.colors(image)
+
+        image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
+        return image / 255
+
+
+# ------------------------ main DiffJPEG ------------------------ #
+
+
+class DiffJPEG(nn.Module):
+    """This JPEG algorithm result is slightly different from cv2.
+    DiffJPEG supports batch processing.
+
+    Args:
+        differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round
+    """
+
+    def __init__(self, differentiable=True):
+        super(DiffJPEG, self).__init__()
+        if differentiable:
+            rounding = diff_round
+        else:
+            rounding = torch.round
+
+        self.compress = CompressJpeg(rounding=rounding)
+        self.decompress = DeCompressJpeg(rounding=rounding)
+
+    def forward(self, x, quality):
+        """
+        Args:
+            x (Tensor): Input image, bchw, rgb, [0, 1]
+            quality(float): Quality factor for jpeg compression scheme.
+        """
+        factor = quality
+        if isinstance(factor, (int, float)):
+            factor = quality_to_factor(factor)
+        else:
+            for i in range(factor.size(0)):
+                factor[i] = quality_to_factor(factor[i])
+        h, w = x.size()[-2:]
+        h_pad, w_pad = 0, 0
+        # pad to a multiple of 16: chroma subsampling halves H and W, and each channel is then split into 8x8 blocks
+        if h % 16 != 0:
+            h_pad = 16 - h % 16
+        if w % 16 != 0:
+            w_pad = 16 - w % 16
+        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
+
+        y, cb, cr = self.compress(x, factor=factor)
+        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
+        recovered = recovered[:, :, 0:h, 0:w]
+        return recovered
+
+
+if __name__ == '__main__':
+    import cv2
+
+    from basicsr.utils import img2tensor, tensor2img
+
+    img_gt = cv2.imread('test.png') / 255.
+
+    # -------------- cv2 -------------- #
+    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
+    _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
+    img_lq = np.float32(cv2.imdecode(encimg, 1))
+    cv2.imwrite('cv2_JPEG_20.png', img_lq)
+
+    # -------------- DiffJPEG -------------- #
+    jpeger = DiffJPEG(differentiable=False).cuda()
+    img_gt = img2tensor(img_gt)
+    img_gt = torch.stack([img_gt, img_gt]).cuda()
+    quality = img_gt.new_tensor([20, 40])
+    out = jpeger(img_gt, quality=quality)
+
+    cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
+    cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
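A small CPU sketch of the DiffJPEG module defined above, exercising both a scalar quality and a per-sample quality tensor; shapes and quality values are illustrative:

```python
import torch

from basicsr.utils.diffjpeg import DiffJPEG

jpeger = DiffJPEG(differentiable=True)              # differentiable rounding approximation
x = torch.rand(2, 3, 96, 96, requires_grad=True)    # bchw, RGB, [0, 1]; 96 is a multiple of 16

out_scalar = jpeger(x, quality=50)                        # one quality for the whole batch
out_tensor = jpeger(x, quality=x.new_tensor([20., 80.]))  # per-sample quality
out_tensor.mean().backward()                              # gradients flow through the codec
print(out_scalar.shape, x.grad.shape)
```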
diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0
--- /dev/null
+++ b/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+    """Initialize slurm distributed training environment.
+
+    If argument ``port`` is not specified, the master port is taken from the
+    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set,
+    the default port ``29500`` will be used.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+    # specify master port
+    if port is not None:
+        os.environ['MASTER_PORT'] = str(port)
+    elif 'MASTER_PORT' in os.environ:
+        pass  # use MASTER_PORT in the environment variable
+    else:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if dist.is_available():
+        initialized = dist.is_initialized()
+    else:
+        initialized = False
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
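A minimal sketch of how these helpers are typically combined; the decorated function below is illustrative:

```python
from basicsr.utils.dist_util import get_dist_info, master_only


@master_only
def log_once(msg):
    # Runs only on rank 0; on other ranks the wrapper silently returns None.
    print(msg)


rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized
log_once(f'rank {rank} of {world_size} processes')
```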
diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f73abd0e1831b8cab6277d780331a5103785b9ec
--- /dev/null
+++ b/basicsr/utils/download_util.py
@@ -0,0 +1,98 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+    """Download files from google drive.
+
+    Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive
+
+    Args:
+        file_id (str): File id.
+        save_path (str): Save path.
+    """
+
+    session = requests.Session()
+    URL = 'https://docs.google.com/uc?export=download'
+    params = {'id': file_id}
+
+    response = session.get(URL, params=params, stream=True)
+    token = get_confirm_token(response)
+    if token:
+        params['confirm'] = token
+        response = session.get(URL, params=params, stream=True)
+
+    # get file size
+    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+    if 'Content-Range' in response_file_size.headers:
+        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+    else:
+        file_size = None
+
+    save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+    return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+    if file_size is not None:
+        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+        readable_file_size = sizeof_fmt(file_size)
+    else:
+        pbar = None
+
+    with open(destination, 'wb') as f:
+        downloaded_size = 0
+        for chunk in response.iter_content(chunk_size):
+            downloaded_size += chunk_size
+            if pbar is not None:
+                pbar.update(1)
+                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+        if pbar is not None:
+            pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Load file form http url, will download models if necessary.
+
+    Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+    Args:
+        url (str): URL to be downloaded.
+        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+            Default: None.
+        progress (bool): Whether to show the download progress. Default: True.
+        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+    Returns:
+        str: The path to the downloaded file.
+    """
+    if model_dir is None:  # use the pytorch hub_dir
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file
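A minimal sketch of fetching a pretrained model with load_file_from_url; the URL and target directory below are placeholders, not real assets from this repository:

```python
from basicsr.utils.download_util import load_file_from_url

ckpt_path = load_file_from_url(
    url='https://example.com/models/some_model.pth',  # placeholder URL
    model_dir='experiments/pretrained_models',        # created if it does not exist
    progress=True,
    file_name=None)                                   # keep the file name from the URL
print(ckpt_path)
```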
diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92
--- /dev/null
+++ b/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+    """Abstract class of storage backends.
+
+    All backends need to implement two APIs: ``get()`` and ``get_text()``.
+    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+    as text.
+    """
+
+    @abstractmethod
+    def get(self, filepath):
+        pass
+
+    @abstractmethod
+    def get_text(self, filepath):
+        pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+    """Memcached storage backend.
+
+    Attributes:
+        server_list_cfg (str): Config file for memcached server list.
+        client_cfg (str): Config file for memcached client.
+        sys_path (str | None): Additional path to be appended to `sys.path`.
+            Default: None.
+    """
+
+    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+        if sys_path is not None:
+            import sys
+            sys.path.append(sys_path)
+        try:
+            import mc
+        except ImportError:
+            raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+        self.server_list_cfg = server_list_cfg
+        self.client_cfg = client_cfg
+        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+        self._mc_buffer = mc.pyvector()
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        import mc
+        self._client.Get(filepath, self._mc_buffer)
+        value_buf = mc.ConvertBuffer(self._mc_buffer)
+        return value_buf
+
+    def get_text(self, filepath):
+        raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+    """Raw hard disks storage backend."""
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'rb') as f:
+            value_buf = f.read()
+        return value_buf
+
+    def get_text(self, filepath):
+        filepath = str(filepath)
+        with open(filepath, 'r') as f:
+            value_buf = f.read()
+        return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+    """Lmdb storage backend.
+
+    Args:
+        db_paths (str | list[str]): Lmdb database paths.
+        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+        readonly (bool, optional): Lmdb environment parameter. If True,
+            disallow any write operations. Default: True.
+        lock (bool, optional): Lmdb environment parameter. If False, when
+            concurrent access occurs, do not lock the database. Default: False.
+        readahead (bool, optional): Lmdb environment parameter. If False,
+            disable the OS filesystem readahead mechanism, which may improve
+            random read performance when a database is larger than RAM.
+            Default: False.
+
+    Attributes:
+        db_paths (list): Lmdb database path.
+        _client (list): A list of several lmdb envs.
+    """
+
+    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+        try:
+            import lmdb
+        except ImportError:
+            raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+        if isinstance(client_keys, str):
+            client_keys = [client_keys]
+
+        if isinstance(db_paths, list):
+            self.db_paths = [str(v) for v in db_paths]
+        elif isinstance(db_paths, str):
+            self.db_paths = [str(db_paths)]
+        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+        self._client = {}
+        for client, path in zip(client_keys, self.db_paths):
+            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+    def get(self, filepath, client_key):
+        """Get values according to the filepath from one lmdb named client_key.
+
+        Args:
+            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+        """
+        filepath = str(filepath)
+        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
+        client = self._client[client_key]
+        with client.begin(write=False) as txn:
+            value_buf = txn.get(filepath.encode('ascii'))
+        return value_buf
+
+    def get_text(self, filepath):
+        raise NotImplementedError
+
+
+class FileClient(object):
+    """A general file client to access files in different backend.
+
+    The client loads a file or text from a specified backend given its path
+    and returns the content as binary data. It can also register other backend
+    accessors with a given name and backend class.
+
+    Attributes:
+        backend (str): The storage backend type. Options are "disk",
+            "memcached" and "lmdb".
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+    }
+
+    def __init__(self, backend='disk', **kwargs):
+        if backend not in self._backends:
+            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+                             f' are {list(self._backends.keys())}')
+        self.backend = backend
+        self.client = self._backends[backend](**kwargs)
+
+    def get(self, filepath, client_key='default'):
+        # client_key is used only for lmdb, where different fileclients have
+        # different lmdb environments.
+        if self.backend == 'lmdb':
+            return self.client.get(filepath, client_key)
+        else:
+            return self.client.get(filepath)
+
+    def get_text(self, filepath):
+        return self.client.get_text(filepath)
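A minimal sketch of the FileClient API; the file paths, lmdb database and keys below are illustrative:

```python
from basicsr.utils.file_client import FileClient

client = FileClient(backend='disk')      # default backend reads from the local disk
img_bytes = client.get('example.png')    # raw bytes, e.g. to decode with imfrombytes
text = client.get_text('meta_info.txt')  # text content

# The lmdb backend takes db_paths/client_keys and looks files up by lmdb key:
# lmdb_client = FileClient(backend='lmdb', db_paths=['datasets/DIV2K_train_HR.lmdb'], client_keys=['gt'])
# img_bytes = lmdb_client.get('0001', client_key='gt')
```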
diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7180b4e9b5c8f2eb36a9a0e4ff6affdaae84b8
--- /dev/null
+++ b/basicsr/utils/flow_util.py
@@ -0,0 +1,170 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py  # noqa: E501
+import cv2
+import numpy as np
+import os
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+    """Read an optical flow map.
+
+    Args:
+        flow_path (ndarray or str): Flow path.
+        quantize (bool): Whether to read a quantized pair. If set to True,
+            remaining args will be passed to :func:`dequantize_flow`.
+        concat_axis (int): The axis that dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+
+    Returns:
+        ndarray: Optical flow represented as a (h, w, 2) numpy array
+    """
+    if quantize:
+        assert concat_axis in [0, 1]
+        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+        if cat_flow.ndim != 2:
+            raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
+        assert cat_flow.shape[concat_axis] % 2 == 0
+        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+        flow = dequantize_flow(dx, dy, *args, **kwargs)
+    else:
+        with open(flow_path, 'rb') as f:
+            try:
+                header = f.read(4).decode('utf-8')
+            except Exception:
+                raise IOError(f'Invalid flow file: {flow_path}')
+            else:
+                if header != 'PIEH':
+                    raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
+
+            w = np.fromfile(f, np.int32, 1).squeeze()
+            h = np.fromfile(f, np.int32, 1).squeeze()
+            flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+
+    return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+    """Write optical flow to file.
+
+    If the flow is not quantized, it will be saved losslessly as a .flo file;
+    otherwise it is saved as a jpeg image, which is lossy but much smaller. (dx and dy
+    will be concatenated along ``concat_axis`` into a single image if quantize is True.)
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        filename (str): Output filepath.
+        quantize (bool): Whether to quantize the flow and save it as a jpeg
+            image. If set to True, remaining args will be passed to
+            :func:`quantize_flow`.
+        concat_axis (int): The axis that dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+    """
+    if not quantize:
+        with open(filename, 'wb') as f:
+            f.write('PIEH'.encode('utf-8'))
+            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+            flow = flow.astype(np.float32)
+            flow.tofile(f)
+            f.flush()
+    else:
+        assert concat_axis in [0, 1]
+        dx, dy = quantize_flow(flow, *args, **kwargs)
+        dxdy = np.concatenate((dx, dy), axis=concat_axis)
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        cv2.imwrite(filename, dxdy)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+    """Quantize flow to [0, 255].
+
+    After this step, the size of flow will be much smaller, and can be
+    dumped as jpeg images.
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        max_val (float): Maximum value of flow, values beyond
+                        [-max_val, max_val] will be truncated.
+        norm (bool): Whether to divide flow values by image width/height.
+
+    Returns:
+        tuple[ndarray]: Quantized dx and dy.
+    """
+    h, w, _ = flow.shape
+    dx = flow[..., 0]
+    dy = flow[..., 1]
+    if norm:
+        dx = dx / w  # avoid inplace operations
+        dy = dy / h
+    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+    flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
+    return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+    """Recover from quantized flow.
+
+    Args:
+        dx (ndarray): Quantized dx.
+        dy (ndarray): Quantized dy.
+        max_val (float): Maximum value used when quantizing.
+        denorm (bool): Whether to multiply flow values with width/height.
+
+    Returns:
+        ndarray: Dequantized flow.
+    """
+    assert dx.shape == dy.shape
+    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+    if denorm:
+        dx *= dx.shape[1]
+        dy *= dx.shape[0]
+    flow = np.dstack((dx, dy))
+    return flow
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+    """Quantize an array of (-inf, inf) to [0, levels-1].
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the quantized array.
+
+    Returns:
+        ndarray: Quantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(f'levels must be a positive integer, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    arr = np.clip(arr, min_val, max_val) - min_val
+    quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+    return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+    """Dequantize an array.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        ndarray: Dequantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(f'levels must be a positive integer, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
+
+    return dequantized_arr
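A short round-trip sketch for the quantization helpers above; the random flow field is illustrative:

```python
import numpy as np

from basicsr.utils.flow_util import dequantize_flow, quantize_flow

flow = np.random.uniform(-1, 1, size=(64, 64, 2)).astype(np.float32)  # (h, w, 2)
dx, dy = quantize_flow(flow, max_val=0.02, norm=True)  # two uint8 maps
restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)
print(dx.dtype, dy.dtype, restored.shape)              # uint8 uint8 (64, 64, 2)
```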
diff --git a/basicsr/utils/img_process_util.py b/basicsr/utils/img_process_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e02f09930dbf13bcd12bbe16b76e4fce52578e
--- /dev/null
+++ b/basicsr/utils/img_process_util.py
@@ -0,0 +1,83 @@
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def filter2D(img, kernel):
+    """PyTorch version of cv2.filter2D
+
+    Args:
+        img (Tensor): (b, c, h, w)
+        kernel (Tensor): (b, k, k)
+    """
+    k = kernel.size(-1)
+    b, c, h, w = img.size()
+    if k % 2 == 1:
+        img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
+    else:
+        raise ValueError('Wrong kernel size')
+
+    ph, pw = img.size()[-2:]
+
+    if kernel.size(0) == 1:
+        # apply the same kernel to all batch images
+        img = img.view(b * c, 1, ph, pw)
+        kernel = kernel.view(1, 1, k, k)
+        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
+    else:
+        img = img.view(1, b * c, ph, pw)
+        kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
+        return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
+
+
+def usm_sharp(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening.
+
+    Input image: I; Blurred image: B.
+    1. sharp = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) * 255 > threshold, else 0
+    3. Blur the mask to get a soft mask.
+    4. Out = soft_mask * sharp + (1 - soft_mask) * I
+
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (int): Kernel size of the Gaussian blur (made odd if even). Default: 50.
+        threshold (int): Residual threshold (on the 0-255 scale) for the mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    sharp = img + weight * residual
+    sharp = np.clip(sharp, 0, 1)
+    return soft_mask * sharp + (1 - soft_mask) * img
+
+
+class USMSharp(torch.nn.Module):
+
+    def __init__(self, radius=50, sigma=0):
+        super(USMSharp, self).__init__()
+        if radius % 2 == 0:
+            radius += 1
+        self.radius = radius
+        kernel = cv2.getGaussianKernel(radius, sigma)
+        kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
+        self.register_buffer('kernel', kernel)
+
+    def forward(self, img, weight=0.5, threshold=10):
+        blur = filter2D(img, self.kernel)
+        residual = img - blur
+
+        mask = torch.abs(residual) * 255 > threshold
+        mask = mask.float()
+        soft_mask = filter2D(mask, self.kernel)
+        sharp = img + weight * residual
+        sharp = torch.clip(sharp, 0, 1)
+        return soft_mask * sharp + (1 - soft_mask) * img
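A minimal sketch of the two sharpening entry points above, on random data (illustrative only):

```python
import numpy as np
import torch

from basicsr.utils.img_process_util import USMSharp, usm_sharp

# Numpy version: HWC, BGR, float32 in [0, 1].
img = np.random.rand(64, 64, 3).astype(np.float32)
sharpened = usm_sharp(img, weight=0.5, radius=50, threshold=10)

# Batched tensor version; the module caches a 51-tap Gaussian kernel as a buffer.
usm = USMSharp()
img_t = torch.rand(2, 3, 64, 64)
sharpened_t = usm(img_t, weight=0.5, threshold=10)
print(sharpened.shape, sharpened_t.shape)
```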
diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbce5dba5b01deb78f2453edc801a76e6a126998
--- /dev/null
+++ b/basicsr/utils/img_util.py
@@ -0,0 +1,172 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Numpy array to tensor.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input images.
+        bgr2rgb (bool): Whether to change bgr to rgb.
+        float32 (bool): Whether to change to float32.
+
+    Returns:
+        list[tensor] | tensor: Tensor images. If returned results only have
+            one element, just return tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+    """Convert torch Tensors into image numpy arrays.
+
+    After clamping to [min, max], values will be normalized to [0, 1].
+
+    Args:
+        tensor (Tensor or list[Tensor]): Accept shapes:
+            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+            2) 3D Tensor of shape (3/1 x H x W);
+            3) 2D Tensor of shape (H x W).
+            Tensor channel should be in RGB order.
+        rgb2bgr (bool): Whether to change rgb to bgr.
+        out_type (numpy type): output types. If ``np.uint8``, transform outputs
+            to uint8 type with range [0, 255]; otherwise, float type with
+            range [0, 1]. Default: ``np.uint8``.
+        min_max (tuple[int]): min and max values for clamp.
+
+    Returns:
+        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+        shape (H x W). The channel order is BGR.
+    """
+    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+    if torch.is_tensor(tensor):
+        tensor = [tensor]
+    result = []
+    for _tensor in tensor:
+        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+        n_dim = _tensor.dim()
+        if n_dim == 4:
+            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if rgb2bgr:
+                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 3:
+            img_np = _tensor.numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if img_np.shape[2] == 1:  # gray image
+                img_np = np.squeeze(img_np, axis=2)
+            else:
+                if rgb2bgr:
+                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 2:
+            img_np = _tensor.numpy()
+        else:
+            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+        if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+            img_np = (img_np * 255.0).round()
+        img_np = img_np.astype(out_type)
+        result.append(img_np)
+    if len(result) == 1:
+        result = result[0]
+    return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+    """This implementation is slightly faster than tensor2img.
+    It now only supports torch tensor with shape (1, c, h, w).
+
+    Args:
+        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+        min_max (tuple[int]): min and max values for clamp.
+    """
+    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+    output = output.type(torch.uint8).cpu().numpy()
+    if rgb2bgr:
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+    return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+    """Read an image from bytes.
+
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also norm
+            to [0, 1]. Default: False.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+    img_np = np.frombuffer(content, np.uint8)
+    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+    img = cv2.imdecode(img_np, imread_flags[flag])
+    if float32:
+        img = img.astype(np.float32) / 255.
+    return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    ok = cv2.imwrite(file_path, img, params)
+    if not ok:
+        raise IOError('Failed in writing images.')
+
+
+def crop_border(imgs, crop_border):
+    """Crop borders of images.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+    Returns:
+        list[ndarray]: Cropped images.
+    """
+    if crop_border == 0:
+        return imgs
+    else:
+        if isinstance(imgs, list):
+            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+        else:
+            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
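A minimal round-trip sketch for the tensor/ndarray conversion helpers above; the random image and output path are illustrative:

```python
import numpy as np

from basicsr.utils.img_util import img2tensor, imwrite, tensor2img

# A BGR float32 image in [0, 1]; random data stands in for cv2.imread(...) / 255.
img = np.random.rand(32, 32, 3).astype(np.float32)

t = img2tensor(img, bgr2rgb=True, float32=True)  # (3, 32, 32), RGB, float32
restored = tensor2img(t, rgb2bgr=True)           # (32, 32, 3), BGR, uint8
imwrite(restored, 'tmp/restored.png')            # parent folder is created automatically
```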
diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b45ce01d5e32ddbf8354d71fd1c8678bede822
--- /dev/null
+++ b/basicsr/utils/lmdb_util.py
@@ -0,0 +1,199 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+                        lmdb_path,
+                        img_path_list,
+                        keys,
+                        batch=5000,
+                        compress_level=1,
+                        multiprocessing_read=False,
+                        n_thread=40,
+                        map_size=None):
+    """Make lmdb from images.
+
+    Contents of lmdb. The file structure is:
+
+    ::
+
+        example.lmdb
+        ├── data.mdb
+        ├── lock.mdb
+        ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a specified txt file to record the meta information
+    of our datasets. It will be automatically created when preparing
+    datasets by our provided dataset tools.
+    Each line in the txt file records 1) image name (with extension),
+    2) image shape, and 3) compression level, separated by a white space.
+
+    For example, the meta information could be:
+    `000_00000000.png (720,1280,3) 1`, which means:
+    1) image name (with extension): 000_00000000.png;
+    2) image shape: (720,1280,3);
+    3) compression level: 1
+
+    We use the image name without extension as the lmdb key.
+
+    If `multiprocessing_read` is True, it will read all the images to memory
+    using multiprocessing. Thus, your server needs to have enough memory.
+
+    Args:
+        data_path (str): Data path for reading images.
+        lmdb_path (str): Lmdb save path.
+        img_path_list (list[str]): Image path list.
+        keys (list[str]): Used for lmdb keys.
+        batch (int): After processing batch images, lmdb commits.
+            Default: 5000.
+        compress_level (int): Compress level when encoding images. Default: 1.
+        multiprocessing_read (bool): Whether to use multiprocessing to read all
+            the images into memory. Default: False.
+        n_thread (int): Number of worker processes used when multiprocessing_read is True. Default: 40.
+        map_size (int | None): Map size for lmdb env. If None, use the
+            estimated size from images. Default: None
+    """
+
+    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+                                             f'but got {len(img_path_list)} and {len(keys)}')
+    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+    print(f'Total images: {len(img_path_list)}')
+    if not lmdb_path.endswith('.lmdb'):
+        raise ValueError("lmdb_path must end with '.lmdb'.")
+    if osp.exists(lmdb_path):
+        print(f'Folder {lmdb_path} already exists. Exit.')
+        sys.exit(1)
+
+    if multiprocessing_read:
+        # read all the images to memory (multiprocessing)
+        dataset = {}  # use dict to keep the order for multiprocessing
+        shapes = {}
+        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+        pbar = tqdm(total=len(img_path_list), unit='image')
+
+        def callback(arg):
+            """get the image data and update pbar."""
+            key, dataset[key], shapes[key] = arg
+            pbar.update(1)
+            pbar.set_description(f'Read {key}')
+
+        pool = Pool(n_thread)
+        for path, key in zip(img_path_list, keys):
+            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+        pool.close()
+        pool.join()
+        pbar.close()
+        print(f'Finish reading {len(img_path_list)} images.')
+
+    # create lmdb environment
+    if map_size is None:
+        # obtain data size for one image
+        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+        data_size_per_img = img_byte.nbytes
+        print('Data size per image is: ', data_size_per_img)
+        data_size = data_size_per_img * len(img_path_list)
+        map_size = data_size * 10
+
+    env = lmdb.open(lmdb_path, map_size=map_size)
+
+    # write data to lmdb
+    pbar = tqdm(total=len(img_path_list), unit='chunk')
+    txn = env.begin(write=True)
+    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+        pbar.update(1)
+        pbar.set_description(f'Write {key}')
+        key_byte = key.encode('ascii')
+        if multiprocessing_read:
+            img_byte = dataset[key]
+            h, w, c = shapes[key]
+        else:
+            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+            h, w, c = img_shape
+
+        txn.put(key_byte, img_byte)
+        # write meta information
+        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+        if idx % batch == 0:
+            txn.commit()
+            txn = env.begin(write=True)
+    pbar.close()
+    txn.commit()
+    env.close()
+    txt_file.close()
+    print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+    """Read image worker.
+
+    Args:
+        path (str): Image path.
+        key (str): Image key.
+        compress_level (int): Compress level when encoding images.
+
+    Returns:
+        str: Image key.
+        byte: Image byte.
+        tuple[int]: Image shape.
+    """
+
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+    if img.ndim == 2:
+        h, w = img.shape
+        c = 1
+    else:
+        h, w, c = img.shape
+    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+    return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+    """LMDB Maker.
+
+    Args:
+        lmdb_path (str): Lmdb save path.
+        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+        batch (int): After processing batch images, lmdb commits.
+            Default: 5000.
+        compress_level (int): Compress level when encoding images. Default: 1.
+    """
+
+    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+        if not lmdb_path.endswith('.lmdb'):
+            raise ValueError("lmdb_path must end with '.lmdb'.")
+        if osp.exists(lmdb_path):
+            print(f'Folder {lmdb_path} already exists. Exit.')
+            sys.exit(1)
+
+        self.lmdb_path = lmdb_path
+        self.batch = batch
+        self.compress_level = compress_level
+        self.env = lmdb.open(lmdb_path, map_size=map_size)
+        self.txn = self.env.begin(write=True)
+        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+        self.counter = 0
+
+    def put(self, img_byte, key, img_shape):
+        self.counter += 1
+        key_byte = key.encode('ascii')
+        self.txn.put(key_byte, img_byte)
+        # write meta information
+        h, w, c = img_shape
+        self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+        if self.counter % self.batch == 0:
+            self.txn.commit()
+            self.txn = self.env.begin(write=True)
+
+    def close(self):
+        self.txn.commit()
+        self.env.close()
+        self.txt_file.close()
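A minimal sketch of writing a couple of encoded images with LmdbMaker; the lmdb path and keys are illustrative, and the target folder must not exist yet:

```python
import cv2
import numpy as np

from basicsr.utils.lmdb_util import LmdbMaker

maker = LmdbMaker('example.lmdb', map_size=1024**3)  # 1 GB map size for this toy example
for key in ('0001', '0002'):
    img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
    maker.put(img_byte, key, img.shape)              # img.shape is (h, w, c)
maker.close()                                        # commits and writes meta_info.txt
```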
diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..73553dc664781a061737e94880ea1c6788c09043
--- /dev/null
+++ b/basicsr/utils/logger.py
@@ -0,0 +1,213 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class AvgTimer():
+
+    def __init__(self, window=200):
+        self.window = window  # average window
+        self.current_time = 0
+        self.total_time = 0
+        self.count = 0
+        self.avg_time = 0
+        self.start()
+
+    def start(self):
+        self.start_time = self.tic = time.time()
+
+    def record(self):
+        self.count += 1
+        self.toc = time.time()
+        self.current_time = self.toc - self.tic
+        self.total_time += self.current_time
+        # calculate average time
+        self.avg_time = self.total_time / self.count
+
+        # reset
+        if self.count > self.window:
+            self.count = 0
+            self.total_time = 0
+
+        self.tic = time.time()
+
+    def get_current_time(self):
+        return self.current_time
+
+    def get_avg_time(self):
+        return self.avg_time
+
+
+class MessageLogger():
+    """Message logger for printing.
+
+    Args:
+        opt (dict): Config. It contains the following keys:
+            name (str): Exp name.
+            logger (dict): Contains 'print_freq' (int) for logger interval.
+            train (dict): Contains 'total_iter' (int) for total iters.
+            use_tb_logger (bool): Use tensorboard logger.
+        start_iter (int): Start iter. Default: 1.
+        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+    """
+
+    def __init__(self, opt, start_iter=1, tb_logger=None):
+        self.exp_name = opt['name']
+        self.interval = opt['logger']['print_freq']
+        self.start_iter = start_iter
+        self.max_iters = opt['train']['total_iter']
+        self.use_tb_logger = opt['logger']['use_tb_logger']
+        self.tb_logger = tb_logger
+        self.start_time = time.time()
+        self.logger = get_root_logger()
+
+    def reset_start_time(self):
+        self.start_time = time.time()
+
+    @master_only
+    def __call__(self, log_vars):
+        """Format logging message.
+
+        Args:
+            log_vars (dict): It contains the following keys:
+                epoch (int): Epoch number.
+                iter (int): Current iter.
+                lrs (list): List for learning rates.
+
+                time (float): Iter time.
+                data_time (float): Data time for each iter.
+        """
+        # epoch, iter, learning rates
+        epoch = log_vars.pop('epoch')
+        current_iter = log_vars.pop('iter')
+        lrs = log_vars.pop('lrs')
+
+        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
+        for v in lrs:
+            message += f'{v:.3e},'
+        message += ')] '
+
+        # time and estimated time
+        if 'time' in log_vars.keys():
+            iter_time = log_vars.pop('time')
+            data_time = log_vars.pop('data_time')
+
+            total_time = time.time() - self.start_time
+            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+            message += f'[eta: {eta_str}, '
+            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+        # other items, especially losses
+        for k, v in log_vars.items():
+            message += f'{k}: {v:.4e} '
+            # tensorboard logger
+            if self.use_tb_logger and 'debug' not in self.exp_name:
+                if k.startswith('l_'):
+                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+                else:
+                    self.tb_logger.add_scalar(k, v, current_iter)
+        self.logger.info(message)
+
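+# Usage sketch (illustrative only; the option values below are assumptions):
+#     opt = {'name': 'demo', 'logger': {'print_freq': 100, 'use_tb_logger': False},
+#            'train': {'total_iter': 1000}}
+#     msg_logger = MessageLogger(opt)
+#     msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4], 'l_pix': 0.01})
+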
+
+@master_only
+def init_tb_logger(log_dir):
+    from torch.utils.tensorboard import SummaryWriter
+    tb_logger = SummaryWriter(log_dir=log_dir)
+    return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+    """We now only use wandb to sync tensorboard log."""
+    import wandb
+    logger = get_root_logger()
+
+    project = opt['logger']['wandb']['project']
+    resume_id = opt['logger']['wandb'].get('resume_id')
+    if resume_id:
+        wandb_id = resume_id
+        resume = 'allow'
+        logger.warning(f'Resume wandb logger with id={wandb_id}.')
+    else:
+        wandb_id = wandb.util.generate_id()
+        resume = 'never'
+
+    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+    """Get the root logger.
+
+    The logger will be initialized if it has not been initialized. By default a
+    StreamHandler will be added. If `log_file` is specified, a FileHandler will
+    also be added.
+
+    Args:
+        logger_name (str): root logger name. Default: 'basicsr'.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the root logger.
+        log_level (int): The root logger level. Note that only the process of
+            rank 0 is affected, while other processes will set the level to
+            "Error" and be silent most of the time.
+
+    Returns:
+        logging.Logger: The root logger.
+    """
+    logger = logging.getLogger(logger_name)
+    # if the logger has been initialized, just return it
+    if logger_name in initialized_logger:
+        return logger
+
+    format_str = '%(asctime)s %(levelname)s: %(message)s'
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter(format_str))
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+    rank, _ = get_dist_info()
+    if rank != 0:
+        logger.setLevel('ERROR')
+    elif log_file is not None:
+        logger.setLevel(log_level)
+        # add file handler
+        file_handler = logging.FileHandler(log_file, 'w')
+        file_handler.setFormatter(logging.Formatter(format_str))
+        file_handler.setLevel(log_level)
+        logger.addHandler(file_handler)
+    initialized_logger[logger_name] = True
+    return logger
+
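+# Usage sketch (illustrative only; the log file path is an assumption):
+#     logger = get_root_logger(log_level=logging.INFO, log_file='experiments/demo/train.log')
+#     logger.info('Training started.')
+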
+
+def get_env_info():
+    """Get environment information.
+
+    Currently, it only logs software versions.
+    """
+    import torch
+    import torchvision
+
+    from basicsr.version import __version__
+    msg = r"""
+                ____                _       _____  ____
+               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
+             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
+            /_____/ \__,_//____//_/ \___//____//_/ |_|
+     ______                   __   __                 __      __
+    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
+   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
+  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
+  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
+    """
+    msg += ('\nVersion Information: '
+            f'\n\tBasicSR: {__version__}'
+            f'\n\tPyTorch: {torch.__version__}'
+            f'\n\tTorchVision: {torchvision.__version__}')
+    return msg
diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a201f79aaf030cdba710dd97c28af1b29a93ed2a
--- /dev/null
+++ b/basicsr/utils/matlab_functions.py
@@ -0,0 +1,178 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+    """cubic function used for calculate_weights_indices."""
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+                                                                                     (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    """Calculate weights and indices, used for imresize function.
+
+    Args:
+        in_length (int): Input length.
+        out_length (int): Output length.
+        scale (float): Scale factor.
+        kernel (str): Kernel type. Currently only 'cubic' is supported.
+        kernel_width (int): Kernel width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+    """
+
+    if (scale < 1) and antialiasing:
+        # Use a modified kernel (larger kernel width) to simultaneously
+        # interpolate and antialias
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5 + scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    p = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+        out_length, p)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+    # apply cubic kernel
+    if (scale < 1) and antialiasing:
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, p)
+
+    # If a column in weights is all zero, get rid of it. only consider the
+    # first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, p - 2)
+        weights = weights.narrow(1, 1, p - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, p - 2)
+        weights = weights.narrow(1, 0, p - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+    """imresize function same as MATLAB.
+
+    It now only supports bicubic.
+    The same scale applies for both height and width.
+
+    Args:
+        img (Tensor | Numpy array):
+            Tensor: Input image with shape (c, h, w), [0, 1] range.
+            Numpy: Input image with shape (h, w, c), [0, 1] range.
+        scale (float): Scale factor. The same scale applies for both height
+            and width.
+        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+            Default: True.
+
+    Returns:
+        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+    """
+    squeeze_flag = False
+    if type(img).__module__ == np.__name__:  # numpy type
+        numpy_type = True
+        if img.ndim == 2:
+            img = img[:, :, None]
+            squeeze_flag = True
+        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+    else:
+        numpy_type = False
+        if img.ndim == 2:
+            img = img.unsqueeze(0)
+            squeeze_flag = True
+
+    in_c, in_h, in_w = img.size()
+    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # get weights and indices
+    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+                                                                             antialiasing)
+    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+                                                                             antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+    sym_patch = img[:, :sym_len_hs, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+    sym_patch = img[:, -sym_len_he:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(in_c, out_h, in_w)
+    kernel_width = weights_h.size(1)
+    for i in range(out_h):
+        idx = int(indices_h[i][0])
+        for j in range(in_c):
+            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+    sym_patch = out_1[:, :, :sym_len_ws]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, :, -sym_len_we:]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(in_c, out_h, out_w)
+    kernel_width = weights_w.size(1)
+    for i in range(out_w):
+        idx = int(indices_w[i][0])
+        for j in range(in_c):
+            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+    if squeeze_flag:
+        out_2 = out_2.squeeze(0)
+    if numpy_type:
+        out_2 = out_2.numpy()
+        if not squeeze_flag:
+            out_2 = out_2.transpose(1, 2, 0)
+
+    return out_2
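+
+
+if __name__ == '__main__':
+    # Minimal self-check (illustrative only): downscale a random image by 0.25 in the
+    # same way MATLAB's imresize would, and print the resulting shape.
+    rng = np.random.default_rng(0)
+    demo_img = rng.random((64, 64, 3)).astype(np.float32)
+    demo_out = imresize(demo_img, 0.25)
+    print(demo_out.shape)  # expected: (16, 16, 3)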
diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8d4a1403509672e85e74ac476e028cefb6dbb62
--- /dev/null
+++ b/basicsr/utils/misc.py
@@ -0,0 +1,141 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+
+
+def set_random_seed(seed):
+    """Set random seeds."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+    """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+    Args:
+        path (str): Folder path.
+    """
+    if osp.exists(path):
+        new_name = path + '_archived_' + get_time_str()
+        print(f'Path already exists. Rename it to {new_name}', flush=True)
+        os.rename(path, new_name)
+    os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+    """Make dirs for experiments."""
+    path_opt = opt['path'].copy()
+    if opt['is_train']:
+        mkdir_and_rename(path_opt.pop('experiments_root'))
+    else:
+        mkdir_and_rename(path_opt.pop('results_root'))
+    for key, path in path_opt.items():
+        if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
+            continue
+        else:
+            os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+
+    Args:
+        dir_path (str): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+
+    Returns:
+        A generator for all the files of interest, with relative paths.
+    """
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if suffix is None:
+                    yield return_path
+                elif return_path.endswith(suffix):
+                    yield return_path
+            else:
+                if recursive:
+                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
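+
+# Usage sketch (illustrative only; the directory path and suffixes are assumptions):
+#     for img_path in scandir('datasets/DIV2K/HR', suffix=('.png', '.jpg'), recursive=True):
+#         print(img_path)
+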
+
+
+def check_resume(opt, resume_iter):
+    """Check resume states and pretrain_network paths.
+
+    Args:
+        opt (dict): Options.
+        resume_iter (int): Resume iteration.
+    """
+    if opt['path']['resume_state']:
+        # get all the networks
+        networks = [key for key in opt.keys() if key.startswith('network_')]
+        flag_pretrain = False
+        for network in networks:
+            if opt['path'].get(f'pretrain_{network}') is not None:
+                flag_pretrain = True
+        if flag_pretrain:
+            print('pretrain_network path will be ignored during resuming.')
+        # set pretrained model paths
+        for network in networks:
+            name = f'pretrain_{network}'
+            basename = network.replace('network_', '')
+            if opt['path'].get('ignore_resume_networks') is None or (network
+                                                                     not in opt['path']['ignore_resume_networks']):
+                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+                print(f"Set {name} to {opt['path'][name]}")
+
+        # change param_key to params in resume
+        param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
+        for param_key in param_keys:
+            if opt['path'][param_key] == 'params_ema':
+                opt['path'][param_key] = 'params'
+                print(f'Set {param_key} to params')
+
+
+def sizeof_fmt(size, suffix='B'):
+    """Get human readable file size.
+
+    Args:
+        size (int): File size.
+        suffix (str): Suffix. Default: 'B'.
+
+    Return:
+        str: Formatted file size.
+    """
+    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+        if abs(size) < 1024.0:
+            return f'{size:3.1f} {unit}{suffix}'
+        size /= 1024.0
+    return f'{size:3.1f} Y{suffix}'
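+
+
+if __name__ == '__main__':
+    # Quick illustration of sizeof_fmt (illustrative only; not used by the training code).
+    print(sizeof_fmt(1536))         # 1.5 KB
+    print(sizeof_fmt(3 * 1024**3))  # 3.0 GB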
diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..3afd79c4f3e73f44f36503288c3959125ac3df34
--- /dev/null
+++ b/basicsr/utils/options.py
@@ -0,0 +1,210 @@
+import argparse
+import os
+import random
+import torch
+import yaml
+from collections import OrderedDict
+from os import path as osp
+
+from basicsr.utils import set_random_seed
+from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+
+
+def ordered_yaml():
+    """Support OrderedDict for yaml.
+
+    Returns:
+        tuple: yaml Loader and Dumper.
+    """
+    try:
+        from yaml import CDumper as Dumper
+        from yaml import CLoader as Loader
+    except ImportError:
+        from yaml import Dumper, Loader
+
+    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+    def dict_representer(dumper, data):
+        return dumper.represent_dict(data.items())
+
+    def dict_constructor(loader, node):
+        return OrderedDict(loader.construct_pairs(node))
+
+    Dumper.add_representer(OrderedDict, dict_representer)
+    Loader.add_constructor(_mapping_tag, dict_constructor)
+    return Loader, Dumper
+
+
+def yaml_load(f):
+    """Load yaml file or string.
+
+    Args:
+        f (str): File path or a python string.
+
+    Returns:
+        dict: Loaded dict.
+    """
+    if os.path.isfile(f):
+        with open(f, 'r') as f:
+            return yaml.load(f, Loader=ordered_yaml()[0])
+    else:
+        return yaml.load(f, Loader=ordered_yaml()[0])
+
+
+def dict2str(opt, indent_level=1):
+    """dict to string for printing options.
+
+    Args:
+        opt (dict): Option dict.
+        indent_level (int): Indent level. Default: 1.
+
+    Return:
+        (str): Option string for printing.
+    """
+    msg = '\n'
+    for k, v in opt.items():
+        if isinstance(v, dict):
+            msg += ' ' * (indent_level * 2) + k + ':['
+            msg += dict2str(v, indent_level + 1)
+            msg += ' ' * (indent_level * 2) + ']\n'
+        else:
+            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+    return msg
+
+
+def _postprocess_yml_value(value):
+    # None
+    if value == '~' or value.lower() == 'none':
+        return None
+    # bool
+    if value.lower() == 'true':
+        return True
+    elif value.lower() == 'false':
+        return False
+    # !!float number
+    if value.startswith('!!float'):
+        return float(value.replace('!!float', ''))
+    # number
+    if value.isdigit():
+        return int(value)
+    elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
+        return float(value)
+    # list
+    if value.startswith('['):
+        return eval(value)
+    # str
+    return value
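+
+# Illustration of how --force_yml values are coerced (comments only; values are examples):
+#     _postprocess_yml_value('0.999')         -> 0.999 (float)
+#     _postprocess_yml_value('!!float 1e-4')  -> 0.0001 (float)
+#     _postprocess_yml_value('true')          -> True (bool)
+#     _postprocess_yml_value('~')             -> None
+#     _postprocess_yml_value('[1, 2]')        -> [1, 2] (list, via eval)
+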
+
+
+def parse_options(root_path, is_train=True):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+    parser.add_argument('--auto_resume', action='store_true')
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
+    args = parser.parse_args()
+
+    # parse yml to dict
+    opt = yaml_load(args.opt)
+
+    # distributed settings
+    if args.launcher == 'none':
+        opt['dist'] = False
+        print('Disable distributed.', flush=True)
+    else:
+        opt['dist'] = True
+        if args.launcher == 'slurm' and 'dist_params' in opt:
+            init_dist(args.launcher, **opt['dist_params'])
+        else:
+            init_dist(args.launcher)
+    opt['rank'], opt['world_size'] = get_dist_info()
+
+    # random seed
+    seed = opt.get('manual_seed')
+    if seed is None:
+        seed = random.randint(1, 10000)
+        opt['manual_seed'] = seed
+    set_random_seed(seed + opt['rank'])
+
+    # force to update yml options
+    if args.force_yml is not None:
+        for entry in args.force_yml:
+            # creating new keys is not supported yet
+            keys, value = entry.split('=')
+            keys, value = keys.strip(), value.strip()
+            value = _postprocess_yml_value(value)
+            eval_str = 'opt'
+            for key in keys.split(':'):
+                eval_str += f'["{key}"]'
+            eval_str += '=value'
+            # using exec function
+            exec(eval_str)
+
+    opt['auto_resume'] = args.auto_resume
+    opt['is_train'] = is_train
+
+    # debug setting
+    if args.debug and not opt['name'].startswith('debug'):
+        opt['name'] = 'debug_' + opt['name']
+
+    if opt['num_gpu'] == 'auto':
+        opt['num_gpu'] = torch.cuda.device_count()
+
+    # datasets
+    for phase, dataset in opt['datasets'].items():
+        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
+        phase = phase.split('_')[0]
+        dataset['phase'] = phase
+        if 'scale' in opt:
+            dataset['scale'] = opt['scale']
+        if dataset.get('dataroot_gt') is not None:
+            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+        if dataset.get('dataroot_lq') is not None:
+            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+    # paths
+    for key, val in opt['path'].items():
+        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+            opt['path'][key] = osp.expanduser(val)
+
+    if is_train:
+        experiments_root = osp.join(root_path, 'experiments', opt['name'])
+        opt['path']['experiments_root'] = experiments_root
+        opt['path']['models'] = osp.join(experiments_root, 'models')
+        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+        opt['path']['log'] = experiments_root
+        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+        # change some options for debug mode
+        if 'debug' in opt['name']:
+            if 'val' in opt:
+                opt['val']['val_freq'] = 8
+            opt['logger']['print_freq'] = 1
+            opt['logger']['save_checkpoint_freq'] = 8
+    else:  # test
+        results_root = osp.join(root_path, 'results', opt['name'])
+        opt['path']['results_root'] = results_root
+        opt['path']['log'] = results_root
+        opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+    return opt, args
+
+
+@master_only
+def copy_opt_file(opt_file, experiments_root):
+    # copy the yml file to the experiment root
+    import sys
+    import time
+    from shutil import copyfile
+    cmd = ' '.join(sys.argv)
+    filename = osp.join(experiments_root, osp.basename(opt_file))
+    copyfile(opt_file, filename)
+
+    with open(filename, 'r+') as f:
+        lines = f.readlines()
+        lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
+        f.seek(0)
+        f.writelines(lines)
diff --git a/basicsr/utils/plot_util.py b/basicsr/utils/plot_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6da5bc29e706da87ab83af6d5367176fe78763
--- /dev/null
+++ b/basicsr/utils/plot_util.py
@@ -0,0 +1,83 @@
+import re
+
+
+def read_data_from_tensorboard(log_path, tag):
+    """Get raw data (steps and values) from tensorboard events.
+
+    Args:
+        log_path (str): Path to the tensorboard log.
+        tag (str): tag to be read.
+    """
+    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+    # tensorboard event
+    event_acc = EventAccumulator(log_path)
+    event_acc.Reload()
+    scalar_list = event_acc.Tags()['scalars']
+    print('tag list: ', scalar_list)
+    steps = [int(s.step) for s in event_acc.Scalars(tag)]
+    values = [s.value for s in event_acc.Scalars(tag)]
+    return steps, values
+
+
+def read_data_from_txt_2v(path, pattern, step_one=False):
+    """Read data from txt with 2 returned values (usually [step, value]).
+
+    Args:
+        path (str): path to the txt file.
+        pattern (str): re (regular expression) pattern.
+        step_one (bool): add 1 to steps. Default: False.
+    """
+    with open(path) as f:
+        lines = f.readlines()
+    lines = [line.strip() for line in lines]
+    steps = []
+    values = []
+
+    pattern = re.compile(pattern)
+    for line in lines:
+        match = pattern.match(line)
+        if match:
+            steps.append(int(match.group(1)))
+            values.append(float(match.group(2)))
+    if step_one:
+        steps = [v + 1 for v in steps]
+    return steps, values
+
+
+def read_data_from_txt_1v(path, pattern):
+    """Read data from txt with 1 returned values.
+
+    Args:
+        path (str): path to the txt file.
+        pattern (str): re (regular expression) pattern.
+    """
+    with open(path) as f:
+        lines = f.readlines()
+    lines = [line.strip() for line in lines]
+    data = []
+
+    pattern = re.compile(pattern)
+    for line in lines:
+        match = pattern.match(line)
+        if match:
+            data.append(float(match.group(1)))
+    return data
+
+
+def smooth_data(values, smooth_weight):
+    """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
+
+    Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704  # noqa: E501
+
+    Args:
+        values (list): A list of values to be smoothed.
+        smooth_weight (float): Smooth weight.
+    """
+    values_sm = []
+    last_sm_value = values[0]
+    for value in values:
+        value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
+        values_sm.append(value_sm)
+        last_sm_value = value_sm
+    return values_sm
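+
+
+if __name__ == '__main__':
+    # Quick illustration of the smoothing filter (illustrative only; data is made up).
+    raw = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
+    print(smooth_data(raw, smooth_weight=0.6))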
diff --git a/basicsr/utils/realesrgan_utils.py b/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff934e5150b4aa568a51ab9614a2057b011a6014
--- /dev/null
+++ b/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,293 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can also be a URL (the model will be downloaded
+            automatically on first use).
+        model (nn.Module): The defined network. Default: None.
+        tile (int): Since very large inputs can cause GPU out-of-memory errors, this option first crops the input
+            image into tiles, processes each tile, and finally merges them back into one image.
+            0 means tiling is disabled. Default: 0.
+        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+    """
+
+    def __init__(self,
+                 scale,
+                 model_path,
+                 model=None,
+                 tile=0,
+                 tile_pad=10,
+                 pre_pad=10,
+                 half=False,
+                 device=None,
+                 gpu_id=None):
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale = None
+        self.half = half
+
+        # initialize model
+        if gpu_id:
+            self.device = torch.device(
+                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+        else:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+        # if the model_path starts with https, the model will first be downloaded to the folder: weights/realesrgan
+        if model_path.startswith('https://'):
+            model_path = load_file_from_url(
+                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+        # prefer to use params_ema
+        if 'params_ema' in loadnet:
+            keyname = 'params_ema'
+        else:
+            keyname = 'params'
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+        if self.half:
+            self.model = self.model.half()
+
+    def pre_process(self, img):
+        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+        """
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img.unsqueeze(0).to(self.device)
+        if self.half:
+            self.img = self.img.half()
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+        # mod pad for divisible borders
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if (h % self.mod_scale != 0):
+                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+            if (w % self.mod_scale != 0):
+                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+    def process(self):
+        # model inference
+        self.output = self.model(self.img)
+
+    def tile_process(self):
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one images.
+
+        Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
+
+        # loop over all tiles
+        for y in range(tiles_y):
+            for x in range(tiles_x):
+                # extract tile from input image
+                ofs_x = x * self.tile_size
+                ofs_y = y * self.tile_size
+                # input tile area on total image
+                input_start_x = ofs_x
+                input_end_x = min(ofs_x + self.tile_size, width)
+                input_start_y = ofs_y
+                input_end_y = min(ofs_y + self.tile_size, height)
+
+                # input tile area on total image with padding
+                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+                input_end_x_pad = min(input_end_x + self.tile_pad, width)
+                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+                input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+                # input tile dimensions
+                input_tile_width = input_end_x - input_start_x
+                input_tile_height = input_end_y - input_start_y
+                tile_idx = y * tiles_x + x + 1
+                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+                # upscale tile
+                try:
+                    with torch.no_grad():
+                        output_tile = self.model(input_tile)
+                except RuntimeError as error:
+                    print('Error', error)
+                    # skip this tile if inference fails (e.g. CUDA OOM); otherwise output_tile
+                    # below would be undefined or stale
+                    continue
+                # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+                # output tile area on total image
+                output_start_x = input_start_x * self.scale
+                output_end_x = input_end_x * self.scale
+                output_start_y = input_start_y * self.scale
+                output_end_y = input_end_y * self.scale
+
+                # output tile area without padding
+                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+                # put tile into output image
+                self.output[:, :, output_start_y:output_end_y,
+                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+                                                                       output_start_x_tile:output_end_x_tile]
+
+    def post_process(self):
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+        # remove prepad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+        return self.output
+
+    @torch.no_grad()
+    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+        h_input, w_input = img.shape[0:2]
+        # img: numpy
+        img = img.astype(np.float32)
+        if np.max(img) > 256:  # 16-bit image
+            max_range = 65535
+            print('\tInput is a 16-bit image')
+        else:
+            max_range = 255
+        img = img / max_range
+        if len(img.shape) == 2:  # gray image
+            img_mode = 'L'
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            img_mode = 'RGBA'
+            alpha = img[:, :, 3]
+            img = img[:, :, 0:3]
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if alpha_upsampler == 'realesrgan':
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = 'RGB'
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+        self.pre_process(img)
+        if self.tile_size > 0:
+            self.tile_process()
+        else:
+            self.process()
+        output_img = self.post_process()
+        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+        if img_mode == 'L':
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+        # ------------------- process the alpha channel if necessary ------------------- #
+        if img_mode == 'RGBA':
+            if alpha_upsampler == 'realesrgan':
+                self.pre_process(alpha)
+                if self.tile_size > 0:
+                    self.tile_process()
+                else:
+                    self.process()
+                output_alpha = self.post_process()
+                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+            else:  # use the cv2 resize for alpha channel
+                h, w = alpha.shape[0:2]
+                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+            # merge the alpha channel
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+            output_img[:, :, 3] = output_alpha
+
+        # ------------------------------ return ------------------------------ #
+        if max_range == 65535:  # 16-bit image
+            output = (output_img * 65535.0).round().astype(np.uint16)
+        else:
+            output = (output_img * 255.0).round().astype(np.uint8)
+
+        if outscale is not None and outscale != float(self.scale):
+            output = cv2.resize(
+                output, (
+                    int(w_input * outscale),
+                    int(h_input * outscale),
+                ), interpolation=cv2.INTER_LANCZOS4)
+
+        return output, img_mode
+
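+# Usage sketch (illustrative only; the architecture, weight path and image paths below are
+# assumptions and not part of this repository's pipeline):
+#     from basicsr.archs.rrdbnet_arch import RRDBNet
+#     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+#     upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth',
+#                              model=model, tile=400, tile_pad=10, pre_pad=0, half=True)
+#     img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)
+#     output, img_mode = upsampler.enhance(img, outscale=4)
+#     cv2.imwrite('output.png', output)
+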
+
+class PrefetchReader(threading.Thread):
+    """Prefetch images.
+
+    Args:
+        img_list (list[str]): A list of image paths to be read.
+        num_prefetch_queue (int): Size of the prefetch queue.
+    """
+
+    def __init__(self, img_list, num_prefetch_queue):
+        super().__init__()
+        self.que = queue.Queue(num_prefetch_queue)
+        self.img_list = img_list
+
+    def run(self):
+        for img_path in self.img_list:
+            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+            self.que.put(img)
+
+        self.que.put(None)
+
+    def __next__(self):
+        next_item = self.que.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __iter__(self):
+        return self
+
+
+class IOConsumer(threading.Thread):
+
+    def __init__(self, opt, que, qid):
+        super().__init__()
+        self._queue = que
+        self.qid = qid
+        self.opt = opt
+
+    def run(self):
+        while True:
+            msg = self._queue.get()
+            if isinstance(msg, str) and msg == 'quit':
+                break
+
+            output = msg['output']
+            save_path = msg['save_path']
+            cv2.imwrite(save_path, output)
+        print(f'IO worker {self.qid} is done.')
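+
+# Usage sketch for the I/O helpers (illustrative only; `result_img` and the save path are assumptions):
+#     que = queue.Queue()
+#     consumer = IOConsumer(opt={}, que=que, qid=0)
+#     consumer.start()
+#     que.put({'output': result_img, 'save_path': 'results/out.png'})
+#     que.put('quit')
+#     consumer.join()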
diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e72ef7ff21b94f50e6caa8948f69ca0b04bc968
--- /dev/null
+++ b/basicsr/utils/registry.py
@@ -0,0 +1,88 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501
+
+
+class Registry():
+    """
+    The registry that provides name -> object mapping, to support third-party
+    users' custom modules.
+
+    To create a registry (e.g. a backbone registry):
+
+    .. code-block:: python
+
+        BACKBONE_REGISTRY = Registry('BACKBONE')
+
+    To register an object:
+
+    .. code-block:: python
+
+        @BACKBONE_REGISTRY.register()
+        class MyBackbone():
+            ...
+
+    Or:
+
+    .. code-block:: python
+
+        BACKBONE_REGISTRY.register(MyBackbone)
+    """
+
+    def __init__(self, name):
+        """
+        Args:
+            name (str): the name of this registry
+        """
+        self._name = name
+        self._obj_map = {}
+
+    def _do_register(self, name, obj, suffix=None):
+        if isinstance(suffix, str):
+            name = name + '_' + suffix
+
+        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+                                             f"in '{self._name}' registry!")
+        self._obj_map[name] = obj
+
+    def register(self, obj=None, suffix=None):
+        """
+        Register the given object under the name `obj.__name__`.
+        Can be used either as a decorator or as a plain function call.
+        See docstring of this class for usage.
+        """
+        if obj is None:
+            # used as a decorator
+            def deco(func_or_class):
+                name = func_or_class.__name__
+                self._do_register(name, func_or_class, suffix)
+                return func_or_class
+
+            return deco
+
+        # used as a function call
+        name = obj.__name__
+        self._do_register(name, obj, suffix)
+
+    def get(self, name, suffix='basicsr'):
+        ret = self._obj_map.get(name)
+        if ret is None:
+            ret = self._obj_map.get(name + '_' + suffix)
+            print(f'Name {name} is not found, use name: {name}_{suffix}!')
+        if ret is None:
+            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+        return ret
+
+    def __contains__(self, name):
+        return name in self._obj_map
+
+    def __iter__(self):
+        return iter(self._obj_map.items())
+
+    def keys(self):
+        return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
diff --git a/configs/sr.yaml b/configs/sr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9ba43a29fde4690dc9483f6fb536ca5dbe8f6e18
--- /dev/null
+++ b/configs/sr.yaml
@@ -0,0 +1,110 @@
+sf: 4
+degradation:
+  # the first degradation process
+  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+  resize_range: [0.3, 1.5]
+  gaussian_noise_prob: 0.5
+  noise_range: [1, 15]
+  poisson_scale_range: [0.05, 2.0]
+  gray_noise_prob: 0.4
+  jpeg_range: [60, 95]
+
+  # the second degradation process
+  second_blur_prob: 0.5
+  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+  resize_range2: [0.6, 1.2]
+  gaussian_noise_prob2: 0.5
+  noise_range2: [1, 12]
+  poisson_scale_range2: [0.05, 1.0]
+  gray_noise_prob2: 0.4
+  jpeg_range2: [60, 100]
+
+  gt_size: 512
+  no_degradation_prob: 0.01
+
+train:
+  queue_size: 180
+  gt_path: ['dataset_path/LSDIR/']
+  face_gt_path: 'dataset_path/FFHQ/'
+  num_face: 10000
+  crop_size: 512
+  io_backend:
+    type: disk
+
+  blur_kernel_size: 21
+  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob: 0.1
+  blur_sigma: [0.2, 1.5]
+  betag_range: [0.5, 2.0]
+  betap_range: [1, 1.5]
+
+  blur_kernel_size2: 11
+  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob2: 0.1
+  blur_sigma2: [0.2, 1.0]
+  betag_range2: [0.5, 2.0]
+  betap_range2: [1, 1.5]
+
+  final_sinc_prob: 0.8
+
+  gt_size: 512
+  use_hflip: True
+  use_rot: False
+
+validation:
+  gt_path: dataset_path/DIV2K_valid_HR/
+  crop_size: 512
+  io_backend:
+    type: disk
+
+  blur_kernel_size: 21
+  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob: 0.1
+  blur_sigma: [0.2, 1.5]
+  betag_range: [0.5, 2.0]
+  betap_range: [1, 1.5]
+
+  blur_kernel_size2: 11
+  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob2: 0.1
+  blur_sigma2: [0.2, 1.0]
+  betag_range2: [0.5, 2.0]
+  betap_range2: [1, 1.5]
+
+  final_sinc_prob: 0.8
+
+  gt_size: 512
+  use_hflip: True
+  use_rot: False
+
+test:
+  gt_path: dataset_path/DIV2K_valid_HR/
+  crop_size: 512
+  io_backend:
+    type: disk
+
+  blur_kernel_size: 21
+  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob: 0.1
+  blur_sigma: [0.2, 1.5]
+  betag_range: [0.5, 2.0]
+  betap_range: [1, 1.5]
+
+  blur_kernel_size2: 11
+  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+  kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+  sinc_prob2: 0.1
+  blur_sigma2: [0.2, 1.0]
+  betag_range2: [0.5, 2.0]
+  betap_range2: [1, 1.5]
+
+  final_sinc_prob: 0.8
+
+  gt_size: 512
+  use_hflip: True
+  use_rot: False
\ No newline at end of file
diff --git a/configs/sr_test.yaml b/configs/sr_test.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ae403a0275e35a3637264946db822b4bc6e4855
--- /dev/null
+++ b/configs/sr_test.yaml
@@ -0,0 +1,6 @@
+sf: 4
+
+validation:
+  lr_path: path_to_LR_image_folder
+  io_backend:
+    type: disk
\ No newline at end of file
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0730f5e21cd6c4b75b7659e989a6fc51c61e40f6
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,35 @@
+name: s3diff
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - python=3.10
+  - pip:
+      - einops>=0.6.1
+      - numpy>=1.24.4
+      - open-clip-torch>=2.20.0
+      - opencv-python==4.6.0.66
+      - pillow>=9.5.0
+      - scipy==1.11.1
+      - timm>=0.9.2
+      - tokenizers
+      - torch>=2.0.1
+
+      - torchaudio>=2.0.2
+      - torchdata==0.6.1
+      - torchmetrics>=1.0.1
+      - torchvision>=0.15.2
+
+      - tqdm>=4.65.0
+      - transformers==4.35.2
+      - triton==2.0.0
+      - urllib3<1.27,>=1.25.4
+      - xformers>=0.0.20
+      - streamlit-keyup==0.2.0
+      - lpips
+      - peft
+      - pyiqa
+      - omegaconf
+      - dominate
+      - diffusers==0.25.1
+      - gradio==3.43.1
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..177c3c3edd583cead7360e95b5d4774034218662
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+einops>=0.6.1
+numpy>=1.24.4
+open-clip-torch>=2.20.0
+opencv-python==4.6.0.66
+pillow>=9.5.0
+scipy==1.11.1
+timm>=0.9.2
+tokenizers
+
+torch>=2.0.1
+torchaudio>=2.0.2
+torchdata==0.6.1
+torchmetrics>=1.0.1
+torchvision>=0.15.2
+
+tqdm>=4.65.0
+transformers==4.35.2
+triton==2.0.0
+urllib3<1.27,>=1.25.4
+xformers>=0.0.20
+streamlit-keyup==0.2.0
+lpips
+peft
+pyiqa
+omegaconf
+dominate
+diffusers==0.25.1
+gradio==3.43.1
diff --git a/run_evaluate.sh b/run_evaluate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a47eff30dee5e0297655ea88a1531b7f3e80749
--- /dev/null
+++ b/run_evaluate.sh
@@ -0,0 +1 @@
+python src/evaluate_img.py -i "path_to_generated_HR" -r "path_to_ground_truth"
diff --git a/run_inference.sh b/run_inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..57cab9d9fbdc349059325b8693dfd877aef02943
--- /dev/null
+++ b/run_inference.sh
@@ -0,0 +1,7 @@
+accelerate launch --num_processes=1 --gpu_ids="0," --main_process_port 29300 src/inference_s3diff.py \
+    --sd_path="path_to_checkpoints/sd-turbo" \
+    --de_net_path="assets/mm-realsr/de_net.pth" \
+    --pretrained_path="path_to_checkpoints_folder/model_30001.pkl" \
+    --output_dir="./output" \
+    --ref_path="path_to_ground_truth_folder" \
+    --align_method="wavelet"
diff --git a/run_training.sh b/run_training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1ab504323fd203cae2445a7dea10181207a3f26
--- /dev/null
+++ b/run_training.sh
@@ -0,0 +1,8 @@
+accelerate launch --num_processes=4 --gpu_ids="0,1,2,3" --main_process_port 29300 src/train_s3diff.py \
+    --sd_path="path_to_checkpoints/sd-turbo" \
+    --de_net_path="assets/mm-realsr/de_net.pth" \
+    --output_dir="./output" \
+    --resolution=512 \
+    --train_batch_size=4 \
+    --enable_xformers_memory_efficient_attention \
+    --viz_freq 25
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b74bb6c0195ef838585c37e9c95c45db9e98a16
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='S3Diff',
+    version='0.0.1',
+    description='',
+    packages=find_packages(),
+    install_requires=[
+        'torch',
+        'numpy',
+        'tqdm',
+    ],
+)
diff --git a/src/de_net.py b/src/de_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..0465d9d6c043ed8171c3ceb178f25e05ca5fc558
--- /dev/null
+++ b/src/de_net.py
@@ -0,0 +1,127 @@
+import torch
+import copy
+from torch import nn as nn
+from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
+
+class DEResNet(nn.Module):
+    """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
+
+    As shown in the paper 'Towards Flexible Blind JPEG Artifacts Removal',
+    a ResNet architecture works well for image quality estimation.
+
+    Args:
+        num_in_ch (int): Channel number of the input. Default: 3.
+        num_degradation (int): Number of degradations to estimate. Default: 2 (blur + noise).
+        degradation_degree_actv (str): Activation function for the degradation degree scalar. Default: sigmoid.
+        num_feats (list): Channel number of each stage.
+        num_blocks (list): Number of residual blocks in each stage.
+        downscales (list): Downscale factor of each stage.
+    """
+
+    def __init__(self,
+                 num_in_ch=3,
+                 num_degradation=2,
+                 degradation_degree_actv='sigmoid',
+                 num_feats=[64, 64, 64, 128],
+                 num_blocks=[2, 2, 2, 2],
+                 downscales=[1, 1, 2, 1]):
+        super(DEResNet, self).__init__()
+
+        assert isinstance(num_feats, list)
+        assert isinstance(num_blocks, list)
+        assert isinstance(downscales, list)
+        assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
+
+        num_stage = len(num_feats)
+
+        self.conv_first = nn.ModuleList()
+        for _ in range(num_degradation):
+            self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
+        self.body = nn.ModuleList()
+        for _ in range(num_degradation):
+            body = list()
+            for stage in range(num_stage):
+                for _ in range(num_blocks[stage]):
+                    body.append(ResidualBlockNoBN(num_feats[stage]))
+                if downscales[stage] == 1:
+                    if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
+                        body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
+                    continue
+                elif downscales[stage] == 2:
+                    body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
+                else:
+                    raise NotImplementedError
+            self.body.append(nn.Sequential(*body))
+
+        self.num_degradation = num_degradation
+        self.fc_degree = nn.ModuleList()
+        if degradation_degree_actv == 'sigmoid':
+            actv = nn.Sigmoid
+        elif degradation_degree_actv == 'tanh':
+            actv = nn.Tanh
+        else:
+            raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
+                                      f'{degradation_degree_actv} is not supported yet.')
+        for _ in range(num_degradation):
+            self.fc_degree.append(
+                nn.Sequential(
+                    nn.Linear(num_feats[-1], 512),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(512, 1),
+                    actv(),
+                ))
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
+
+    def clone_module(self, module):
+        new_module = copy.deepcopy(module)
+        return new_module
+
+    def average_parameters(self, modules):
+        avg_module = self.clone_module(modules[0])
+        for name, param in avg_module.named_parameters():
+            avg_param = sum([mod.state_dict()[name].data for mod in modules]) / len(modules)
+            param.data.copy_(avg_param)
+        return avg_module
+
+    def expand_degradation_modules(self, new_num_degradation):
+        if new_num_degradation <= self.num_degradation:
+            return
+        initial_modules = [self.conv_first, self.body, self.fc_degree]
+
+        for modules in initial_modules:
+            avg_module = self.average_parameters(modules[:2])
+            while len(modules) < new_num_degradation:
+                modules.append(self.clone_module(avg_module))
+
+    def load_and_expand_model(self, path, num_degradation):
+        state_dict = torch.load(path, map_location=torch.device('cpu'))
+        self.load_state_dict(state_dict, strict=True)
+        
+        self.expand_degradation_modules(num_degradation)
+        self.num_degradation = num_degradation
+
+    def load_model(self, path):
+        state_dict = torch.load(path, map_location=torch.device('cpu'))
+        self.load_state_dict(state_dict, strict=True)
+
+    def set_train(self):
+        self.conv_first.requires_grad_(True)
+        self.fc_degree.requires_grad_(True)
+        for n, _p in self.body.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+
+    def forward(self, x):
+        degrees = []
+        for i in range(self.num_degradation):
+            x_out = self.conv_first[i](x)
+            feat = self.body[i](x_out)
+            feat = self.avg_pool(feat)
+            feat = feat.squeeze(-1).squeeze(-1)
+            degrees.append(self.fc_degree[i](feat).squeeze(-1))
+        return torch.stack(degrees, dim=1)
\ No newline at end of file
diff --git a/src/evaluate_img.py b/src/evaluate_img.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac421da6f30d0b57d611d60478562ae8219616b
--- /dev/null
+++ b/src/evaluate_img.py
@@ -0,0 +1,72 @@
+import pyiqa
+import os
+import argparse
+from pathlib import Path
+import torch
+from utils import util_image
+import tqdm
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+print(pyiqa.list_models())
+def evaluate(in_path, ref_path, ntest):
+    metric_dict = {}
+    metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
+    metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
+    metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
+    metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
+    metric_paired_dict = {}
+    
+    in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
+    assert in_path.is_dir()
+    
+    ref_path_list = None
+    if ref_path is not None:
+        ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path
+        ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")])
+        if ntest is not None: ref_path_list = ref_path_list[:ntest]
+        
+        metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
+        metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device)
+        metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device)
+        metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device)
+        
+    lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
+    if ntest is not None: lr_path_list = lr_path_list[:ntest]
+    
+    print(f'Found {len(lr_path_list)} images in {in_path}')
+    result = {}
+    for i in tqdm.tqdm(range(len(lr_path_list))):
+        _in_path = lr_path_list[i]
+        _ref_path = ref_path_list[i] if ref_path_list is not None else None
+        
+        im_in = util_image.imread(_in_path, chn='rgb', dtype='float32')  # h x w x c
+        im_in_tensor = util_image.img2tensor(im_in).cuda()              # 1 x c x h x w
+        for key, metric in metric_dict.items():
+            with torch.cuda.amp.autocast():
+                result[key] = result.get(key, 0) + metric(im_in_tensor).item()
+        
+        if ref_path is not None:
+            im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32')  # h x w x c
+            im_ref_tensor = util_image.img2tensor(im_ref).cuda()    
+            for key, metric in metric_paired_dict.items():
+                result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item()
+    
+    if ref_path is not None:
+        fid_metric = pyiqa.create_metric('fid')
+        result['fid'] = fid_metric(in_path, ref_path)
+
+    for key, res in result.items():
+        if key == 'fid':
+            print(f"{key}: {res:.2f}")
+        else:
+            print(f"{key}: {res/len(lr_path_list):.5f}")
+        
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i',"--in_path", type=str, required=True)
+    parser.add_argument("-r", "--ref_path", type=str, default=None)
+    parser.add_argument("--ntest", type=int, default=None)
+    args = parser.parse_args()
+    evaluate(args.in_path, args.ref_path, args.ntest)
+    
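+
+    # Example invocation (illustrative; both folders are placeholders for your own data):
+    #   python src/evaluate_img.py -i results/sr_outputs -r datasets/val_HR --ntest 100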
\ No newline at end of file
diff --git a/src/gradio_s3diff.py b/src/gradio_s3diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbc1a6d18ac2f92692a51829559bc86d9550ca64
--- /dev/null
+++ b/src/gradio_s3diff.py
@@ -0,0 +1,157 @@
+import gradio as gr
+import os
+import sys
+import math
+from typing import List
+
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from diffusers.utils.import_utils import is_xformers_available
+
+from my_utils.testing_utils import parse_args_paired_testing
+from de_net import DEResNet
+from s3diff_tile import S3Diff
+from torchvision import transforms
+from utils.wavelet_color import wavelet_color_fix, adain_color_fix
+
+tensor_transforms = transforms.Compose([
+                transforms.ToTensor(),
+            ])
+
+args = parse_args_paired_testing()
+
+# Load scheduler, tokenizer and models.
+pretrained_model_path = 'checkpoint-path/s3diff.pkl'
+t2i_path = 'sd-turbo-path'
+de_net_path = 'assets/mm-realsr/de_net.pth'
+
+# initialize net_sr
+net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=t2i_path, pretrained_path=pretrained_model_path, args=args)
+net_sr.set_eval()
+
+# initialize degradation estimation network
+net_de = DEResNet(num_in_ch=3, num_degradation=2)
+net_de.load_model(de_net_path)
+net_de = net_de.cuda()
+net_de.eval()
+
+if args.enable_xformers_memory_efficient_attention:
+    if is_xformers_available():
+        net_sr.unet.enable_xformers_memory_efficient_attention()
+    else:
+        raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+if args.gradient_checkpointing:
+    net_sr.unet.enable_gradient_checkpointing()
+
+weight_dtype = torch.float32
+device = "cuda"
+
+# Move the SR and degradation estimation networks to GPU and cast to weight_dtype
+net_sr.to(device, dtype=weight_dtype)
+net_de.to(device, dtype=weight_dtype)
+
+@torch.no_grad()
+def process(
+    input_image: Image.Image,
+    scale_factor: float,
+    cfg_scale: float,
+    latent_tiled_size: int,
+    latent_tiled_overlap: int,
+    align_method: str,
+    ) -> Image.Image:
+
+    # positive_prompt = ""
+    # negative_prompt = ""
+
+    net_sr._set_latent_tile(latent_tiled_size = latent_tiled_size, latent_tiled_overlap = latent_tiled_overlap)
+
+    im_lr = tensor_transforms(input_image).unsqueeze(0).to(device)
+    ori_h, ori_w = im_lr.shape[2:]
+    im_lr_resize = F.interpolate(
+        im_lr,
+        size=(int(ori_h * scale_factor),
+              int(ori_w * scale_factor)),
+        mode='bicubic',
+        )
+    im_lr_resize = im_lr_resize.contiguous() 
+    im_lr_resize_norm = im_lr_resize * 2 - 1.0
+    im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
+    resize_h, resize_w = im_lr_resize_norm.shape[2:]
+
+    pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
+    pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
+    im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')
+      
+    try:
+        with torch.autocast("cuda"):
+            deg_score = net_de(im_lr)
+
+            pos_tag_prompt = [args.pos_prompt]
+            neg_tag_prompt = [args.neg_prompt]
+
+            x_tgt_pred = net_sr(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
+            x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
+            out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
+
+        output_pil = transforms.ToPILImage()(out_img[0])
+
+        if align_method == 'no fix':
+            image = output_pil
+        else:
+            im_lr_resize = transforms.ToPILImage()(im_lr_resize[0])
+            if align_method == 'wavelet':
+                image = wavelet_color_fix(output_pil, im_lr_resize)
+            elif align_method == 'adain':
+                image = adain_color_fix(output_pil, im_lr_resize)
+
+    except Exception as e:
+        print(e)
+        image = Image.new(mode="RGB", size=(512, 512))
+
+    return image
+
+
+MARKDOWN = \
+"""
+## Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors
+
+[GitHub](https://github.com/ArcticHare105/S3Diff) | [Paper](https://arxiv.org/abs/2409.17058)
+
+If S3Diff is helpful for you, please help star the GitHub Repo. Thanks!
+"""
+
+block = gr.Blocks().queue()
+with block:
+    with gr.Row():
+        gr.Markdown(MARKDOWN)
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(source="upload", type="pil")
+            run_button = gr.Button(label="Run")
+            with gr.Accordion("Options", open=True):
+                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=1.0, maximum=1.1, value=1.07, step=0.01)
+                scale_factor = gr.Number(label="SR Scale", value=4)
+                latent_tiled_size = gr.Slider(label="Tile Size", minimum=64, maximum=160, value=96, step=1)
+                latent_tiled_overlap = gr.Slider(label="Tile Overlap", minimum=16, maximum=48, value=32, step=1)
+                align_method = gr.Dropdown(label="Color Correction", choices=["wavelet", "adain", "no fix"], value="wavelet")
+        with gr.Column():
+            result_image = gr.Image(label="Output", show_label=False, elem_id="result_image", source="canvas", width="100%", height="auto")
+
+    inputs = [
+        input_image,
+        scale_factor,
+        cfg_scale,
+        latent_tiled_size,
+        latent_tiled_overlap,
+        align_method
+    ]
+    run_button.click(fn=process, inputs=inputs, outputs=[result_image])
+
+block.launch()
+
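+# Launch sketch (illustrative; --output_dir is required by parse_args_paired_testing even though
+# this demo only reads the model paths hard-coded above):
+#   python src/gradio_s3diff.py --output_dir ./gradio_out
+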
diff --git a/src/inference_s3diff.py b/src/inference_s3diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..1240141c523b7d33769068b713634924f3c670c1
--- /dev/null
+++ b/src/inference_s3diff.py
@@ -0,0 +1,218 @@
+import os
+import gc
+import tqdm
+import math
+import lpips
+import pyiqa
+import argparse
+import clip
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+
+from omegaconf import OmegaConf
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from PIL import Image
+from torchvision import transforms
+# from tqdm.auto import tqdm
+
+import diffusers
+import utils.misc as misc
+
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.optimization import get_scheduler
+
+from de_net import DEResNet
+from s3diff_tile import S3Diff
+from my_utils.testing_utils import parse_args_paired_testing, PlainDataset, lr_proc
+from utils.util_image import ImageSpliterTh
+from my_utils.utils import instantiate_from_config
+from pathlib import Path
+from utils import util_image
+from utils.wavelet_color import wavelet_color_fix, adain_color_fix
+
+def evaluate(in_path, ref_path, ntest):
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    metric_dict = {}
+    metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
+    metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
+    metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
+    metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
+    metric_paired_dict = {}
+    
+    in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
+    assert in_path.is_dir()
+    
+    ref_path_list = None
+    if ref_path is not None:
+        ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path
+        ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")])
+        if ntest is not None: ref_path_list = ref_path_list[:ntest]
+        
+        metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
+        metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device)
+        metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device)
+        metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device)
+        
+    lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
+    if ntest is not None: lr_path_list = lr_path_list[:ntest]
+    
+    print(f'Found {len(lr_path_list)} images in {in_path}')
+    result = {}
+    for i in tqdm.tqdm(range(len(lr_path_list))):
+        _in_path = lr_path_list[i]
+        _ref_path = ref_path_list[i] if ref_path_list is not None else None
+        
+        im_in = util_image.imread(_in_path, chn='rgb', dtype='float32')  # h x w x c
+        im_in_tensor = util_image.img2tensor(im_in).cuda()              # 1 x c x h x w
+        for key, metric in metric_dict.items():
+            with torch.cuda.amp.autocast():
+                result[key] = result.get(key, 0) + metric(im_in_tensor).item()
+        
+        if ref_path is not None:
+            im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32')  # h x w x c
+            im_ref_tensor = util_image.img2tensor(im_ref).cuda()    
+            for key, metric in metric_paired_dict.items():
+                result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item()
+    
+    if ref_path is not None:
+        fid_metric = pyiqa.create_metric('fid')
+        result['fid'] = fid_metric(in_path, ref_path)
+
+    print_results = []
+    for key, res in result.items():
+        if key == 'fid':
+            print(f"{key}: {res:.2f}")
+            print_results.append(f"{key}: {res:.2f}")
+        else:
+            print(f"{key}: {res/len(lr_path_list):.5f}")
+            print_results.append(f"{key}: {res/len(lr_path_list):.5f}")
+    return print_results
+
+
+def main(args):
+    config = OmegaConf.load(args.base_config)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+    )
+
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if accelerator.is_main_process:
+        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
+
+    # initialize net_sr
+    net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=args.sd_path, pretrained_path=args.pretrained_path, args=args)
+    net_sr.set_eval()
+
+    net_de = DEResNet(num_in_ch=3, num_degradation=2)
+    net_de.load_model(args.de_net_path)
+    net_de = net_de.cuda()
+    net_de.eval()
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            net_sr.unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+    if args.gradient_checkpointing:
+        net_sr.unet.enable_gradient_checkpointing()
+
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    dataset_val = PlainDataset(config.validation)
+    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
+
+    # Prepare everything with our `accelerator`.
+    net_sr, net_de = accelerator.prepare(net_sr, net_de)
+
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move all networks to device and cast to weight_dtype
+    net_sr.to(accelerator.device, dtype=weight_dtype)
+    net_de.to(accelerator.device, dtype=weight_dtype)
+      
+    offset = args.padding_offset
+    for step, batch_val in enumerate(dl_val):
+        lr_path = batch_val['lr_path'][0]
+        (path, name) = os.path.split(lr_path)
+
+        im_lr = batch_val['lr'].cuda()
+        im_lr = im_lr.to(memory_format=torch.contiguous_format).float()    
+
+        ori_h, ori_w = im_lr.shape[2:]
+        im_lr_resize = F.interpolate(
+            im_lr,
+            size=(ori_h * config.sf,
+                  ori_w * config.sf),
+            mode='bicubic',
+            )
+
+        im_lr_resize = im_lr_resize.contiguous() 
+        im_lr_resize_norm = im_lr_resize * 2 - 1.0
+        im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
+        resize_h, resize_w = im_lr_resize_norm.shape[2:]
+
+        pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
+        pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
+        im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')
+        
+        B = im_lr_resize.size(0)
+        with torch.no_grad():
+            # forward pass
+            deg_score = net_de(im_lr)
+            pos_tag_prompt = [args.pos_prompt for _ in range(B)]
+            neg_tag_prompt = [args.neg_prompt for _ in range(B)]
+            x_tgt_pred = accelerator.unwrap_model(net_sr)(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
+            x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
+            out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
+
+        output_pil = transforms.ToPILImage()(out_img[0])
+
+        if args.align_method == 'nofix':
+            output_pil = output_pil
+        else:
+            im_lr_resize = transforms.ToPILImage()(im_lr_resize[0].cpu().detach())
+            if args.align_method == 'wavelet':
+                output_pil = wavelet_color_fix(output_pil, im_lr_resize)
+            elif args.align_method == 'adain':
+                output_pil = adain_color_fix(output_pil, im_lr_resize)
+
+        fname, ext = os.path.splitext(name)
+        outf = os.path.join(args.output_dir, fname+'.png')
+        output_pil.save(outf)
+
+    print_results = evaluate(args.output_dir, args.ref_path, None)
+    out_t = os.path.join(args.output_dir, 'results.txt')
+    with open(out_t, 'w', encoding='utf-8') as f:
+        for item in print_results:
+            f.write(f"{item}\n")
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+if __name__ == "__main__":
+    args = parse_args_paired_testing()
+    main(args)
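+
+    # Example invocation (illustrative paths; dataset settings come from configs/sr_test.yaml):
+    #   python src/inference_s3diff.py --sd_path sd-turbo-path \
+    #       --de_net_path assets/mm-realsr/de_net.pth \
+    #       --pretrained_path checkpoint-path/s3diff.pkl \
+    #       --output_dir results/s3diff --ref_path datasets/val_HR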
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb96328092875072b0b0e4447d3539e5371c312
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,80 @@
+import torch
+import os
+import requests
+from tqdm import tqdm
+from diffusers import DDPMScheduler, EulerDiscreteScheduler
+from typing import Any, Optional, Union
+
+# def make_1step_sched(pretrained_path, step=4):
+#     noise_scheduler_1step = EulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
+#     noise_scheduler_1step.set_timesteps(step, device="cuda")
+#     noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+#     return noise_scheduler_1step
+
+
+def make_1step_sched(pretrained_path):
+    noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
+    noise_scheduler_1step.set_timesteps(1, device="cuda")
+    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+    return noise_scheduler_1step
+
+
+def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+    self._check_forward_args(x, *args, **kwargs)
+    adapter_names = kwargs.pop("adapter_names", None)
+
+    if self.disable_adapters:
+        if self.merged:
+            self.unmerge()
+        result = self.base_layer(x, *args, **kwargs)
+    elif adapter_names is not None:
+        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+    elif self.merged:
+        result = self.base_layer(x, *args, **kwargs)
+    else:
+        result = self.base_layer(x, *args, **kwargs)
+        torch_result_dtype = result.dtype
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.lora_A.keys():
+                continue
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+            x = x.to(lora_A.weight.dtype)
+
+            if not self.use_dora[active_adapter]:
+                _tmp = lora_A(dropout(x))
+                if isinstance(lora_A, torch.nn.Conv2d):
+                    _tmp = torch.einsum('...khw,...kr->...rhw', _tmp, self.de_mod)
+                elif isinstance(lora_A, torch.nn.Linear):
+                    _tmp = torch.einsum('...lk,...kr->...lr', _tmp, self.de_mod)
+                else:
+                    raise NotImplementedError('only Conv2d and Linear LoRA layers are supported.')
+
+                result = result + lora_B(_tmp) * scaling
+            else:
+                x = dropout(x)
+                result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
+
+        result = result.to(torch_result_dtype)
+
+    return result
+
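+# Note on the modulation above (an assumption read off the einsum patterns): `self.de_mod` acts as a
+# rank-mixing matrix applied between lora_A and lora_B, i.e.
+#   Conv2d: (B, rank, H, W) @ de_mod -> (B, rank, H, W)
+#   Linear: (B, L, rank)    @ de_mod -> (B, L, rank)
+# so the degradation-aware weights can re-scale/mix the low-rank update for each active adapter.
+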
+def download_url(url, outf):
+    if not os.path.exists(outf):
+        print(f"Downloading checkpoint to {outf}")
+        response = requests.get(url, stream=True)
+        total_size_in_bytes = int(response.headers.get('content-length', 0))
+        block_size = 1024  # 1 Kibibyte
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+        with open(outf, 'wb') as file:
+            for data in response.iter_content(block_size):
+                progress_bar.update(len(data))
+                file.write(data)
+        progress_bar.close()
+        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+            print("ERROR, something went wrong")
+        print(f"Downloaded successfully to {outf}")
+    else:
+        print(f"Skipping download, {outf} already exists")
diff --git a/src/my_utils/devices.py b/src/my_utils/devices.py
new file mode 100644
index 0000000000000000000000000000000000000000..7313838c3c1f153816031dc70c8beb765751ed9e
--- /dev/null
+++ b/src/my_utils/devices.py
@@ -0,0 +1,138 @@
+import sys
+import contextlib
+from functools import lru_cache
+
+import torch
+#from modules import errors
+
+if sys.platform == "darwin":
+    from modules import mac_specific
+
+
+def has_mps() -> bool:
+    if sys.platform != "darwin":
+        return False
+    else:
+        return mac_specific.has_mps
+
+
+def get_cuda_device_string():
+    return "cuda"
+
+
+def get_optimal_device_name():
+    if torch.cuda.is_available():
+        return get_cuda_device_string()
+
+    if has_mps():
+        return "mps"
+
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
+
+
+def get_device_for(task):
+    return get_optimal_device()
+
+
+def torch_gc():
+
+    if torch.cuda.is_available():
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+    if has_mps():
+        mac_specific.torch_mps_gc()
+
+
+def enable_tf32():
+    if torch.cuda.is_available():
+
+        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
+            torch.backends.cudnn.benchmark = True
+
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+
+enable_tf32()
+#errors.run(enable_tf32, "Enabling TF32")
+
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
+dtype = torch.float16
+dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
+
+
+def cond_cast_unet(input):
+    return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input
+
+
+def randn(seed, shape):
+    torch.manual_seed(seed)
+    return torch.randn(shape, device=device)
+
+
+def randn_without_seed(shape):
+    return torch.randn(shape, device=device)
+
+
+def autocast(disable=False):
+    if disable:
+        return contextlib.nullcontext()
+
+    return torch.autocast("cuda")
+
+
+def without_autocast(disable=False):
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
+class NansException(Exception):
+    pass
+
+
+def test_for_nans(x, where):
+    if not torch.all(torch.isnan(x)).item():
+        return
+
+    if where == "unet":
+        message = "A tensor with all NaNs was produced in Unet."
+
+    elif where == "vae":
+        message = "A tensor with all NaNs was produced in VAE."
+
+    else:
+        message = "A tensor with all NaNs was produced."
+
+    message += " Use --disable-nan-check commandline argument to disable this check."
+
+    raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+    """
+    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVIDIA.
+    """
+
+    x = torch.zeros((1, 1)).to(device, dtype)
+    linear = torch.nn.Linear(1, 1).to(device, dtype)
+    linear(x)
+
+    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+    conv2d(x)
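+
+
+# Illustrative usage (`model` and `x` are placeholders):
+#   with autocast():
+#       y = model(x)   # forward pass under torch.autocast("cuda")
+#   torch_gc()         # empties the CUDA cache (and the MPS cache on macOS)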
diff --git a/src/my_utils/dino_struct.py b/src/my_utils/dino_struct.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2721c9b61b5fbef650e5c9e2133c93a6b6a4ea4
--- /dev/null
+++ b/src/my_utils/dino_struct.py
@@ -0,0 +1,185 @@
+import torch
+import torchvision
+import torch.nn.functional as F
+
+
+def attn_cosine_sim(x, eps=1e-08):
+    x = x[0]  # TEMP: getting rid of redundant dimension, TBF
+    norm1 = x.norm(dim=2, keepdim=True)
+    factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
+    sim_matrix = (x @ x.permute(0, 2, 1)) / factor
+    return sim_matrix
+
+
+class VitExtractor:
+    BLOCK_KEY = 'block'
+    ATTN_KEY = 'attn'
+    PATCH_IMD_KEY = 'patch_imd'
+    QKV_KEY = 'qkv'
+    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
+
+    def __init__(self, model_name, device):
+        # pdb.set_trace()
+        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
+        self.model.eval()
+        self.model_name = model_name
+        self.hook_handlers = []
+        self.layers_dict = {}
+        self.outputs_dict = {}
+        for key in VitExtractor.KEY_LIST:
+            self.layers_dict[key] = []
+            self.outputs_dict[key] = []
+        self._init_hooks_data()
+
+    def _init_hooks_data(self):
+        self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        for key in VitExtractor.KEY_LIST:
+            # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
+            self.outputs_dict[key] = []
+
+    def _register_hooks(self, **kwargs):
+        for block_idx, block in enumerate(self.model.blocks):
+            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
+                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
+            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
+                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
+            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
+                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
+            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
+                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
+
+    def _clear_hooks(self):
+        for handler in self.hook_handlers:
+            handler.remove()
+        self.hook_handlers = []
+
+    def _get_block_hook(self):
+        def _get_block_output(model, input, output):
+            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
+
+        return _get_block_output
+
+    def _get_attn_hook(self):
+        def _get_attn_output(model, inp, output):
+            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
+
+        return _get_attn_output
+
+    def _get_qkv_hook(self):
+        def _get_qkv_output(model, inp, output):
+            self.outputs_dict[VitExtractor.QKV_KEY].append(output)
+
+        return _get_qkv_output
+
+    # TODO: CHECK ATTN OUTPUT TUPLE
+    def _get_patch_imd_hook(self):
+        def _get_attn_output(model, inp, output):
+            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
+
+        return _get_attn_output
+
+    def get_feature_from_input(self, input_img):  # List([B, N, D])
+        self._register_hooks()
+        self.model(input_img)
+        feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
+        self._clear_hooks()
+        self._init_hooks_data()
+        return feature
+
+    def get_qkv_feature_from_input(self, input_img):
+        self._register_hooks()
+        self.model(input_img)
+        feature = self.outputs_dict[VitExtractor.QKV_KEY]
+        self._clear_hooks()
+        self._init_hooks_data()
+        return feature
+
+    def get_attn_feature_from_input(self, input_img):
+        self._register_hooks()
+        self.model(input_img)
+        feature = self.outputs_dict[VitExtractor.ATTN_KEY]
+        self._clear_hooks()
+        self._init_hooks_data()
+        return feature
+
+    def get_patch_size(self):
+        return 8 if "8" in self.model_name else 16
+
+    def get_width_patch_num(self, input_img_shape):
+        b, c, h, w = input_img_shape
+        patch_size = self.get_patch_size()
+        return w // patch_size
+
+    def get_height_patch_num(self, input_img_shape):
+        b, c, h, w = input_img_shape
+        patch_size = self.get_patch_size()
+        return h // patch_size
+
+    def get_patch_num(self, input_img_shape):
+        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
+        return patch_num
+
+    def get_head_num(self):
+        if "dino" in self.model_name:
+            return 6 if "s" in self.model_name else 12
+        return 6 if "small" in self.model_name else 12
+
+    def get_embedding_dim(self):
+        if "dino" in self.model_name:
+            return 384 if "s" in self.model_name else 768
+        return 384 if "small" in self.model_name else 768
+
+    def get_queries_from_qkv(self, qkv, input_img_shape):
+        patch_num = self.get_patch_num(input_img_shape)
+        head_num = self.get_head_num()
+        embedding_dim = self.get_embedding_dim()
+        q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
+        return q
+
+    def get_keys_from_qkv(self, qkv, input_img_shape):
+        patch_num = self.get_patch_num(input_img_shape)
+        head_num = self.get_head_num()
+        embedding_dim = self.get_embedding_dim()
+        k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
+        return k
+
+    def get_values_from_qkv(self, qkv, input_img_shape):
+        patch_num = self.get_patch_num(input_img_shape)
+        head_num = self.get_head_num()
+        embedding_dim = self.get_embedding_dim()
+        v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
+        return v
+
+    def get_keys_from_input(self, input_img, layer_num):
+        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
+        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
+        return keys
+
+    def get_keys_self_sim_from_input(self, input_img, layer_num):
+        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
+        h, t, d = keys.shape
+        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
+        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
+        return ssim_map
+
+
+class DinoStructureLoss:
+    def __init__(self, ):
+        self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
+        self.preprocess = torchvision.transforms.Compose([
+            torchvision.transforms.Resize(224),
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+        ])
+
+    def calculate_global_ssim_loss(self, outputs, inputs):
+        loss = 0.0
+        for a, b in zip(inputs, outputs):  # avoid memory limitations
+            with torch.no_grad():
+                target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
+            keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
+            loss += F.mse_loss(keys_ssim, target_keys_self_sim)
+        return loss
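+
+
+# Minimal usage sketch (illustrative; image paths are placeholders and a CUDA device is assumed):
+#   from PIL import Image
+#   loss_fn = DinoStructureLoss()
+#   hr = loss_fn.preprocess(Image.open('hr.png').convert('RGB')).cuda()
+#   sr = loss_fn.preprocess(Image.open('sr.png').convert('RGB')).cuda()
+#   loss = loss_fn.calculate_global_ssim_loss(outputs=[sr], inputs=[hr])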
diff --git a/src/my_utils/testing_utils.py b/src/my_utils/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4667c5f82cbcf83da5fd4f8fc29e7dbb88337a
--- /dev/null
+++ b/src/my_utils/testing_utils.py
@@ -0,0 +1,210 @@
+import argparse
+import json
+from PIL import Image
+from torchvision import transforms
+import torch.nn.functional as F
+from glob import glob
+
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.transforms import paired_random_crop, triplet_random_crop
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
+
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+def parse_args_paired_testing(input_args=None):
+    """
+    Parses command-line arguments used for configuring a paired testing session (pix2pix-Turbo).
+    This function sets up an argument parser to handle various testing options.
+
+    Returns:
+        argparse.Namespace: The parsed command-line arguments.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ref_path", type=str, default=None,)
+    parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
+    parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
+
+    # details about the model architecture
+    parser.add_argument("--sd_path")
+    parser.add_argument("--de_net_path")
+    parser.add_argument("--pretrained_path", type=str, default=None,)
+    parser.add_argument("--revision", type=str, default=None,)
+    parser.add_argument("--variant", type=str, default=None,)
+    parser.add_argument("--tokenizer_name", type=str, default=None)
+    parser.add_argument("--lora_rank_unet", default=32, type=int)
+    parser.add_argument("--lora_rank_vae", default=16, type=int)
+
+    parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
+    parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
+    parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
+    parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")
+
+    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
+    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024) 
+    parser.add_argument("--latent_tiled_size", type=int, default=96) 
+    parser.add_argument("--latent_tiled_overlap", type=int, default=32)
+
+    parser.add_argument("--align_method", type=str, default="wavelet")
+    
+    parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
+    parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
+
+    # training details
+    parser.add_argument("--output_dir", required=True)
+    parser.add_argument("--cache_dir", default=None,)
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument("--resolution", type=int, default=512,)
+    parser.add_argument("--checkpointing_steps", type=int, default=500,)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
+    parser.add_argument("--gradient_checkpointing", action="store_true",)
+
+    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
+    parser.add_argument("--allow_tf32", action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--report_to", type=str, default="wandb",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument("--set_grads_to_none", action="store_true",)
+
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--local_rank', default=-1, type=int)
+    parser.add_argument('--dist_url', default='env://',
+                        help='url used to set up distributed training')
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    return args
+
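+# Illustrative: the parser can also be driven programmatically, e.g.
+#   args = parse_args_paired_testing(['--output_dir', 'results/demo'])
+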
+
+class PlainDataset(data.Dataset):
+    """Modified dataset based on the dataset used for Real-ESRGAN model:
+    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It loads gt (Ground-Truth) images, and augments them.
+    It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUS for faster processing.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            meta_info (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            use_hflip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            Please see more options in the codes.
+    """
+
+    def __init__(self, opt):
+        super(PlainDataset, self).__init__()
+        self.opt = opt
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        if 'image_type' not in opt:
+            opt['image_type'] = 'png'
+
+        # support multiple type of data: file path and meta data, remove support of lmdb
+        self.lr_paths = []
+        if 'lr_path' in opt:
+            if isinstance(opt['lr_path'], str):
+                self.lr_paths.extend(sorted(
+                    [str(x) for x in Path(opt['lr_path']).glob('*.png')] +
+                    [str(x) for x in Path(opt['lr_path']).glob('*.jpg')] +
+                    [str(x) for x in Path(opt['lr_path']).glob('*.jpeg')]
+                ))
+            else:
+                self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][0]).glob('*.'+opt['image_type'])]))
+                if len(opt['lr_path']) > 1:
+                    for i in range(len(opt['lr_path'])-1):
+                        self.lr_paths.extend(sorted([str(x) for x in Path(opt['lr_path'][i+1]).glob('*.'+opt['image_type'])]))
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # -------------------------------- Load gt images -------------------------------- #
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+        lr_path = self.lr_paths[index]
+
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                lr_img_bytes = self.file_client.get(lr_path, 'gt')
+            except (IOError, OSError) as e:
+                # logger = get_root_logger()
+                # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # change another file to read
+                index = random.randint(0, self.__len__()-1)
+                lr_path = self.lr_paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+
+        img_lr = imfrombytes(lr_img_bytes, float32=True)
+        
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]
+
+        return_d = {'lr': img_lr, 'lr_path': lr_path}
+        return return_d
+
+    def __len__(self):
+        return len(self.lr_paths)
+
+
+def lr_proc(config, batch, device):
+    im_lr = batch['lr'].cuda()
+    im_lr = im_lr.to(memory_format=torch.contiguous_format).float()    
+
+    ori_lr = im_lr
+
+    im_lr = F.interpolate(
+            im_lr,
+            size=(im_lr.size(-2) * config.sf,
+                  im_lr.size(-1) * config.sf),
+            mode='bicubic',
+            )
+
+    im_lr = im_lr.contiguous() 
+    im_lr = im_lr * 2 - 1.0
+    im_lr = torch.clamp(im_lr, -1.0, 1.0)
+
+    ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)
+
+    pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
+    pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
+    im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')
+
+    return im_lr.to(device), ori_lr.to(device), (ori_h, ori_w)
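+
+
+# Minimal usage sketch (illustrative; 'disk' is a basicsr FileClient backend, the folder and `config` are placeholders):
+#   opt = {'io_backend': {'type': 'disk'}, 'lr_path': 'testsets/my_lr_images'}
+#   loader = data.DataLoader(PlainDataset(opt), batch_size=1, shuffle=False)
+#   for batch in loader:
+#       im_lr, ori_lr, (h, w) = lr_proc(config, batch, 'cuda')  # config.sf is the SR scale factor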
diff --git a/src/my_utils/training_utils.py b/src/my_utils/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3f7d92f17f8b0f54f88bdefdb610744c115b37
--- /dev/null
+++ b/src/my_utils/training_utils.py
@@ -0,0 +1,532 @@
+import argparse
+import json
+from PIL import Image
+from torchvision import transforms
+import torch.nn.functional as F
+from glob import glob
+
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.transforms import paired_random_crop, triplet_random_crop
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
+
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+def parse_args_paired_training(input_args=None):
+    """
+    Parses command-line arguments used for configuring a paired training session (pix2pix-Turbo).
+    This function sets up an argument parser to handle various training options.
+
+    Returns:
+        argparse.Namespace: The parsed command-line arguments.
+    """
+    parser = argparse.ArgumentParser()
+    # args for the loss function
+    parser.add_argument("--gan_disc_type", default="vagan")
+    parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
+    parser.add_argument("--lambda_gan", default=0.5, type=float)
+    parser.add_argument("--lambda_lpips", default=5.0, type=float)
+    parser.add_argument("--lambda_l2", default=2.0, type=float)
+    parser.add_argument("--base_config", default="./configs/sr.yaml", type=str)
+
+    # validation eval args
+    parser.add_argument("--eval_freq", default=100, type=int)
+    parser.add_argument("--save_val", default=True, action="store_false")
+    parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
+
+    parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
+
+    # details about the model architecture
+    parser.add_argument("--sd_path")
+    parser.add_argument("--pretrained_path", type=str, default=None,)
+    parser.add_argument("--de_net_path")
+    parser.add_argument("--revision", type=str, default=None,)
+    parser.add_argument("--variant", type=str, default=None,)
+    parser.add_argument("--tokenizer_name", type=str, default=None)
+    parser.add_argument("--lora_rank_unet", default=32, type=int)
+    parser.add_argument("--lora_rank_vae", default=16, type=int)
+    parser.add_argument("--neg_prob", default=0.05, type=float)
+    parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
+    parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")
+
+    # training details
+    parser.add_argument("--output_dir", required=True)
+    parser.add_argument("--cache_dir", default=None,)
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument("--resolution", type=int, default=512,)
+    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
+    parser.add_argument("--num_training_epochs", type=int, default=50)
+    parser.add_argument("--max_train_steps", type=int, default=50000,)
+    parser.add_argument("--checkpointing_steps", type=int, default=500,)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate before performing a backward/update pass.",)
+    parser.add_argument("--gradient_checkpointing", action="store_true",)
+    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--lr_scheduler", type=str, default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "piecewise_constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
+    parser.add_argument("--lr_num_cycles", type=int, default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=0.1, help="Power factor of the polynomial scheduler.")
+
+    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--allow_tf32", action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--report_to", type=str, default="wandb",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
+    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
+    parser.add_argument("--set_grads_to_none", action="store_true",)
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    return args
+
+
+# @DATASET_REGISTRY.register(suffix='basicsr')
+class PairedDataset(data.Dataset):
+    """Modified dataset based on the dataset used for Real-ESRGAN model:
+    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It loads gt (Ground-Truth) images, and augments them.
+    It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUs for faster processing.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            meta_info (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            use_hflip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            Please see more options in the codes.
+    """
+
+    def __init__(self, opt):
+        super(PairedDataset, self).__init__()
+        self.opt = opt
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+        if 'crop_size' in opt:
+            self.crop_size = opt['crop_size']
+        else:
+            self.crop_size = 512
+        if 'image_type' not in opt:
+            opt['image_type'] = 'png'
+
+        # support multiple type of data: file path and meta data, remove support of lmdb
+        self.paths = []
+        if 'meta_info' in opt:
+            with open(self.opt['meta_info']) as fin:
+                paths = [line.strip().split(' ')[0] for line in fin]
+                self.paths = [v for v in paths]
+            if 'meta_num' in opt:
+                self.paths = sorted(self.paths)[:opt['meta_num']]
+        if 'gt_path' in opt:
+            if isinstance(opt['gt_path'], str):
+                # Use rglob to recursively search for images
+                self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).rglob('*.' + opt['image_type'])]))
+            else:
+                for path in opt['gt_path']:
+                    self.paths.extend(sorted([str(x) for x in Path(path).rglob('*.' + opt['image_type'])]))
+                
+        # if 'gt_path' in opt:
+        #     if isinstance(opt['gt_path'], str):
+        #         self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])]))
+        #     else:
+        #         self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])]))
+        #         if len(opt['gt_path']) > 1:
+        #             for i in range(len(opt['gt_path'])-1):
+        #                 self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])]))
+        if 'imagenet_path' in opt:
+            class_list = os.listdir(opt['imagenet_path'])
+            for class_file in class_list:
+                self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')]))
+        if 'face_gt_path' in opt:
+            if isinstance(opt['face_gt_path'], str):
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+            else:
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+                if len(opt['face_gt_path']) > 1:
+                    for i in range(len(opt['face_gt_path'])-1):
+                        self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][i+1]).glob('*.'+opt['image_type'])])[:opt['num_face']])
+
+        # limit number of pictures for test
+        if 'num_pic' in opt:
+            if 'val' in opt or 'test' in opt:
+                random.shuffle(self.paths)
+                self.paths = self.paths[:opt['num_pic']]
+            else:
+                self.paths = self.paths[:opt['num_pic']]
+
+        if 'mul_num' in opt:
+            self.paths = self.paths * opt['mul_num']
+            # print('>>>>>>>>>>>>>>>>>>>>>')
+            # print(self.paths)
+
+        # blur settings for the first degradation
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
+        self.blur_sigma = opt['blur_sigma']
+        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
+        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
+        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters
+
+        # blur settings for the second degradation
+        self.blur_kernel_size2 = opt['blur_kernel_size2']
+        self.kernel_list2 = opt['kernel_list2']
+        self.kernel_prob2 = opt['kernel_prob2']
+        self.blur_sigma2 = opt['blur_sigma2']
+        self.betag_range2 = opt['betag_range2']
+        self.betap_range2 = opt['betap_range2']
+        self.sinc_prob2 = opt['sinc_prob2']
+
+        # a final sinc filter
+        self.final_sinc_prob = opt['final_sinc_prob']
+
+        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
+        # TODO: kernel range is now hard-coded, should be in the configure file
+        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
+        self.pulse_tensor[10, 10] = 1
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # -------------------------------- Load gt images -------------------------------- #
+        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+        gt_path = self.paths[index]
+        # avoid errors caused by high latency in reading files
+        retry = 3
+        while retry > 0:
+            try:
+                img_bytes = self.file_client.get(gt_path, 'gt')
+            except (IOError, OSError) as e:
+                # logger = get_root_logger()
+                # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
+                # change another file to read
+                index = random.randint(0, self.__len__()-1)
+                gt_path = self.paths[index]
+                time.sleep(1)  # sleep 1s for occasional server congestion
+            else:
+                break
+            finally:
+                retry -= 1
+        img_gt = imfrombytes(img_bytes, float32=True)
+        # filter the dataset and remove images with too low quality
+        img_size = os.path.getsize(gt_path)
+        img_size = img_size / 1024
+
+        while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100:
+            index = random.randint(0, self.__len__()-1)
+            gt_path = self.paths[index]
+
+            time.sleep(0.1)  # sleep 0.1s for occasional server congestion
+            img_bytes = self.file_client.get(gt_path, 'gt')
+            img_gt = imfrombytes(img_bytes, float32=True)
+            img_size = os.path.getsize(gt_path)
+            img_size = img_size / 1024
+
+        # -------------------- Do augmentation for training: flip, rotation -------------------- #
+        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+        # crop or pad to 400
+        # TODO: 400 is hard-coded. You may change it accordingly
+        h, w = img_gt.shape[0:2]
+        crop_pad_size = self.crop_size
+        # pad
+        if h < crop_pad_size or w < crop_pad_size:
+            pad_h = max(0, crop_pad_size - h)
+            pad_w = max(0, crop_pad_size - w)
+            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+        # crop
+        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+            h, w = img_gt.shape[0:2]
+            # randomly choose top and left coordinates
+            top = random.randint(0, h - crop_pad_size)
+            left = random.randint(0, w - crop_pad_size)
+            # top = (h - crop_pad_size) // 2 -1
+            # left = (w - crop_pad_size) // 2 -1
+            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob']:
+            # this sinc filter setting is for kernels ranging from [7, 21]
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel = random_mixed_kernels(
+                self.kernel_list,
+                self.kernel_prob,
+                kernel_size,
+                self.blur_sigma,
+                self.blur_sigma, [-math.pi, math.pi],
+                self.betag_range,
+                self.betap_range,
+                noise_range=None)
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
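+        # e.g. a 13x13 kernel gets (21 - 13) // 2 = 4 zero pixels on each side,
+        # so every kernel ends up as a fixed 21x21 array that can be batched.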
+
+        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+        kernel_size = random.choice(self.kernel_range)
+        if np.random.uniform() < self.opt['sinc_prob2']:
+            if kernel_size < 13:
+                omega_c = np.random.uniform(np.pi / 3, np.pi)
+            else:
+                omega_c = np.random.uniform(np.pi / 5, np.pi)
+            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+        else:
+            kernel2 = random_mixed_kernels(
+                self.kernel_list2,
+                self.kernel_prob2,
+                kernel_size,
+                self.blur_sigma2,
+                self.blur_sigma2, [-math.pi, math.pi],
+                self.betag_range2,
+                self.betap_range2,
+                noise_range=None)
+
+        # pad kernel
+        pad_size = (21 - kernel_size) // 2
+        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+        # ------------------------------------- the final sinc kernel ------------------------------------- #
+        if np.random.uniform() < self.opt['final_sinc_prob']:
+            kernel_size = random.choice(self.kernel_range)
+            omega_c = np.random.uniform(np.pi / 3, np.pi)
+            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+            sinc_kernel = torch.FloatTensor(sinc_kernel)
+        else:
+            sinc_kernel = self.pulse_tensor
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+        kernel = torch.FloatTensor(kernel)
+        kernel2 = torch.FloatTensor(kernel2)
+
+        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
+        return return_d
+
+    def __len__(self):
+        return len(self.paths)
+
+
+def randn_cropinput(lq, gt, base_size=[64, 128, 256, 512]):
+    """Center-crop lq and gt to a randomly chosen height/width from base_size.
+
+    Assumes lq and gt share the same spatial resolution; the crop is at least 64x64.
+    """
+    cur_size_h = random.choice(base_size)
+    cur_size_w = random.choice(base_size)
+    init_h = lq.size(-2)//2
+    init_w = lq.size(-1)//2
+    lq = lq[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
+    gt = gt[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
+    assert lq.size(-1) >= 64
+    assert lq.size(-2) >= 64
+    return [lq, gt]
+
+
+def degradation_proc(configs, batch, device, val=False, use_usm=False, resize_lq=True, random_size=False):
+    """Two-stage degradation pipeline, modified from Real-ESRGAN:
+    https://github.com/xinntao/Real-ESRGAN
+    """
+    jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+    usm_sharpener = USMSharp().cuda()  # USM (unsharp masking) sharpening of the GT
+
+    im_gt = batch['gt'].cuda()
+    if use_usm:
+        im_gt = usm_sharpener(im_gt)
+    im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
+    kernel1 = batch['kernel1'].cuda()
+    kernel2 = batch['kernel2'].cuda()
+    sinc_kernel = batch['sinc_kernel'].cuda()
+
+    ori_h, ori_w = im_gt.size()[2:4]
+
+    # ----------------------- The first degradation process ----------------------- #
+    # blur
+    out = filter2D(im_gt, kernel1)
+    # random resize
+    updown_type = random.choices(
+            ['up', 'down', 'keep'],
+            configs.degradation['resize_prob'],
+            )[0]
+    if updown_type == 'up':
+        scale = random.uniform(1, configs.degradation['resize_range'][1])
+    elif updown_type == 'down':
+        scale = random.uniform(configs.degradation['resize_range'][0], 1)
+    else:
+        scale = 1
+    mode = random.choice(['area', 'bilinear', 'bicubic'])
+    out = F.interpolate(out, scale_factor=scale, mode=mode)
+    # add noise
+    gray_noise_prob = configs.degradation['gray_noise_prob']
+    if random.random() < configs.degradation['gaussian_noise_prob']:
+        out = random_add_gaussian_noise_pt(
+            out,
+            sigma_range=configs.degradation['noise_range'],
+            clip=True,
+            rounds=False,
+            gray_prob=gray_noise_prob,
+            )
+    else:
+        out = random_add_poisson_noise_pt(
+            out,
+            scale_range=configs.degradation['poisson_scale_range'],
+            gray_prob=gray_noise_prob,
+            clip=True,
+            rounds=False)
+    # JPEG compression
+    jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range'])
+    out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+    out = jpeger(out, quality=jpeg_p)
+
+    # ----------------------- The second degradation process ----------------------- #
+    # blur
+    if random.random() < configs.degradation['second_blur_prob']:
+        out = filter2D(out, kernel2)
+    # random resize
+    updown_type = random.choices(
+            ['up', 'down', 'keep'],
+            configs.degradation['resize_prob2'],
+            )[0]
+    if updown_type == 'up':
+        scale = random.uniform(1, configs.degradation['resize_range2'][1])
+    elif updown_type == 'down':
+        scale = random.uniform(configs.degradation['resize_range2'][0], 1)
+    else:
+        scale = 1
+    mode = random.choice(['area', 'bilinear', 'bicubic'])
+    out = F.interpolate(
+            out,
+            size=(int(ori_h / configs.sf * scale),
+                  int(ori_w / configs.sf * scale)),
+            mode=mode,
+            )
+    # add noise
+    gray_noise_prob = configs.degradation['gray_noise_prob2']
+    if random.random() < configs.degradation['gaussian_noise_prob2']:
+        out = random_add_gaussian_noise_pt(
+            out,
+            sigma_range=configs.degradation['noise_range2'],
+            clip=True,
+            rounds=False,
+            gray_prob=gray_noise_prob,
+            )
+    else:
+        out = random_add_poisson_noise_pt(
+            out,
+            scale_range=configs.degradation['poisson_scale_range2'],
+            gray_prob=gray_noise_prob,
+            clip=True,
+            rounds=False,
+            )
+
+    # JPEG compression + the final sinc filter
+    # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+    # as one operation.
+    # We consider two orders:
+    #   1. [resize back + sinc filter] + JPEG compression
+    #   2. JPEG compression + [resize back + sinc filter]
+    # Empirically, we find that other combinations (sinc + JPEG + resize) introduce twisted lines.
+    if random.random() < 0.5:
+        # resize back + the final sinc filter
+        mode = random.choice(['area', 'bilinear', 'bicubic'])
+        out = F.interpolate(
+                out,
+                size=(ori_h // configs.sf,
+                      ori_w // configs.sf),
+                mode=mode,
+                )
+        out = filter2D(out, sinc_kernel)
+        # JPEG compression
+        jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2'])
+        out = torch.clamp(out, 0, 1)
+        out = jpeger(out, quality=jpeg_p)
+    else:
+        # JPEG compression
+        jpeg_p = out.new_zeros(out.size(0)).uniform_(*configs.degradation['jpeg_range2'])
+        out = torch.clamp(out, 0, 1)
+        out = jpeger(out, quality=jpeg_p)
+        # resize back + the final sinc filter
+        mode = random.choice(['area', 'bilinear', 'bicubic'])
+        out = F.interpolate(
+                out,
+                size=(ori_h // configs.sf,
+                      ori_w // configs.sf),
+                mode=mode,
+                )
+        out = filter2D(out, sinc_kernel)
+
+    # clamp and round
+    im_lq = torch.clamp(out, 0, 1.0)
+
+    # random crop
+    gt_size = configs.degradation['gt_size']
+    im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, configs.sf)
+    lq, gt = im_lq, im_gt
+    ori_lq = im_lq
+
+    if resize_lq:
+        lq = F.interpolate(
+                lq,
+                size=(gt.size(-2),
+                      gt.size(-1)),
+                mode='bicubic',
+                )
+
+    if random.random() < configs.degradation['no_degradation_prob'] or torch.isnan(lq).any():
+        lq = gt
+
+    lq = lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
+    # rescale from [0, 1] to [-1, 1] for the diffusion model
+    lq = lq * 2 - 1.0  # TODO: keep the [0, 1] range instead?
+    gt = gt * 2 - 1.0
+
+    if random_size:
+        lq, gt = randn_cropinput(lq, gt)
+
+    lq = torch.clamp(lq, -1.0, 1.0)
+
+    return lq.to(device), gt.to(device), ori_lq.to(device)
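+
+
+# Usage sketch (illustrative only): `configs` and `train_loader` are assumed to be built
+# elsewhere, with `configs.degradation` holding the keys accessed above and `configs.sf`
+# the super-resolution scale factor.
+#
+#     for batch in train_loader:  # batches produced by the dataset's __getitem__ above
+#         lq, gt, ori_lq = degradation_proc(configs, batch, device='cuda')
+#         # lq and gt are in [-1, 1] and ready to be fed to the restoration model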
diff --git a/src/my_utils/utils.py b/src/my_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c1bdc214ecc7defaa76cf47b592174ab1d580bf
--- /dev/null
+++ b/src/my_utils/utils.py
@@ -0,0 +1,213 @@
+import importlib
+
+import torch
+import numpy as np
+from collections import abc
+from einops import rearrange
+from functools import partial
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh: a tuple of (width, height)
+    # xc: a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+        nc = int(40 * (wh[0] / 256))
+        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def instantiate_from_config_sr(config):
+    if not "target" in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+    # run prefetching
+    if idx_to_fn:
+        res = func(data, worker_id=idx)
+    else:
+        res = func(data)
+    Q.put([idx, res])
+    Q.put("Done")
+
+
+def parallel_data_prefetch(
+        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+    # if target_data_type not in ["ndarray", "list"]:
+    #     raise ValueError(
+    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+    #     )
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
+    elif isinstance(data, abc.Iterable):
+        if isinstance(data, dict):
+            print(
+                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: '
+                'using only its values and disregarding keys.'
+            )
+            data = list(data.values())
+        if target_data_type == "ndarray":
+            data = np.asarray(data)
+        else:
+            data = list(data)
+    else:
+        raise TypeError(
+            f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+        )
+
+    if cpu_intensive:
+        Q = mp.Queue(1000)
+        proc = mp.Process
+    else:
+        Q = Queue(1000)
+        proc = Thread
+    # spawn processes
+    if target_data_type == "ndarray":
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(np.array_split(data, n_proc))
+        ]
+    else:
+        step = (
+            int(len(data) / n_proc + 1)
+            if len(data) % n_proc != 0
+            else int(len(data) / n_proc)
+        )
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(
+                [data[i: i + step] for i in range(0, len(data), step)]
+            )
+        ]
+    processes = []
+    for i in range(n_proc):
+        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+        processes += [p]
+
+    # start processes
+    print(f"Start prefetching...")
+    import time
+
+    start = time.time()
+    gather_res = [[] for _ in range(n_proc)]
+    try:
+        for p in processes:
+            p.start()
+
+        k = 0
+        while k < n_proc:
+            # get result
+            res = Q.get()
+            if res == "Done":
+                k += 1
+            else:
+                gather_res[res[0]] = res[1]
+
+    except Exception as e:
+        print("Exception: ", e)
+        for p in processes:
+            p.terminate()
+
+        raise e
+    finally:
+        for p in processes:
+            p.join()
+        print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+    if target_data_type == 'ndarray':
+        if not isinstance(gather_res[0], np.ndarray):
+            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+        # order outputs
+        return np.concatenate(gather_res, axis=0)
+    elif target_data_type == 'list':
+        out = []
+        for r in gather_res:
+            out.extend(r)
+        return out
+    else:
+        return gather_res
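+
+
+# Usage sketch (illustrative): square a list of numbers in 4 worker processes.
+# The worker function receives a chunk of the data and returns the processed chunk;
+# it must be defined at module level so multiprocessing can pickle it.
+#
+#     def square_chunk(chunk):
+#         return [x * x for x in chunk]
+#
+#     squares = parallel_data_prefetch(square_chunk, list(range(1000)), n_proc=4,
+#                                      target_data_type="list")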
diff --git a/src/my_utils/vaehook.py b/src/my_utils/vaehook.py
new file mode 100644
index 0000000000000000000000000000000000000000..2975dea13fb55fca903df6d26501e3c6ab1541c5
--- /dev/null
+++ b/src/my_utils/vaehook.py
@@ -0,0 +1,828 @@
+# ------------------------------------------------------------------------
+#
+#   Ultimate VAE Tile Optimization
+#
+#   Introducing a revolutionary new optimization designed to make
+#   the VAE work with giant images on limited VRAM!
+#   Say goodbye to the frustration of OOM and hello to seamless output!
+#
+# ------------------------------------------------------------------------
+#
+#   This script is a wild hack that splits the image into tiles,
+#   encodes each tile separately, and merges the result back together.
+#
+#   Advantages:
+#   - The VAE can now work with giant images on limited VRAM
+#       (~10 GB for 8K images!)
+#   - The merged output is completely seamless without any post-processing.
+#
+#   Drawbacks:
+#   - Giant RAM needed. To store the intermediate results for a 4096x4096
+#       image, you need a 32 GB RAM machine (it consumes ~20 GB); for 8192x8192
+#       you need a 128 GB RAM machine (it consumes ~100 GB).
+#   - NaNs always appear for 8k images when you use the fp16 (half) VAE.
+#       You must use --no-half-vae to disable the half VAE for such giant images.
+#   - Slow speed. With the default tile size, it takes around 50/200 seconds
+#       to encode/decode a 4096x4096 image, and 200/900 seconds to encode/decode
+#       an 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
+#   - The gradient calculation is not compatible with this hack. It
+#       will break any backward() or torch.autograd.grad() that passes through the VAE.
+#       (But you can still use the VAE to generate training data.)
+#
+#   How it works:
+#   1) The image is split into tiles.
+#       - To ensure perfect results, each tile is padded with 32 pixels
+#           on each side.
+#       - Then the conv2d/silu/upsample/downsample can produce identical
+#           results to the original image without splitting.
+#   2) The original forward is decomposed into a task queue and a task worker.
+#       - The task queue is a list of functions that will be executed in order.
+#       - The task worker is a loop that executes the tasks in the queue.
+#   3) The task queue is executed for each tile.
+#       - The current tile is sent to the GPU.
+#       - Local operations are executed directly.
+#       - Group norm calculation is temporarily suspended until the mean
+#           and var of all tiles are calculated.
+#       - The residual is pre-calculated, stored, and added back later.
+#       - When moving to the next tile, the current tile is sent back to the CPU.
+#   4) After all tiles are processed, the tiles are merged on the CPU and returned.
+#
+#   Enjoy!
+#
+#   @author: LI YI @ Nanyang Technological University - Singapore
+#   @date: 2023-03-02
+#   @license: MIT License
+#
+#   Please give me a star if you like this project!
+#
+# -------------------------------------------------------------------------
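+#
+#   Usage sketch (illustrative; assumes a diffusers-style AutoencoderKL `vae` whose
+#   decoder keeps its unpatched forward in `original_forward`):
+#
+#       vae.decoder.original_forward = vae.decoder.forward
+#       vae.decoder.forward = VAEHook(vae.decoder, DEFAULT_DECODER_TILE_SIZE,
+#                                     is_decoder=True, fast_decoder=True,
+#                                     fast_encoder=False, color_fix=False)
+#
+# -------------------------------------------------------------------------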
+
+import gc
+from time import time
+import math
+from tqdm import tqdm
+
+import torch
+import torch.version
+import torch.nn.functional as F
+from einops import rearrange
+import os
+import sys
+sys.path.append(os.getcwd())
+import my_utils.devices as devices
+
+try:
+    import xformers
+    import xformers.ops
+except ImportError:
+    pass
+
+sd_flag = False
+
+def get_recommend_encoder_tile_size():
+    if torch.cuda.is_available():
+        total_memory = torch.cuda.get_device_properties(
+            devices.device).total_memory // 2**20
+        if total_memory > 16*1000:
+            ENCODER_TILE_SIZE = 3072
+        elif total_memory > 12*1000:
+            ENCODER_TILE_SIZE = 2048
+        elif total_memory > 8*1000:
+            ENCODER_TILE_SIZE = 1536
+        else:
+            ENCODER_TILE_SIZE = 960
+    else:
+        ENCODER_TILE_SIZE = 512
+    return ENCODER_TILE_SIZE
+
+
+def get_recommend_decoder_tile_size():
+    if torch.cuda.is_available():
+        total_memory = torch.cuda.get_device_properties(
+            devices.device).total_memory // 2**20
+        if total_memory > 30*1000:
+            DECODER_TILE_SIZE = 256
+        elif total_memory > 16*1000:
+            DECODER_TILE_SIZE = 192
+        elif total_memory > 12*1000:
+            DECODER_TILE_SIZE = 128
+        elif total_memory > 8*1000:
+            DECODER_TILE_SIZE = 96
+        else:
+            DECODER_TILE_SIZE = 64
+    else:
+        DECODER_TILE_SIZE = 64
+    return DECODER_TILE_SIZE
+
+
+if 'global const':
+    DEFAULT_ENABLED = False
+    DEFAULT_MOVE_TO_GPU = False
+    DEFAULT_FAST_ENCODER = True
+    DEFAULT_FAST_DECODER = True
+    DEFAULT_COLOR_FIX = 0
+    DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
+    DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
+
+
+# inplace version of silu
+def inplace_nonlinearity(x):
+    # Test: fix for Nans
+    return F.silu(x, inplace=True)
+
+# extracted from ldm.modules.diffusionmodules.model
+
+# from diffusers lib
+def attn_forward_new(self, h_):
+    batch_size, channel, height, width = h_.shape
+    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
+
+    attention_mask = None
+    encoder_hidden_states = None
+    batch_size, sequence_length, _ = hidden_states.shape
+    attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+    query = self.to_q(hidden_states)
+
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+    elif self.norm_cross:
+        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
+
+    key = self.to_k(encoder_hidden_states)
+    value = self.to_v(encoder_hidden_states)
+
+    query = self.head_to_batch_dim(query)
+    key = self.head_to_batch_dim(key)
+    value = self.head_to_batch_dim(value)
+
+    attention_probs = self.get_attention_scores(query, key, attention_mask)
+    hidden_states = torch.bmm(attention_probs, value)
+    hidden_states = self.batch_to_head_dim(hidden_states)
+
+    # linear proj
+    hidden_states = self.to_out[0](hidden_states)
+    # dropout
+    hidden_states = self.to_out[1](hidden_states)
+
+    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+    return hidden_states
+
+def attn_forward(self, h_):
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+
+    # compute attention
+    b, c, h, w = q.shape
+    q = q.reshape(b, c, h*w)
+    q = q.permute(0, 2, 1)   # b,hw,c
+    k = k.reshape(b, c, h*w)  # b,c,hw
+    w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+    w_ = w_ * (int(c)**(-0.5))
+    w_ = torch.nn.functional.softmax(w_, dim=2)
+
+    # attend to values
+    v = v.reshape(b, c, h*w)
+    w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
+    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+    h_ = torch.bmm(v, w_)
+    h_ = h_.reshape(b, c, h, w)
+
+    h_ = self.proj_out(h_)
+
+    return h_
+
+
+def xformer_attn_forward(self, h_):
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+
+    # compute attention
+    B, C, H, W = q.shape
+    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(B, t.shape[1], 1, C)
+        .permute(0, 2, 1, 3)
+        .reshape(B * 1, t.shape[1], C)
+        .contiguous(),
+        (q, k, v),
+    )
+    out = xformers.ops.memory_efficient_attention(
+        q, k, v, attn_bias=None, op=self.attention_op)
+
+    out = (
+        out.unsqueeze(0)
+        .reshape(B, 1, out.shape[1], C)
+        .permute(0, 2, 1, 3)
+        .reshape(B, out.shape[1], C)
+    )
+    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+    out = self.proj_out(out)
+    return out
+
+
+def attn2task(task_queue, net):
+    if False: #isinstance(net, AttnBlock):
+        task_queue.append(('store_res', lambda x: x))
+        task_queue.append(('pre_norm', net.norm))
+        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
+        task_queue.append(['add_res', None])
+    elif False: #isinstance(net, MemoryEfficientAttnBlock):
+        task_queue.append(('store_res', lambda x: x))
+        task_queue.append(('pre_norm', net.norm))
+        task_queue.append(
+            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
+        task_queue.append(['add_res', None])
+    else:
+        task_queue.append(('store_res', lambda x: x))
+        task_queue.append(('pre_norm', net.group_norm))
+        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
+        task_queue.append(['add_res', None])
+
+def resblock2task(queue, block):
+    """
+    Turn a ResNetBlock into a sequence of tasks and append to the task queue
+
+    @param queue: the target task queue
+    @param block: ResNetBlock
+
+    """
+    if block.in_channels != block.out_channels:
+        if sd_flag:
+            if block.use_conv_shortcut:
+                queue.append(('store_res', block.conv_shortcut))
+            else:
+                queue.append(('store_res', block.nin_shortcut))
+        else:
+            if block.use_in_shortcut:
+                queue.append(('store_res', block.conv_shortcut))
+            else:
+                queue.append(('store_res', block.nin_shortcut))
+
+    else:
+        queue.append(('store_res', lambda x: x))
+    queue.append(('pre_norm', block.norm1))
+    queue.append(('silu', inplace_nonlinearity))
+    queue.append(('conv1', block.conv1))
+    queue.append(('pre_norm', block.norm2))
+    queue.append(('silu', inplace_nonlinearity))
+    queue.append(('conv2', block.conv2))
+    queue.append(['add_res', None])
+
+
+
+def build_sampling(task_queue, net, is_decoder):
+    """
+    Build the sampling part of a task queue
+    @param task_queue: the target task queue
+    @param net: the network
+    @param is_decoder: True when building the decoder, False for the encoder
+    """
+    if is_decoder:
+        # resblock2task(task_queue, net.mid.block_1)
+        # attn2task(task_queue, net.mid.attn_1)
+        # resblock2task(task_queue, net.mid.block_2)
+        # resolution_iter = reversed(range(net.num_resolutions))
+        # block_ids = net.num_res_blocks + 1
+        # condition = 0
+        # module = net.up
+        # func_name = 'upsample'
+        resblock2task(task_queue, net.mid_block.resnets[0])
+        attn2task(task_queue, net.mid_block.attentions[0])
+        resblock2task(task_queue, net.mid_block.resnets[1])
+        resolution_iter = (range(len(net.up_blocks)))  # range(0,4)
+        block_ids = 2 + 1
+        condition = len(net.up_blocks) - 1
+        module = net.up_blocks
+        func_name = 'upsamplers'
+    else:
+        # resolution_iter = range(net.num_resolutions)
+        # block_ids = net.num_res_blocks
+        # condition = net.num_resolutions - 1
+        # module = net.down
+        # func_name = 'downsample'
+        resolution_iter = (range(len(net.down_blocks)))  # range(0,4)
+        block_ids = 2
+        condition = len(net.down_blocks) - 1
+        module = net.down_blocks
+        func_name = 'downsamplers'
+
+
+    for i_level in resolution_iter:
+        for i_block in range(block_ids):
+            resblock2task(task_queue, module[i_level].resnets[i_block])
+        if i_level != condition:
+            if is_decoder:
+                task_queue.append((func_name, module[i_level].upsamplers[0]))
+            else:
+                task_queue.append((func_name, module[i_level].downsamplers[0]))
+
+    if not is_decoder:
+        resblock2task(task_queue, net.mid_block.resnets[0])
+        attn2task(task_queue, net.mid_block.attentions[0])
+        resblock2task(task_queue, net.mid_block.resnets[1])
+
+
+def build_task_queue(net, is_decoder):
+    """
+    Build a single task queue for the encoder or decoder
+    @param net: the VAE decoder or encoder network
+    @param is_decoder: True when building the decoder, False for the encoder
+    @return: the task queue
+    """
+    task_queue = []
+    task_queue.append(('conv_in', net.conv_in))
+
+    # construct the sampling part of the task queue
+    # because encoder and decoder share the same architecture, we extract the sampling part
+    build_sampling(task_queue, net, is_decoder)
+    if is_decoder and not sd_flag:
+        net.give_pre_end = False
+        net.tanh_out = False
+
+    if not is_decoder or not net.give_pre_end:
+        if sd_flag:
+            task_queue.append(('pre_norm', net.norm_out))
+        else:
+            task_queue.append(('pre_norm', net.conv_norm_out))
+        task_queue.append(('silu', inplace_nonlinearity))
+        task_queue.append(('conv_out', net.conv_out))
+        if is_decoder and net.tanh_out:
+            task_queue.append(('tanh', torch.tanh))
+
+    return task_queue
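+
+# For reference, a decoder task queue produced above looks roughly like
+# (each entry pairs a task name with the corresponding module or callable):
+#     [('conv_in', ...), ('store_res', ...), ('pre_norm', ...), ('silu', ...),
+#      ('conv1', ...), ..., ('attn', ...), ['add_res', None], ...,
+#      ('upsamplers', ...), ..., ('pre_norm', ...), ('silu', ...), ('conv_out', ...)]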
+
+
+def clone_task_queue(task_queue):
+    """
+    Clone a task queue
+    @param task_queue: the task queue to be cloned
+    @return: the cloned task queue
+    """
+    return [[item for item in task] for task in task_queue]
+
+
+def get_var_mean(input, num_groups, eps=1e-6):
+    """
+    Get mean and var for group norm
+    """
+    b, c = input.size(0), input.size(1)
+    channel_in_group = int(c/num_groups)
+    input_reshaped = input.contiguous().view(
+        1, int(b * num_groups), channel_in_group, *input.size()[2:])
+    var, mean = torch.var_mean(
+        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
+    return var, mean
+
+
+def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
+    """
+    Custom group norm with fixed mean and var
+
+    @param input: input tensor
+    @param num_groups: number of groups. by default, num_groups = 32
+    @param mean: mean, must be pre-calculated by get_var_mean
+    @param var: var, must be pre-calculated by get_var_mean
+    @param weight: weight, should be fetched from the original group norm
+    @param bias: bias, should be fetched from the original group norm
+    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
+
+    @return: normalized tensor
+    """
+    b, c = input.size(0), input.size(1)
+    channel_in_group = int(c/num_groups)
+    input_reshaped = input.contiguous().view(
+        1, int(b * num_groups), channel_in_group, *input.size()[2:])
+
+    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
+                       training=False, momentum=0, eps=eps)
+
+    out = out.view(b, c, *input.size()[2:])
+
+    # post affine transform
+    if weight is not None:
+        out *= weight.view(1, -1, 1, 1)
+    if bias is not None:
+        out += bias.view(1, -1, 1, 1)
+    return out
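+
+# Sanity-check sketch (illustrative): for a full, untiled tensor the two-step
+# get_var_mean + custom_group_norm path should match F.group_norm up to the eps used here (1e-6):
+#
+#     x = torch.randn(1, 64, 32, 32)
+#     gn = torch.nn.GroupNorm(32, 64, eps=1e-6)
+#     var, mean = get_var_mean(x, 32)
+#     y_tiled = custom_group_norm(x, 32, mean, var, gn.weight, gn.bias)
+#     y_ref = F.group_norm(x, 32, gn.weight, gn.bias, eps=1e-6)
+#     # torch.allclose(y_tiled, y_ref, atol=1e-5) is expected to hold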
+
+
+def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
+    """
+    Crop the valid region from the tile
+    @param x: input tile
+    @param input_bbox: original input bounding box
+    @param target_bbox: output bounding box
+    @param scale: scale factor
+    @return: cropped tile
+    """
+    padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
+    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
+    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
+
+# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
+
+
+def perfcount(fn):
+    def wrapper(*args, **kwargs):
+        ts = time()
+
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(devices.device)
+        devices.torch_gc()
+        gc.collect()
+
+        ret = fn(*args, **kwargs)
+
+        devices.torch_gc()
+        gc.collect()
+        if torch.cuda.is_available():
+            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
+            torch.cuda.reset_peak_memory_stats(devices.device)
+            print(
+                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
+        else:
+            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
+
+        return ret
+    return wrapper
+
+# copy end :)
+
+
+class GroupNormParam:
+    def __init__(self):
+        self.var_list = []
+        self.mean_list = []
+        self.pixel_list = []
+        self.weight = None
+        self.bias = None
+
+    def add_tile(self, tile, layer):
+        var, mean = get_var_mean(tile, 32)
+        # For giant images, the variance can exceed the float16 maximum.
+        # In that case we recompute the statistics on a float32 copy.
+        if var.dtype == torch.float16 and var.isinf().any():
+            fp32_tile = tile.float()
+            var, mean = get_var_mean(fp32_tile, 32)
+        # ============= DEBUG: test for infinite =============
+        # if torch.isinf(var).any():
+        #    print('var: ', var)
+        # ====================================================
+        self.var_list.append(var)
+        self.mean_list.append(mean)
+        self.pixel_list.append(
+            tile.shape[2]*tile.shape[3])
+        if hasattr(layer, 'weight'):
+            self.weight = layer.weight
+            self.bias = layer.bias
+        else:
+            self.weight = None
+            self.bias = None
+
+    def summary(self):
+        """
+        summarize the mean and var and return a function
+        that apply group norm on each tile
+        """
+        if len(self.var_list) == 0:
+            return None
+        var = torch.vstack(self.var_list)
+        mean = torch.vstack(self.mean_list)
+        max_value = max(self.pixel_list)
+        pixels = torch.tensor(
+            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value
+        sum_pixels = torch.sum(pixels)
+        pixels = pixels.unsqueeze(
+            1) / sum_pixels
+        var = torch.sum(
+            var * pixels, dim=0)
+        mean = torch.sum(
+            mean * pixels, dim=0)
+        return lambda x:  custom_group_norm(x, 32, mean, var, self.weight, self.bias)
+
+    @staticmethod
+    def from_tile(tile, norm):
+        """
+        create a function from a single tile without summary
+        """
+        var, mean = get_var_mean(tile, 32)
+        if var.dtype == torch.float16 and var.isinf().any():
+            fp32_tile = tile.float()
+            var, mean = get_var_mean(fp32_tile, 32)
+            # on Apple MPS devices we need to convert back to float16
+            if var.device.type == 'mps':
+                # clamp to avoid overflow
+                var = torch.clamp(var, 0, 60000)
+                var = var.half()
+                mean = mean.half()
+        if hasattr(norm, 'weight'):
+            weight = norm.weight
+            bias = norm.bias
+        else:
+            weight = None
+            bias = None
+
+        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
+            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
+        return group_norm_func
+
+
+class VAEHook:
+    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
+        self.net = net                  # encoder | decoder
+        self.tile_size = tile_size
+        self.is_decoder = is_decoder
+        self.fast_mode = (fast_encoder and not is_decoder) or (
+            fast_decoder and is_decoder)
+        self.color_fix = color_fix and not is_decoder
+        self.to_gpu = to_gpu
+        self.pad = 11 if is_decoder else 32
+
+    def __call__(self, x):
+        B, C, H, W = x.shape
+        original_device = next(self.net.parameters()).device
+        try:
+            if self.to_gpu:
+                self.net.to(devices.get_optimal_device())
+            if max(H, W) <= self.pad * 2 + self.tile_size:
+                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
+                return self.net.original_forward(x)
+            else:
+                return self.vae_tile_forward(x)
+        finally:
+            self.net.to(original_device)
+
+    def get_best_tile_size(self, lowerbound, upperbound):
+        """
+        Get the best tile size for GPU memory
+        """
+        divider = 32
+        while divider >= 2:
+            remainder = lowerbound % divider
+            if remainder == 0:
+                return lowerbound
+            candidate = lowerbound - remainder + divider
+            if candidate <= upperbound:
+                return candidate
+            divider //= 2
+        return lowerbound
+
+    def split_tiles(self, h, w):
+        """
+        Tool function to split the image into tiles
+        @param h: height of the image
+        @param w: width of the image
+        @return: tile_input_bboxes, tile_output_bboxes
+        """
+        tile_input_bboxes, tile_output_bboxes = [], []
+        tile_size = self.tile_size
+        pad = self.pad
+        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
+        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
+        # If either count is 0, set it to 1
+        # This is to deal with long and thin images
+        num_height_tiles = max(num_height_tiles, 1)
+        num_width_tiles = max(num_width_tiles, 1)
+
+        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
+        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
+        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
+        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
+        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
+
+        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
+              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
+
+        for i in range(num_height_tiles):
+            for j in range(num_width_tiles):
+                # bbox: [x1, x2, y1, y2]
+                # the padding is unnecessary at image borders, so we directly start from (pad, pad)
+                input_bbox = [
+                    pad + j * real_tile_width,
+                    min(pad + (j + 1) * real_tile_width, w),
+                    pad + i * real_tile_height,
+                    min(pad + (i + 1) * real_tile_height, h),
+                ]
+
+                # if the output bbox is close to the image boundary, we extend it to the image boundary
+                output_bbox = [
+                    input_bbox[0] if input_bbox[0] > pad else 0,
+                    input_bbox[1] if input_bbox[1] < w - pad else w,
+                    input_bbox[2] if input_bbox[2] > pad else 0,
+                    input_bbox[3] if input_bbox[3] < h - pad else h,
+                ]
+
+                # scale to get the final output bbox
+                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
+                tile_output_bboxes.append(output_bbox)
+
+                # expand the input bbox by pad pixels on each side so tile outputs are indistinguishable from the untiled result
+                tile_input_bboxes.append([
+                    max(0, input_bbox[0] - pad),
+                    min(w, input_bbox[1] + pad),
+                    max(0, input_bbox[2] - pad),
+                    min(h, input_bbox[3] + pad),
+                ])
+
+        return tile_input_bboxes, tile_output_bboxes
+
+    @torch.no_grad()
+    def estimate_group_norm(self, z, task_queue, color_fix):
+        device = z.device
+        tile = z
+        last_id = len(task_queue) - 1
+        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
+            last_id -= 1
+        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
+            raise ValueError('No group norm found in the task queue')
+        # estimate until the last group norm
+        for i in range(last_id + 1):
+            task = task_queue[i]
+            if task[0] == 'pre_norm':
+                group_norm_func = GroupNormParam.from_tile(tile, task[1])
+                task_queue[i] = ('apply_norm', group_norm_func)
+                if i == last_id:
+                    return True
+                tile = group_norm_func(tile)
+            elif task[0] == 'store_res':
+                task_id = i + 1
+                while task_id < last_id and task_queue[task_id][0] != 'add_res':
+                    task_id += 1
+                if task_id >= last_id:
+                    continue
+                task_queue[task_id][1] = task[1](tile)
+            elif task[0] == 'add_res':
+                tile += task[1].to(device)
+                task[1] = None
+            elif color_fix and task[0] == 'downsample':
+                for j in range(i, last_id + 1):
+                    if task_queue[j][0] == 'store_res':
+                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
+                return True
+            else:
+                tile = task[1](tile)
+            try:
+                devices.test_for_nans(tile, "vae")
+            except Exception:
+                print('NaN detected in fast-mode estimation. Fast mode disabled.')
+                return False
+
+        raise IndexError('Should not reach here')
+
+    @perfcount
+    @torch.no_grad()
+    def vae_tile_forward(self, z):
+        """
+        Encode or decode the input tensor z tile by tile.
+        @param z: input latents (decoder) or image (encoder)
+        @return: decoded image or encoded latents
+        """
+        device = next(self.net.parameters()).device
+        net = self.net
+        tile_size = self.tile_size
+        is_decoder = self.is_decoder
+
+        z = z.detach() # detach the input to avoid backprop
+
+        N, height, width = z.shape[0], z.shape[2], z.shape[3]
+        net.last_z_shape = z.shape
+
+        # Split the input into tiles and build a task queue for each tile
+        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
+
+        in_bboxes, out_bboxes = self.split_tiles(height, width)
+
+        # Prepare tiles by splitting the input latents
+        tiles = []
+        for input_bbox in in_bboxes:
+            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
+            tiles.append(tile)
+
+        num_tiles = len(tiles)
+        num_completed = 0
+
+        # Build task queues
+        single_task_queue = build_task_queue(net, is_decoder)
+        #print(single_task_queue)
+        if self.fast_mode:
+            # Fast mode: downsample the input image to the tile size,
+            # then estimate the group norm parameters on the downsampled image
+            scale_factor = tile_size / max(height, width)
+            z = z.to(device)
+            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
+            # use nearest-exact to keep the statistics as close as possible
+            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
+
+            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
+            # The downsampling will heavily distort its mean and std, so we need to recover it.
+            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
+            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
+            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
+            del std_old, mean_old, std_new, mean_new
+            # occasionally std_new is too small or too large and exceeds the float16 range,
+            # so we clamp the result back to z's value range.
+            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
+            estimate_task_queue = clone_task_queue(single_task_queue)
+            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
+                single_task_queue = estimate_task_queue
+            del downsampled_z
+
+        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]
+
+        # Dummy result
+        result = None
+        result_approx = None
+        #try:
+        #    with devices.autocast():
+        #        result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()
+        #except: pass
+        # Free memory of input latent tensor
+        del z
+
+        # Task queue execution
+        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
+
+        # execute the tasks back and forth when switching tiles so that we always
+        # keep one tile on the GPU and reduce unnecessary data transfer
+        forward = True
+        interrupted = False
+        #state.interrupted = interrupted
+        while True:
+            #if state.interrupted: interrupted = True ; break
+
+            group_norm_param = GroupNormParam()
+            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
+                #if state.interrupted: interrupted = True ; break
+
+                tile = tiles[i].to(device)
+                input_bbox = in_bboxes[i]
+                task_queue = task_queues[i]
+
+                interrupted = False
+                while len(task_queue) > 0:
+                    #if state.interrupted: interrupted = True ; break
+
+                    # DEBUG: current task
+                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)
+                    task = task_queue.pop(0)
+                    if task[0] == 'pre_norm':
+                        group_norm_param.add_tile(tile, task[1])
+                        break
+                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
+                        task_id = 0
+                        res = task[1](tile)
+                        if not self.fast_mode or task[0] == 'store_res_cpu':
+                            res = res.cpu()
+                        while task_queue[task_id][0] != 'add_res':
+                            task_id += 1
+                        task_queue[task_id][1] = res
+                    elif task[0] == 'add_res':
+                        tile += task[1].to(device)
+                        task[1] = None
+                    else:
+                        tile = task[1](tile)
+                    pbar.update(1)
+
+                if interrupted: break
+
+                # check for NaNs in the tile.
+                # If there are NaNs, we abort the process to save user's time
+                #devices.test_for_nans(tile, "vae")
+
+                #print(tiles[i].shape, tile.shape, i, num_tiles)
+                if len(task_queue) == 0:
+                    tiles[i] = None
+                    num_completed += 1
+                    if result is None:      # NOTE: dim C varies across cases, so the result tensor can only be initialized lazily
+                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
+                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
+                    del tile
+                elif i == num_tiles - 1 and forward:
+                    forward = False
+                    tiles[i] = tile
+                elif i == 0 and not forward:
+                    forward = True
+                    tiles[i] = tile
+                else:
+                    tiles[i] = tile.cpu()
+                    del tile
+
+            if interrupted: break
+            if num_completed == num_tiles: break
+
+            # insert the group norm task to the head of each task queue
+            group_norm_func = group_norm_param.summary()
+            if group_norm_func is not None:
+                for i in range(num_tiles):
+                    task_queue = task_queues[i]
+                    task_queue.insert(0, ('apply_norm', group_norm_func))
+
+        # Done!
+        pbar.close()
+        return result if result is not None else result_approx.to(device)
\ No newline at end of file
diff --git a/src/s3diff.py b/src/s3diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b2f03d14dd07e5d0ab85dfc1e90a342f6d3a1b6
--- /dev/null
+++ b/src/s3diff.py
@@ -0,0 +1,305 @@
+import os
+import re
+import requests
+import sys
+import copy
+import numpy as np
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from peft import LoraConfig, get_peft_model
+p = "src/"
+sys.path.append(p)
+from model import make_1step_sched, my_lora_fwd
+from basicsr.archs.arch_util import default_init_weights
+
+def get_layer_number(module_name):
+    base_layers = {
+        'down_blocks': 0,
+        'mid_block': 4,
+        'up_blocks': 5
+    }
+
+    if module_name == 'conv_out':
+        return 9
+
+    base_layer = None
+    for key in base_layers:
+        if key in module_name:
+            base_layer = base_layers[key]
+            break
+
+    if base_layer is None:
+        return None
+
+    additional_layers = int(re.findall(r'\.(\d+)', module_name)[0])  # offset by the first numeric index in the module name
+    final_layer = base_layer + additional_layers
+    return final_layer
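+
+# Illustrative mapping (the module names below are hypothetical examples):
+#     get_layer_number('down_blocks.1.attentions.0.to_q')  -> 0 + 1 = 1
+#     get_layer_number('mid_block.resnets.0.conv1')        -> 4 + 0 = 4
+#     get_layer_number('up_blocks.2.resnets.1.conv2')      -> 5 + 2 = 7
+#     get_layer_number('conv_out')                         -> 9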
+
+
+class S3Diff(torch.nn.Module):
+    def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
+        self.sched = make_1step_sched(sd_path)
+
+        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
+
+        target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
+        target_modules_unet = [
+            "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
+            "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
+        ]
+
+        num_embeddings = 64
+        self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
+
+        self.vae_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.unet_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.vae_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.unet_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
+        self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
+
+        default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
+            self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
+
+        # vae
+        self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
+        self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
+
+        if pretrained_path is not None:
+            sd = torch.load(pretrained_path, map_location="cpu")
+            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            _sd_vae = vae.state_dict()
+            for k in sd["state_dict_vae"]:
+                _sd_vae[k] = sd["state_dict_vae"][k]
+            vae.load_state_dict(_sd_vae)
+
+            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+            unet.add_adapter(unet_lora_config)
+            _sd_unet = unet.state_dict()
+            for k in sd["state_dict_unet"]:
+                _sd_unet[k] = sd["state_dict_unet"][k]
+            unet.load_state_dict(_sd_unet)
+
+            _vae_de_mlp = self.vae_de_mlp.state_dict()
+            for k in sd["state_dict_vae_de_mlp"]:
+                _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
+            self.vae_de_mlp.load_state_dict(_vae_de_mlp)
+
+            _unet_de_mlp = self.unet_de_mlp.state_dict()
+            for k in sd["state_dict_unet_de_mlp"]:
+                _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
+            self.unet_de_mlp.load_state_dict(_unet_de_mlp)
+
+            _vae_block_mlp = self.vae_block_mlp.state_dict()
+            for k in sd["state_dict_vae_block_mlp"]:
+                _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
+            self.vae_block_mlp.load_state_dict(_vae_block_mlp)
+
+            _unet_block_mlp = self.unet_block_mlp.state_dict()
+            for k in sd["state_dict_unet_block_mlp"]:
+                _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
+            self.unet_block_mlp.load_state_dict(_unet_block_mlp)
+
+            _vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
+            for k in sd["state_dict_vae_fuse_mlp"]:
+                _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
+            self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
+
+            _unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
+            for k in sd["state_dict_unet_fuse_mlp"]:
+                _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
+            self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
+
+            self.W = nn.Parameter(sd["w"], requires_grad=False)
+
+            embeddings_state_dict = sd["state_embeddings"]
+            self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
+            self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
+        else:
+            print("Initializing model with random weights")
+            vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
+                target_modules=target_modules_vae)
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
+                target_modules=target_modules_unet
+            )
+            unet.add_adapter(unet_lora_config)
+
+        self.lora_rank_unet = lora_rank_unet
+        self.lora_rank_vae = lora_rank_vae
+        self.target_modules_vae = target_modules_vae
+        self.target_modules_unet = target_modules_unet
+
+        self.vae_lora_layers = []
+        for name, module in vae.named_modules():
+            if 'base_layer' in name:
+                self.vae_lora_layers.append(name[:-len(".base_layer")])
+                
+        for name, module in vae.named_modules():
+            if name in self.vae_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_lora_layers = []
+        for name, module in unet.named_modules():
+            if 'base_layer' in name:
+                self.unet_lora_layers.append(name[:-len(".base_layer")])
+
+        for name, module in unet.named_modules():
+            if name in self.unet_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
+
+        unet.to("cuda")
+        vae.to("cuda")
+        self.unet, self.vae = unet, vae
+        self.timesteps = torch.tensor([999], device="cuda").long()
+        self.text_encoder.requires_grad_(False)
+
+    def set_eval(self):
+        self.unet.eval()
+        self.vae.eval()
+        self.vae_de_mlp.eval()
+        self.unet_de_mlp.eval()
+        self.vae_block_mlp.eval()
+        self.unet_block_mlp.eval()
+        self.vae_fuse_mlp.eval()
+        self.unet_fuse_mlp.eval()
+
+        self.vae_block_embeddings.requires_grad_(False)
+        self.unet_block_embeddings.requires_grad_(False)
+
+        self.unet.requires_grad_(False)
+        self.vae.requires_grad_(False)
+
+    def set_train(self):
+        self.unet.train()
+        self.vae.train()
+        self.vae_de_mlp.train()
+        self.unet_de_mlp.train()
+        self.vae_block_mlp.train()
+        self.unet_block_mlp.train()
+        self.vae_fuse_mlp.train()
+        self.unet_fuse_mlp.train()    
+
+        self.vae_block_embeddings.requires_grad_(True)
+        self.unet_block_embeddings.requires_grad_(True)
+
+        for n, _p in self.unet.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+
+        self.unet.conv_in.requires_grad_(True)
+
+        for n, _p in self.vae.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+
+    def forward(self, c_t, deg_score, prompt):
+ 
+        if prompt is not None:
+            # encode the text prompt
+            caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            caption_enc = self.text_encoder(caption_tokens)[0]
+        else:
+            # no prompt given: fall back to encoding an empty prompt
+            caption_tokens = self.tokenizer("", max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            caption_enc = self.text_encoder(caption_tokens)[0]
+
+        # degradation fourier embedding
+        deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
+        deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
+        deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
+
+        # degradation mlp forward
+        vae_de_c_embed = self.vae_de_mlp(deg_proj)
+        unet_de_c_embed = self.unet_de_mlp(deg_proj)
+
+        # block embedding mlp forward
+        vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
+        unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
+        vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
+            vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
+        unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
+            unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
+
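+        # distribute the fused embeddings to the LoRA layers: each layer receives a rank x rank
+        # modulation matrix (de_mod) selected by the block it belongs to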
+        for layer_name, module in self.vae.named_modules():
+            if layer_name in self.vae_lora_layers:
+                split_name = layer_name.split(".")
+                if split_name[1] == 'down_blocks':
+                    block_id = int(split_name[2])
+                    vae_embed = vae_embeds[:, block_id]
+                elif split_name[1] == 'mid_block':
+                    vae_embed = vae_embeds[:, -2]
+                else:
+                    vae_embed = vae_embeds[:, -1]
+                module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
+
+        for layer_name, module in self.unet.named_modules():
+            if layer_name in self.unet_lora_layers:
+                split_name = layer_name.split(".")
+
+                if split_name[0] == 'down_blocks':
+                    block_id = int(split_name[1])
+                    unet_embed = unet_embeds[:, block_id]
+                elif split_name[0] == 'mid_block':
+                    unet_embed = unet_embeds[:, 4]
+                elif split_name[0] == 'up_blocks':
+                    block_id = int(split_name[1]) + 5
+                    unet_embed = unet_embeds[:, block_id]
+                else:
+                    unet_embed = unet_embeds[:, -1]
+                module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
+
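+        # one-step restoration: encode the LQ image, run a single UNet pass at t=999,
+        # take one scheduler step, and decode back to image space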
+        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+        model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
+        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+
+        return output_image
+
+    def save_model(self, outf):
+        sd = {}
+        sd["unet_lora_target_modules"] = self.target_modules_unet
+        sd["vae_lora_target_modules"] = self.target_modules_vae
+        sd["rank_unet"] = self.lora_rank_unet
+        sd["rank_vae"] = self.lora_rank_vae
+        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
+        sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
+        sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
+        sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
+        sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
+        sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
+        sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
+        sd["w"] = self.W
+
+        sd["state_embeddings"] = {
+                    "state_dict_vae_block": self.vae_block_embeddings.state_dict(),
+                    "state_dict_unet_block": self.unet_block_embeddings.state_dict(),
+                }
+
+        torch.save(sd, outf)
diff --git a/src/s3diff_cfg.py b/src/s3diff_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..54a7b2353a702fffb01edb1b4a63a1e6d77dc389
--- /dev/null
+++ b/src/s3diff_cfg.py
@@ -0,0 +1,316 @@
+import os
+import re
+import requests
+import sys
+import copy
+import numpy as np
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from peft import LoraConfig, get_peft_model
+p = "src/"
+sys.path.append(p)
+from model import make_1step_sched, my_lora_fwd
+from basicsr.archs.arch_util import default_init_weights
+
+
+def get_layer_number(module_name):
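+    # map a UNet module name to a flat block index: down_blocks.i -> i, mid_block -> 4,
+    # up_blocks.i -> 5 + i, conv_out -> 9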
+    base_layers = {
+        'down_blocks': 0,
+        'mid_block': 4,
+        'up_blocks': 5
+    }
+
+    if module_name == 'conv_out':
+        return 9
+
+    base_layer = None
+    for key in base_layers:
+        if key in module_name:
+            base_layer = base_layers[key]
+            break
+
+    if base_layer is None:
+        return None
+
+    additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
+    final_layer = base_layer + additional_layers
+    return final_layer
+
+
+class S3Diff(torch.nn.Module):
+    def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=8, lora_rank_vae=4, block_embedding_dim=64):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
+        self.sched = make_1step_sched(sd_path)
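+        # classifier-free-guidance weight used in forward() to blend the positive and negative predictions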
+        self.guidance_scale = 1.07
+
+        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
+
+        target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
+        target_modules_unet = [
+            "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
+            "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
+        ]
+
+        num_embeddings = 64
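+        # fixed random frequencies for the Fourier embedding of the degradation scores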
+        self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
+
+        self.vae_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.unet_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.vae_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.unet_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
+        self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
+
+        default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
+            self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
+
+        # learnable per-block condition embeddings: 6 slots for the VAE encoder and 10 for the UNet,
+        # indexed the same way forward() assigns blocks (down blocks by id, then mid block, then the rest)
+        self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
+        self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
+
+        if pretrained_path is not None:
+            sd = torch.load(pretrained_path, map_location="cpu")
+            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            _sd_vae = vae.state_dict()
+            for k in sd["state_dict_vae"]:
+                _sd_vae[k] = sd["state_dict_vae"][k]
+            vae.load_state_dict(_sd_vae)
+
+            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+            unet.add_adapter(unet_lora_config)
+            _sd_unet = unet.state_dict()
+            for k in sd["state_dict_unet"]:
+                _sd_unet[k] = sd["state_dict_unet"][k]
+            unet.load_state_dict(_sd_unet)
+
+            _vae_de_mlp = self.vae_de_mlp.state_dict()
+            for k in sd["state_dict_vae_de_mlp"]:
+                _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
+            self.vae_de_mlp.load_state_dict(_vae_de_mlp)
+
+            _unet_de_mlp = self.unet_de_mlp.state_dict()
+            for k in sd["state_dict_unet_de_mlp"]:
+                _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
+            self.unet_de_mlp.load_state_dict(_unet_de_mlp)
+
+            _vae_block_mlp = self.vae_block_mlp.state_dict()
+            for k in sd["state_dict_vae_block_mlp"]:
+                _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
+            self.vae_block_mlp.load_state_dict(_vae_block_mlp)
+
+            _unet_block_mlp = self.unet_block_mlp.state_dict()
+            for k in sd["state_dict_unet_block_mlp"]:
+                _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
+            self.unet_block_mlp.load_state_dict(_unet_block_mlp)
+
+            _vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
+            for k in sd["state_dict_vae_fuse_mlp"]:
+                _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
+            self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
+
+            _unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
+            for k in sd["state_dict_unet_fuse_mlp"]:
+                _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
+            self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
+
+            self.W = nn.Parameter(sd["w"], requires_grad=False)
+
+            embeddings_state_dict = sd["state_embeddings"]
+            self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
+            self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
+        else:
+            print("Initializing model with random weights")
+            vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
+                target_modules=target_modules_vae)
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
+                target_modules=target_modules_unet
+            )
+            unet.add_adapter(unet_lora_config)
+
+        self.lora_rank_unet = lora_rank_unet
+        self.lora_rank_vae = lora_rank_vae
+        self.target_modules_vae = target_modules_vae
+        self.target_modules_unet = target_modules_unet
+
+        self.vae_lora_layers = []
+        for name, module in vae.named_modules():
+            if 'base_layer' in name:
+                self.vae_lora_layers.append(name[:-len(".base_layer")])
+                
+        for name, module in vae.named_modules():
+            if name in self.vae_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_lora_layers = []
+        for name, module in unet.named_modules():
+            if 'base_layer' in name:
+                self.unet_lora_layers.append(name[:-len(".base_layer")])
+
+        for name, module in unet.named_modules():
+            if name in self.unet_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
+
+        unet.to("cuda")
+        vae.to("cuda")
+        self.unet, self.vae = unet, vae
+        self.timesteps = torch.tensor([999], device="cuda").long()
+        self.text_encoder.requires_grad_(False)
+
+    def set_eval(self):
+        self.unet.eval()
+        self.vae.eval()
+        self.vae_de_mlp.eval()
+        self.unet_de_mlp.eval()
+        self.vae_block_mlp.eval()
+        self.unet_block_mlp.eval()
+        self.vae_fuse_mlp.eval()
+        self.unet_fuse_mlp.eval()
+
+        self.vae_block_embeddings.requires_grad_(False)
+        self.unet_block_embeddings.requires_grad_(False)
+
+        self.unet.requires_grad_(False)
+        self.vae.requires_grad_(False)
+
+    def set_train(self):
+        self.unet.train()
+        self.vae.train()
+        self.vae_de_mlp.train()
+        self.unet_de_mlp.train()
+        self.vae_block_mlp.train()
+        self.unet_block_mlp.train()
+        self.vae_fuse_mlp.train()
+        self.unet_fuse_mlp.train()    
+
+        self.vae_block_embeddings.requires_grad_(True)
+        self.unet_block_embeddings.requires_grad_(True)
+
+        for n, _p in self.unet.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+        self.unet.conv_in.requires_grad_(True)
+
+        for n, _p in self.vae.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+
+    def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
+ 
+        if pos_prompt is not None:
+            # encode the text prompt
+            pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
+        else:
+            # no positive prompt given: fall back to encoding an empty prompt
+            pos_caption_tokens = self.tokenizer("", max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
+
+        if neg_prompt is not None:
+            # encode the text prompt
+            neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
+        else:
+            # no negative prompt given: fall back to encoding an empty prompt
+            neg_caption_tokens = self.tokenizer("", max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
+
+        # degradation fourier embedding 
+        deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
+        deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
+        deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
+
+        # degradation mlp forward
+        vae_de_c_embed = self.vae_de_mlp(deg_proj)
+        unet_de_c_embed = self.unet_de_mlp(deg_proj)
+
+        # block embedding mlp forward
+        vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
+        unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
+        vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
+            vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
+        unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
+            unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
+
+        for layer_name, module in self.vae.named_modules():
+            if layer_name in self.vae_lora_layers:
+                split_name = layer_name.split(".")
+                if split_name[1] == 'down_blocks':
+                    block_id = int(split_name[2])
+                    vae_embed = vae_embeds[:, block_id]
+                elif split_name[1] == 'mid_block':
+                    vae_embed = vae_embeds[:, -2]
+                else:
+                    vae_embed = vae_embeds[:, -1]
+                module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
+
+        for layer_name, module in self.unet.named_modules():
+            if layer_name in self.unet_lora_layers:
+                split_name = layer_name.split(".")
+                if split_name[0] == 'down_blocks':
+                    block_id = int(split_name[1])
+                    unet_embed = unet_embeds[:, block_id]
+                elif split_name[0] == 'mid_block':
+                    unet_embed = unet_embeds[:, 4]
+                elif split_name[0] == 'up_blocks':
+                    block_id = int(split_name[1]) + 5
+                    unet_embed = unet_embeds[:, block_id]
+                else:
+                    unet_embed = unet_embeds[:, -1]
+                module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
+
+        encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
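+        # classifier-free guidance: combine the positive- and negative-prompt predictions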
+        pos_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
+        neg_model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
+        model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
+
+        x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
+        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+
+        return output_image
+
+    def save_model(self, outf):
+        sd = {}
+        sd["unet_lora_target_modules"] = self.target_modules_unet
+        sd["vae_lora_target_modules"] = self.target_modules_vae
+        sd["rank_unet"] = self.lora_rank_unet
+        sd["rank_vae"] = self.lora_rank_vae
+        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
+        sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
+        sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
+        sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
+        sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
+        sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
+        sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
+        sd["w"] = self.W
+
+        sd["state_embeddings"] = {
+                    "state_dict_vae_block": self.vae_block_embeddings.state_dict(),
+                    "state_dict_unet_block": self.unet_block_embeddings.state_dict(),
+                }
+
+        torch.save(sd, outf)
diff --git a/src/s3diff_tile.py b/src/s3diff_tile.py
new file mode 100644
index 0000000000000000000000000000000000000000..7391d15868856c13860028f7b7a9e9e2b5036971
--- /dev/null
+++ b/src/s3diff_tile.py
@@ -0,0 +1,455 @@
+import os
+import re
+import requests
+import sys
+import copy
+import numpy as np
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from peft import LoraConfig, get_peft_model
+p = "src/"
+sys.path.append(p)
+from model import make_1step_sched, my_lora_fwd
+from basicsr.archs.arch_util import default_init_weights
+from my_utils.vaehook import VAEHook, perfcount
+
+
+def get_layer_number(module_name):
+    base_layers = {
+        'down_blocks': 0,
+        'mid_block': 4,
+        'up_blocks': 5
+    }
+
+    if module_name == 'conv_out':
+        return 9
+
+    base_layer = None
+    for key in base_layers:
+        if key in module_name:
+            base_layer = base_layers[key]
+            break
+
+    if base_layer is None:
+        return None
+
+    additional_layers = int(re.findall(r'\.(\d+)', module_name)[0]) #sum(int(num) for num in re.findall(r'\d+', module_name))
+    final_layer = base_layer + additional_layers
+    return final_layer
+
+
+class S3Diff(torch.nn.Module):
+    def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64, args=None):
+        super().__init__()
+        self.args = args
+        self.latent_tiled_size = args.latent_tiled_size
+        self.latent_tiled_overlap = args.latent_tiled_overlap
+
+        self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
+        self.sched = make_1step_sched(sd_path)
+        self.guidance_scale = 1.07
+
+        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
+
+        target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
+        target_modules_unet = [
+            "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
+            "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
+        ]
+
+        num_embeddings = 64
+        self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)
+
+        self.vae_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.unet_de_mlp = nn.Sequential(
+            nn.Linear(num_embeddings * 4, 256),
+            nn.ReLU(True),
+        )
+
+        self.vae_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.unet_block_mlp = nn.Sequential(
+            nn.Linear(block_embedding_dim, 64),
+            nn.ReLU(True),
+        )
+
+        self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
+        self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)
+
+        default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp, \
+            self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)
+
+        # learnable per-block condition embeddings: 6 slots for the VAE encoder and 10 for the UNet,
+        # indexed the same way forward() assigns blocks (down blocks by id, then mid block, then the rest)
+        self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
+        self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)
+
+        if pretrained_path is not None:
+            sd = torch.load(pretrained_path, map_location="cpu")
+            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            _sd_vae = vae.state_dict()
+            for k in sd["state_dict_vae"]:
+                _sd_vae[k] = sd["state_dict_vae"][k]
+            vae.load_state_dict(_sd_vae)
+
+            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
+            unet.add_adapter(unet_lora_config)
+            _sd_unet = unet.state_dict()
+            for k in sd["state_dict_unet"]:
+                _sd_unet[k] = sd["state_dict_unet"][k]
+            unet.load_state_dict(_sd_unet)
+
+            _vae_de_mlp = self.vae_de_mlp.state_dict()
+            for k in sd["state_dict_vae_de_mlp"]:
+                _vae_de_mlp[k] = sd["state_dict_vae_de_mlp"][k]
+            self.vae_de_mlp.load_state_dict(_vae_de_mlp)
+
+            _unet_de_mlp = self.unet_de_mlp.state_dict()
+            for k in sd["state_dict_unet_de_mlp"]:
+                _unet_de_mlp[k] = sd["state_dict_unet_de_mlp"][k]
+            self.unet_de_mlp.load_state_dict(_unet_de_mlp)
+
+            _vae_block_mlp = self.vae_block_mlp.state_dict()
+            for k in sd["state_dict_vae_block_mlp"]:
+                _vae_block_mlp[k] = sd["state_dict_vae_block_mlp"][k]
+            self.vae_block_mlp.load_state_dict(_vae_block_mlp)
+
+            _unet_block_mlp = self.unet_block_mlp.state_dict()
+            for k in sd["state_dict_unet_block_mlp"]:
+                _unet_block_mlp[k] = sd["state_dict_unet_block_mlp"][k]
+            self.unet_block_mlp.load_state_dict(_unet_block_mlp)
+
+            _vae_fuse_mlp = self.vae_fuse_mlp.state_dict()
+            for k in sd["state_dict_vae_fuse_mlp"]:
+                _vae_fuse_mlp[k] = sd["state_dict_vae_fuse_mlp"][k]
+            self.vae_fuse_mlp.load_state_dict(_vae_fuse_mlp)
+
+            _unet_fuse_mlp = self.unet_fuse_mlp.state_dict()
+            for k in sd["state_dict_unet_fuse_mlp"]:
+                _unet_fuse_mlp[k] = sd["state_dict_unet_fuse_mlp"][k]
+            self.unet_fuse_mlp.load_state_dict(_unet_fuse_mlp)
+
+            self.W = nn.Parameter(sd["w"], requires_grad=False)
+
+            embeddings_state_dict = sd["state_embeddings"]
+            self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
+            self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
+        else:
+            print("Initializing model with random weights")
+            vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
+                target_modules=target_modules_vae)
+            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+            unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
+                target_modules=target_modules_unet
+            )
+            unet.add_adapter(unet_lora_config)
+
+        self.lora_rank_unet = lora_rank_unet
+        self.lora_rank_vae = lora_rank_vae
+        self.target_modules_vae = target_modules_vae
+        self.target_modules_unet = target_modules_unet
+
+        self.vae_lora_layers = []
+        for name, module in vae.named_modules():
+            if 'base_layer' in name:
+                self.vae_lora_layers.append(name[:-len(".base_layer")])
+                
+        for name, module in vae.named_modules():
+            if name in self.vae_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_lora_layers = []
+        for name, module in unet.named_modules():
+            if 'base_layer' in name:
+                self.unet_lora_layers.append(name[:-len(".base_layer")])
+
+        for name, module in unet.named_modules():
+            if name in self.unet_lora_layers:
+                module.forward = my_lora_fwd.__get__(module, module.__class__)
+
+        self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}
+
+        unet.to("cuda")
+        vae.to("cuda")
+        self.unet, self.vae = unet, vae
+        self.timesteps = torch.tensor([999], device="cuda").long()
+        self.text_encoder.requires_grad_(False)
+
+        # wrap the VAE encoder/decoder forwards so large inputs are processed tile by tile
+        self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)
+
+    def set_eval(self):
+        self.unet.eval()
+        self.vae.eval()
+        self.vae_de_mlp.eval()
+        self.unet_de_mlp.eval()
+        self.vae_block_mlp.eval()
+        self.unet_block_mlp.eval()
+        self.vae_fuse_mlp.eval()
+        self.unet_fuse_mlp.eval()
+
+        self.vae_block_embeddings.requires_grad_(False)
+        self.unet_block_embeddings.requires_grad_(False)
+
+        self.unet.requires_grad_(False)
+        self.vae.requires_grad_(False)
+
+    def set_train(self):
+        self.unet.train()
+        self.vae.train()
+        self.vae_de_mlp.train()
+        self.unet_de_mlp.train()
+        self.vae_block_mlp.train()
+        self.unet_block_mlp.train()
+        self.vae_fuse_mlp.train()
+        self.unet_fuse_mlp.train()    
+
+        self.vae_block_embeddings.requires_grad_(True)
+        self.unet_block_embeddings.requires_grad_(True)
+
+        for n, _p in self.unet.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+        self.unet.conv_in.requires_grad_(True)
+
+        for n, _p in self.vae.named_parameters():
+            if "lora" in n:
+                _p.requires_grad = True
+
+    @perfcount
+    @torch.no_grad()
+    def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
+ 
+        if pos_prompt is not None:
+            # encode the text prompt
+            pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
+        else:
+            # no positive prompt given: fall back to encoding an empty prompt
+            pos_caption_tokens = self.tokenizer("", max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]
+
+        if neg_prompt is not None:
+            # encode the text prompt
+            neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
+        else:
+            # no negative prompt given: fall back to encoding an empty prompt
+            neg_caption_tokens = self.tokenizer("", max_length=self.tokenizer.model_max_length,
+                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
+            neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]
+
+        # degradation fourier embedding
+        deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
+        deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
+        deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)
+
+        # degradation mlp forward
+        vae_de_c_embed = self.vae_de_mlp(deg_proj)
+        unet_de_c_embed = self.unet_de_mlp(deg_proj)
+
+        # block embedding mlp forward
+        vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
+        unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)
+
+        vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1), \
+            vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0],1,1)], -1))
+        unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1), \
+            unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0],1,1)], -1))
+
+        for layer_name, module in self.vae.named_modules():
+            if layer_name in self.vae_lora_layers:
+                split_name = layer_name.split(".")
+                if split_name[1] == 'down_blocks':
+                    block_id = int(split_name[2])
+                    vae_embed = vae_embeds[:, block_id]
+                elif split_name[1] == 'mid_block':
+                    vae_embed = vae_embeds[:, -2]
+                else:
+                    vae_embed = vae_embeds[:, -1]
+                module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)
+
+        for layer_name, module in self.unet.named_modules():
+            if layer_name in self.unet_lora_layers:
+                split_name = layer_name.split(".")
+                if split_name[0] == 'down_blocks':
+                    block_id = int(split_name[1])
+                    unet_embed = unet_embeds[:, block_id]
+                elif split_name[0] == 'mid_block':
+                    unet_embed = unet_embeds[:, 4]
+                elif split_name[0] == 'up_blocks':
+                    block_id = int(split_name[1]) + 5
+                    unet_embed = unet_embeds[:, block_id]
+                else:
+                    unet_embed = unet_embeds[:, -1]
+                module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)
+
+        lq_latent = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
+
+        ## add tile function
+        _, _, h, w = lq_latent.size()
+        tile_size, tile_overlap = (self.latent_tiled_size, self.latent_tiled_overlap)
+        if h * w <= tile_size * tile_size:
+            print(f"[Tiled Latent]: the input size is tiny and unnecessary to tile.")
+            pos_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
+            neg_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
+            model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
+        else:
+            print(f"[Tiled Latent]: the input size is {c_t.shape[-2]}x{c_t.shape[-1]}, need to tiled")
+            # tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to()
+            tile_size = min(tile_size, min(h, w))
+            tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to(c_t.device)
+
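+            # number of tiles along the latent width (grid_rows) and height (grid_cols),
+            # stepping by tile_size - tile_overlap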
+            grid_rows = 0
+            cur_x = 0
+            while cur_x < lq_latent.size(-1):
+                cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
+                grid_rows += 1
+
+            grid_cols = 0
+            cur_y = 0
+            while cur_y < lq_latent.size(-2):
+                cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
+                grid_cols += 1
+
+            input_list = []
+            noise_preds = []
+            for row in range(grid_rows):
+                noise_preds_row = []
+                for col in range(grid_cols):
+                    if col < grid_cols-1 or row < grid_rows-1:
+                        # extract tile from input image
+                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                        # input tile area on total image
+                    if row == grid_rows-1:
+                        ofs_x = w - tile_size
+                    if col == grid_cols-1:
+                        ofs_y = h - tile_size
+
+                    input_start_x = ofs_x
+                    input_end_x = ofs_x + tile_size
+                    input_start_y = ofs_y
+                    input_end_y = ofs_y + tile_size
+
+                    # input tile dimensions
+                    input_tile = lq_latent[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                    input_list.append(input_tile)
+
+                    if len(input_list) == 1 or col == grid_cols-1:
+                        input_list_t = torch.cat(input_list, dim=0)
+                        # predict the noise residual
+                        pos_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
+                        neg_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
+                        model_out = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
+                        input_list = []
+                    noise_preds.append(model_out)
+
+            # Stitch noise predictions for all tiles
+            noise_pred = torch.zeros(lq_latent.shape, device=lq_latent.device)
+            contributors = torch.zeros(lq_latent.shape, device=lq_latent.device)
+            # Add each tile contribution to overall latents
+            for row in range(grid_rows):
+                for col in range(grid_cols):
+                    if col < grid_cols-1 or row < grid_rows-1:
+                        # extract tile from input image
+                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                        # input tile area on total image
+                    if row == grid_rows-1:
+                        ofs_x = w - tile_size
+                    if col == grid_cols-1:
+                        ofs_y = h - tile_size
+
+                    input_start_x = ofs_x
+                    input_end_x = ofs_x + tile_size
+                    input_start_y = ofs_y
+                    input_end_y = ofs_y + tile_size
+
+                    noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
+                    contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
+            # Average overlapping areas with more than 1 contributor
+            noise_pred /= contributors
+            model_pred = noise_pred
+
+        x_denoised = self.sched.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample
+        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
+
+        return output_image
+
+    def save_model(self, outf):
+        sd = {}
+        sd["unet_lora_target_modules"] = self.target_modules_unet
+        sd["vae_lora_target_modules"] = self.target_modules_vae
+        sd["rank_unet"] = self.lora_rank_unet
+        sd["rank_vae"] = self.lora_rank_vae
+        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
+        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
+        sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
+        sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
+        sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
+        sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
+        sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
+        sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
+        sd["w"] = self.W
+
+        sd["state_embeddings"] = {
+                    "state_dict_vae_block": self.vae_block_embeddings.state_dict(),
+                    "state_dict_unet_block": self.unet_block_embeddings.state_dict(),
+                }
+
+        torch.save(sd, outf)
+
+    def _set_latent_tile(self,
+        latent_tiled_size = 96,
+        latent_tiled_overlap = 32):
+        self.latent_tiled_size = latent_tiled_size
+        self.latent_tiled_overlap = latent_tiled_overlap
+    
+    def _init_tiled_vae(self,
+            encoder_tile_size = 256,
+            decoder_tile_size = 256,
+            fast_decoder = False,
+            fast_encoder = False,
+            color_fix = False,
+            vae_to_gpu = True):
+        # save original forward (only once)
+        if not hasattr(self.vae.encoder, 'original_forward'):
+            setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
+        if not hasattr(self.vae.decoder, 'original_forward'):
+            setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)
+
+        encoder = self.vae.encoder
+        decoder = self.vae.decoder
+
+        self.vae.encoder.forward = VAEHook(
+            encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
+        self.vae.decoder.forward = VAEHook(
+            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
+
+    def _gaussian_weights(self, tile_width, tile_height, nbatches):
+        """Generates a gaussian mask of weights for tile contributions"""
+        from numpy import pi, exp, sqrt
+        import numpy as np
+
+        latent_width = tile_width
+        latent_height = tile_height
+
+        var = 0.01
+        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
+        x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
+        midpoint = latent_height / 2
+        y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
+
+        weights = np.outer(y_probs, x_probs)
+        return torch.tile(torch.tensor(weights), (nbatches, self.unet.config.in_channels, 1, 1))
+
diff --git a/src/train_s3diff.py b/src/train_s3diff.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5b2129b1fa5171de529566f78b6d80ff2d1580
--- /dev/null
+++ b/src/train_s3diff.py
@@ -0,0 +1,284 @@
+import os
+os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
+
+import gc
+import lpips
+import clip
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+
+from omegaconf import OmegaConf
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.optimization import get_scheduler
+
+from de_net import DEResNet
+from s3diff import S3Diff
+from my_utils.training_utils import parse_args_paired_training, PairedDataset, degradation_proc
+
+def main(args):
+
+    # init and save configs
+    config = OmegaConf.load(args.base_config)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+    )
+
+    if accelerator.is_local_main_process:
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if accelerator.is_main_process:
+        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
+        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
+
+    # initialize degradation estimation network
+    net_de = DEResNet(num_in_ch=3, num_degradation=2)
+    net_de.load_model(args.de_net_path)
+    net_de = net_de.cuda()
+    net_de.eval()
+
+    # initialize net_sr
+    net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=args.sd_path, pretrained_path=args.pretrained_path)
+    net_sr.set_train()
+
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            net_sr.unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available, please install it by running `pip install xformers`")
+
+    if args.gradient_checkpointing:
+        net_sr.unet.enable_gradient_checkpointing()
+
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.gan_disc_type == "vagan":
+        import vision_aided_loss
+        net_disc = vision_aided_loss.Discriminator(cv_type='dino', output_type='conv_multi_level', loss_type=args.gan_loss_type, device="cuda")
+    else:
+        raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented")
+
+    net_disc = net_disc.cuda()
+    net_disc.requires_grad_(True)
+    net_disc.cv_ensemble.requires_grad_(False)
+    net_disc.train()
+
+    net_lpips = lpips.LPIPS(net='vgg').cuda()
+    net_lpips.requires_grad_(False)
+
+    # make the optimizer
+    layers_to_opt = []
+    layers_to_opt = layers_to_opt + list(net_sr.vae_block_embeddings.parameters()) + list(net_sr.unet_block_embeddings.parameters())
+    layers_to_opt = layers_to_opt + list(net_sr.vae_de_mlp.parameters()) + list(net_sr.unet_de_mlp.parameters()) + \
+        list(net_sr.vae_block_mlp.parameters()) + list(net_sr.unet_block_mlp.parameters()) + \
+        list(net_sr.vae_fuse_mlp.parameters()) + list(net_sr.unet_fuse_mlp.parameters())
+
+    for n, _p in net_sr.unet.named_parameters():
+        if "lora" in n:
+            assert _p.requires_grad
+            layers_to_opt.append(_p)
+    layers_to_opt += list(net_sr.unet.conv_in.parameters())
+
+    for n, _p in net_sr.vae.named_parameters():
+        if "lora" in n:
+            assert _p.requires_grad
+            layers_to_opt.append(_p)
+
+    dataset_train = PairedDataset(config.train)
+    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
+    dataset_val = PairedDataset(config.validation)
+    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
+
+
+    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,)
+    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_cycles=args.lr_num_cycles, power=args.lr_power,)
+
+    optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,)
+    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+            num_training_steps=args.max_train_steps * accelerator.num_processes,
+            num_cycles=args.lr_num_cycles, power=args.lr_power)
+
+    # Prepare everything with our `accelerator`.
+    net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
+        net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
+    )
+    net_de, net_lpips = accelerator.prepare(net_de, net_lpips)
+    # choose the weight dtype according to the accelerator's mixed-precision setting
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move all networks to the device and cast them to weight_dtype
+    net_sr.to(accelerator.device, dtype=weight_dtype)
+    net_de.to(accelerator.device, dtype=weight_dtype)
+    net_disc.to(accelerator.device, dtype=weight_dtype)
+    net_lpips.to(accelerator.device, dtype=weight_dtype)
+
+    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
+        disable=not accelerator.is_local_main_process,)
+
+    for name, module in net_disc.named_modules():
+        if "attn" in name:
+            module.fused_attn = False
+
+    # start the training loop
+    global_step = 0
+    for epoch in range(0, args.num_training_epochs):
+        for step, batch in enumerate(dl_train):
+            l_acc = [net_sr, net_disc]
+            with accelerator.accumulate(*l_acc):
+                x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch, accelerator.device)
+                B, C, H, W = x_src.shape
+                with torch.no_grad():
+                    deg_score = net_de(x_ori_size_src.detach()).detach()
+
+                pos_tag_prompt = [args.pos_prompt for _ in range(B)]                
+                neg_tag_prompt = [args.neg_prompt for _ in range(B)]
+
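+                # with probability neg_prob, swap in the negative prompt and supervise the
+                # model to reproduce the LQ input instead of the HQ target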
+                neg_probs = torch.rand(B).to(accelerator.device)
+                
+                # build mixed prompt and target
+                mixed_tag_prompt = [_neg_tag if p_i < args.neg_prob else _pos_tag for _neg_tag, _pos_tag, p_i in zip(neg_tag_prompt, pos_tag_prompt, neg_probs)]
+                neg_probs = neg_probs.reshape(B, 1, 1, 1)
+                mixed_tgt = torch.where(neg_probs < args.neg_prob, x_src, x_tgt)
+
+                x_tgt_pred = net_sr(x_src.detach(), deg_score, mixed_tag_prompt)
+                loss_l2 = F.mse_loss(x_tgt_pred.float(), mixed_tgt.detach().float(), reduction="mean") * args.lambda_l2
+                loss_lpips = net_lpips(x_tgt_pred.float(), mixed_tgt.detach().float()).mean() * args.lambda_lpips
+
+                loss = loss_l2 + loss_lpips
+
+                accelerator.backward(loss, retain_graph=False)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+                """
+                Generator loss: fool the discriminator
+                """
+                x_tgt_pred = net_sr(x_src.detach(), deg_score, pos_tag_prompt)
+                lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan
+                accelerator.backward(lossG)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+                """
+                Discriminator loss: fake image vs real image
+                """
+                # real image
+                lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan
+                accelerator.backward(lossD_real.mean())
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
+                optimizer_disc.step()
+                lr_scheduler_disc.step()
+                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
+                # fake image
+                lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan
+                accelerator.backward(lossD_fake.mean())
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
+                optimizer_disc.step()
+                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
+                lossD = lossD_real + lossD_fake
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                if accelerator.is_main_process:
+                    logs = {}
+                    logs["lossG"] = lossG.detach().item()
+                    logs["lossD"] = lossD.detach().item()
+                    logs["loss_l2"] = loss_l2.detach().item()
+                    logs["loss_lpips"] = loss_lpips.detach().item()
+                    progress_bar.set_postfix(**logs)
+
+                    # checkpoint the model
+                    if global_step % args.checkpointing_steps == 1:
+                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
+                        accelerator.unwrap_model(net_sr).save_model(outf)
+
+                    # compute validation L2 and LPIPS
+                    if global_step % args.eval_freq == 1:
+                        l_l2, l_lpips = [], []
+        
+                        val_count = 0
+                        for step, batch_val in enumerate(dl_val):
+                            if step >= args.num_samples_eval:
+                                break
+                            x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch_val, accelerator.device)
+                            B, C, H, W = x_src.shape
+                            assert B == 1, "Use batch size 1 for eval."
+                            with torch.no_grad():
+                                # forward pass
+                                deg_score = net_de(x_ori_size_src.detach())
+
+                                pos_tag_prompt = [args.pos_prompt for _ in range(B)]
+                                x_tgt_pred = accelerator.unwrap_model(net_sr)(x_src.detach(), deg_score, pos_tag_prompt)
+                                # compute the reconstruction losses
+                                loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.detach().float(), reduction="mean")
+                                loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.detach().float()).mean()
+
+                                l_l2.append(loss_l2.item())
+                                l_lpips.append(loss_lpips.item())
+
+                            if args.save_val and val_count < 5:
+                                x_src = x_src.cpu().detach() * 0.5 + 0.5
+                                x_tgt = x_tgt.cpu().detach() * 0.5 + 0.5
+                                x_tgt_pred = x_tgt_pred.cpu().detach() * 0.5 + 0.5
+
+                                combined = torch.cat([x_src, x_tgt_pred, x_tgt], dim=3)
+                                output_pil = transforms.ToPILImage()(combined[0])
+                                outf = os.path.join(args.output_dir, f"val_{step}.png")
+                                output_pil.save(outf)
+                                val_count += 1
+
+                        logs["val/l2"] = np.mean(l_l2)
+                        logs["val/lpips"] = np.mean(l_lpips)
+                        gc.collect()
+                        torch.cuda.empty_cache()
+                    accelerator.log(logs, step=global_step)
+
+
+if __name__ == "__main__":
+    args = parse_args_paired_training()
+    main(args)
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..999d2f5205b22dda9ac8acc3e738fa0e1099a426
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,528 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+import json
+import subprocess
+
+import torch
+import torch.distributed as dist
+
+from typing import List, Dict, Tuple, Optional
+from torch import Tensor
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    builtin_print = builtins.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        force = force or (get_world_size() > 8)
+        if is_master or force:
+            now = datetime.datetime.now().time()
+            builtin_print('[{}] '.format(now), end='')  # print with time stamp
+            builtin_print(*args, **kwargs)
+
+    builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+        args.dist_url = 'env://'
+        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
+    elif 'SLURM_PROCID' in os.environ:
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        addr = subprocess.getoutput(
+            'scontrol show hostname {} | head -n1'.format(node_list))
+        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29200')
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+        os.environ['LOCAL_SIZE'] = str(num_gpus)
+        args.dist_url = 'env://'
+        args.world_size = ntasks
+        args.rank = proc_id
+        args.gpu = proc_id % num_gpus
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
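+# Usage sketch (illustrative only; not called anywhere in this repository). Under
+# `torchrun`, RANK / WORLD_SIZE / LOCAL_RANK are provided by the launcher, so a bare
+# namespace is enough for init_distributed_mode. The helper name is an assumption.
+def _example_init_distributed():
+    import argparse
+    args = argparse.Namespace()
+    init_distributed_mode(args)
+    if is_main_process():
+        print('world size:', get_world_size())
+    return args
+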
+def clip_grad_norm_(
+        parameters, max_norm: float, norm_type: float = 2.0,
+        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
+    r"""Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float): max norm of the gradients
+        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(grads) == 0:
+        return torch.tensor(0.)
+    first_device = grads[0].device
+    grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
+        = {(first_device, grads[0].dtype): [[g.detach() for g in grads]]}
+    
+    norms = [torch.norm(g, norm_type) for g in grads]
+    total_norm = torch.norm(torch.stack(norms), norm_type)
+
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for ((device, _), [device_grads]) in grouped_grads.items():
+        if foreach is None or foreach:
+            # fused multi-tensor scaling (all gradients are grouped on a single device above)
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))  # type: ignore[call-overload]
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.detach().mul_(clip_coef_clamped_device)
+
+    return total_norm
+
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+
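+# Usage sketch (illustrative, not wired into any trainer here): how the scaler above is
+# typically driven inside a mixed-precision step. `model`, `batch`, and the mean loss
+# are assumptions, not code from this repository.
+def _example_amp_step(model, optimizer, loss_scaler, batch, clip_grad=1.0):
+    with torch.cuda.amp.autocast():
+        loss = model(batch).mean()
+    optimizer.zero_grad()
+    # backward + unscale + (optional) clipping + optimizer step, all inside the scaler
+    grad_norm = loss_scaler(loss, optimizer, clip_grad=clip_grad,
+                            parameters=model.parameters(), update_grad=True)
+    return loss.item(), grad_norm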
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == float('inf'):
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+  
+    # checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
+    checkpoint_paths = [output_dir / 'checkpoint.pth']
+    for checkpoint_path in checkpoint_paths:
+        to_save = {
+            'model': model_without_ddp.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch,
+            'args': args,
+        }
+
+        save_on_master(to_save, checkpoint_path)
+
+def load_model(args, model_without_ddp, optimizer):
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            print("With optim & sched!")
+
+def auto_load_model(args, model, model_without_ddp, optimizer):
+    output_dir = Path(args.output_dir)
+
+    # auto-resume: pick up the latest 'checkpoint-*.pth' in output_dir when --resume is empty
+    if args.auto_resume and len(args.resume) == 0:
+        import glob
+        all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
+        latest_ckpt = -1
+        for ckpt in all_checkpoints:
+            t = ckpt.split('-')[-1].split('.')[0]
+            if t.isdigit():
+                latest_ckpt = max(int(t), latest_ckpt)
+        if latest_ckpt >= 0:
+            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
+        print("Auto resume checkpoint: %s" % args.resume)
+
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            print("With optim & sched!")
+ 
+
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
+
+
+def create_ds_config(args):
+    args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
+    with open(args.deepspeed_config, mode="w") as writer:
+        ds_config = {
+            "train_batch_size": args.batch_size * args.accum_iter * get_world_size(),
+            "train_micro_batch_size_per_gpu": args.batch_size,
+            "steps_per_print": 1000,
+            "optimizer": {
+                "type": "Adam",
+                "adam_w_mode": True,
+                "params": {
+                    "lr": args.lr,
+                    "weight_decay": args.weight_decay,
+                    "bias_correction": True,
+                    "betas": [
+                        args.opt_betas[0],
+                        args.opt_betas[1]
+                    ],
+                    "eps": args.opt_eps
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 16,
+                "loss_scale_window": 1000,
+                "hysteresis": 2,
+                "min_loss_scale": 1
+            },
+            # "bf16": {
+            #     "enabled": True
+            # },
+            "amp": {
+                "enabled": False,
+                "opt_level": "O2"
+            },
+            "flops_profiler": {
+                "enabled": True,
+                "profile_step": -1,
+                "module_depth": -1,
+                "top_modules": 1,
+                "detailed": True,
+            },
+        }
+
+        if args.clip_grad is not None:
+            ds_config.update({'gradient_clipping': args.clip_grad})
+
+        if args.zero_stage == 1:
+            ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}})
+        elif args.zero_stage > 1:
+            raise NotImplementedError()
+
+        writer.write(json.dumps(ds_config, indent=2))
+
+def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
+    parameter_group_names = {}
+    parameter_group_vars = {}
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+            group_name = "no_decay"
+            this_weight_decay = 0.
+        else:
+            group_name = "decay"
+            this_weight_decay = weight_decay
+        if get_num_layer is not None:
+            layer_id = get_num_layer(name)
+            group_name = "layer_%d_%s" % (layer_id, group_name)
+        else:
+            layer_id = None
+
+        if group_name not in parameter_group_names:
+            if get_layer_scale is not None:
+                scale = get_layer_scale(layer_id)
+            else:
+                scale = 1.
+
+            parameter_group_names[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr_scale": scale
+            }
+            parameter_group_vars[group_name] = {
+                "weight_decay": this_weight_decay,
+                "params": [],
+                "lr_scale": scale
+            }
+
+        parameter_group_vars[group_name]["params"].append(param)
+        parameter_group_names[group_name]["params"].append(name)
+    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+    return list(parameter_group_vars.values())
+
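+# Usage sketch (illustrative): extra keys such as "lr_scale" are simply carried along in
+# the optimizer's param groups, so the output above can be fed to AdamW directly.
+# `model`, the skip list, and the hyper-parameters are assumptions.
+def _example_build_optimizer(model, lr=1e-4, weight_decay=0.05):
+    param_groups = get_parameter_groups(model, weight_decay=weight_decay,
+                                        skip_list=('pos_embed', 'cls_token'))
+    return torch.optim.AdamW(param_groups, lr=lr)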
diff --git a/utils/util_image.py b/utils/util_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..88d6f307b208ead04389c870d1b24a82c5c0960e
--- /dev/null
+++ b/utils/util_image.py
@@ -0,0 +1,935 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Powered by Zongsheng Yue 2021-11-24 16:54:19
+
+import sys
+import cv2
+import math
+import torch
+import random
+import numpy as np
+from scipy import fft
+from pathlib import Path
+from einops import rearrange
+from skimage import img_as_ubyte, img_as_float32
+from torchvision.utils import make_grid  # used by tensor2img for 4D inputs
+
+# --------------------------Metrics----------------------------
+def ssim(img1, img2):
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                            (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
+
+def calculate_ssim(im1, im2, border=0, ycbcr=False):
+    '''
+    SSIM; gives the same outputs as MATLAB's implementation.
+    im1, im2: h x w x c, [0, 255], uint8
+    '''
+    if not im1.shape == im2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+
+    if ycbcr:
+        im1 = rgb2ycbcr(im1, True)
+        im2 = rgb2ycbcr(im2, True)
+
+    h, w = im1.shape[:2]
+    im1 = im1[border:h-border, border:w-border]
+    im2 = im2[border:h-border, border:w-border]
+
+    if im1.ndim == 2:
+        return ssim(im1, im2)
+    elif im1.ndim == 3:
+        if im1.shape[2] == 3:
+            ssims = []
+            for i in range(3):
+                ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
+            return np.array(ssims).mean()
+        elif im1.shape[2] == 1:
+            return ssim(np.squeeze(im1), np.squeeze(im2))
+    else:
+        raise ValueError('Wrong input image dimensions.')
+
+def calculate_psnr(im1, im2, border=0, ycbcr=False):
+    '''
+    PSNR metric.
+    im1, im2: h x w x c, [0, 255], uint8
+    '''
+    if not im1.shape == im2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+
+    if ycbcr:
+        im1 = rgb2ycbcr(im1, True)
+        im2 = rgb2ycbcr(im2, True)
+
+    h, w = im1.shape[:2]
+    im1 = im1[border:h-border, border:w-border]
+    im2 = im2[border:h-border, border:w-border]
+
+    im1 = im1.astype(np.float64)
+    im2 = im2.astype(np.float64)
+    mse = np.mean((im1 - im2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+def batch_PSNR(img, imclean, border=0, ycbcr=False):
+    if ycbcr:
+        img = rgb2ycbcrTorch(img, True)
+        imclean = rgb2ycbcrTorch(imclean, True)
+    Img = img.data.cpu().numpy()
+    Iclean = imclean.data.cpu().numpy()
+    Img = img_as_ubyte(Img)
+    Iclean = img_as_ubyte(Iclean)
+    PSNR = 0
+    h, w = Iclean.shape[2:]
+    for i in range(Img.shape[0]):
+        PSNR += calculate_psnr(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
+    return PSNR
+
+def batch_SSIM(img, imclean, border=0, ycbcr=False):
+    if ycbcr:
+        img = rgb2ycbcrTorch(img, True)
+        imclean = rgb2ycbcrTorch(imclean, True)
+    Img = img.data.cpu().numpy()
+    Iclean = imclean.data.cpu().numpy()
+    Img = img_as_ubyte(Img)
+    Iclean = img_as_ubyte(Iclean)
+    SSIM = 0
+    for i in range(Img.shape[0]):
+        SSIM += calculate_ssim(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
+    return SSIM
+
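+# Illustrative example (assumption: `im_gt` and `im_sr` are aligned uint8 RGB arrays in
+# [0, 255] loaded elsewhere): evaluate a restored image on the luminance channel.
+def _example_evaluate_pair(im_gt, im_sr, border=4):
+    psnr = calculate_psnr(im_gt, im_sr, border=border, ycbcr=True)
+    ssim_val = calculate_ssim(im_gt, im_sr, border=border, ycbcr=True)
+    return psnr, ssim_val
+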
+def normalize_np(im, mean=0.5, std=0.5, reverse=False):
+    '''
+    Input:
+        im: h x w x c, numpy array
+        Normalize: (im - mean) / std
+        Reverse: im * std + mean
+
+    '''
+    if not isinstance(mean, (list, tuple)):
+        mean = [mean, ] * im.shape[2]
+    mean = np.array(mean).reshape([1, 1, im.shape[2]])
+
+    if not isinstance(std, (list, tuple)):
+        std = [std, ] * im.shape[2]
+    std = np.array(std).reshape([1, 1, im.shape[2]])
+
+    if not reverse:
+        out = (im.astype(np.float32) - mean) / std
+    else:
+        out = im.astype(np.float32) * std + mean
+    return out
+
+def normalize_th(im, mean=0.5, std=0.5, reverse=False):
+    '''
+    Input:
+        im: b x c x h x w, torch tensor
+        Normalize: (im - mean) / std
+        Reverse: im * std + mean
+
+    '''
+    if not isinstance(mean, (list, tuple)):
+        mean = [mean, ] * im.shape[1]
+    mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])
+
+    if not isinstance(std, (list, tuple)):
+        std = [std, ] * im.shape[1]
+    std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])
+
+    if not reverse:
+        out = (im - mean) / std
+    else:
+        out = im * std + mean
+    return out
+
+# ------------------------Image format--------------------------
+def rgb2ycbcr(im, only_y=True):
+    '''
+    same as matlab rgb2ycbcr
+    Input:
+        im: uint8 [0,255] or float [0,1]
+        only_y: only return Y channel
+    '''
+    # transform to float64 data type, range [0, 255]
+    if im.dtype == np.uint8:
+        im_temp = im.astype(np.float64)
+    else:
+        im_temp = (im * 255).astype(np.float64)
+
+    # convert
+    if only_y:
+        rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
+    else:
+        rlt = np.matmul(im_temp, np.array([[65.481,  -37.797, 112.0  ],
+                                           [128.553, -74.203, -93.786],
+                                           [24.966,  112.0,   -18.214]])/255.0) + [16, 128, 128]
+    if im.dtype == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(im.dtype)
+
+def rgb2ycbcrTorch(im, only_y=True):
+    '''
+    same as matlab rgb2ycbcr
+    Input:
+        im: float [0,1], N x 3 x H x W
+        only_y: only return Y channel
+    '''
+    # transform to range [0,255.0]
+    im_temp = im.permute([0,2,3,1]) * 255.0  # N x H x W x C --> N x H x W x C
+    # convert
+    if only_y:
+        rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
+                                        device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
+    else:
+        rlt = torch.matmul(im_temp, torch.tensor([[65.481,  -37.797, 112.0  ],
+                                                  [128.553, -74.203, -93.786],
+                                                  [24.966,  112.0,   -18.214]],
+                                                  device=im.device, dtype=im.dtype)/255.0) + \
+              torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([1, 1, 1, 3])
+    rlt /= 255.0
+    rlt.clamp_(0.0, 1.0)
+    return rlt.permute([0, 3, 1, 2])
+
+def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+    """Convert torch Tensors into image numpy arrays.
+
+    After clamping to [min, max], values will be normalized to [0, 1].
+
+    Args:
+        tensor (Tensor or list[Tensor]): Accept shapes:
+            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+            2) 3D Tensor of shape (3/1 x H x W);
+            3) 2D Tensor of shape (H x W).
+            Tensor channel should be in RGB order.
+        rgb2bgr (bool): Whether to change rgb to bgr.
+        out_type (numpy type): output types. If ``np.uint8``, transform outputs
+            to uint8 type with range [0, 255]; otherwise, float type with
+            range [0, 1]. Default: ``np.uint8``.
+        min_max (tuple[int]): min and max values for clamp.
+
+    Returns:
+        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+        shape (H x W). The channel order is BGR.
+    """
+    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+    flag_tensor = torch.is_tensor(tensor)
+    if flag_tensor:
+        tensor = [tensor]
+    result = []
+    for _tensor in tensor:
+        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+        n_dim = _tensor.dim()
+        if n_dim == 4:
+            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if rgb2bgr:
+                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 3:
+            img_np = _tensor.numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if img_np.shape[2] == 1:  # gray image
+                img_np = np.squeeze(img_np, axis=2)
+            else:
+                if rgb2bgr:
+                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 2:
+            img_np = _tensor.numpy()
+        else:
+            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+        if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+            img_np = (img_np * 255.0).round()
+        img_np = img_np.astype(out_type)
+        result.append(img_np)
+    if len(result) == 1 and flag_tensor:
+        result = result[0]
+    return result
+
+def img2tensor(imgs, out_type=torch.float32):
+    """Convert image numpy arrays into torch tensor.
+    Args:
+        imgs (Array or list[array]): Accept shapes:
+            3) list of numpy arrays
+            1) 3D numpy array of shape (H x W x 3/1);
+            2) 2D Tensor of shape (H x W).
+            Tensor channel should be in RGB order.
+
+    Returns:
+        (array or list): 4D ndarray of shape (1 x C x H x W)
+    """
+
+    def _img2tensor(img):
+        if img.ndim == 2:
+            tensor = torch.from_numpy(img[None, None,]).type(out_type)
+        elif img.ndim == 3:
+            tensor = torch.from_numpy(rearrange(img, 'h w c -> c h w')).type(out_type).unsqueeze(0)
+        else:
+            raise TypeError(f'2D or 3D numpy array expected, got {img.ndim}D array')
+        return tensor
+
+    if not (isinstance(imgs, np.ndarray) or (isinstance(imgs, list) and all(isinstance(t, np.ndarray) for t in imgs))):
+        raise TypeError(f'Numpy array or list of numpy array expected, got {type(imgs)}')
+
+    flag_numpy = isinstance(imgs, np.ndarray)
+    if flag_numpy:
+        imgs = [imgs,]
+    result = []
+    for _img in imgs:
+        result.append(_img2tensor(_img))
+
+    if len(result) == 1 and flag_numpy:
+        result = result[0]
+    return result
+
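+# Round-trip sketch (illustrative): read an RGB image, run it through a network, and
+# write the prediction back to disk. `net` and the paths are assumptions; imread/imwrite
+# are defined later in this file and resolved at call time.
+def _example_infer_single(net, in_path, out_path):
+    im = imread(in_path, chn='rgb', dtype='float32')       # h x w x 3, float32 in [0, 1]
+    inp = img2tensor(im)                                    # 1 x 3 x h x w
+    with torch.no_grad():
+        pred = net(inp)
+    out = tensor2img(pred, rgb2bgr=False, min_max=(0, 1))   # h x w x 3, uint8, RGB
+    imwrite(out, out_path, chn='rgb', dtype_in='uint8')
+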
+# ------------------------Image resize-----------------------------
+def imresize_np(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: Numpy, HWC or HW [0,1]
+    # output: HWC or HW [0,1] w/o round
+    img = torch.from_numpy(img)
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(2)
+
+    in_H, in_W, in_C = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # Return the desired dimension order for performing the resize.  The
+    # strategy is to perform the resize first along the dimension with the
+    # smallest scale factor.
+    # Now we do not support this.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:sym_len_Hs, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[-sym_len_He:, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(out_H, in_W, in_C)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :sym_len_Ws, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, -sym_len_We:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(out_H, out_W, in_C)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+
+    return out_2.numpy()
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias - larger kernel width
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5+scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    P = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+        1, P).expand(out_length, P)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+    # apply cubic kernel
+    if (scale < 1) and (antialiasing):
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, P)
+
+    # If a column in weights is all zero, get rid of it. only consider the first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, P - 2)
+        weights = weights.narrow(1, 1, P - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, P - 2)
+        weights = weights.narrow(1, 0, P - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+# ------------------------Image I/O-----------------------------
+def imread(path, chn='rgb', dtype='float32'):
+    '''
+    Read image.
+    chn: 'rgb', 'bgr' or 'gray'
+    out:
+        im: h x w x c, numpy array
+    '''
+    im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)  # BGR, uint8
+    try:
+        if chn.lower() == 'rgb':
+            if im.ndim == 3:
+                im = bgr2rgb(im)
+            else:
+                im = np.stack((im, im, im), axis=2)
+        elif chn.lower() == 'gray':
+            assert im.ndim == 2
+    except:
+        print(str(path))
+
+    if dtype == 'float32':
+        im = im.astype(np.float32) / 255.
+    elif dtype ==  'float64':
+        im = im.astype(np.float64) / 255.
+    elif dtype == 'uint8':
+        pass
+    else:
+        sys.exit('Please input a correct dtype: float32, float64 or uint8!')
+
+    return im
+
+def imwrite(im_in, path, chn='rgb', dtype_in='float32', qf=None):
+    '''
+    Save image.
+    Input:
+        im: h x w x c, numpy array
+        path: the saving path
+        chn: the channel order of the input image
+    '''
+    im = im_in.copy()
+    if isinstance(path, str):
+        path = Path(path)
+    if dtype_in != 'uint8':
+        im = img_as_ubyte(im)
+
+    if chn.lower() == 'rgb' and im.ndim == 3:
+        im = rgb2bgr(im)
+
+    if qf is not None and path.suffix.lower() in ['.jpg', '.jpeg']:
+        flag = cv2.imwrite(str(path), im, [int(cv2.IMWRITE_JPEG_QUALITY), int(qf)])
+    else:
+        flag = cv2.imwrite(str(path), im)
+
+    return flag
+
+def jpeg_compress(im, qf, chn_in='rgb'):
+    '''
+    Input:
+        im: h x w x 3 array
+        qf: compress factor, (0, 100]
+        chn_in: 'rgb' or 'bgr'
+    Return:
+        Compressed Image with channel order: chn_in
+    '''
+    # transform to BGR channel order and uint8 data type
+    im_bgr = rgb2bgr(im) if chn_in.lower() == 'rgb' else im
+    if im.dtype != np.dtype('uint8'): im_bgr = img_as_ubyte(im_bgr)
+
+    # JPEG compress
+    flag, encimg = cv2.imencode('.jpg', im_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
+    assert flag
+    im_jpg_bgr = cv2.imdecode(encimg, 1)    # uint8, BGR
+
+    # transform back to original channel and the original data type
+    im_out = bgr2rgb(im_jpg_bgr) if chn_in.lower() == 'rgb' else im_jpg_bgr
+    if im.dtype != np.dtype('uint8'): im_out = img_as_float32(im_out).astype(im.dtype)
+    return im_out
+
+# ------------------------Augmentation-----------------------------
+def data_aug_np(image, mode):
+    '''
+    Performs data augmentation of the input image
+    Input:
+        image: a cv2 (OpenCV) image
+        mode: int. Choice of transformation to apply to the image
+                0 - no transformation
+                1 - flip up and down
+                2 - rotate counterclockwise 90 degrees
+                3 - rotate 90 degree and flip up and down
+                4 - rotate 180 degree
+                5 - rotate 180 degree and flip
+                6 - rotate 270 degree
+                7 - rotate 270 degree and flip
+    '''
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        # flip up and down
+        out = np.flipud(image)
+    elif mode == 2:
+        # rotate counterclockwise 90 degrees
+        out = np.rot90(image)
+    elif mode == 3:
+        # rotate 90 degree and flip up and down
+        out = np.rot90(image)
+        out = np.flipud(out)
+    elif mode == 4:
+        # rotate 180 degree
+        out = np.rot90(image, k=2)
+    elif mode == 5:
+        # rotate 180 degree and flip
+        out = np.rot90(image, k=2)
+        out = np.flipud(out)
+    elif mode == 6:
+        # rotate 270 degree
+        out = np.rot90(image, k=3)
+    elif mode == 7:
+        # rotate 270 degree and flip
+        out = np.rot90(image, k=3)
+        out = np.flipud(out)
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out.copy()
+
+def inverse_data_aug_np(image, mode):
+    '''
+    Performs inverse data augmentation of the input image
+    '''
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        out = np.flipud(image)
+    elif mode == 2:
+        out = np.rot90(image, axes=(1,0))
+    elif mode == 3:
+        out = np.flipud(image)
+        out = np.rot90(out, axes=(1,0))
+    elif mode == 4:
+        out = np.rot90(image, k=2, axes=(1,0))
+    elif mode == 5:
+        out = np.flipud(image)
+        out = np.rot90(out, k=2, axes=(1,0))
+    elif mode == 6:
+        out = np.rot90(image, k=3, axes=(1,0))
+    elif mode == 7:
+        # rotate 270 degree and flip
+        out = np.flipud(image)
+        out = np.rot90(out, k=3, axes=(1,0))
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out
+
+class SpatialAug:
+    def __init__(self):
+        pass
+
+    def __call__(self, im, flag=None):
+        if flag is None:
+            flag = random.randint(0, 7)
+
+        out = data_aug_np(im, flag)
+        return out
+
+# ----------------------Visualization----------------------------
+def imshow(x, title=None, cbar=False):
+    import matplotlib.pyplot as plt
+    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+    if title:
+        plt.title(title)
+    if cbar:
+        plt.colorbar()
+    plt.show()
+
+# -----------------------Convolution------------------------------
+def imgrad(im, pading_mode='mirror'):
+    '''
+    Calculate image gradient.
+    Input:
+        im: h x w x c numpy array
+    '''
+    from scipy.ndimage import correlate  # lazy import
+    wx = np.array([[0, 0, 0],
+                   [-1, 1, 0],
+                   [0, 0, 0]], dtype=np.float32)
+    wy = np.array([[0, -1, 0],
+                   [0, 1, 0],
+                   [0, 0, 0]], dtype=np.float32)
+    if im.ndim == 3:
+        gradx = np.stack(
+                [correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])],
+                axis=2
+                )
+        grady = np.stack(
+                [correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])],
+                axis=2
+                )
+        grad = np.concatenate((gradx, grady), axis=2)
+    else:
+        gradx = correlate(im, wx, mode=pading_mode)
+        grady = correlate(im, wy, mode=pading_mode)
+        grad = np.stack((gradx, grady), axis=2)
+
+    return {'gradx': gradx, 'grady': grady, 'grad':grad}
+
+def imgrad_fft(im):
+    '''
+    Calculate image gradient.
+    Input:
+        im: h x w x c numpy array
+    '''
+    wx = np.rot90(np.array([[0, 0, 0],
+                            [-1, 1, 0],
+                            [0, 0, 0]], dtype=np.float32), k=2)
+    gradx = convfft(im, wx)
+    wy = np.rot90(np.array([[0, -1, 0],
+                            [0, 1, 0],
+                            [0, 0, 0]], dtype=np.float32), k=2)
+    grady = convfft(im, wy)
+    grad = np.concatenate((gradx, grady), axis=2)
+
+    return {'gradx': gradx, 'grady': grady, 'grad':grad}
+
+def convfft(im, weight):
+    '''
+    Convolution with FFT
+    Input:
+        im: h1 x w1 x c numpy array
+        weight: h2 x w2 numpy array
+    Output:
+        out: h1 x w1 x c numpy array
+    '''
+    axes = (0,1)
+    otf = psf2otf(weight, im.shape[:2])
+    if im.ndim == 3:
+        otf = np.tile(otf[:, :, None], (1,1,im.shape[2]))
+    out = fft.ifft2(fft.fft2(im, axes=axes) * otf, axes=axes).real
+    return out
+
+def psf2otf(psf, shape):
+    """
+    MATLAB psf2otf function.
+    Borrowed from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py.
+    Input:
+        psf : h x w numpy array
+        shape : list or tuple, output shape of the OTF array
+    Output:
+        otf : OTF array with the desirable shape
+    """
+    if np.all(psf == 0):
+        return np.zeros_like(psf)
+
+    inshape = psf.shape
+    # Pad the PSF to outsize
+    psf = zero_pad(psf, shape, position='corner')
+
+    # Circularly shift OTF so that the 'center' of the PSF is [0,0] element of the array
+    for axis, axis_size in enumerate(inshape):
+        psf = np.roll(psf, -int(axis_size / 2), axis=axis)
+
+    # Compute the OTF
+    otf = fft.fft2(psf)
+
+    # Estimate the rough number of operations involved in the FFT
+    # and discard the PSF imaginary part if within roundoff error
+    # roundoff error  = machine epsilon = sys.float_info.epsilon
+    # or np.finfo().eps
+    n_ops = np.sum(psf.size * np.log2(psf.shape))
+    otf = np.real_if_close(otf, tol=n_ops)
+
+    return otf
+
+# ----------------------Patch Cropping----------------------------
+def random_crop(im, pch_size):
+    '''
+    Randomly crop a patch from the given image.
+    '''
+    h, w = im.shape[:2]
+    if h == pch_size and w == pch_size:
+        im_pch = im
+    else:
+        assert h >= pch_size or w >= pch_size
+        ind_h = random.randint(0, h-pch_size)
+        ind_w = random.randint(0, w-pch_size)
+        im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
+
+    return im_pch
+
+class RandomCrop:
+    def __init__(self, pch_size):
+        self.pch_size = pch_size
+
+    def __call__(self, im):
+        return random_crop(im, self.pch_size)
+
+class ImageSpliterNp:
+    def __init__(self, im, pch_size, stride, sf=1):
+        '''
+        Input:
+            im: h x w x c, numpy array, [0, 1], low-resolution image in SR
+            pch_size, stride: patch setting
+            sf: scale factor in image super-resolution
+        '''
+        assert stride <= pch_size
+        self.stride = stride
+        self.pch_size = pch_size
+        self.sf = sf
+
+        if im.ndim == 2:
+            im = im[:, :, None]
+
+        height, width, chn = im.shape
+        self.height_starts_list = self.extract_starts(height)
+        self.width_starts_list = self.extract_starts(width)
+        self.length = self.__len__()
+        self.num_pchs = 0
+
+        self.im_ori = im
+        self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
+        self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
+
+    def extract_starts(self, length):
+        starts = list(range(0, length, self.stride))
+        if starts[-1] + self.pch_size > length:
+            starts[-1] = length - self.pch_size
+        return starts
+
+    def __len__(self):
+        return len(self.height_starts_list) * len(self.width_starts_list)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.num_pchs < self.length:
+            w_start_idx = self.num_pchs // len(self.height_starts_list)
+            w_start = self.width_starts_list[w_start_idx] * self.sf
+            w_end = w_start + self.pch_size * self.sf
+
+            h_start_idx = self.num_pchs % len(self.height_starts_list)
+            h_start = self.height_starts_list[h_start_idx] * self.sf
+            h_end = h_start + self.pch_size * self.sf
+
+            pch = self.im_ori[h_start:h_end, w_start:w_end,]
+            self.w_start, self.w_end = w_start, w_end
+            self.h_start, self.h_end = h_start, h_end
+
+            self.num_pchs += 1
+        else:
+            raise StopIteration(0)
+
+        return pch, (h_start, h_end, w_start, w_end)
+
+    def update(self, pch_res, index_infos):
+        '''
+        Input:
+            pch_res: pch_size x pch_size x 3, [0,1]
+            index_infos: (h_start, h_end, w_start, w_end)
+        '''
+        if index_infos is None:
+            w_start, w_end = self.w_start, self.w_end
+            h_start, h_end = self.h_start, self.h_end
+        else:
+            h_start, h_end, w_start, w_end = index_infos
+
+        self.im_res[h_start:h_end, w_start:w_end] += pch_res
+        self.pixel_count[h_start:h_end, w_start:w_end] += 1
+
+    def gather(self):
+        assert np.all(self.pixel_count != 0)
+        return self.im_res / self.pixel_count
+
+class ImageSpliterTh:
+    def __init__(self, im, pch_size, stride, sf=1, extra_bs=1):
+        '''
+        Input:
+            im: n x c x h x w, torch tensor, float, low-resolution image in SR
+            pch_size, stride: patch setting
+            sf: scale factor in image super-resolution
+            extra_bs: number of patches aggregated per forward pass; only used when a single image is input
+        '''
+        assert stride <= pch_size
+        self.stride = stride
+        self.pch_size = pch_size
+        self.sf = sf
+        self.extra_bs = extra_bs
+
+        bs, chn, height, width= im.shape
+        self.true_bs = bs
+
+        self.height_starts_list = self.extract_starts(height)
+        self.width_starts_list = self.extract_starts(width)
+        self.starts_list = []
+        for ii in self.height_starts_list:
+            for jj in self.width_starts_list:
+                self.starts_list.append([ii, jj])
+
+        self.length = self.__len__()
+        self.count_pchs = 0
+
+        self.im_ori = im
+        self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
+        self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
+
+    def extract_starts(self, length):
+        if length <= self.pch_size:
+            starts = [0,]
+        else:
+            starts = list(range(0, length, self.stride))
+            for ii in range(len(starts)):
+                if starts[ii] + self.pch_size > length:
+                    starts[ii] = length - self.pch_size
+            starts = sorted(set(starts), key=starts.index)
+        return starts
+
+    def __len__(self):
+        return len(self.height_starts_list) * len(self.width_starts_list)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.count_pchs < self.length:
+            index_infos = []
+            current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs]
+            for ii, (h_start, w_start) in enumerate(current_starts_list):
+                w_end = w_start + self.pch_size
+                h_end = h_start + self.pch_size
+                current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end]
+                if ii == 0:
+                    pch =  current_pch
+                else:
+                    pch = torch.cat([pch, current_pch], dim=0)
+
+                h_start *= self.sf
+                h_end *= self.sf
+                w_start *= self.sf
+                w_end *= self.sf
+                index_infos.append([h_start, h_end, w_start, w_end])
+
+            self.count_pchs += len(current_starts_list)
+        else:
+            raise StopIteration()
+
+        return pch, index_infos
+
+    def update(self, pch_res, index_infos):
+        '''
+        Input:
+            pch_res: (n*extra_bs) x c x pch_size x pch_size, float
+            index_infos: [(h_start, h_end, w_start, w_end),]
+        '''
+        assert pch_res.shape[0] % self.true_bs == 0
+        pch_list = torch.split(pch_res, self.true_bs, dim=0)
+        assert len(pch_list) == len(index_infos)
+        for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos):
+            current_pch = pch_list[ii]
+            self.im_res[:, :, h_start:h_end, w_start:w_end] +=  current_pch
+            self.pixel_count[:, :, h_start:h_end, w_start:w_end] += 1
+
+    def gather(self):
+        assert torch.all(self.pixel_count != 0)
+        return self.im_res.div(self.pixel_count)
+
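+# Tiled-inference sketch (illustrative): split a large low-resolution tensor into
+# overlapping patches, super-resolve each, and stitch the results back together.
+# `net`, the patch size, stride, and scale factor are assumptions.
+def _example_tiled_inference(net, lr_img, pch_size=128, stride=96, sf=4):
+    spliter = ImageSpliterTh(lr_img, pch_size, stride, sf=sf)
+    for pch, index_infos in spliter:
+        with torch.no_grad():
+            sr_pch = net(pch)                 # same spatial size as pch, scaled by sf
+        spliter.update(sr_pch, index_infos)
+    return spliter.gather()                   # n x c x (h*sf) x (w*sf)
+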
+# ----------------------Patch Cropping----------------------------
+class Clamper:
+    def __init__(self, min_max=(-1, 1)):
+        self.min_bound, self.max_bound = min_max[0], min_max[1]
+
+    def __call__(self, im):
+        if isinstance(im, np.ndarray):
+            return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)
+        elif isinstance(im, torch.Tensor):
+            return torch.clamp(im, min=self.min_bound, max=self.max_bound)
+        else:
+            raise TypeError(f'ndarray or Tensor expected, got {type(im)}')
+
+if __name__ == '__main__':
+    im = np.random.randn(64, 64, 3).astype(np.float32)
+
+    grad1 = imgrad(im)['grad']
+    grad2 = imgrad_fft(im)['grad']
+
+    error = np.abs(grad1 -grad2).max()
+    mean_error = np.abs(grad1 -grad2).mean()
+    print('The largest error is {:.2e}'.format(error))
+    print('The mean error is {:.2e}'.format(mean_error))
\ No newline at end of file
diff --git a/utils/wavelet_color.py b/utils/wavelet_color.py
new file mode 100644
index 0000000000000000000000000000000000000000..0947a8621ecea7ef3b96d8b56acbdaecd3821a3d
--- /dev/null
+++ b/utils/wavelet_color.py
@@ -0,0 +1,119 @@
+'''
+# --------------------------------------------------------------------------------
+#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
+# --------------------------------------------------------------------------------
+'''
+
+import torch
+from PIL import Image
+from torch import Tensor
+from torch.nn import functional as F
+
+from torchvision.transforms import ToTensor, ToPILImage
+
+def adain_color_fix(target: Image, source: Image):
+    # Convert images to tensors
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    # Apply adaptive instance normalization
+    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+    # Convert tensor back to image
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
+
+def wavelet_color_fix(target: Image, source: Image):
+    # Convert images to tensors
+    to_tensor = ToTensor()
+    target_tensor = to_tensor(target).unsqueeze(0)
+    source_tensor = to_tensor(source).unsqueeze(0)
+
+    # Apply wavelet reconstruction
+    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+    # Convert tensor back to image
+    to_image = ToPILImage()
+    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+    return result_image
+
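+# Usage sketch (illustrative; the file paths are assumptions): transplant the color
+# statistics of the low-quality input onto a super-resolved output of the same size.
+def _example_color_fix(sr_path: str, lq_path: str, out_path: str):
+    sr_img = Image.open(sr_path).convert('RGB')
+    lq_img = Image.open(lq_path).convert('RGB').resize(sr_img.size, Image.BICUBIC)
+    fixed = wavelet_color_fix(sr_img, lq_img)
+    fixed.save(out_path)
+    return fixed
+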
+def calc_mean_std(feat: Tensor, eps=1e-5):
+    """Calculate mean and std for adaptive_instance_normalization.
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    b, c = size[:2]
+    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+    return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
+    """Adaptive instance normalization.
+    Adjust the reference features to have color and illumination similar to
+    those of the degraded features.
+    Args:
+        content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def wavelet_blur(image: Tensor, radius: int):
+    """
+    Apply wavelet blur to the input tensor.
+    """
+    # input shape: (1, 3, H, W)
+    # convolution kernel
+    kernel_vals = [
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625],
+    ]
+    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+    # add channel dimensions to the kernel to make it a 4D tensor
+    kernel = kernel[None, None]
+    # repeat the kernel across all input channels
+    kernel = kernel.repeat(3, 1, 1, 1)
+    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+    # apply convolution
+    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    return output
+
+def wavelet_decomposition(image: Tensor, levels=5):
+    """
+    Apply wavelet decomposition to the input tensor.
+    This function only returns the low frequency & the high frequency.
+    """
+    high_freq = torch.zeros_like(image)
+    for i in range(levels):
+        radius = 2 ** i
+        low_freq = wavelet_blur(image, radius)
+        high_freq += (image - low_freq)
+        image = low_freq
+
+    return high_freq, low_freq
+
+def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
+    """
+    Recombine the content's high-frequency detail with the style's low-frequency component, so the content takes on the style's colors.
+    """
+    # calculate the wavelet decomposition of the content feature
+    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+    del content_low_freq
+    # calculate the wavelet decomposition of the style feature
+    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+    del style_high_freq
+    # reconstruct the content feature with the style's high frequency
+    return content_high_freq + style_low_freq
\ No newline at end of file