import torch
import torch.nn as nn


class ContextualAlphaMask(nn.Module):
    def __init__(
        self,
        dim: int = 768,
    ):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        quarter_dim = dim // 4
        # funnel MLP: dim -> dim -> dim/2 -> dim/2 -> dim/4 -> dim/4 -> 1
        self.fc1 = nn.Linear(self.dim, self.dim)
        self.fc2 = nn.Linear(self.dim, half_dim)
        self.norm1 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, half_dim)
        self.fc4 = nn.Linear(half_dim, quarter_dim)
        self.norm2 = nn.LayerNorm(quarter_dim)
        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
        self.fc6 = nn.Linear(quarter_dim, 1)
        # set fc6 weights to near zero so the initial alpha starts close to 0.5
        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # x: (batch_size, 77, 768)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.norm1(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        x = self.act_fn(x)
        x = self.fc4(x)
        x = self.norm2(x)
        x = self.act_fn(x)
        x = self.fc5(x)
        x = self.act_fn(x)
        x = self.fc6(x)
        # per-token alpha in (0, 1)
        x = torch.sigmoid(x)
        return x
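
# Illustrative check (not part of the original module; the shapes below are
# assumptions): the mask maps each of the 77 prompt tokens to one scalar in (0, 1).
#
#   ctx_mask = ContextualAlphaMask(dim=768)
#   alpha = ctx_mask(torch.randn(2, 77, 768))
#   alpha.shape  # torch.Size([2, 77, 1])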

class ZipperModule(nn.Module):
    def __init__(
        self,
        in_size,
        in_tokens,
        out_size,
        out_tokens,
        hidden_size,
        hidden_tokens,
        use_residual=False,
    ):
        super().__init__()
        self.in_size = in_size
        self.in_tokens = in_tokens
        self.out_size = out_size
        self.out_tokens = out_tokens
        self.hidden_size = hidden_size
        self.hidden_tokens = hidden_tokens
        self.use_residual = use_residual
        self.act_fn = nn.GELU()
        self.layernorm = nn.LayerNorm(self.in_size)
        # 1x1 convs resample along the token axis; linears resample along the feature axis
        self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
        # act
        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        # act
        self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
        # act
        self.fc2 = nn.Linear(self.hidden_size, self.out_size)

    def forward(self, x):
        # x: (batch, in_tokens, in_size) -> (batch, out_tokens, out_size)
        residual = x
        x = self.layernorm(x)
        x = self.conv1(x)
        x = self.act_fn(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.conv2(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            # only valid when input and output shapes match (middle blocks)
            x = x + residual
        return x
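
# Illustrative usage (assumed shapes, not from the original file): a single
# ZipperModule resampling 257 image tokens of width 1024 down to 77 tokens of
# width 768 through a 128-token / 512-dim hidden stage.
#
#   zipper = ZipperModule(in_size=1024, in_tokens=257, out_size=768,
#                         out_tokens=77, hidden_size=512, hidden_tokens=128)
#   zipper(torch.randn(2, 257, 1024)).shape  # torch.Size([2, 77, 768])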

class ZipperResampler(nn.Module):
    def __init__(
        self,
        in_size,
        in_tokens,
        out_size,
        out_tokens,
        hidden_size,
        hidden_tokens,
        num_blocks=1,
        is_conv_input=False,
    ):
        super().__init__()
        self.is_conv_input = is_conv_input
        module_list = []
        for i in range(num_blocks):
            this_in_size = in_size
            this_in_tokens = in_tokens
            this_out_size = out_size
            this_out_tokens = out_tokens
            this_hidden_size = hidden_size
            this_hidden_tokens = hidden_tokens
            use_residual = False
            # keep intermediate shapes at hidden_size / hidden_tokens
            if i == 0:  # first block
                this_in_size = in_size
                this_in_tokens = in_tokens
                if num_blocks == 1:
                    this_out_size = out_size
                    this_out_tokens = out_tokens
                else:
                    this_out_size = hidden_size
                    this_out_tokens = hidden_tokens
            elif i == num_blocks - 1:  # last block
                this_out_size = out_size
                this_out_tokens = out_tokens
                if num_blocks == 1:
                    this_in_size = in_size
                    this_in_tokens = in_tokens
                else:
                    this_in_size = hidden_size
                    this_in_tokens = hidden_tokens
            else:  # middle blocks keep hidden shapes so the residual connection lines up
                this_out_size = hidden_size
                this_out_tokens = hidden_tokens
                this_in_size = hidden_size
                this_in_tokens = hidden_tokens
                use_residual = True
            module_list.append(ZipperModule(
                in_size=this_in_size,
                in_tokens=this_in_tokens,
                out_size=this_out_size,
                out_tokens=this_out_tokens,
                hidden_size=this_hidden_size,
                hidden_tokens=this_hidden_tokens,
                use_residual=use_residual,
            ))
        self.blocks = nn.ModuleList(module_list)
        self.ctx_alpha = ContextualAlphaMask(
            dim=out_size,
        )

    def forward(self, x):
        if self.is_conv_input:
            # flatten conv feature maps: (batch, channels, h, w) -> (batch, channels, h*w)
            x = x.view(x.size(0), x.size(1), -1)
            # rearrange to (batch, tokens, size)
            x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        # gate each output token by its learned alpha in (0, 1)
        alpha = self.ctx_alpha(x)
        return x * alpha
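
# Minimal end-to-end sketch (assumed dimensions, not from the original file):
# two ZipperModule blocks map 257 tokens of width 1024 to 77 tokens of width
# 768, and the contextual alpha mask then gates each output token.
if __name__ == "__main__":
    resampler = ZipperResampler(
        in_size=1024,
        in_tokens=257,
        out_size=768,
        out_tokens=77,
        hidden_size=512,
        hidden_tokens=128,
        num_blocks=2,
    )
    dummy = torch.randn(2, 257, 1024)  # (batch, tokens, size)
    out = resampler(dummy)
    print(out.shape)  # expected: torch.Size([2, 77, 768])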