PLPQ / wavelet.py
TheTrueJard's picture
Upload folder using huggingface_hub
748c921 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
class WaveletTransform(nn.Module):
    '''
    Multi-level 2D Haar wavelet transform over image patches, or its inverse.

    Each level halves the spatial size and multiplies the channel count by 4
    (sub-bands are concatenated as LL, LH, HL, HH along the channel axis).
    `log2(patch_size)` levels are applied, so a `patch_size` x `patch_size`
    patch is reduced to a single spatial position.

    `patchwise` in forward/invert makes *no difference*; the result
    is numerically identical either way. It's still enabled by default
    in case we pass in a non-square image, which may not be equivalent.
    `reshape` is pretty much useless.
    TODO: Clean up these options.
    '''

    def __init__(self, patch_size: int, inverse: bool = False):
        '''
        ### Parameters:
        `patch_size`: Side length of the square patches; must be a positive
            power of two (1 gives the identity transform)
        `inverse`: If True, `forward` runs `invert` instead of `transform`
        ### Raises:
        `ValueError` if `patch_size` is not a positive power of two
        '''
        super().__init__()
        if patch_size < 1:
            raise ValueError(
                f'patch_size must be a positive power of two, got {patch_size}'
            )
        steps = int(round(math.log2(patch_size)))
        if 2 ** steps != patch_size:
            # A truncated log2 would leave patches larger than 1x1 after the
            # last level and silently scramble the patchwise output layout.
            raise ValueError(
                f'patch_size must be a positive power of two, got {patch_size}'
            )
        self.patch_size = patch_size
        self.inverse = inverse
        # From https://github.com/NVIDIA/Cosmos-Tokenizer/blob/3584ae752ce8ebdbe06a420bf60d7513c0e878cc/cosmos_tokenizer/modules/patching.py#L33
        self.haar = torch.tensor([0.7071067811865476, 0.7071067811865476])
        self.arange = torch.arange(len(self.haar))
        self.steps = steps

    def num_transformed_channels(self, in_channels: int = 3) -> int:
        '''
        Returns the number of channels to expect in the transformed image
        given the channels in the input image.
        '''
        return in_channels * (4 ** self.steps)

    def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        Dispatches to `invert` when the module was built with `inverse=True`,
        otherwise to `transform`. `reshape` is passed to `invert` as
        `from_reshaped`.
        '''
        if self.inverse:
            return self.invert(x, patchwise=patchwise, from_reshaped=reshape)
        return self.transform(x, patchwise=patchwise, reshape=reshape)

    def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
        `x`: ImageNet-normalized images with shape (B C H W)
        `patchwise`: Whether to compute independently on patches
        `reshape`: Reshape the results to match the input HxW
        ### Returns:
        If `reshape`, returns (B C H W)
        otherwise, returns (B C*patch_size**2 H/patch_size W/patch_size)
        '''
        p = self.patch_size
        if patchwise:
            # Place patches into batch dimension
            b, c, h, w = x.shape
            lh, lw = h // p, w // p
            # (B C H W) -> (B C LH P LW P) -> (B C LH LW P P)
            x = x.reshape(b, c, lh, p, lw, p).moveaxis(4, 3)
            # (B C LH LW P P) -> (B LH LW C P P) -> (B' C P P)
            x = x.moveaxis(1, 3).reshape(-1, c, p, p)

        for _ in range(self.steps):
            x = self.dwt(x)

        if patchwise:
            # Extract patches from batch dimension
            # (B' C' 1 1) -> (B LH LW C') -> (B C' LH LW)
            x = x.reshape(b, lh, lw, -1).moveaxis(3, 1)

        if reshape:
            # (B C*patch_size**2 H/patch_size W/patch_size) -> (B C H W)
            b, cp2, hdp, wdp = x.shape
            c = cp2 // (p ** 2)
            x = x.reshape(b, p, p, c, hdp, wdp)
            # (B P P C H' W') -> (B C P P H' W') -> (B C P H' P W')
            x = x.moveaxis(3, 1).moveaxis(3, 4)
            x = x.reshape(b, c, hdp * p, wdp * p).contiguous()
        return x

    def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
        `x`: Wavelet-space input of either (B C H W) (when `from_reshaped=True`) or
            (B C*patch_size**2 H/patch_size W/patch_size)
        `patchwise`: Whether to compute independently on patches
        `from_reshaped`: Determines the shape of `x`; should match the value of `reshape`
            used when calling `forward`
        ### Returns:
        Reconstructed tensor of shape (B C H W)
        '''
        p = self.patch_size
        if from_reshaped:
            # (B C H W) -> (B C*patch_size**2 H/patch_size W/patch_size)
            b, c, h, w = x.shape
            hdp, wdp = h // p, w // p
            x = x.reshape(b, c, p, hdp, p, wdp)
            # (B C P H' P W') -> (B C P P H' W') -> (B P P C H' W')
            x = x.moveaxis(4, 3).moveaxis(1, 3)
            x = x.reshape(b, c * p ** 2, hdp, wdp)

        if patchwise:
            # Put patches into batch dimension
            # (B C' LH LW) -> (B LH LW C') -> (B' C' 1 1)
            init_b, channels, lh, lw = x.shape
            x = x.moveaxis(1, 3).reshape(-1, channels, 1, 1)

        for _ in range(self.steps):
            x = self.idwt(x)

        if patchwise:
            # Extract patches from batch dimension and expand
            # (B' C P P) -> (B LH LW C P P) -> (B C LH LW P P)
            x = x.reshape(init_b, lh, lw, *x.shape[1:]).moveaxis(3, 1)
            # (B C LH LW P P) -> (B C LH P LW P) -> (B C H W)
            x = x.moveaxis(3, 4).reshape(*x.shape[:2], lh * p, lw * p)
        return x

    def _filters(self, groups: int, ref: torch.Tensor):
        '''
        Builds the low-pass and high-pass Haar filters, each of shape
        (groups, 1, 2), on the device and dtype of `ref`.
        '''
        h = self.haar
        low = h.flip(0).reshape(1, 1, -1).repeat(groups, 1, 1)
        # Alternating signs turn the averaging filter into a differencing one.
        high = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(groups, 1, 1)
        low = low.to(device=ref.device, dtype=ref.dtype)
        high = high.to(device=ref.device, dtype=ref.dtype)
        return low, high

    def dwt(self, x: torch.Tensor):
        '''
        One level of the 2D Haar transform.

        ### Parameters:
        `x`: Input of shape (B C H W)
        ### Returns:
        Tensor with 4*C channels ordered as [LL, LH, HL, HH] sub-bands (each
        with C channels) and halved spatial size for even H and W. With the
        0.5 scaling, the LL band is the mean of each 2x2 block.
        '''
        n = self.haar.shape[0]
        g = x.shape[1]
        hl, hh = self._filters(g, x)

        # Reflect-pad one pixel on the right/bottom (n - 2 == 0 on the
        # left/top); the padding is only consumed when H or W is odd.
        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode='reflect')
        # Filter along the width first, then along the height.
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
        return 0.5 * torch.cat([xll, xlh, xhl, xhh], dim=1)

    def idwt(self, x: torch.Tensor):
        '''
        One level of the inverse 2D Haar transform (undoes `dwt`).

        ### Parameters:
        `x`: Input of shape (B 4*C H W) with channels ordered as
            [LL, LH, HL, HH] sub-bands
        ### Returns:
        Tensor of shape (B C 2*H 2*W)
        '''
        n = self.haar.shape[0]
        g = x.shape[1] // 4
        hl, hh = self._filters(g, x)
        xll, xlh, xhl, xhh = torch.chunk(x, 4, dim=1)

        # Inverse transform: undo the height filtering, then the width filtering.
        yl = F.conv_transpose2d(
            xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yl += F.conv_transpose2d(
            xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh = F.conv_transpose2d(
            xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh += F.conv_transpose2d(
            xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        y = F.conv_transpose2d(
            yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )
        y += F.conv_transpose2d(
            yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )
        # Compensates for the 0.5 scaling applied in `dwt`.
        return 2.0 * y