Spaces:
Running
on
L40S
Running
on
L40S
from dataclasses import dataclass, field | |
import torch | |
import torch.nn as nn | |
from seva.modules.layers import ( | |
Downsample, | |
GroupNorm32, | |
ResBlock, | |
TimestepEmbedSequential, | |
Upsample, | |
timestep_embedding, | |
) | |
from seva.modules.transformer import MultiviewTransformer | |
@dataclass
class SevaParams(object):
    """Hyper-parameters consumed by ``Seva.__init__`` and ``Seva.forward``.

    The class must be a dataclass: the list-valued attributes are declared
    with ``field(default_factory=...)`` and validation lives in
    ``__post_init__``, both of which only take effect under ``@dataclass``.
    """

    # Channels of the tensor fed to the first 3x3 convolution.
    in_channels: int = 11
    # Base channel width; each level uses ``model_channels * channel_mult[level]``.
    model_channels: int = 320
    # Channels produced by the final 3x3 convolution.
    out_channels: int = 4
    # Fallback for ``Seva.forward`` when the caller passes no ``num_frames``.
    num_frames: int = 21
    # Number of ResBlocks per level on the downsampling path (the upsampling
    # path builds one more per level).
    num_res_blocks: int = 2
    # Downsampling factors (1, 2, 4, ...) at which a MultiviewTransformer is
    # appended after the ResBlock.
    attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
    # Per-level multiplier applied to ``model_channels``.
    channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    # Channels per attention head; head count is ``ch // num_head_channels``.
    num_head_channels: int = 64
    # Per-level ``depth`` forwarded to MultiviewTransformer.
    transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
    # ``context_dim`` forwarded to MultiviewTransformer.
    context_dim: int = 1024
    # ``dense_in_channels`` forwarded to every ResBlock.
    dense_in_channels: int = 6
    # ``dropout`` forwarded to every ResBlock.
    dropout: float = 0.0
    # ``unflatten_names`` forwarded to every MultiviewTransformer.
    unflatten_names: list[str] = field(
        default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
    )

    def __post_init__(self):
        # ``transformer_depth`` is indexed by level, so it needs exactly one
        # entry per ``channel_mult`` entry.
        assert len(self.channel_mult) == len(self.transformer_depth), (
            "channel_mult and transformer_depth must have the same length"
        )
class Seva(nn.Module):
    """U-Net style network built from ResBlocks and MultiviewTransformers.

    The model has a downsampling path (``input_blocks``), a bottleneck
    (``middle_block``) and an upsampling path (``output_blocks``) that consumes
    the activations stored on the way down as skip connections, followed by a
    norm / SiLU / 3x3-conv output head (``out``).
    """

    def __init__(self, params: SevaParams) -> None:
        """Build all sub-modules from ``params``.

        Args:
            params: Architecture hyper-parameters (channel widths, number of
                levels, where attention is inserted, ...).
        """
        super().__init__()
        self.params = params
        self.model_channels = params.model_channels
        self.out_channels = params.out_channels
        self.num_head_channels = params.num_head_channels

        # MLP applied to the timestep embedding before it reaches the blocks.
        time_embed_dim = params.model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(params.model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # ---- Downsampling path -------------------------------------------
        # The first entry is a plain 3x3 convolution lifting the input to
        # ``model_channels``.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = params.model_channels
        # Output width of every input block, in order; popped in reverse when
        # building the upsampling path to size the skip concatenations.
        input_block_chans = [params.model_channels]
        ch = params.model_channels
        ds = 1  # current cumulative downsampling factor
        for level, mult in enumerate(params.channel_mult):
            for _ in range(params.num_res_blocks):
                input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
                    ResBlock(
                        channels=ch,
                        emb_channels=time_embed_dim,
                        out_channels=mult * params.model_channels,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = mult * params.model_channels
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    input_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"input_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*input_layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Every level except the last ends with a Downsample block.
            if level != len(params.channel_mult) - 1:
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
                )
                ch = out_ch
                input_block_chans.append(ch)
                self._feature_size += ch

        # ---- Bottleneck: ResBlock -> MultiviewTransformer -> ResBlock -----
        num_heads = ch // params.num_head_channels
        dim_head = params.num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
            MultiviewTransformer(
                ch,
                num_heads,
                dim_head,
                name=f"middle_ds{ds}",
                depth=params.transformer_depth[-1],
                context_dim=params.context_dim,
                unflatten_names=params.unflatten_names,
            ),
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
        )
        self._feature_size += ch

        # ---- Upsampling path ---------------------------------------------
        # Levels are walked in reverse; each level has ``num_res_blocks + 1``
        # blocks so that every stored input-block width is consumed once.
        self.output_blocks = nn.ModuleList([])
        for level, mult in reversed(list(enumerate(params.channel_mult))):
            for i in range(params.num_res_blocks + 1):
                ich = input_block_chans.pop()
                output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
                    ResBlock(
                        # Input is the running feature map concatenated with
                        # the matching skip activation (see ``forward``).
                        channels=ch + ich,
                        emb_channels=time_embed_dim,
                        out_channels=params.model_channels * mult,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = params.model_channels * mult
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    output_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"output_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                # The last block of every level except level 0 upsamples.
                if level and i == params.num_res_blocks:
                    out_ch = ch
                    ds //= 2
                    output_layers.append(Upsample(ch, out_ch))
                self.output_blocks.append(TimestepEmbedSequential(*output_layers))
                self._feature_size += ch

        # ---- Output head -------------------------------------------------
        # At this point ``ch == model_channels * channel_mult[0]``. The conv
        # must take that width (not a hard-coded ``model_channels``) so that it
        # matches the preceding norm whenever ``channel_mult[0] != 1``; for
        # ``channel_mult[0] == 1`` the two are the same value.
        self.out = nn.Sequential(
            GroupNorm32(32, ch),
            nn.SiLU(),
            nn.Conv2d(ch, params.out_channels, 3, padding=1),
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        dense_y: torch.Tensor,
        num_frames: int | None = None,
    ) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input tensor fed to the first convolution; its channel count
                must equal ``params.in_channels``.
            t: Timesteps, turned into an embedding via ``timestep_embedding``
                and ``self.time_embed``.
            y: Passed to every block as ``context``.
            dense_y: Passed to every block as ``dense_emb``.
            num_frames: Forwarded to every block; a falsy value (``None`` or
                ``0``) falls back to ``params.num_frames``.

        Returns:
            Output of the final conv head, computed on features cast back to
            ``x.dtype``.
        """
        num_frames = num_frames or self.params.num_frames
        t_emb = timestep_embedding(t, self.model_channels)
        t_emb = self.time_embed(t_emb)

        # Activations of the downsampling path, consumed last-in-first-out by
        # the upsampling path as skip connections.
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb=t_emb,
            context=y,
            dense_emb=dense_y,
            num_frames=num_frames,
        )
        for module in self.output_blocks:
            # Concatenate the skip activation along dim 1 before each block.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
        h = h.type(x.dtype)
        return self.out(h)
class SGMWrapper(nn.Module):
    """Adapter that calls a wrapped module using a conditioning dictionary.

    ``forward`` reads ``"concat"`` (optional), ``"crossattn"`` and
    ``"dense_vector"`` from the dictionary and maps them onto the wrapped
    module's ``x`` / ``y`` / ``dense_y`` arguments.
    """

    def __init__(self, module: Seva):
        """Store ``module`` as the network that ``forward`` delegates to."""
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        """Append ``c["concat"]`` to ``x`` along dim 1 and call the module.

        Args:
            x: Input tensor.
            t: Passed through to the wrapped module as ``t``.
            c: Conditioning dict. ``"crossattn"`` and ``"dense_vector"`` are
                required; ``"concat"`` is optional.
            **kwargs: Forwarded unchanged to the wrapped module.

        Returns:
            Whatever the wrapped module returns.
        """
        # An empty 1-D tensor matching x's type stands in for a missing
        # "concat" entry, so the concatenation leaves x's values unchanged.
        placeholder = torch.Tensor([]).type_as(x)
        extra_channels = c.get("concat", placeholder)
        net_input = torch.cat((x, extra_channels), dim=1)

        context = c["crossattn"]
        dense = c["dense_vector"]
        return self.module(net_input, t=t, y=context, dense_y=dense, **kwargs)