from dataclasses import dataclass, field
import torch
import torch.nn as nn
from seva.modules.layers import (
Downsample,
GroupNorm32,
ResBlock,
TimestepEmbedSequential,
Upsample,
timestep_embedding,
)
from seva.modules.transformer import MultiviewTransformer


@dataclass
class SevaParams:
    """Configuration for the Seva denoiser."""

    in_channels: int = 11
model_channels: int = 320
out_channels: int = 4
num_frames: int = 21
num_res_blocks: int = 2
attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
num_head_channels: int = 64
transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
context_dim: int = 1024
dense_in_channels: int = 6
dropout: float = 0.0
unflatten_names: list[str] = field(
default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
)

    def __post_init__(self):
        assert len(self.channel_mult) == len(self.transformer_depth), (
            "channel_mult and transformer_depth must have one entry per "
            "resolution level"
        )
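

# Example (illustrative, not part of the original file): the defaults above
# describe a four-level UNet with channel widths 320/640/1280/1280
# (model_channels * channel_mult) and multiview attention at downsample
# factors 1, 2, and 4, plus the middle block at factor 8:
#
#     params = SevaParams()              # defaults
#     params = SevaParams(num_frames=8)  # hypothetical override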


class Seva(nn.Module):
    """UNet-style denoiser with multiview attention blocks."""

    def __init__(self, params: SevaParams) -> None:
        super().__init__()
        self.params = params
self.model_channels = params.model_channels
self.out_channels = params.out_channels
self.num_head_channels = params.num_head_channels
        # Sinusoidal timestep features are projected to 4x width and passed as
        # the conditioning embedding to every ResBlock.
        time_embed_dim = params.model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(params.model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
        # Stem: lift the input to model_channels with a 3x3 convolution.
        self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
)
]
)
        self._feature_size = params.model_channels
        input_block_chans = [params.model_channels]
        ch = params.model_channels
        ds = 1  # current downsample factor
        # Encoder: num_res_blocks ResBlocks per level, with multiview attention
        # at the factors listed in attention_resolutions, and a Downsample
        # between levels.
        for level, mult in enumerate(params.channel_mult):
for _ in range(params.num_res_blocks):
input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=mult * params.model_channels,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
)
]
ch = mult * params.model_channels
if ds in params.attention_resolutions:
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
input_layers.append(
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"input_ds{ds}",
depth=params.transformer_depth[level],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
)
)
self.input_blocks.append(TimestepEmbedSequential(*input_layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(params.channel_mult) - 1:
ds *= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
self._feature_size += ch
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
        # Bottleneck at the lowest resolution: ResBlock -> attention -> ResBlock.
        self.middle_block = TimestepEmbedSequential(
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=None,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
),
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"middle_ds{ds}",
depth=params.transformer_depth[-1],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
),
ResBlock(
channels=ch,
emb_channels=time_embed_dim,
out_channels=None,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
),
)
self._feature_size += ch
        # Decoder: mirrors the encoder, consuming one skip connection from
        # input_block_chans per ResBlock and upsampling between levels.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(params.channel_mult))[::-1]:
for i in range(params.num_res_blocks + 1):
ich = input_block_chans.pop()
output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
ResBlock(
channels=ch + ich,
emb_channels=time_embed_dim,
out_channels=params.model_channels * mult,
dense_in_channels=params.dense_in_channels,
dropout=params.dropout,
)
]
ch = params.model_channels * mult
if ds in params.attention_resolutions:
num_heads = ch // params.num_head_channels
dim_head = params.num_head_channels
output_layers.append(
MultiviewTransformer(
ch,
num_heads,
dim_head,
name=f"output_ds{ds}",
depth=params.transformer_depth[level],
context_dim=params.context_dim,
unflatten_names=params.unflatten_names,
)
)
if level and i == params.num_res_blocks:
out_ch = ch
ds //= 2
output_layers.append(Upsample(ch, out_ch))
self.output_blocks.append(TimestepEmbedSequential(*output_layers))
self._feature_size += ch
        # Output head; ch equals model_channels again after the last level.
        self.out = nn.Sequential(
GroupNorm32(32, ch),
nn.SiLU(),
nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        dense_y: torch.Tensor,
        num_frames: int | None = None,
    ) -> torch.Tensor:
        """Denoise `x` at timestep `t`, conditioned on the cross-attention
        context `y` and the dense conditioning `dense_y`."""
        num_frames = num_frames or self.params.num_frames
t_emb = timestep_embedding(t, self.model_channels)
t_emb = self.time_embed(t_emb)
        hs = []  # skip activations, consumed by the decoder in reverse order
        h = x
        for module in self.input_blocks:
h = module(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
hs.append(h)
h = self.middle_block(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(
h,
emb=t_emb,
context=y,
dense_emb=dense_y,
num_frames=num_frames,
)
h = h.type(x.dtype)
return self.out(h)
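
    # Shape sketch (inferred from the constructor arguments, not asserted by
    # the original code): with the defaults, `x` is
    # (B * num_frames, in_channels=11, H, W), `t` is (B * num_frames,), `y`
    # has last dimension context_dim=1024, and `dense_y` carries
    # dense_in_channels=6 channels; the output is
    # (B * num_frames, out_channels=4, H, W).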


class SGMWrapper(nn.Module):
    """Adapter that unpacks an sgm-style conditioning dict into Seva's forward."""

    def __init__(self, module: Seva):
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        # Channel-concatenate the "concat" conditioning onto the input when
        # present. (The original expression relied on torch.cat's legacy
        # handling of a 1-D empty tensor as the default.)
        if "concat" in c:
            x = torch.cat((x, c["concat"]), dim=1)
return self.module(
x,
t=t,
y=c["crossattn"],
dense_y=c["dense_vector"],
**kwargs,
)
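

# A minimal smoke-test sketch, not part of the original file. The shapes for
# "concat", "crossattn", and "dense_vector" are assumptions inferred from the
# constructor arguments (in_channels=11, context_dim=1024,
# dense_in_channels=6), not documented behavior.
if __name__ == "__main__":
    params = SevaParams()
    model = SGMWrapper(Seva(params))
    b, f, h, w = 1, params.num_frames, 32, 32
    x = torch.randn(b * f, params.out_channels, h, w)  # noisy latent
    t = torch.randint(0, 1000, (b * f,))
    c = {
        # Remaining in_channels after the latent: 11 - 4 = 7 (assumed split).
        "concat": torch.randn(b * f, params.in_channels - params.out_channels, h, w),
        "crossattn": torch.randn(b * f, 1, params.context_dim),  # assumed seq len 1
        "dense_vector": torch.randn(b * f, params.dense_in_channels, h, w),
    }
    out = model(x, t, c, num_frames=f)
    print(out.shape)  # expected: (b * f, 4, h, w)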