from dataclasses import dataclass, field

import torch
import torch.nn as nn

from seva.modules.layers import (
    Downsample,
    GroupNorm32,
    ResBlock,
    TimestepEmbedSequential,
    Upsample,
    timestep_embedding,
)
from seva.modules.transformer import MultiviewTransformer


@dataclass
class SevaParams(object):
    in_channels: int = 11
    model_channels: int = 320
    out_channels: int = 4
    num_frames: int = 21
    num_res_blocks: int = 2
    attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
    channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_head_channels: int = 64
    transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
    context_dim: int = 1024
    dense_in_channels: int = 6
    dropout: float = 0.0
    unflatten_names: list[str] = field(
        default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
    )

    def __post_init__(self):
        # One transformer depth entry per resolution level.
        assert len(self.channel_mult) == len(self.transformer_depth)


class Seva(nn.Module):
    """UNet-style denoiser with multiview transformer blocks at selected resolutions."""

    def __init__(self, params: SevaParams) -> None:
        super().__init__()
        self.params = params
        self.model_channels = params.model_channels
        self.out_channels = params.out_channels
        self.num_head_channels = params.num_head_channels

        # MLP over sinusoidal timestep features.
        time_embed_dim = params.model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(params.model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Downsampling path: stem conv, then ResBlocks (with multiview
        # attention at the configured resolutions) and Downsample layers.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = params.model_channels
        input_block_chans = [params.model_channels]
        ch = params.model_channels
        ds = 1

        for level, mult in enumerate(params.channel_mult):
            for _ in range(params.num_res_blocks):
                input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
                    ResBlock(
                        channels=ch,
                        emb_channels=time_embed_dim,
                        out_channels=mult * params.model_channels,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = mult * params.model_channels
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    input_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"input_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*input_layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(params.channel_mult) - 1:
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
                )
                ch = out_ch
                input_block_chans.append(ch)
                self._feature_size += ch

        num_heads = ch // params.num_head_channels
        dim_head = params.num_head_channels

        # Bottleneck: ResBlock -> multiview transformer -> ResBlock.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
            MultiviewTransformer(
                ch,
                num_heads,
                dim_head,
                name=f"middle_ds{ds}",
                depth=params.transformer_depth[-1],
                context_dim=params.context_dim,
                unflatten_names=params.unflatten_names,
            ),
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
        )
        self._feature_size += ch

        # Upsampling path: mirrors the downsampling path, consuming the
        # skip-connection channels recorded in input_block_chans.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(params.channel_mult))[::-1]:
            for i in range(params.num_res_blocks + 1):
                ich = input_block_chans.pop()
                output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
                    ResBlock(
                        channels=ch + ich,
                        emb_channels=time_embed_dim,
                        out_channels=params.model_channels * mult,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = params.model_channels * mult
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    output_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"output_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                if level and i == params.num_res_blocks:
                    out_ch = ch
                    ds //= 2
                    output_layers.append(Upsample(ch, out_ch))
                self.output_blocks.append(TimestepEmbedSequential(*output_layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            GroupNorm32(32, ch),
            nn.SiLU(),
            nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        dense_y: torch.Tensor,
        num_frames: int | None = None,
    ) -> torch.Tensor:
        num_frames = num_frames or self.params.num_frames

        t_emb = timestep_embedding(t, self.model_channels)
        t_emb = self.time_embed(t_emb)

        # Downsampling path, stashing activations for the skip connections.
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb=t_emb,
            context=y,
            dense_emb=dense_y,
            num_frames=num_frames,
        )
        # Upsampling path, concatenating the stashed skip activations.
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
        h = h.type(x.dtype)
        return self.out(h)


class SGMWrapper(nn.Module):
    """Adapts Seva to an SGM-style conditioning dict."""

    def __init__(self, module: Seva):
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        # If no "concat" conditioning is provided, fall back to an empty
        # tensor, which torch.cat skips, leaving x unchanged.
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        return self.module(
            x,
            t=t,
            y=c["crossattn"],
            dense_y=c["dense_vector"],
            **kwargs,
        )
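

if __name__ == "__main__":
    # Minimal smoke-test sketch, NOT part of the original module. The tensor
    # shapes below are assumptions about what the seva layer modules accept:
    # x stacks `num_frames` latent frames along the batch dimension, y is a
    # per-frame cross-attention context of width `context_dim`, and
    # dense_vector is a per-frame dense conditioning map with
    # `dense_in_channels` channels at the input resolution.
    params = SevaParams()
    model = SGMWrapper(Seva(params)).eval()

    num_frames = 2  # small for the smoke test; the default is 21
    h = w = 32  # spatial size divisible by 8 (three downsampling stages)
    x = torch.randn(num_frames, params.in_channels, h, w)
    t = torch.zeros(num_frames)
    c = {
        "crossattn": torch.randn(num_frames, 1, params.context_dim),
        "dense_vector": torch.randn(num_frames, params.dense_in_channels, h, w),
    }  # no "concat" key: SGMWrapper then concatenates an empty tensor

    with torch.no_grad():
        out = model(x, t, c, num_frames=num_frames)
    print(out.shape)  # expected: (num_frames, params.out_channels, h, w)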