hangg-sai's picture
Initial commit
a342aa8
import math
import torch
import torch.nn.functional as F
from einops import repeat
from torch import nn
from .transformer import MultiviewTransformer
def timestep_embedding(
    timesteps: torch.Tensor,
    dim: int,
    max_period: int = 10000,
    repeat_only: bool = False,
) -> torch.Tensor:
    """Build sinusoidal embeddings for a batch of scalar timesteps.

    With ``half = dim // 2``, frequencies are
    ``exp(-log(max_period) * i / half)`` for ``i`` in ``[0, half)``. The output
    is ``[cos(t * freqs), sin(t * freqs)]`` concatenated along the last axis,
    with one extra zero column when ``dim`` is odd.

    Args:
        timesteps: 1-D tensor of shape ``(B,)``. It is cast to float32 on the
            sinusoidal path.
        dim: Width of the returned embedding.
        max_period: Sets the frequency range (the smallest frequency
            approaches ``1 / max_period``).
        repeat_only: If True, skip the sinusoids and tile the raw timestep
            value ``dim`` times along a new last axis.

    Returns:
        Tensor of shape ``(B, dim)``.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # Frequencies are computed on CPU in float32 and then moved, as before,
    # so results stay bit-identical across devices.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Build the pad column explicitly. Slicing ``embedding[:, :1]`` would
        # be empty when dim == 1 (half == 0), leaving a (B, 0) result.
        embedding = torch.cat(
            [embedding, embedding.new_zeros(embedding.shape[0], 1)], dim=-1
        )
    return embedding
class Upsample(nn.Module):
def __init__(self, channels: int, out_channels: int | None = None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[1] == self.channels
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, channels: int, out_channels: int | None = None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[1] == self.channels
return self.op(x)
class GroupNorm32(nn.GroupNorm):
    """``nn.GroupNorm`` that always runs the normalization in float32.

    The input is cast to float32, normalized by the parent implementation,
    and the result is cast back to the input's original dtype.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        source_dtype = input.dtype
        normalized = super().forward(input.float())
        return normalized.to(source_dtype)
class TimestepEmbedSequential(nn.Sequential):
def forward( # type: ignore[override]
self,
x: torch.Tensor,
emb: torch.Tensor,
context: torch.Tensor,
dense_emb: torch.Tensor,
num_frames: int,
) -> torch.Tensor:
for layer in self:
if isinstance(layer, MultiviewTransformer):
assert num_frames is not None
x = layer(x, context, num_frames)
elif isinstance(layer, ResBlock):
x = layer(x, emb, dense_emb)
else:
x = layer(x)
return x
class ResBlock(nn.Module):
    """Residual conv block conditioned on a global and a dense (spatial) embedding.

    Forward pass:
        1. ``h = SiLU(GroupNorm(x))``.
        2. ``dense_emb`` is bilinearly resized to ``h``'s spatial size and
           projected by a 1x1 conv to ``2 * channels`` maps. These are split
           into ``scale`` and ``shift``, and ``h = h * (1 + scale) + shift``.
        3. ``h = conv3x3(h)``, then a projection of ``emb`` is added as a
           per-channel bias broadcast over the trailing dims.
        4. ``h = conv3x3(Dropout(SiLU(GroupNorm(h))))``.
        5. Returns ``skip_connection(x) + h``. The skip is the identity when
           the channel count is unchanged, otherwise a 1x1 conv.

    Args:
        channels: Input channel count. Must be divisible by 32 because the
            norms use 32 groups.
        emb_channels: Feature width of ``emb`` (input size of the Linear
            projection).
        out_channels: Output channel count. A falsy value (``None``/0) means
            ``channels``. Must also be divisible by 32.
        dense_in_channels: Channel count of ``dense_emb``.
        dropout: Dropout probability applied before the final convolution.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: int | None,
        dense_in_channels: int,
        dropout: float,
    ):
        super().__init__()
        out_channels = out_channels or channels
        # Submodule names and layouts define the state-dict keys; keep them stable.
        self.in_layers = nn.Sequential(
            GroupNorm32(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, 1, 1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(), nn.Linear(emb_channels, out_channels)
        )
        # Predicts a scale and a shift (``channels`` maps each) that modulate
        # the normalized input before the first convolution.
        self.dense_emb_layers = nn.Sequential(
            nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
        )
        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )
        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
    ) -> torch.Tensor:
        """Apply the conditioned residual block.

        ``x`` and ``dense_emb`` must be 4-D ``(B, C, H, W)``, as required by
        the bilinear resize to ``h.shape[2:]``. ``emb`` is presumably
        ``(B, emb_channels)`` (TODO confirm against callers). Whatever its rank,
        its projection gets trailing singleton axes appended so it broadcasts
        against ``h``.
        """
        # Run norm + activation separately from the conv so the dense
        # modulation can be applied in between.
        in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
        h = in_rest(x)
        dense = self.dense_emb_layers(
            F.interpolate(
                dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
            )
        ).type(h.dtype)
        dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
        h = h * (1 + dense_scale) + dense_shift
        h = in_conv(h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append all missing trailing singleton axes in one indexing op
        # (e.g. (B, C) -> (B, C, 1, 1)) instead of looping one axis at a time.
        emb_out = emb_out[(...,) + (None,) * (h.ndim - emb_out.ndim)]
        h = h + emb_out

        h = self.out_layers(h)
        return self.skip_connection(x) + h