import os
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from omegaconf import OmegaConf
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
torch.backends.cuda.enable_mem_efficient_sdp(True)
N_REPEAT = 2  # number of clones (virtual batch size) of the generated soundscape
def _shift(x):
    # batch-independent shift: roll each clone by a random offset so the
    # N_REPEAT copies of the soundscape do not start at the same instant
    for i, _slice in enumerate(x):
        n = x.shape[2]
        offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice is 2D [C, T]
    return x
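# A minimal sketch of what torch.roll does per element (toy values, not from a real run):
#   >>> torch.roll(torch.tensor([[1, 2, 3, 4, 5]]), 2, dims=1)
#   tensor([[4, 5, 1, 2, 3]])
# so each clone restarts the same soundscape from a random time offset, wrapping around the end.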
class AudioGen(torch.nn.Module):
    # weights from https://huggingface.co/facebook/audiogen-medium
    def __init__(self):
        super().__init__()
        # EnCodec (compression model) weights
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # audiocraft.__version__, found in its __init__.py
        pkg = torch.load(_file_1, map_location='cpu')
        self.compression_model = EncodecModel()
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also holds unused encoder weights
        self.compression_model.eval()
        self._chunk_len = 476
        # language model (token predictor) weights
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        pkg = torch.load(_file_2, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])  # training config stored inside the torch bin (unused here)
        _best = pkg['best_state']
        # rename the T5 projection keys to match this module's layout
        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
        self.lm = LMModel()
        self.lm.load_state_dict(_best, strict=True)
        self.lm.eval()
    def generate(self,
                 prompt='dogs meow',
                 duration=2.24,   # seconds of audio
                 cache_lim=71,    # flush the kv cache after cache_lim tokens
                 ):
        torch.manual_seed(42)  # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        self.lm.cache_lim = cache_lim
        self.lm.n_draw = int(.8 * duration) + 1  # one extra draw per ~1.25 s of audio
        with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # e.g. ['dogs', 'dogs', '', ''] - empty strings are the null condition
                max_tokens=int(.04 * duration / N_REPEAT * self.compression_model.frame_rate) + 12)  # [bs, 4, max_tokens * self.lm.n_draw]
        # decode in overlapping chunks - vocoding all tokens at once can OOM
        x = []
        for i in range(7, gen_tokens.shape[2], self._chunk_len):  # a minimum 2 s soundscape assures at least 10 tokens
            decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i - 7:i + self._chunk_len])
            x.append(decoded_chunk)
        x = torch.cat(x, 2)  # [bs, 1, samples]
        x = _shift(x)  # desynchronise the clones
        return x.reshape(-1)  # optionally peak-normalise: x / (x.abs().max() + 1e-7)
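# Chunked-decoding arithmetic (a sketch, assuming frame_rate=50 and the 320x SEANet
# upsampling below): with _chunk_len=476 the loop decodes windows [i-7, i+476) for
# i = 7, 483, 959, ..., i.e. [0, 483), [476, 959), ... - each window re-decodes the
# previous window's last 7 tokens (7 * 320 = 2240 samples) as left context; note the
# overlap is kept in the concatenated output rather than cross-faded.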
class EncodecModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = SEANetDecoder()
        self.quantizer = ResidualVectorQuantizer()
        self.frame_rate = 50
    def decode(self, codes):
        # codes [B, K, T] -> waveform [B, C, T']
        emb = self.quantizer.decode(codes)
        return self.decoder(emb)
class StreamableLSTM(nn.Module):
    def __init__(self,
                 dimension,
                 num_layers=2,
                 skip=True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
    def forward(self, x):
        x = x.permute(2, 0, 1)   # [B, C, T] -> [T, B, C]
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)   # [T, B, C] -> [B, C, T]
        return y
class SEANetResnetBlock(nn.Module):
    def __init__(self,
                 dim,
                 kernel_sizes=[3, 1],
                 pad_mode='reflect',
                 compress=2):
        super().__init__()
        hidden = dim // compress
        block = []
        for i, kernel_size in enumerate(kernel_sizes):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [nn.ELU(),
                      StreamableConv1d(in_chs,
                                       out_chs,
                                       kernel_size=kernel_size,
                                       pad_mode=pad_mode)]
        self.block = nn.Sequential(*block)
    def forward(self, x):
        return x + self.block(x)
class SEANetDecoder(nn.Module):
    # defaults of the original audiocraft config:
    # channels=1 dimension=128 n_filters=64 n_residual_layers=1 ratios=[8, 5, 4, 2]
    # activation='ELU' activation_params={'alpha': 1.0} final_activation=None
    # final_activation_params=None norm='weight_norm' norm_params={}
    # kernel_size=7 last_kernel_size=7 residual_kernel_size=3 dilation_base=2
    # causal=False pad_mode='constant' true_skip=True compress=2 lstm=2
    # disable_norm_outer_blocks=0 trim_right_ratio=1.0
    def __init__(self,
                 channels=1,
                 dimension=128,
                 n_filters=64,
                 n_residual_layers=1,
                 ratios=[8, 5, 4, 2],
                 kernel_size=7,
                 last_kernel_size=7,
                 residual_kernel_size=3,
                 pad_mode='constant',
                 compress=2,
                 lstm=2):
        super().__init__()
        mult = int(2 ** len(ratios))
        model = [
            StreamableConv1d(dimension, mult * n_filters,
                             kernel_size,
                             pad_mode=pad_mode)
        ]
        if lstm:
            model += [StreamableLSTM(mult * n_filters,
                                     num_layers=lstm)]
        # upsample to raw audio scale
        for i, ratio in enumerate(ratios):
            model += [
                nn.ELU(),
                StreamableConvTranspose1d(mult * n_filters,
                                          mult * n_filters // 2,
                                          kernel_size=ratio * 2,
                                          stride=ratio),
            ]
            # add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2,
                                      kernel_sizes=[residual_kernel_size, 1],
                                      pad_mode=pad_mode,
                                      compress=compress)]
            mult //= 2
        # final layers
        model += [
            nn.ELU(),
            StreamableConv1d(n_filters,
                             channels,
                             last_kernel_size,
                             pad_mode=pad_mode)]
        self.model = nn.Sequential(*model)
    def forward(self, z):
        return self.model(z)
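# Upsampling arithmetic for the decoder above: each stage upsamples by its ratio,
# so the total hop is 8 * 5 * 4 * 2 = 320 samples per token; at the EnCodec frame
# rate of 50 Hz this gives 50 * 320 = 16000 samples/s, matching the 16 kHz file
# written in __main__ below. Channel widths halve per stage:
# 2**4 * 64 = 1024 -> 512 -> 256 -> 128 -> 64 -> 1.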
def unpad1d(x, paddings):
    padding_left, padding_right = paddings
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]
class NormConv1d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv = weight_norm(nn.Conv1d(*args, **kwargs))  # norm='weight_norm'
    def forward(self, x):
        return self.conv(x)
class NormConvTranspose1d(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.convtr = weight_norm(nn.ConvTranspose1d(*args, **kwargs))
    def forward(self, x):
        return self.convtr(x)
class StreamableConv1d(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 bias=True,
                 pad_mode='reflect'):
        super().__init__()
        if (stride != 1) or (groups != 1):
            raise ValueError('StreamableConv1d only supports stride=1, groups=1')
        self.conv = NormConv1d(in_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               groups=groups,
                               bias=bias)
        self.pad_mode = pad_mode
    def forward(self, x):
        kernel_size = self.conv.conv.kernel_size[0]
        kernel_size = (kernel_size - 1) * self.conv.conv.dilation[0] + 1  # effective kernel size with dilation
        padding_total = kernel_size - self.conv.conv.stride[0]
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        x = F.pad(x, (padding_left, padding_right), self.pad_mode)
        return self.conv(x)
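# "Same"-padding arithmetic (sketch): for kernel_size=7, stride=1, dilation=1 the
# effective kernel is 7, padding_total = 7 - 1 = 6, split as left=3 / right=3, so
# the output length equals the input length; with an odd padding_total the extra
# sample goes to the left (padding_left = padding_total - padding_total // 2).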
class StreamableConvTranspose1d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride)
    def forward(self, x):
        padding_total = self.convtr.convtr.kernel_size[0] - self.convtr.convtr.stride[0]
        y = self.convtr(x)
        # trim the transposed-conv overlap; asymmetric split for odd padding totals
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        y = unpad1d(y, (padding_left, padding_right))
        return y
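# Trim arithmetic (sketch, using the decoder's kernel_size = 2 * stride): an input of
# T frames yields (T - 1) * stride + kernel_size = T * stride + stride raw samples;
# padding_total = kernel_size - stride = stride, trimmed half from each side, leaves
# exactly T * stride samples, i.e. clean integer upsampling by `stride`.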
# VQ
class EuclideanCodebook(nn.Module):
    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        self.register_buffer("embed", torch.zeros(codebook_size, dim))
class VectorQuantization(nn.Module):
    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        self._codebook = EuclideanCodebook(dim=dim,
                                           codebook_size=codebook_size)
    def decode(self, _ind):
        return F.embedding(_ind, self._codebook.embed)
class ResidualVectorQuantization(nn.Module):
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
    def decode(self, _ind):
        x = 0.0
        for i, _code in enumerate(_ind):
            x = x + self.layers[i].decode(_code)
        return x.transpose(1, 2)
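# Residual VQ decode is just a sum of per-stage codebook lookups (sketch with
# made-up indices): for codes [k0, k1, k2, k3] at one frame, the latent is
#   embed_0[k0] + embed_1[k1] + embed_2[k2] + embed_3[k3]
# where each stage's codebook was trained to quantize the residual left over
# by the previous stages.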
class ResidualVectorQuantizer(nn.Module):
    # defaults of the original audiocraft config:
    # dimension=128 n_q=4 q_dropout=False bins=2048 decay=0.99 kmeans_init=True
    # kmeans_iters=50 threshold_ema_dead_code=2 orthogonal_reg_weight=0.0
    # orthogonal_reg_active_codes_only=False orthogonal_reg_max_codes=None
    def __init__(
            self,
            dimension=128,
            n_q=4,
            bins=2048
    ):
        super().__init__()
        self.vq = ResidualVectorQuantization(dim=dimension,
                                             codebook_size=bins,
                                             num_quantizers=n_q)
    def decode(self, codes):
        # codes is [B, K, T] with K codebooks and T frames; vq.decode expects [K, B, T]
        return self.vq.decode(codes.transpose(0, 1))
class T5(nn.Module):
    def __init__(self):
        super().__init__()
        self.output_proj = nn.Linear(1024,  # t5-large hidden size
                                     1536)  # lm hidden size
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
        # bypass register_module() so the frozen t5 is not part of the saved checkpoint
        self.__dict__['t5'] = t5.to('cpu')
    def forward(self, prompt):
        with torch.set_grad_enabled(False):
            bs = len(prompt) // 2  # second half of the batch is the null condition
            d = self.t5_tokenizer(prompt,
                                  return_tensors='pt',
                                  padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0  # the null condition's t5 attn_mask is all zeros
            x = self.t5(input_ids=d['input_ids'],
                        attention_mask=d['attention_mask']).last_hidden_state  # no kv cache
            # self.output_proj() runs outside the t5's precision context but inside the
            # lm autocast, hence in reduced precision; its result can differ slightly
            # with batch composition (e.g. without the duplicated text condition)
            x = self.output_proj(x)
            x[bs:, :, :] = 0  # zero the null-condition embeddings (cf. audiocraft/modules/conditioners.py tokenize())
        return x
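# Conditioning layout (sketch for N_REPEAT = 2): text_condition arrives as
# ['dogs bark', 'dogs bark', '', ''], so x[:2] holds the projected T5 text
# embeddings and x[2:] is zeroed - the null condition consumed by
# classifier-free guidance in LMModel.forward().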
class LMModel(nn.Module):
    def __init__(self,
                 n_q=4,
                 card=2048,
                 dim=1536
                 ):
        super().__init__()
        self.cache_lim = -1
        self.t5 = T5()
        self.card = card  # 2048
        self.n_draw = 1  # draw > 1 tokens per step, each with a different CFG scale
        # (cheaper than batch size > 1, which calls the transformer on a larger batch)
        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])  # embeddings have 2049 entries (incl. the special fill token)
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # output heads have only 2048
    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None
                ):
        bs, n_q, time_frames = sequence.shape  # [bs, 4, time]
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
        out = self.transformer(torch.cat([input_, input_], 0),  # duplicate (2 * bs) to pair with the null condition for classifier-free guidance
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position)
        out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(n_q)], dim=1)  # [2*bs, 4, 1, 2048]
        # classifier-free guidance with a randomised negative scale (see self._scale)
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # [bs, 4, n_draw, 2048]
        tokens = torch.multinomial(torch.softmax(logits.view(bs * self.n_draw * n_q, 2048), dim=1),
                                   num_samples=1)
        return tokens.view(bs, n_q, self.n_draw).transpose(1, 2)  # [bs, n_draw, 4]
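# Classifier-free guidance sketch: the textbook form is
#   logits = s * cond - (s - 1) * uncond
# so a guidance scale s = 3 would give 3 * cond - 2 * uncond; here the uncond
# coefficient is randomised per draw in [1.94, 2.24) instead of fixed at 2,
# giving each of the n_draw samples a slightly different guidance strength.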
    def generate(self,
                 max_tokens=None,
                 text_condition=None
                 ):
        x = self.t5(text_condition)
        bs = x.shape[0] // 2  # first half conditioned, second half null condition
        # per-draw negative CFG scale in [1.94, 2.24)
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
        cache_position = 0
        out_codes = torch.full((bs,
                                self.n_draw,
                                4,
                                4 + 3 + max_tokens),  # 4 leading + (n_q - 1) trailing fill tokens around max_tokens, so every antidiagonal of the 4-codebook delay pattern is indexable
                               self.card,
                               dtype=torch.long,
                               device=x.device)  # [bs, n_draw, 4, dur]
        # autoregressive loop
        for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1
            # read the current antidiagonal via fancy indexing: rows [0,1,2,3], cols [3,2,1,0] + offset
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # expand to [bs, n_q, dur=1]
                                      condition_tensors=x,
                                      cache_position=cache_position)  # [bs, n_draw, 4]
            # The lm reads antidiagonals of the delay pattern and its prediction is written
            # back onto the next antidiagonal; the pattern is un-delayed only at the end,
            # when out_codes is sliced for the vocoder.
            # Do not overwrite the 2048 fill of the triangular start/end regions: the
            # 0-th antidiagonal must stay [2048, 2048, 2048, 2048], as in
            #
            # [2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048, 2048],
            # [2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048, 2048],
            # [2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6, 2048],
            # [2048, 2048, 2048, 2048, 2048, 2048, 2048,    0,    1,    2,    3,    4,    5,    6]]
            if offset == 0:
                next_token[:, :, 1:4] = 2048  # bottom 3 entries of the antidiagonal stay 2048 (= self.card)
            elif offset == 1:
                next_token[:, :, 2:4] = 2048  # bottom 2 entries stay 2048
            elif offset == 2:
                next_token[:, :, 3:4] = 2048
            elif offset == max_tokens:
                next_token[:, :, 0:1] = 2048  # top entry of the antidiagonal stays 2048
            elif offset == (max_tokens + 1):
                next_token[:, :, 0:2] = 2048
            elif offset == (max_tokens + 2):
                next_token[:, :, 0:3] = 2048
            else:
                pass  # offsets 3 .. max_tokens - 1 fill all n_q = 4 antidiagonal entries
            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
            # sink attention: periodically flush the kv cache, preserving the first positions
            if (offset > 0) and (offset % self.cache_lim) == 0:
                n_preserve = 4
                self.transformer._flush(n_preserve=n_preserve)
                cache_position = n_preserve
            else:
                cache_position += 1
        # drop the leading/trailing 2048 fill: [bs, n_draw, 4, time+extra] -> [bs, 4, n_draw * time]
        out_codes = out_codes[:, :, :, 4:max_tokens + 4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)
        self.transformer._flush()  # reset for the next API call
        return out_codes
def create_sin_embedding(positions,
                         dim,
                         max_period=10000
                         ):
    # dim is assumed even
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device,
                        dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([],
                                   max_period,
                                   device=positions.device,
                                   dtype=torch.float)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    # the official impl computes this in torch.float32, while self_attn.in_proj_weight may be half precision
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
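# Positional-encoding formula as implemented above, for i in [0, dim/2):
#   phase_i = pos / max_period ** (i / (dim/2 - 1))
#   emb = [cos(phase_0 ...), sin(phase_0 ...)]   # cos half then sin half, concatenated,
# unlike the interleaved sin/cos of Vaswani et al.; only the scalar cache_position
# is embedded per step, since generation feeds one token at a time.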
class StreamingMultiheadAttention(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 cross_attention=False,
                 ):
        super().__init__()
        self.cross_attention = cross_attention
        # only self-attention (not cross_attention) keeps a kv cache; each of the
        # 48 layers holds its own history, cleared via StreamingTransformer._flush()
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))
    def forward(self,
                query,
                key=None,
                value=None):
        layout = "b h t d"
        if self.cross_attention:
            # different queries (audio) and keys/values (text): split in_proj_weight
            dim = self.in_proj_weight.shape[0] // 3
            q = nn.functional.linear(query, self.in_proj_weight[:dim])
            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
            q, k, v = [
                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            # self-attention of the audio tokens with themselves
            projected = nn.functional.linear(query, self.in_proj_weight, None)
            bound_layout = "b h p t d"
            packed = rearrange(
                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
            if self.k_history is not None:
                # note: appending k and v is not atomic; an interrupt (e.g. Ctrl-C in a
                # live demo) between the two lines leaves k_history/v_history with
                # incompatible lengths on the next call
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v
            # attend over the full cached history
            k = self.k_history
            v = self.v_history
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x
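# kv-cache shapes (sketch for this model: 2 * bs rows from the CFG duplication,
# 24 heads, 1536 / 24 = 64 dims per head): after t generated tokens
#   k_history, v_history: [2 * bs, 24, t, 64]
# each single-token step appends along dim 2, so attention cost grows linearly
# with t until StreamingTransformer._flush() truncates or clears the history.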
class StreamingTransformerLayer(nn.Module):
    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
    def forward(self,
                x,
                cross_attention_src=None):
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)  # text condition
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
class StreamingTransformer(nn.Module):
    def __init__(self,
                 d_model=1536,
                 num_heads=24,
                 num_layers=48,
                 dim_feedforward=6144):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
            ]
        )
    def forward(self,
                x,
                cache_position=None,
                cross_attention_src=None):
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)
        for lay in self.layers:
            x = lay(x,
                    cross_attention_src=cross_attention_src)
        return x
    def _flush(self,
               n_preserve=None):
        # drop every layer's kv cache; if n_preserve is given, keep the first
        # n_preserve positions (preserving positions from the end would be hard
        # to reconcile with cache_position)
        for lay in self.layers:
            if n_preserve is not None:
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
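# _flush(n_preserve=4) keeps only the first 4 cached positions and rewinds
# cache_position to 4 - a scheme resembling the attention sinks of StreamingLLM
# (Xiao et al., 2023): the earliest positions are kept as stable attention
# anchors while the rest of the kv cache is dropped to bound memory.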
if __name__ == '__main__':
    import audiofile
    model = AudioGen().to('cpu')
    x = model.generate(prompt='swims in lake frogs', duration=6.4).cpu().numpy()
    audiofile.write('_sound_.wav', x, 16000)