# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""

from typing import Tuple, List

import torch
from torch import nn
from torch.nn import functional as F

from cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice2.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice2.utils.mask import (
    make_pad_mask,
)

import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 128


class Upsample1D(nn.Module):
    """A 1D nearest-neighbour upsampling layer followed by a causal convolution.

    Parameters:
        channels (`int`): number of channels in the inputs.
        out_channels (`int`): number of output channels.
        stride (`int`, default `2`): upsampling factor applied to the sequence lengths.
        scale_factor (`float`, optional): interpolation scale factor. Defaults to `stride`.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2, scale_factor: float = None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat-interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
        self.scale_factor = float(self.stride) if scale_factor is None else float(scale_factor)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
        # causal left pad so the conv does not shorten the upsampled sequence
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride

    def forward_chunk(self,
                      inputs: torch.Tensor,
                      input_lengths: torch.Tensor,
                      cache: torch.Tensor = None):
        """
        Args:
            inputs(torch.Tensor): shape (b, c, t)
            input_lengths(torch.Tensor): shape (b), can be None
            cache(torch.Tensor): shape (b, c, cache_t), where cache_t = stride * 2
        """
        outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
        if cache is None:
            cache = inputs.new_zeros(inputs.shape[0], inputs.shape[1], self.stride * 2)
        outputs = torch.cat([cache, outputs], dim=2)
        new_cache = outputs[..., -self.stride * 2:]
        outputs = self.conv(outputs)
        if input_lengths is not None:
            input_lengths = input_lengths * self.stride
        return outputs, input_lengths, new_cache
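
# A minimal shape sketch for Upsample1D (illustrative only; sizes are assumed,
# not taken from any CosyVoice config):
#
#     up = Upsample1D(channels=8, out_channels=8, stride=2)
#     x = torch.randn(1, 8, 10)                 # (batch, channels, time)
#     y, y_lens = up(x, torch.tensor([10]))
#     # nearest interpolation doubles the time axis, and the left pad of
#     # stride * 2 compensates for the kernel of size stride * 2 + 1,
#     # so y is (1, 8, 20) and y_lens is tensor([20]).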

class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        # look ahead: right pad so conv1 can see pre_lookahead_len future frames
        outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # causal conv2: left pad by kernel_size - 1
        outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs

    def forward_chunk(self, inputs: torch.Tensor, cache: torch.Tensor = None):
        """
        Args:
            inputs(torch.Tensor): shape (b, t, c)
            cache(torch.Tensor): shape (b, c, cache_t=2), c = channels
        """
        outputs = inputs.transpose(1, 2).contiguous()
        outputs = F.leaky_relu(self.conv1(outputs))
        # the length of outputs is input length - pre_lookahead_len
        if cache is None:
            cache = outputs.new_zeros(outputs.shape[0], outputs.shape[1], 2)  # NOTE
        new_cache = outputs[..., -2:]
        outputs = torch.cat([cache, outputs], dim=2)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs[:, :-self.pre_lookahead_len]
        return outputs, new_cache
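
# A minimal shape sketch for PreLookaheadLayer (illustrative only; sizes are assumed):
#
#     layer = PreLookaheadLayer(channels=8, pre_lookahead_len=3)
#     x = torch.randn(1, 10, 8)                 # (batch, time, channels)
#     y = layer(x)                              # (1, 10, 8): conv1 peeks at 3 zero-padded
#                                               # future frames, conv2 stays causal
#     y_chunk, cnn_cache = layer.forward_chunk(x)
#     # y_chunk is (1, 7, 8): in streaming mode the last pre_lookahead_len frames
#     # are withheld until real future frames arrive; cnn_cache is (1, 8, 2).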

# Customize each sample's chunk attention mask
class UpsampleConformerEncoderV2(torch.nn.Module):

    def __init__(
        self,
        # input & output
        input_size: int,
        output_size: int = 256,
        input_layer: str = "linear",
        pre_lookahead_len: int = 3,
        # size
        num_blocks: int = 6,
        num_up_blocks: int = 4,
        # upsampling
        up_stride: int = 2,
        up_scale_factor: float = 2,
        # attention
        attention_heads: int = 4,
        pos_enc_layer_type: str = "rel_pos_espnet",
        selfattention_layer_type: str = "rel_selfattn",
        key_bias: bool = True,
        # mlp
        linear_units: int = 2048,
        # dropouts
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        # other
        normalize_before: bool = True,
        activation_type: str = "swish",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size,
                positional_dropout_rate
            ),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )

        self.pre_lookahead_layer = PreLookaheadLayer(
            channels=output_size,
            pre_lookahead_len=pre_lookahead_len
        )
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args
                ),
                PositionwiseFeedForward(*positionwise_layer_args),
                None,
                None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(
            channels=output_size,
            out_channels=output_size,
            stride=up_stride,
            scale_factor=up_scale_factor
        )
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size,
                positional_dropout_rate
            ),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args
                ),
                PositionwiseFeedForward(*positionwise_layer_args),
                None,
                None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_up_blocks)
        ])

        self.enable_cuda_graph = False
        self.use_cuda_graph = False
        self.graph_encoder = {}
        self.graph_up_encoder = {}
        self.inference_buffers_encoder = {}
        self.inference_buffers_up_encoder = {}
        self.max_static_time = 1500

    # FIXME(sfy) revert hard-coded bfloat16
    # this method is skipped in CausalMaskedDiffWithXvec.scatter_cuda_graph
    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        self.enable_cuda_graph = enable_cuda_graph
        if self.enable_cuda_graph:
            self._init_cuda_graph()

    def _init_cuda_graph(self):
        """Initialize CUDA Graphs for the encoder and up_encoder stacks.

        One graph is captured per static sequence length in range(100, 1500, 10);
        replay only happens when an input length matches one of these buckets
        exactly. The static buffers assume batch size 1 and a hidden size of 512.
        """
        for l in range(100, 1500, 10):
            static_x = torch.zeros((1, l, 512), dtype=torch.float32, device=torch.device('cuda'))
            static_mask = torch.ones((1, 1, l), dtype=torch.bool, device=torch.device('cuda'))
            static_pos_emb = torch.zeros((1, 2 * l - 1, 512), dtype=torch.float32, device=torch.device('cuda'))
            static_inputs = [
                static_x,
                static_mask,
                static_pos_emb,
            ]
            # warm-up run before capture
            self._forward_impl_encoder(
                static_inputs[0],
                static_inputs[1],
                static_inputs[2],
            )
            graph = torch.cuda.CUDAGraph()
            with torch.no_grad():
                with torch.cuda.graph(graph):
                    static_out_x = self._forward_impl_encoder(
                        static_inputs[0],
                        static_inputs[1],
                        static_inputs[2]
                    )
            self.graph_encoder[l] = graph
            static_outputs = [
                static_out_x,
            ]
            self.inference_buffers_encoder[l] = {
                'static_inputs': static_inputs,
                'static_outputs': static_outputs
            }

        for l in range(100, 1500, 10):
            static_x = torch.zeros((1, l, 512), dtype=torch.float32, device=torch.device('cuda'))
            static_mask = torch.ones((1, 1, l), dtype=torch.bool, device=torch.device('cuda'))
            static_pos_emb = torch.zeros((1, 2 * l - 1, 512), dtype=torch.float32, device=torch.device('cuda'))
            static_inputs = [
                static_x,
                static_mask,
                static_pos_emb,
            ]
            # warm-up run before capture
            self._forward_impl_up_encoder(
                static_inputs[0],
                static_inputs[1],
                static_inputs[2],
            )
            graph = torch.cuda.CUDAGraph()
            with torch.no_grad():
                with torch.cuda.graph(graph):
                    static_out_x = self._forward_impl_up_encoder(
                        static_inputs[0],
                        static_inputs[1],
                        static_inputs[2]
                    )
            self.graph_up_encoder[l] = graph
            static_outputs = [
                static_out_x,
            ]
            self.inference_buffers_up_encoder[l] = {
                'static_inputs': static_inputs,
                'static_outputs': static_outputs
            }

        self.use_cuda_graph = True
        print("CUDA Graph initialized successfully for encoder and up_encoder")

    # @torch.compile(dynamic=True, backend="eager")
    def _forward_impl_encoder(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor):
        for layer in self.encoders:
            x, _, _, _ = layer(x, mask, pos_emb)
        return x

    # @torch.compile(dynamic=True, backend="eager")
    def _forward_impl_up_encoder(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor):
        for layer in self.up_encoders:
            x, _, _, _ = layer(x, mask, pos_emb)
        return x

    def output_size(self) -> int:
        return self._output_size
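
    # Shape walk-through for forward() below (illustrative; assumes the default
    # 'linear' input layer and up_stride=2):
    #   xs (B, T, input_size) -> embed -> (B, T, output_size)
    #   -> pre_lookahead_layer (length preserved) -> first conformer stack
    #   -> up_layer (length doubled) -> second conformer stack,
    #   so the returned xs is (B, 2*T, output_size) and masks is (B, 1, 2*T).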
    # @torch.compile(dynamic=True, backend="eager")
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (sfy) chunk training strategy should not be open-sourced
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)

        # lookahead
        xs = self.pre_lookahead_layer(xs)

        # conformer block
        if self.enable_cuda_graph and xs.shape[1] in self.graph_encoder:
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
            self.graph_encoder[xs.shape[1]].replay()
            xs = self.inference_buffers_encoder[xs.shape[1]]['static_outputs'][0]
        else:
            xs = self._forward_impl_encoder(xs, masks, pos_emb)

        # upsample
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()

        # 2nd conformer block
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        if self.enable_cuda_graph and xs.shape[1] in self.graph_up_encoder:
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
            self.graph_up_encoder[xs.shape[1]].replay()
            xs = self.inference_buffers_up_encoder[xs.shape[1]]['static_outputs'][0]
        else:
            xs = self._forward_impl_up_encoder(xs, masks, pos_emb)

        # post norm
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks

    @torch.compile(dynamic=True, backend="eager")
    def forward_chunk(self,
                      xs: torch.Tensor,
                      last_chunk: bool = False,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """
        Args:
            xs: shape (b, dt, c)
            last_chunk: bool. If True, the input is right-padded so the remaining
                lookahead frames are flushed out.
            att_cache: shape (depth1+depth2, b, nh, 2*t1, c), where t1 is the
                attention-cache length of the first conformer stack (the upsampled
                stack caches 2*t1 frames).
            cnn_cache: shape (b, c, t1+t2), where t1=2 (pre_lookahead_layer) and
                t2=stride*2=4 (up_layer).
        """
        if att_cache is not None:
            assert att_cache.shape[3] % 2 == 0, att_cache.shape
        if cnn_cache is not None:
            assert cnn_cache.shape[2] == 2 + self.up_layer.stride * 2, cnn_cache.shape

        # unpack caches
        offset1 = att_cache.shape[3] // 2 if att_cache is not None else 0
        att_cache1 = att_cache[:len(self.encoders), :, :, :offset1] if att_cache is not None else [None] * len(self.encoders)
        att_cache2 = att_cache[len(self.encoders):] if att_cache is not None else [None] * len(self.up_encoders)
        cnn_cache1 = cnn_cache[:, :, :2] if cnn_cache is not None else None
        cnn_cache2 = cnn_cache[:, :, 2:] if cnn_cache is not None else None

        xs, _, _ = self.embed(xs, None)
        if last_chunk:
            xs = F.pad(xs, (0, 0, 0, self.pre_lookahead_layer.pre_lookahead_len))
        # this_cnn_cache: shape (b=1, c=512, t=2)
        xs, new_cnn_cache1 = self.pre_lookahead_layer.forward_chunk(xs, cache=cnn_cache1)

        # remake pos_emb, offset param is ignored by position_encoding
        pos_emb = self.embed.position_encoding(offset=None, size=offset1 + xs.shape[1])

        # first conformer stack
        chunk_masks = torch.zeros((0, 0, 0))
        new_att_cache1 = []
        for idx, layer in enumerate(self.encoders):
            # this_att_cache: shape (b, nh, t, c * 2)
            xs, _, this_new_att_cache1, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache1[idx])
            new_att_cache1.append(this_new_att_cache1)
        new_att_cache1 = torch.stack(new_att_cache1, dim=0)

        # upsample + conformer encoder, xs: (b, t, c) -> (b, c, t)
        xs = xs.transpose(1, 2).contiguous()
        # this_cnn_cache: shape (b=1, c=512, t=2*2)
        xs, _, new_cnn_cache2 = self.up_layer.forward_chunk(xs, None, cache=cnn_cache2)
        xs = xs.transpose(1, 2).contiguous()
        # at this point, xs is doubled in length
        xs, _, _ = self.up_embed(xs, None)

        # remake pos_emb
        pos_emb = self.embed.position_encoding(offset=None, size=offset1 * self.up_layer.stride + xs.shape[1])

        # second conformer stack
        # hard-coded bfloat16, see the FIXME on scatter_cuda_graph
        chunk_masks = torch.zeros((0, 0, 0), dtype=torch.bfloat16)
        new_att_cache2 = []
        for idx, layer in enumerate(self.up_encoders):
            xs, _, this_new_att_cache2, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache2[idx])
            new_att_cache2.append(this_new_att_cache2)
        new_att_cache2 = torch.stack(new_att_cache2, dim=0)

        if self.normalize_before:
            xs = self.after_norm(xs)

        # pack new caches: repeat the first-stack cache along time so both halves
        # share the same 2*t1 length before concatenation
        new_att_cache = torch.cat([new_att_cache1.repeat(1, 1, 1, 2, 1), new_att_cache2], dim=0)
        new_cnn_cache = torch.cat([new_cnn_cache1, new_cnn_cache2], dim=2)
        return xs, new_cnn_cache, new_att_cache
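

# A minimal streaming sketch for forward_chunk (illustrative only; the shapes and
# the chunking scheme are assumed, not taken from the CosyVoice2 runtime):
#
#     encoder = UpsampleConformerEncoderV2(input_size=512, output_size=512)
#     cnn_cache, att_cache = None, None
#     for i, chunk in enumerate(chunks):          # each chunk: (1, dt, 512)
#         ys, cnn_cache, att_cache = encoder.forward_chunk(
#             chunk,
#             last_chunk=(i == len(chunks) - 1),  # flush the lookahead frames
#             cnn_cache=cnn_cache,
#             att_cache=att_cache,
#         )
#     # Offline equivalent: ys, masks = encoder(xs, xs_lens), where ys has
#     # up_stride times as many frames as xs.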