|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import math |
|
|
|
|
|
class WaveletTransform(nn.Module):
    '''
    Multi-level 2D Haar wavelet transform over non-overlapping
    `patch_size` x `patch_size` patches.

    Each level (`dwt`) halves the spatial size and multiplies the channel
    count by 4, so `log2(patch_size)` levels collapse every patch into a
    single spatial location with `C * patch_size**2` channels. With
    `inverse=True`, calling the module runs `invert` instead of `transform`.

    Channel layout: `dwt` concatenates its four sub-bands along the channel
    axis, so after all levels the original input channel is the
    fastest-varying part of the channel index.
    '''

    def __init__(self, patch_size: int, inverse: bool = False):
        '''
        ### Parameters:

        `patch_size`: Patch side length; must be a positive power of two.

        `inverse`: If True, `forward` calls `invert` instead of `transform`.

        ### Raises:

        `ValueError` if `patch_size` is not a positive power of two.

        ### Notes:

        `patchwise` in forward/invert makes *no difference* when H and W are
        multiples of `patch_size`: the 2-tap, stride-2 Haar filters never mix
        pixels across aligned patches, so both paths compute the same values
        (up to floating-point rounding), for square and non-square images
        alike. It is still enabled by default.

        `reshape` is pretty much useless.

        TODO: Clean up these options.
        '''
        super().__init__()
        # A non-power-of-two size would make log2 truncate, and the output
        # channel count would then disagree with `num_transformed_channels`.
        if patch_size < 1:
            raise ValueError(f'patch_size must be a positive power of two, got {patch_size}')
        steps = int(math.log2(patch_size))
        if 2 ** steps != patch_size:
            raise ValueError(f'patch_size must be a positive power of two, got {patch_size}')

        self.patch_size = patch_size
        self.inverse = inverse
        # Orthonormal Haar taps (1/sqrt(2), 1/sqrt(2)).
        self.haar = torch.tensor([0.7071067811865476, 0.7071067811865476])
        # Tap indices, used to build the alternating-sign high-pass filter.
        self.arange = torch.arange(len(self.haar))
        # Number of dwt levels needed to reduce a patch to 1x1.
        self.steps = steps

    def num_transformed_channels(self, in_channels: int = 3) -> int:
        '''
        Returns the number of channels to expect in the transformed image
        given the channels in the input image (`in_channels * patch_size**2`).
        '''
        return in_channels * (4 ** self.steps)

    def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        Dispatch to `invert` (when constructed with `inverse=True`) or to
        `transform`. `reshape` is forwarded as `from_reshaped` to `invert`.
        '''
        if self.inverse:
            return self.invert(x, patchwise=patchwise, from_reshaped=reshape)
        return self.transform(x, patchwise=patchwise, reshape=reshape)

    def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        ### Parameters:

        `x`: Images with shape (B C H W), e.g. ImageNet-normalized. With
        `patchwise`, H and W must be multiples of `patch_size`.

        `patchwise`: Whether to compute independently on patches

        `reshape`: Reshape the results to match the input HxW

        ### Returns:

        If `reshape`, returns (B C H W)

        otherwise, returns (B C*patch_size**2 H/patch_size W/patch_size)
        '''
        p = self.patch_size

        if patchwise:
            b, c, h, w = x.shape
            # (B, C, H, W) -> (B, H/p, W/p, C, p, p) -> (B*H/p*W/p, C, p, p):
            # every patch becomes its own batch element.
            x = x.reshape(b, c, h // p, p, w // p, p).moveaxis(4, 3)
            x = x.moveaxis(1, 3).reshape(-1, c, p, p)

        for _ in range(self.steps):
            x = self.dwt(x)

        if patchwise:
            # (B*H/p*W/p, C*p^2, 1, 1) -> (B, C*p^2, H/p, W/p)
            x = x.reshape(b, h // p, w // p, -1).moveaxis(3, 1)

        if reshape:
            b, cp2, hdp, wdp = x.shape
            c = cp2 // (p ** 2)
            # The input channel is the fastest-varying part of the channel
            # axis, so split it as (p, p, C) and tile the p x p factor back
            # over the spatial axes.
            x = x.reshape(b, p, p, c, hdp, wdp)
            x = x.moveaxis(3, 1).moveaxis(3, 4).reshape(b, c, hdp * p, wdp * p).contiguous()

        return x

    def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
        '''
        ### Parameters:

        `x`: Wavelet-space input of either (B C H W) (when `from_reshaped=True`) or
        (B C*patch_size**2 H/patch_size W/patch_size)

        `patchwise`: Whether to compute independently on patches

        `from_reshaped`: Determines the shape of `x`; should match the value of `reshape`
        used when calling `forward`

        ### Returns:

        Reconstructed images with shape (B C H W).
        '''
        p = self.patch_size

        if from_reshaped:
            # Undo the spatial tiling done by `transform(..., reshape=True)`.
            b, c, h, w = x.shape
            hdp, wdp = h // p, w // p
            x = x.reshape(b, c, p, hdp, p, wdp)
            x = x.moveaxis(4, 3).moveaxis(1, 3).reshape(b, c * p ** 2, hdp, wdp)

        if patchwise:
            # (B, C*p^2, H/p, W/p) -> (B*H/p*W/p, C*p^2, 1, 1)
            b, cp2, lh, lw = x.shape
            x = x.moveaxis(1, 3).reshape(-1, cp2, 1, 1)

        for _ in range(self.steps):
            x = self.idwt(x)

        if patchwise:
            # (B*H/p*W/p, C, p, p) -> (B, C, H/p, W/p, p, p) -> (B, C, H, W)
            c = x.shape[1]
            x = x.reshape(b, lh, lw, c, p, p).moveaxis(3, 1)
            x = x.moveaxis(3, 4).reshape(b, c, lh * p, lw * p)

        return x

    def _haar_filters(self, groups: int, device: torch.device, dtype: torch.dtype):
        '''Return (low-pass, high-pass) Haar filters, each (groups, 1, taps), on `device`/`dtype`.'''
        h = self.haar
        low = h.flip(0).reshape(1, 1, -1).repeat(groups, 1, 1)
        # Alternate the sign of the taps to get the high-pass filter.
        high = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(groups, 1, 1)
        return low.to(device=device, dtype=dtype), high.to(device=device, dtype=dtype)

    def dwt(self, x: torch.Tensor) -> torch.Tensor:
        '''
        One level of the 2D Haar transform: (N, G, H, W) -> (N, 4*G, H/2, W/2)
        for even H and W.

        Channels are concatenated as [xll, xlh, xhl, xhh] (each G wide), where
        the first letter is the horizontal filter and the second the vertical
        one. The result is scaled by 0.5, so the xll band is the mean of each
        2x2 block.
        '''
        n = self.haar.shape[0]
        g = x.shape[1]
        hl, hh = self._haar_filters(g, x.device, x.dtype)

        # Reflect-pad bottom/right (0 before, 1 after for the 2-tap filter);
        # with even H and W the stride-2 convolutions never read the pad.
        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode='reflect')

        # Horizontal pass (kernel 1 x taps, stride 2 along width).
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        # Vertical pass (kernel taps x 1, stride 2 along height).
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))

        return 0.5 * torch.cat([xll, xlh, xhl, xhh], dim=1)

    def idwt(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Invert one level of `dwt`: (N, 4*G, H, W) -> (N, G, 2H, 2W).

        The channels are split into four equal chunks read as
        [xll, xlh, xhl, xhh], the order `dwt` emits. They are recombined with
        transposed convolutions using the same filters as `dwt`. The result is
        scaled by 2.0 to undo the 0.5 factor applied in `dwt`.
        '''
        n = self.haar.shape[0]
        g = x.shape[1] // 4
        hl, hh = self._haar_filters(g, x.device, x.dtype)

        xll, xlh, xhl, xhh = torch.chunk(x, 4, dim=1)

        # Undo the vertical pass: rebuild the horizontal low/high bands.
        v_low, v_high = hl.unsqueeze(3), hh.unsqueeze(3)
        yl = F.conv_transpose2d(xll, v_low, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yl += F.conv_transpose2d(xlh, v_high, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh = F.conv_transpose2d(xhl, v_low, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh += F.conv_transpose2d(xhh, v_high, groups=g, stride=(2, 1), padding=(n - 2, 0))

        # Undo the horizontal pass.
        y = F.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
        y += F.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))

        return 2.0 * y
|
|