# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main """ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from abc import ABC import torch import torch.nn.functional as F from typing import Dict, Optional import torch.nn as nn from einops import pack, rearrange, repeat from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D from .matcha_transformer import BasicTransformerBlock from omegaconf import DictConfig def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: assert mask.dtype == torch.bool assert dtype in [torch.float32, torch.bfloat16, torch.float16] mask = mask.to(dtype) # attention mask bias # NOTE(Mddct): torch.finfo jit issues # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min mask = (1.0 - mask) * torch.finfo(dtype).min return mask def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device Returns: torch.Tensor: mask Examples: >>> subsequent_chunk_mask(4, 2) [[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] """ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks # actually this is not needed after we have inference cache implemented, will remove it later pos_idx = torch.arange(size, device=device) block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) return ret def subsequent_mask( size: int, device: torch.device = torch.device("cpu"), ) -> torch.Tensor: """Create mask for subsequent steps (size, size). This mask is used only in decoder which works in an auto-regressive mode. This means the current step could only do attention with its left steps. In encoder, fully attention is used when streaming is not necessary and the sequence is not long. In this case, no attention mask is needed. When streaming is need, chunk-based attention is used in encoder. See subsequent_chunk_mask for the chunk-based attention mask. Args: size (int): size of mask str device (str): "cpu" or "cuda" or torch.Tensor.device dtype (torch.device): result dtype Returns: torch.Tensor: mask Examples: >>> subsequent_mask(3) [[1, 0, 0], [1, 1, 0], [1, 1, 1]] """ arange = torch.arange(size, device=device) mask = arange.expand(size, size) arange = arange.unsqueeze(-1) mask = mask <= arange return mask def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, use_dynamic_chunk: bool, use_dynamic_left_chunk: bool, decoding_chunk_size: int, static_chunk_size: int, num_decoding_left_chunks: int, enable_full_context: bool = True): """ Apply optional mask for encoder. Args: xs (torch.Tensor): padded input, (B, L, D), L for max length mask (torch.Tensor): mask for xs, (B, 1, L) use_dynamic_chunk (bool): whether to use dynamic chunk or not use_dynamic_left_chunk (bool): whether to use dynamic left chunk for training. decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 0: default for training, use random dynamic chunk. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. static_chunk_size (int): chunk size for static chunk training/decoding if it's greater than 0, if use_dynamic_chunk is true, this parameter will be ignored num_decoding_left_chunks: number of left chunks, this is for decoding, the chunk size is decoding_chunk_size. >=0: use num_decoding_left_chunks <0: use all left chunks enable_full_context (bool): True: chunk size is either [1, 25] or full context(max_len) False: chunk size ~ U[1, 25] Returns: torch.Tensor: chunk mask of the input xs. """ # Whether to use chunk mask or not if use_dynamic_chunk: max_len = xs.size(1) if decoding_chunk_size < 0: chunk_size = max_len num_left_chunks = -1 elif decoding_chunk_size > 0: chunk_size = decoding_chunk_size num_left_chunks = num_decoding_left_chunks else: # chunk size is either [1, 25] or full context(max_len). # Since we use 4 times subsampling and allow up to 1s(100 frames) # delay, the maximum frame is 100 / 4 = 25. chunk_size = torch.randint(1, max_len, (1, )).item() num_left_chunks = -1 if chunk_size > max_len // 2 and enable_full_context: chunk_size = max_len else: chunk_size = chunk_size % 25 + 1 if use_dynamic_left_chunk: max_left_chunks = (max_len - 1) // chunk_size num_left_chunks = torch.randint(0, max_left_chunks, (1, )).item() chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, num_left_chunks, xs.device) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = masks & chunk_masks # (B, L, L) elif static_chunk_size > 0: num_left_chunks = num_decoding_left_chunks chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, num_left_chunks, xs.device) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) chunk_masks = masks & chunk_masks # (B, L, L) else: chunk_masks = masks return chunk_masks def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. Args: lengths (torch.Tensor): Batch of lengths (B,). Returns: torch.Tensor: Mask tensor containing indices of padded part. Examples: >>> lengths = [5, 3, 2] >>> make_pad_mask(lengths) masks = [[0, 0, 0, 0 ,0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]] """ batch_size = lengths.size(0) max_len = max_len if max_len > 0 else lengths.max().item() seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand return mask # Causal class Transpose(torch.nn.Module): def __init__(self, dim0: int, dim1: int): super().__init__() self.dim0 = dim0 self.dim1 = dim1 def forward(self, x: torch.Tensor): x = torch.transpose(x, self.dim0, self.dim1) return x class CausalBlock1D(Block1D): def __init__(self, dim: int, dim_out: int): super(CausalBlock1D, self).__init__(dim, dim_out) self.block = torch.nn.Sequential( CausalConv1d(dim, dim_out, 3), Transpose(1, 2), nn.LayerNorm(dim_out), Transpose(1, 2), nn.Mish(), ) def forward(self, x: torch.Tensor, mask: torch.Tensor): output = self.block(x * mask) return output * mask class CausalResnetBlock1D(ResnetBlock1D): def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) self.block1 = CausalBlock1D(dim, dim_out) self.block2 = CausalBlock1D(dim_out, dim_out) class CausalConv1d(torch.nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None ) -> None: super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype) assert stride == 1 self.causal_padding = (kernel_size - 1, 0) def forward(self, x: torch.Tensor): x = F.pad(x, self.causal_padding) x = super(CausalConv1d, self).forward(x) return x class BASECFM(torch.nn.Module, ABC): def __init__( self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128, ): super().__init__() self.n_feats = n_feats self.n_spks = n_spks self.spk_emb_dim = spk_emb_dim self.solver = cfm_params.solver if hasattr(cfm_params, "sigma_min"): self.sigma_min = cfm_params.sigma_min else: self.sigma_min = 1e-4 self.estimator = None @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion Args: mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) n_timesteps (int): number of diffusion steps temperature (float, optional): temperature for scaling noise. Defaults to 1.0. spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes Returns: sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) def solve_euler(self, x, t_span, mu, mask, spks, cond): """ Fixed euler solver for ODEs. Args: x (torch.Tensor): random noise t_span (torch.Tensor): n_timesteps interpolated shape: (n_timesteps + 1,) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] # I am storing this because I can later plot it by putting a debugger here and saving it to a file # Or in future might add like a return_all_steps flag sol = [] for step in range(1, len(t_span)): dphi_dt = self.estimator(x, mask, mu, t, spks, cond) x = x + dt * dphi_dt t = t + dt sol.append(x) if step < len(t_span) - 1: dt = t_span[step + 1] - t return sol[-1] def compute_loss(self, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss Args: x1 (torch.Tensor): Target shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): target mask shape: (batch_size, 1, mel_timesteps) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) spks (torch.Tensor, optional): speaker embedding. Defaults to None. shape: (batch_size, spk_emb_dim) Returns: loss: conditional flow matching loss y: conditional flow shape: (batch_size, n_feats, mel_timesteps) """ b, _, t = mu.shape # random timestep t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) # sample noise p(x_0) z = torch.randn_like(x1) y = (1 - (1 - self.sigma_min) * t) * z + t * x1 u = x1 - (1 - self.sigma_min) * z loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( torch.sum(mask) * u.shape[1] ) return loss, y def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. Args: lengths (torch.Tensor): Batch of lengths (B,). Returns: torch.Tensor: Mask tensor containing indices of padded part. Examples: >>> lengths = [5, 3, 2] >>> make_pad_mask(lengths) masks = [[0, 0, 0, 0 ,0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]] """ batch_size = lengths.size(0) max_len = max_len if max_len > 0 else lengths.max().item() seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand return mask class ConditionalDecoder(nn.Module): def __init__( self, in_channels, out_channels, causal=False, channels=(256, 256), dropout=0.05, attention_head_dim=64, n_blocks=1, num_mid_blocks=2, num_heads=4, act_fn="snake", gradient_checkpointing=True, ): """ This decoder requires an input with the same shape of the target. So, if your text content is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. """ super().__init__() channels = tuple(channels) self.in_channels = in_channels self.out_channels = out_channels self.causal = causal self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio self.gradient_checkpointing = gradient_checkpointing self.time_embeddings = SinusoidalPosEmb(in_channels) time_embed_dim = channels[0] * 4 self.time_mlp = TimestepEmbedding( in_channels=in_channels, time_embed_dim=time_embed_dim, act_fn="silu", ) self.down_blocks = nn.ModuleList([]) self.mid_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) output_channel = in_channels for i in range(len(channels)): # pylint: disable=consider-using-enumerate input_channel = output_channel output_channel = channels[i] is_last = i == len(channels) - 1 resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) downsample = ( Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) for _ in range(num_mid_blocks): input_channel = channels[-1] out_channels = channels[-1] resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) channels = channels[::-1] + (channels[0],) for i in range(len(channels) - 1): input_channel = channels[i] * 2 output_channel = channels[i + 1] is_last = i == len(channels) - 2 resnet = CausalResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, ) if self.causal else ResnetBlock1D( dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim, ) transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( dim=output_channel, num_attention_heads=num_heads, attention_head_dim=attention_head_dim, dropout=dropout, activation_fn=act_fn, ) for _ in range(n_blocks) ] ) upsample = ( Upsample1D(output_channel, use_conv_transpose=True) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) self.initialize_weights() def initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.GroupNorm): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x, mask, mu, t, spks=None, cond=None): """Forward pass of the UNet1DConditional model. Args: x (torch.Tensor): shape (batch_size, in_channels, time) mask (_type_): shape (batch_size, 1, time) t (_type_): shape (batch_size) spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. cond (_type_, optional): placeholder for future use. Defaults to None. Raises: ValueError: _description_ ValueError: _description_ Returns: _type_: _description_ """ t = self.time_embeddings(t) t = t.to(x.dtype) t = self.time_mlp(t) x = pack([x, mu], "b * t")[0] mask = mask.to(x.dtype) if spks is not None: spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) x = pack([x, spks], "b * t")[0] if cond is not None: x = pack([x, cond], "b * t")[0] hiddens = [] masks = [mask] for resnet, transformer_blocks, downsample in self.down_blocks: mask_down = masks[-1] x = resnet(x, mask_down, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint( create_custom_forward(transformer_block), x, attn_mask, t, ) else: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() hiddens.append(x) # Save hidden states for skip connections x = downsample(x * mask_down) masks.append(mask_down[:, :, ::2]) masks = masks[:-1] mask_mid = masks[-1] for resnet, transformer_blocks in self.mid_blocks: x = resnet(x, mask_mid, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint( create_custom_forward(transformer_block), x, attn_mask, t, ) else: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() for resnet, transformer_blocks, upsample in self.up_blocks: mask_up = masks.pop() skip = hiddens.pop() x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] x = resnet(x, mask_up, t) x = rearrange(x, "b c t -> b t c").contiguous() # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) attn_mask = mask_to_bias(attn_mask == 1, x.dtype) for transformer_block in transformer_blocks: if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint( create_custom_forward(transformer_block), x, attn_mask, t, ) else: x = transformer_block( hidden_states=x, attention_mask=attn_mask, timestep=t, ) x = rearrange(x, "b t c -> b c t").contiguous() x = upsample(x * mask_up) x = self.final_block(x, mask_up) output = self.final_proj(x * mask_up) return output * mask class ConditionalCFM(BASECFM): def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64): super().__init__( n_feats=in_channels, cfm_params=cfm_params, n_spks=n_spks, spk_emb_dim=spk_emb_dim, ) self.t_scheduler = cfm_params.t_scheduler self.training_cfg_rate = cfm_params.training_cfg_rate self.inference_cfg_rate = cfm_params.inference_cfg_rate @torch.inference_mode() def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """Forward diffusion Args: mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) n_timesteps (int): number of diffusion steps temperature (float, optional): temperature for scaling noise. Defaults to 1.0. spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes Returns: sample: generated mel-spectrogram shape: (batch_size, n_feats, mel_timesteps) """ z = torch.randn_like(mu) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) if self.t_scheduler == 'cosine': t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond) def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond): """ Fixed euler solver for ODEs. Args: x (torch.Tensor): random noise t_span (torch.Tensor): n_timesteps interpolated shape: (n_timesteps + 1,) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): output_mask shape: (batch_size, 1, mel_timesteps) spks (torch.Tensor, optional): speaker ids. Defaults to None. shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes """ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] # I am storing this because I can later plot it by putting a debugger here and saving it to a file # Or in future might add like a return_all_steps flag sol = [] for step in range(1, len(t_span)): dphi_dt = estimator(x, mask, mu, t, spks, cond) # Classifier-Free Guidance inference introduced in VoiceBox if self.inference_cfg_rate > 0: cfg_dphi_dt = estimator( x, mask, torch.zeros_like(mu), t, torch.zeros_like(spks) if spks is not None else None, cond=cond ) dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt t = t + dt sol.append(x) if step < len(t_span) - 1: dt = t_span[step + 1] - t return sol[-1] def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None): """Computes diffusion loss Args: x1 (torch.Tensor): Target shape: (batch_size, n_feats, mel_timesteps) mask (torch.Tensor): target mask shape: (batch_size, 1, mel_timesteps) mu (torch.Tensor): output of encoder shape: (batch_size, n_feats, mel_timesteps) spks (torch.Tensor, optional): speaker embedding. Defaults to None. shape: (batch_size, spk_emb_dim) Returns: loss: conditional flow matching loss y: conditional flow shape: (batch_size, n_feats, mel_timesteps) """ org_dtype = x1.dtype b, _, t = mu.shape # random timestep t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) if self.t_scheduler == 'cosine': t = 1 - torch.cos(t * 0.5 * torch.pi) # sample noise p(x_0) z = torch.randn_like(x1) y = (1 - (1 - self.sigma_min) * t) * z + t * x1 u = x1 - (1 - self.sigma_min) * z # during training, we randomly drop condition to trade off mode coverage and sample fidelity if self.training_cfg_rate > 0: cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate mu = mu * cfg_mask.view(-1, 1, 1) if spks is not None: spks = spks * cfg_mask.view(-1, 1) if cond is not None: cond = cond * cfg_mask.view(-1, 1, 1) pred = estimator(y, mask, mu, t.squeeze(), spks, cond) pred = pred.float() u = u.float() loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) loss = loss.to(org_dtype) return loss, y