import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2LMHeadModel
from typing import Tuple, Optional, Union


def get_clapcap(name: str):
    if name == "ClapCaption":
        return ClapCaptionModel
    else:
        raise Exception('The ClapCap model {} is incorrect or not supported'.format(name))


class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'


class MLP(nn.Module):
    """Simple feed-forward projector: stacked Linear layers with an activation in between."""

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class MlpTransformer(nn.Module):
    """Two-layer MLP block used inside each transformer layer."""

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class Transformer(nn.Module):

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x


class TransformerMapper(nn.Module):
    """Maps the audio embedding to a sequence of GPT-2 prefix embeddings via a small transformer."""

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out


class ClapCaptionModel(nn.Module):

    def __init__(self, clap, text_decoder: str, prefix_length: int, clip_length: Optional[int] = None,
                 prefix_size: int = 512, num_layers: int = 8, normalize_prefix: bool = True,
                 mapping_type: str = None, freeze_audio_encoder_weights: bool = True,
                 freeze_gpt_weights: bool = True):
        super(ClapCaptionModel, self).__init__()
        self.clap = clap.audio_encoder
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        self.gpt = GPT2LMHeadModel.from_pretrained(text_decoder)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == 'mlp':
            self.clap_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            self.clap_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                  clip_length, num_layers)

        # Freeze all CLAP parameters
        if freeze_audio_encoder_weights:
            for p in self.clap.parameters():
                p.requires_grad = False

        # Freeze all GPT-2 parameters
        if freeze_gpt_weights:
            for p in self.gpt.parameters():
                p.requires_grad = False

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, audios: torch.Tensor, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        # get audio embeddings
        prefix, _ = self.clap(audios)
        # normalize prefix (audio embedding)
        if self.normalize_prefix:
            prefix = prefix / prefix.norm(2, -1).reshape(-1, 1)
        embedding_text = self.gpt.transformer.wte(tokens['input_ids'])
        prefix_projections = self.clap_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens['input_ids'].shape[0], tokens['input_ids'].device)
            # prepend placeholder labels for the prefix positions
            labels = torch.cat((dummy_token, tokens['input_ids']), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out
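

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# wiring ClapCaptionModel together for a single forward pass. DummyClap /
# DummyAudioEncoder, the "gpt2" checkpoint name, and the waveform shape are
# illustrative assumptions, not the real CLAP audio encoder or training
# pipeline. The only contract used here is that `clap.audio_encoder(audios)`
# returns a tuple whose first element is a (batch, prefix_size) embedding.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    class DummyAudioEncoder(nn.Module):
        # hypothetical stand-in for the CLAP audio encoder
        def __init__(self, prefix_size: int = 512):
            super().__init__()
            self.prefix_size = prefix_size

        def forward(self, audios: torch.Tensor):
            return torch.randn(audios.shape[0], self.prefix_size), None

    class DummyClap:
        # exposes the `.audio_encoder` attribute that ClapCaptionModel expects
        def __init__(self, prefix_size: int = 512):
            self.audio_encoder = DummyAudioEncoder(prefix_size)

    prefix_length = 10
    model = ClapCaptionModel(
        clap=DummyClap(),
        text_decoder="gpt2",         # assumed GPT-2 checkpoint name
        prefix_length=prefix_length,
        clip_length=10,
        prefix_size=512,
        num_layers=8,
        mapping_type="mlp",          # 'mlp' selects the MLP projector; anything else uses TransformerMapper
    )

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokens = tokenizer(["a dog barks in the distance"], return_tensors="pt")
    audios = torch.randn(1, 16000)   # placeholder waveform batch; shape is an assumption

    out = model(audios, tokens, labels=tokens["input_ids"])
    # logits cover the prefix_length prefix positions plus the caption tokens
    print(out.logits.shape)          # (1, prefix_length + seq_len, vocab_size)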