Commit 9d593b2 · Parent(s): b36b21a

add files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- README.md +6 -7
- app.py +65 -0
- chatterbox/src/chatterbox/__init__.py +2 -0
- chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__init__.py +2 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/const.py +1 -0
- chatterbox/src/chatterbox/models/s3gen/decoder.py +317 -0
- chatterbox/src/chatterbox/models/s3gen/f0_predictor.py +55 -0
- chatterbox/src/chatterbox/models/s3gen/flow.py +242 -0
- chatterbox/src/chatterbox/models/s3gen/flow_matching.py +228 -0
- chatterbox/src/chatterbox/models/s3gen/hifigan.py +474 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/decoder.py +443 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/flow_matching.py +129 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/text_encoder.py +413 -0
- chatterbox/src/chatterbox/models/s3gen/matcha/transformer.py +316 -0
- chatterbox/src/chatterbox/models/s3gen/s3gen.py +305 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__init__.py +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/activation.py +84 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/attention.py +330 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/convolution.py +145 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/embedding.py +294 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/encoder_layer.py +236 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +115 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/subsampling.py +383 -0
- chatterbox/src/chatterbox/models/s3gen/transformer/upsample_encoder.py +318 -0
- chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc +0 -0
    	
README.md CHANGED

@@ -1,14 +1,13 @@
 ---
-title: Chatterbox
-emoji: 
-colorFrom: 
+title: Chatterbox TTS
+emoji: 🍿
+colorFrom: indigo
 colorTo: blue
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.29.0
 app_file: app.py
 pinned: false
-
-short_description: SoTA Open-source TTS Model.
+short_description: Expressive Zeroshot TTS
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
app.py ADDED

@@ -0,0 +1,65 @@
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


model = ChatterboxTTS.from_pretrained(DEVICE)

def generate(text, audio_prompt_path, exaggeration, pace, temperature, seed_num):
    if seed_num != 0:
        set_seed(int(seed_num))

    wav = model.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        exaggeration=exaggeration,
        pace=pace,
        temperature=temperature,
    )
    return model.sr, wav.squeeze(0).numpy()


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
            ref_wav = gr.Audio(sources="upload", type="filepath", label="Reference Audio File", value=None)
            exaggeration = gr.Slider(0.25, 2, step=.05, label="exaggeration", value=.7)

            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
                pace = gr.Slider(0.8, 1.2, step=.01, label="pace", value=1)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    run_btn.click(
        fn=generate,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            pace,
            temp,
            seed_num,
        ],
        outputs=audio_output,
    )

if __name__ == "__main__":
    demo.launch()
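Note: the Space drives everything through the API surface visible in app.py above (ChatterboxTTS.from_pretrained, model.generate, model.sr). The following is a minimal sketch of the same calls outside the Gradio UI; the torchaudio dependency and the output filename are illustrative additions, not part of this commit.

import torch
import torchaudio

from chatterbox.src.chatterbox.tts import ChatterboxTTS

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ChatterboxTTS.from_pretrained(device)

# generate() is called exactly as in app.py; it returns a (1, num_samples)
# waveform tensor at model.sr Hz (app.py squeezes the channel dim for Gradio).
wav = model.generate(
    "What does the fox say?",
    audio_prompt_path=None,   # or a path to a reference WAV, as in the UI upload
    exaggeration=0.7,
    pace=1.0,
    temperature=0.8,
)
torchaudio.save("output.wav", wav.cpu(), model.sr)  # "output.wav" is illustrative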
    	
chatterbox/src/chatterbox/__init__.py ADDED

@@ -0,0 +1,2 @@
from .tts import ChatterboxTTS
from .vc import ChatterboxVC
    	
chatterbox/src/chatterbox/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (275 Bytes)

chatterbox/src/chatterbox/__pycache__/tts.cpython-311.pyc ADDED
Binary file (13.8 kB)

chatterbox/src/chatterbox/__pycache__/utils.cpython-311.pyc ADDED
Binary file (858 Bytes)

chatterbox/src/chatterbox/__pycache__/vc.cpython-311.pyc ADDED
Binary file (5.44 kB)
    	
chatterbox/src/chatterbox/models/s3gen/__init__.py ADDED

@@ -0,0 +1,2 @@
from .s3gen import S3Token2Wav as S3Gen
from .const import S3GEN_SR
    	
chatterbox/src/chatterbox/models/s3gen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (294 Bytes)

chatterbox/src/chatterbox/models/s3gen/__pycache__/const.cpython-311.pyc ADDED
Binary file (190 Bytes)

chatterbox/src/chatterbox/models/s3gen/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (16.9 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc ADDED
Binary file (2.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/flow.cpython-311.pyc ADDED
Binary file (13.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/flow_matching.cpython-311.pyc ADDED
Binary file (13.3 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/hifigan.cpython-311.pyc ADDED
Binary file (26.3 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/s3gen.cpython-311.pyc ADDED
Binary file (13.7 kB)

chatterbox/src/chatterbox/models/s3gen/__pycache__/xvector.cpython-311.pyc ADDED
Binary file (24 kB)
    	
chatterbox/src/chatterbox/models/s3gen/const.py ADDED

@@ -0,0 +1 @@
S3GEN_SR = 24000
    	
chatterbox/src/chatterbox/models/s3gen/decoder.py ADDED

@@ -0,0 +1,317 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat

from .utils.mask import add_optional_chunk_mask
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
    TimestepEmbedding, Upsample1D
from .matcha.transformer import BasicTransformerBlock


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * -1.0e+10
    return mask


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels=320,
        out_channels=80,
        causal=True,
        channels=[256],
        dropout=0.0,
        attention_head_dim=64,
        n_blocks=4,
        num_mid_blocks=12,
        num_heads=8,
        act_fn="gelu",
    ):
        """
        This decoder requires an input with the same shape of the target. So, if your text content
        is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # NOTE jrm: `static_chunk_size` is missing?
        self.static_chunk_size = 0

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else
                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            ) if self.causal else ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
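Note: a small sanity-check sketch for the causal pieces above (illustrative, not part of this commit): CausalConv1d left-pads by kernel_size - 1 so the time axis is preserved with no look-ahead, and mask_to_bias turns a boolean attention mask into an additive bias. The import path assumes the package layout added in this commit.

import torch

from chatterbox.src.chatterbox.models.s3gen.decoder import CausalConv1d, mask_to_bias

conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x = torch.randn(1, 4, 10)      # (batch, channels, time)
y = conv(x)
assert y.shape == x.shape      # left padding (2, 0) keeps the sequence length unchanged

# Allowed positions become 0.0, masked positions a large negative bias (-1e10).
mask = torch.tensor([[True, True, False]])
print(mask_to_bias(mask, torch.float32))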
    	
chatterbox/src/chatterbox/models/s3gen/f0_predictor.py ADDED

@@ -0,0 +1,55 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm


class ConvRNNF0Predictor(nn.Module):
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.condnet(x)
        x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))
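Note: a shape sketch for ConvRNNF0Predictor above (illustrative, not part of this commit): the stacked Conv1d layers keep the frame axis unchanged (kernel_size=3, padding=1), and the final Linear plus squeeze yields one non-negative F0 value per frame. The import path assumes the package layout added in this commit.

import torch

from chatterbox.src.chatterbox.models.s3gen.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor()   # defaults: in_channels=80, cond_channels=512
mel = torch.randn(2, 80, 100)      # (batch, mel_bins, frames)
f0 = predictor(mel)
print(f0.shape)                    # torch.Size([2, 100]); torch.abs keeps values >= 0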
    	
chatterbox/src/chatterbox/models/s3gen/flow.py ADDED

@@ -0,0 +1,242 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from .utils.mask import make_pad_mask


class MaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        token = batch['speech_token'].to(device)
        token_len = batch['speech_token_len'].to(device)
        feat = batch['speech_feat'].to(device)
        feat_len = batch['speech_feat_len'].to(device)
        embedding = batch['embedding'].to(device)

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions
        conds = torch.zeros(feat.shape, device=token.device)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = conds.transpose(1, 2)

        mask = (~make_pad_mask(feat_len)).to(h)
        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
        loss, _ = self.decoder.compute_loss(
            feat.transpose(1, 2).contiguous(),
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            embedding,
            cond=conds
        )
        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  flow_cache):
        if self.fp16 is True:
            prompt_feat = prompt_feat.half()
            embedding = embedding.half()

        assert token.shape[0] == 1
            +
                    # xvec projection
         | 
| 120 | 
            +
                    embedding = F.normalize(embedding, dim=1)
         | 
| 121 | 
            +
                    embedding = self.spk_embed_affine_layer(embedding)
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    # concat text and prompt_text
         | 
| 124 | 
            +
                    token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         | 
| 125 | 
            +
                    token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         | 
| 126 | 
            +
                    mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         | 
| 127 | 
            +
                    token = self.input_embedding(torch.clamp(token, min=0)) * mask
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    # text encode
         | 
| 130 | 
            +
                    h, h_lengths = self.encoder(token, token_len)
         | 
| 131 | 
            +
                    h = self.encoder_proj(h)
         | 
| 132 | 
            +
                    mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
         | 
| 133 | 
            +
                    h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # get conditions
         | 
| 136 | 
            +
                    conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         | 
| 137 | 
            +
                    conds[:, :mel_len1] = prompt_feat
         | 
| 138 | 
            +
                    conds = conds.transpose(1, 2)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         | 
| 141 | 
            +
                    feat, flow_cache = self.decoder(
         | 
| 142 | 
            +
                        mu=h.transpose(1, 2).contiguous(),
         | 
| 143 | 
            +
                        mask=mask.unsqueeze(1),
         | 
| 144 | 
            +
                        spks=embedding,
         | 
| 145 | 
            +
                        cond=conds,
         | 
| 146 | 
            +
                        n_timesteps=10,
         | 
| 147 | 
            +
                        prompt_len=mel_len1,
         | 
| 148 | 
            +
                        flow_cache=flow_cache
         | 
| 149 | 
            +
                    )
         | 
| 150 | 
            +
                    feat = feat[:, :, mel_len1:]
         | 
| 151 | 
            +
                    assert feat.shape[2] == mel_len2
         | 
| 152 | 
            +
                    return feat.float(), flow_cache
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
            class CausalMaskedDiffWithXvec(torch.nn.Module):
         | 
| 156 | 
            +
                def __init__(self,
         | 
| 157 | 
            +
                             input_size: int = 512,
         | 
| 158 | 
            +
                             output_size: int = 80,
         | 
| 159 | 
            +
                             spk_embed_dim: int = 192,
         | 
| 160 | 
            +
                             output_type: str = "mel",
         | 
| 161 | 
            +
                             vocab_size: int = 6561,
         | 
| 162 | 
            +
                             input_frame_rate: int = 25,
         | 
| 163 | 
            +
                             only_mask_loss: bool = True,
         | 
| 164 | 
            +
                             token_mel_ratio: int = 2,
         | 
| 165 | 
            +
                             pre_lookahead_len: int = 3,
         | 
| 166 | 
            +
                             encoder: torch.nn.Module = None,
         | 
| 167 | 
            +
                             decoder: torch.nn.Module = None,
         | 
| 168 | 
            +
                             decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
         | 
| 169 | 
            +
                                                   'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
         | 
| 170 | 
            +
                                                                             'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
         | 
| 171 | 
            +
                                                   'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
         | 
| 172 | 
            +
                                                                      'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
         | 
| 173 | 
            +
                             mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
         | 
| 174 | 
            +
                                                    'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
         | 
| 175 | 
            +
                    super().__init__()
         | 
| 176 | 
            +
                    self.input_size = input_size
         | 
| 177 | 
            +
                    self.output_size = output_size
         | 
| 178 | 
            +
                    self.decoder_conf = decoder_conf
         | 
| 179 | 
            +
                    self.mel_feat_conf = mel_feat_conf
         | 
| 180 | 
            +
                    self.vocab_size = vocab_size
         | 
| 181 | 
            +
                    self.output_type = output_type
         | 
| 182 | 
            +
                    self.input_frame_rate = input_frame_rate
         | 
| 183 | 
            +
                    logging.info(f"input frame rate={self.input_frame_rate}")
         | 
| 184 | 
            +
                    self.input_embedding = nn.Embedding(vocab_size, input_size)
         | 
| 185 | 
            +
                    self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
         | 
| 186 | 
            +
                    self.encoder = encoder
         | 
| 187 | 
            +
                    self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
         | 
| 188 | 
            +
                    self.decoder = decoder
         | 
| 189 | 
            +
                    self.only_mask_loss = only_mask_loss
         | 
| 190 | 
            +
                    self.token_mel_ratio = token_mel_ratio
         | 
| 191 | 
            +
                    self.pre_lookahead_len = pre_lookahead_len
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # FIXME: this was missing - just putting it in as false
         | 
| 194 | 
            +
                    self.fp16 = False
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                @torch.inference_mode()
         | 
| 197 | 
            +
                def inference(self,
         | 
| 198 | 
            +
                              token,
         | 
| 199 | 
            +
                              token_len,
         | 
| 200 | 
            +
                              prompt_token,
         | 
| 201 | 
            +
                              prompt_token_len,
         | 
| 202 | 
            +
                              prompt_feat,
         | 
| 203 | 
            +
                              prompt_feat_len,
         | 
| 204 | 
            +
                              embedding,
         | 
| 205 | 
            +
                              finalize):
         | 
| 206 | 
            +
                    if self.fp16 is True:
         | 
| 207 | 
            +
                        prompt_feat = prompt_feat.half()
         | 
| 208 | 
            +
                        embedding = embedding.half()
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    assert token.shape[0] == 1
         | 
| 211 | 
            +
                    # xvec projection
         | 
| 212 | 
            +
                    embedding = F.normalize(embedding, dim=1)
         | 
| 213 | 
            +
                    embedding = self.spk_embed_affine_layer(embedding)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    # concat text and prompt_text
         | 
| 216 | 
            +
                    token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         | 
| 217 | 
            +
                    mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         | 
| 218 | 
            +
                    token = self.input_embedding(torch.clamp(token, min=0)) * mask
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    # text encode
         | 
| 221 | 
            +
                    h, h_lengths = self.encoder(token, token_len)
         | 
| 222 | 
            +
                    if finalize is False:
         | 
| 223 | 
            +
                        h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
         | 
| 224 | 
            +
                    mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         | 
| 225 | 
            +
                    h = self.encoder_proj(h)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # get conditions
         | 
| 228 | 
            +
                    conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         | 
| 229 | 
            +
                    conds[:, :mel_len1] = prompt_feat
         | 
| 230 | 
            +
                    conds = conds.transpose(1, 2)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         | 
| 233 | 
            +
                    feat, _ = self.decoder(
         | 
| 234 | 
            +
                        mu=h.transpose(1, 2).contiguous(),
         | 
| 235 | 
            +
                        mask=mask.unsqueeze(1),
         | 
| 236 | 
            +
                        spks=embedding,
         | 
| 237 | 
            +
                        cond=conds,
         | 
| 238 | 
            +
                        n_timesteps=10
         | 
| 239 | 
            +
                    )
         | 
| 240 | 
            +
                    feat = feat[:, :, mel_len1:]
         | 
| 241 | 
            +
                    assert feat.shape[2] == mel_len2
         | 
| 242 | 
            +
                    return feat.float(), None  # NOTE jrm: why are they returning None here?
         | 
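A note on the length math in MaskedDiffWithXvec.inference above: with input_frame_rate = 25 speech tokens per second and 22050 Hz mels at hop 256, there are 22050 / 256 ≈ 86.1 mel frames per second, i.e. roughly 3.45 mel frames per token. The tiny check below only restates the expression int(token_len2 / self.input_frame_rate * 22050 / 256); the helper name token_to_mel_len is hypothetical and not part of this repo.

# Illustrative check of the token-to-mel length conversion used above (not part of the repo).
input_frame_rate = 25   # speech tokens per second
sampling_rate = 22050   # Hz, from mel_feat_conf
hop_size = 256          # from mel_feat_conf

def token_to_mel_len(n_tokens: int) -> int:
    # mirrors: int(token_len2 / self.input_frame_rate * 22050 / 256)
    return int(n_tokens / input_frame_rate * sampling_rate / hop_size)

print(token_to_mel_len(100))  # 100 tokens = 4 s of speech -> 344 mel frames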
    	
chatterbox/src/chatterbox/models/s3gen/flow_matching.py
ADDED
@@ -0,0 +1,228 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import torch
import torch.nn.functional as F
from .matcha.flow_matching import BASECFM
from omegaconf import OmegaConf


CFM_PARAMS = OmegaConf.create({
    "sigma_min": 1e-06,
    "solver": "euler",
    "t_scheduler": "cosine",
    "training_cfg_rate": 0.2,
    "inference_cfg_rate": 0.7,
    "reg_loss_type": "l1"
})


class ConditionalCFM(BASECFM):
    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        self.estimator = estimator
        self.lock = threading.Lock()

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker embeddings. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
        cache_size = flow_cache.shape[2]
        # fix prompt and overlap part of mu and z
        if cache_size != 0:
            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)

        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed-step Euler solver for the ODE.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embeddings. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)

        # Intermediate steps are collected here; handy for plotting from a debugger,
        # and a return_all_steps flag could expose them in the future.
        sol = []

        # Do not use concat: it may change the memory format and make the TensorRT engine infer wrong results!
        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
        for step in range(1, len(t_span)):
            # Classifier-free guidance inference, as introduced in VoiceBox
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1].float()
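For orientation, the loop above can be restated compactly (this adds nothing beyond the code): with guidance rate \(\gamma\) = inference_cfg_rate, each step evaluates the estimator on a conditioned and an unconditioned copy of the batch and mixes the two velocities before the Euler update,

\[
\hat v(x_t, t) = (1+\gamma)\, v_\theta(x_t, t \mid \mu, s, c) \;-\; \gamma\, v_\theta(x_t, t \mid 0, 0, 0),
\qquad
x_{t+\Delta t} = x_t + \Delta t\, \hat v(x_t, t),
\]

on the cosine-warped time grid \(t_k = 1 - \cos\!\bigl(\tfrac{\pi}{2}\cdot\tfrac{k}{N}\bigr)\), \(k = 0,\dots,N\), where \(N\) = n_timesteps.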

    def forward_estimator(self, x, mask, mu, t, spks, cond):
        if isinstance(self.estimator, torch.nn.Module):
            return self.estimator.forward(x, mask, mu, t, spks, cond)
        else:
            with self.lock:
                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
                self.estimator.set_input_shape('t', (2,))
                self.estimator.set_input_shape('spks', (2, 80))
                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                # run trt engine
                self.estimator.execute_v2([x.contiguous().data_ptr(),
                                           mask.contiguous().data_ptr(),
                                           mu.contiguous().data_ptr(),
                                           t.contiguous().data_ptr(),
                                           spks.contiguous().data_ptr(),
                                           cond.contiguous().data_ptr(),
                                           x.data_ptr()])
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t = 1 - torch.cos(t * 0.5 * torch.pi)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
        if self.training_cfg_rate > 0:
            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
            mu = mu * cfg_mask.view(-1, 1, 1)
            spks = spks * cfg_mask.view(-1, 1)
            cond = cond * cfg_mask.view(-1, 1, 1)

        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
        loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
        return loss, y
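Written out, compute_loss is the usual conditional flow matching objective (same symbols as the code: target \(x_1\), noise \(z \sim \mathcal N(0, I)\), \(\sigma_{\min}\) = sigma_min, frame mask \(m\), feature count \(C = 80\)):

\[
y_t = \bigl(1 - (1-\sigma_{\min})\,t\bigr)\,z + t\,x_1,
\qquad
u = x_1 - (1-\sigma_{\min})\,z,
\qquad
\mathcal L = \frac{\bigl\lVert \bigl(v_\theta(y_t, t \mid \mu, s, c) - u\bigr)\odot m \bigr\rVert_2^2}{C \cdot \sum m},
\]

with \(t \sim \mathcal U(0,1)\), optionally warped to \(1 - \cos(\pi t / 2)\), and the condition \((\mu, s, c)\) zeroed out with probability training_cfg_rate.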


class CausalConditionalCFM(ConditionalCFM):
    def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        self.rand_noise = torch.randn([1, 80, 50 * 300])

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
        # fix prompt and overlap part of mu and z
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
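For readers who want to poke at the sampling scheme without loading the real estimator, here is a minimal, self-contained sketch of the guided Euler loop, with a toy velocity field standing in for the network. The names toy_velocity and cfg_euler_sample are hypothetical and not part of this repo; this only illustrates the classifier-free-guidance Euler scheme under that assumption.

import torch

def toy_velocity(x, mu, t):
    # Stand-in "estimator": a velocity field that pulls the sample towards mu.
    return mu - x

@torch.inference_mode()
def cfg_euler_sample(mu, n_timesteps=10, cfg_rate=0.7):
    x = torch.randn_like(mu)                                 # x_0 ~ N(0, I)
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)          # cosine t-scheduler, as in ConditionalCFM
    for step in range(1, len(t_span)):
        t, dt = t_span[step - 1], t_span[step] - t_span[step - 1]
        cond_v = toy_velocity(x, mu, t)                      # conditioned branch
        uncond_v = toy_velocity(x, torch.zeros_like(mu), t)  # branch with the condition zeroed out
        v = (1.0 + cfg_rate) * cond_v - cfg_rate * uncond_v  # classifier-free guidance mix
        x = x + dt * v                                       # Euler step
    return x

if __name__ == "__main__":
    mel = cfg_euler_sample(torch.randn(1, 80, 120))
    print(mel.shape)  # torch.Size([1, 80, 120])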
    	
chatterbox/src/chatterbox/models/s3gen/hifigan.py
ADDED
@@ -0,0 +1,474 @@
# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
#      most modules should be reusable, but I found their SineGen changed a bit.

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HIFI-GAN"""

from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.distributions.uniform import Uniform
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default; higher values mean higher frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log-scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
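In equation form, the forward pass above computes, per channel,

\[
\mathrm{Snake}_\alpha(x) = x + \frac{1}{\alpha + \varepsilon}\,\sin^2(\alpha x),
\]

where \(\varepsilon = 10^{-9}\) (no_div_by_zero) guards the division, and \(\alpha = \exp(\tilde\alpha)\) when alpha_logscale is set (otherwise \(\alpha = \tilde\alpha\) directly).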



def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN

"""


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced frames the std should be comparable to sine_amp
        #        (std = self.sine_amp / 3 -> max value ~ self.sine_amp);
        #        for voiced frames it is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: add noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
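# Illustrative usage sketch (not part of the original file): SineGen maps a
# sample-rate F0 contour of shape [B, 1, T] to harmonic_num + 1 stacked sine
# waves [B, harmonic_num + 1, T], plus a voiced/unvoiced mask and the additive
# noise it applied, e.g.:
#
#     sine_gen = SineGen(samp_rate=22050, harmonic_num=8)
#     f0 = torch.full((1, 1, 22050), 220.0)   # one second of a 220 Hz tone
#     sine_waves, uv, noise = sine_gen(f0)    # (1, 9, 22050), (1, 1, 22050), (1, 9, 22050)
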
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv

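# Illustrative usage sketch (not part of the original file): SourceModuleHnNSF
# expects F0 already upsampled to the waveform rate, shaped (batch, length, 1),
# and merges the stacked harmonics into a single excitation via Linear + Tanh:
#
#     source = SourceModuleHnNSF(sampling_rate=22050, upsample_scale=256, harmonic_num=8)
#     f0_up = torch.full((1, 22050, 1), 220.0)
#     sine_merge, noise, uv = source(f0_up)   # each of shape (1, 22050, 1)
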
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """
    def __init__(
            self,
            in_channels: int = 80,
            base_channels: int = 512,
            nb_harmonics: int = 8,
            sampling_rate: int = 22050,
            nsf_alpha: float = 0.1,
            nsf_sigma: float = 0.003,
            nsf_voiced_threshold: float = 10,
            upsample_rates: List[int] = [8, 8],
            upsample_kernel_sizes: List[int] = [16, 16],
            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
            resblock_kernel_sizes: List[int] = [3, 7, 11],
            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            source_resblock_kernel_sizes: List[int] = [7, 11],
            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
            lrelu_slope: float = 0.1,
            audio_limit: float = 0.99,
            f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.m_source.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
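Putting the pieces together, HiFTGenerator.inference runs mel -> F0 -> harmonic source -> ISTFT waveform, with an overall upsampling factor of prod(upsample_rates) * hop_len = 8 * 8 * 4 = 256 under the defaults. A minimal calling sketch, where the import paths and the ConvRNNF0Predictor name are assumptions based on the accompanying f0_predictor.py rather than documented usage:

    import torch
    from chatterbox.models.s3gen.f0_predictor import ConvRNNF0Predictor  # assumed name/path
    from chatterbox.models.s3gen.hifigan import HiFTGenerator            # assumed import path

    hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor()).eval()
    mel = torch.randn(1, 80, 200)                  # (batch, n_mels, frames)
    wav, source = hift.inference(speech_feat=mel)  # roughly 200 * 256 samples at 22.05 kHz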
    	
chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc
ADDED
Binary file (21.3 kB).

chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc
ADDED
Binary file (6.46 kB).

chatterbox/src/chatterbox/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (14.7 kB).
    	
chatterbox/src/chatterbox/models/s3gen/matcha/decoder.py
ADDED
@@ -0,0 +1,443 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from conformer import ConformerBlock
from diffusers.models.activations import get_activation
from einops import pack, rearrange, repeat

from .transformer import BasicTransformerBlock


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

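# Illustrative sketch (not part of the original file): SinusoidalPosEmb turns a
# batch of scalar timesteps into dim-dimensional embeddings (half sin, half cos),
# scaled by 1000 by default:
#
#     emb = SinusoidalPosEmb(dim=80)
#     t = torch.rand(4)          # one timestep per batch element
#     e = emb(t)                 # e.shape == (4, 80)
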
class Block1D(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)

        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs

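# Illustrative sketch (not part of the original file): Downsample1D halves the
# time axis with a stride-2 conv and Upsample1D (use_conv_transpose=True) doubles
# it again, which is why Decoder.forward below subsamples each mask with
# mask_down[:, :, ::2] when it descends a level:
#
#     down, up = Downsample1D(64), Upsample1D(64)
#     x = torch.randn(2, 64, 100)
#     y = down(x)                # (2, 64, 50)
#     z = up(y)                  # (2, 64, 100)
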
class ConformerWrapper(ConformerBlock):
    def __init__(  # pylint: disable=useless-super-delegation
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.bool())

class Decoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]

            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """

        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
         | 
| 437 | 
            +
                        mask_up = rearrange(mask_up, "b t -> b 1 t")
         | 
| 438 | 
            +
                        x = upsample(x * mask_up)
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    x = self.final_block(x, mask_up)
         | 
| 441 | 
            +
                    output = self.final_proj(x * mask_up)
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                    return output * mask
         | 
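Editor's note: below is a minimal, self-contained sketch (not part of the committed file) of the einops pack/repeat conditioning pattern the forward pass above relies on to stack x, mu, and a broadcast speaker embedding along the channel axis. The channel counts and the 64-dimensional speaker embedding are illustrative assumptions only.

# Minimal sketch, assuming hypothetical shapes: the pack/repeat pattern from Decoder.forward.
import torch
from einops import pack, repeat

b, c, t = 2, 80, 16                        # hypothetical batch, channels, frames
x = torch.randn(b, c, t)                   # noisy input
mu = torch.randn(b, c, t)                  # encoder output
spks = torch.randn(b, 64)                  # hypothetical speaker embedding

x, _ = pack([x, mu], "b * t")              # concatenate along channels -> (b, 160, t)
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x, _ = pack([x, spks], "b * t")            # -> (b, 224, t)
print(x.shape)                             # torch.Size([2, 224, 16])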
    	
chatterbox/src/chatterbox/models/s3gen/matcha/flow_matching.py
ADDED
@@ -0,0 +1,129 @@
from abc import ABC

import torch
import torch.nn.functional as F

from .decoder import Decoder


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        n_feats,
        cfm_params,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 1e-4

        self.estimator = None

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y


class CFM(BASECFM):
    def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )

        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        # Just change the architecture of the estimator here
        self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
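Editor's note: a minimal sketch (not part of the committed file) of the fixed-step Euler update performed by BASECFM.solve_euler, with a hypothetical closed-form vector field standing in for the learned Decoder estimator; the tensor shapes and step count are illustrative assumptions.

# Minimal sketch, assuming a toy vector field in place of self.estimator(...).
import torch

def toy_estimator(x, target):
    # Hypothetical stand-in for the learned estimator: a field that points
    # from the current state toward the target.
    return target - x

target = torch.randn(2, 80, 16)            # mel-shaped tensor, playing the role of x1
x = torch.randn_like(target)               # x_0 ~ N(0, I), like z in forward()
t_span = torch.linspace(0, 1, 11)          # n_timesteps = 10

t, dt = t_span[0], t_span[1] - t_span[0]
for step in range(1, len(t_span)):
    dphi_dt = toy_estimator(x, target)
    x = x + dt * dphi_dt                   # Euler step
    t = t + dt
    if step < len(t_span) - 1:
        dt = t_span[step + 1] - t

print(torch.mean((x - target) ** 2).item())  # residual shrinks toward the target

The training side matches this solver: in compute_loss the regression target u = x1 - (1 - sigma_min) * z is exactly the time derivative of the interpolation y_t = (1 - (1 - sigma_min) * t) * z + t * x1, so fitting the estimator to u with an MSE teaches it the vector field that solve_euler later integrates.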
    	
chatterbox/src/chatterbox/models/s3gen/matcha/text_encoder.py
ADDED
@@ -0,0 +1,413 @@
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
import torch.nn as nn
from einops import rearrange


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        x = rearrange(x, "b h t d -> t b h d")

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
        key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
        value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)

        query = self.query_rotary_pe(query)
        key = self.key_rotary_pe(key)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)

        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    @staticmethod
    def _attention_bias_proximal(length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(nn.Module):
    def __init__(
        self,
        encoder_type,
        encoder_params,
        duration_predictor_params,
        n_vocab,
        n_spks=1,
        spk_emb_dim=128,
    ):
        super().__init__()
        self.encoder_type = encoder_type
        self.n_vocab = n_vocab
        self.n_feats = encoder_params.n_feats
        self.n_channels = encoder_params.n_channels
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)

        if encoder_params.prenet:
            self.prenet = ConvReluNorm(
                self.n_channels,
                self.n_channels,
                self.n_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        else:
            self.prenet = lambda x, x_mask: x

        self.encoder = Encoder(
            encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
            encoder_params.filter_channels,
            encoder_params.n_heads,
            encoder_params.n_layers,
            encoder_params.kernel_size,
            encoder_params.p_dropout,
        )

        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
        self.proj_w = DurationPredictor(
            self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
            duration_predictor_params.filter_channels_dp,
            duration_predictor_params.kernel_size,
            duration_predictor_params.p_dropout,
        )

    def forward(self, x, x_lengths, spks=None):
        """Run the forward pass through the transformer-based encoder and duration predictor

        Args:
            x (torch.Tensor): text input
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): text input lengths
                shape: (batch_size,)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size,)

        Returns:
            mu (torch.Tensor): average output of the encoder
                shape: (batch_size, n_feats, max_text_length)
            logw (torch.Tensor): log duration predicted by the duration predictor
                shape: (batch_size, 1, max_text_length)
            x_mask (torch.Tensor): mask for the text input
                shape: (batch_size, 1, max_text_length)
        """
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
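Editor's note: a minimal sketch (not part of the committed file) of what sequence_mask produces for a padded batch and how it becomes the (batch, 1, time) mask and the pairwise attention mask used throughout the encoder; the lengths are illustrative assumptions.

# Minimal sketch, assuming a batch of two sequences with lengths [3, 5].
import torch

length = torch.tensor([3, 5])
max_length = 5
pos = torch.arange(max_length, dtype=length.dtype)
mask = pos.unsqueeze(0) < length.unsqueeze(1)           # same computation as sequence_mask
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

x_mask = mask.unsqueeze(1).float()                      # (2, 1, 5), as in TextEncoder.forward
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)  # (2, 1, 5, 5) pairwise mask, as in Encoder.forward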
    	
chatterbox/src/chatterbox/models/s3gen/matcha/transformer.py
ADDED
@@ -0,0 +1,316 @@
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from diffusers.models.attention import (
    GEGLU,
    GELU,
    AdaLayerNorm,
    AdaLayerNormZero,
    ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super().__init__()
        self.in_features = out_features if isinstance(out_features, list) else [out_features]
        self.proj = LoRACompatibleLinear(in_features, out_features)

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
    	
chatterbox/src/chatterbox/models/s3gen/s3gen.py
ADDED
@@ -0,0 +1,305 @@
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional
from omegaconf import DictConfig

from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
from .flow import CausalMaskedDiffWithXvec
from .xvector import CAMPPlus
from .utils.mel import mel_spectrogram
from .f0_predictor import ConvRNNF0Predictor
from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder


def drop_invalid_tokens(x):
    assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
    return x[x < SPEECH_VOCAB_SIZE]


# TODO: global resampler cache
@lru_cache(100)
def get_resampler(src_sr, dst_sr, device):
    return ta.transforms.Resample(src_sr, dst_sr).to(device)


class S3Token2Mel(torch.nn.Module):
    """
    CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        self.mel_extractor = mel_spectrogram  # TODO: make it a torch module?
        self.speaker_encoder = CAMPPlus()  # use default args

        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )

        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        cfm_params = DictConfig({
            "sigma_min": 1e-06,
            "solver": 'euler',
            "t_scheduler": 'cosine',
            "training_cfg_rate": 0.2,
            "inference_cfg_rate": 0.7,
            "reg_loss_type": 'l1',
        })
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )

        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder
        )

        self.resamplers = {}

    @property
    def device(self):
        params = self.tokenizer.parameters()
        return next(params).device

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()

        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)

        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0)  # (B, L)

        if ref_wav.size(1) > 10 * ref_sr:
            print("WARNING: cosydec received ref longer than 10s")

        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)

        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
        ref_mels_24_len = None

        # Resample to 16kHz
        ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)

        # Speaker embedding
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16)

        # Tokenize 16khz reference
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)

        # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
        if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
            ref_speech_token_lens[0] = ref_speech_tokens.shape[1]

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """
        Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # type/device casting (all values will be numpy if it's from a prod API call)
            for rk in list(ref_dict):
                if isinstance(ref_dict[rk], np.ndarray):
                    ref_dict[rk] = torch.from_numpy(ref_dict[rk])
                if torch.is_tensor(ref_dict[rk]):
                    ref_dict[rk] = ref_dict[rk].to(self.device)

        if len(speech_tokens.shape) == 1:
            speech_tokens = speech_tokens.unsqueeze(0)

        # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
        speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            **ref_dict,
        )
        return output_mels


class S3Token2Wav(S3Token2Mel):
    """
    The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.

    TODO: make these modules configurable?
    """

    def __init__(self):
        super().__init__()

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # silence out a few ms and fade audio in to reduce artifacts
        n_trim = S3GEN_SR // 50  # 20ms = half of a frame
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False)  # (buffers get automatic device casting)

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False
    ):
        output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

        # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
        hift_cache_source = torch.zeros(1, 1, 0).to(self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
            output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs

    @torch.inference_mode()
    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

    @torch.inference_mode()
    def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0).to(self.device)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    @torch.inference_mode()
    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (prod API)
        ref_dict: Optional[dict] = None,
        cache_source: torch.Tensor = None,  # NOTE: this arg is for streaming, it can probably be removed here
        finalize: bool = True,
    ):
        output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
        output_wavs, output_sources = self.hift_inference(output_mels, cache_source)

        # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
        output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs, output_sources
    	
chatterbox/src/chatterbox/models/s3gen/transformer/__init__.py
ADDED
File without changes

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (190 Bytes).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc
ADDED
Binary file (3.58 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc
ADDED
Binary file (15.7 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc
ADDED
Binary file (5.54 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc
ADDED
Binary file (17.3 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc
ADDED
Binary file (11.2 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc
ADDED
Binary file (6.24 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc
ADDED
Binary file (18.9 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc
ADDED
Binary file (15.6 kB).

chatterbox/src/chatterbox/models/s3gen/transformer/activation.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
#               2020 Northwestern Polytechnical University (Pengcheng Guo)
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Swish(torch.nn.Module):
    """Construct an Swish object."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation function."""
        return x * torch.sigmoid(x)


# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
#   LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake ∶= x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
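
The two activations above are self-contained, so a quick smoke test is easy to sketch. The snippet below is illustrative only and not part of the commit; the import path and the (batch, channels, time) sizes are assumptions based on the file location and the Snake docstring.

import torch

# Assumed package path, mirroring chatterbox/src/chatterbox/models/s3gen/transformer/activation.py
from chatterbox.models.s3gen.transformer.activation import Swish, Snake

x = torch.randn(2, 80, 100)                  # (B, C, T) as in the Snake docstring
swish = Swish()
snake = Snake(in_features=80, alpha_logscale=True)

y_swish = swish(x)                           # elementwise x * sigmoid(x)
y_snake = snake(x)                           # x + (1/alpha) * sin^2(alpha * x)
assert y_swish.shape == y_snake.shape == x.shape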
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/attention.py
    ADDED
    
@@ -0,0 +1,330 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song ([email protected])
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T)  shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L)  shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
            time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
        if matrix_ac.shape != matrix_bd.shape:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
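
Both attention classes return an (output, new_cache) pair so the caller can thread key/value state across streaming chunks. A minimal sketch of that flow for the base MultiHeadedAttention follows; it is not part of the commit, the import path is assumed, and the head/feature sizes are arbitrary.

import torch

# Assumed package path for the file added above.
from chatterbox.models.s3gen.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)

# First chunk: the default cache is the zero-shaped (0, 0, 0, 0) tensor.
x1 = torch.randn(1, 10, 256)                         # (batch, time1, n_feat)
mask1 = torch.ones(1, 10, 10, dtype=torch.bool)      # full self-attention mask
out1, cache = mha(x1, x1, x1, mask=mask1)
# cache holds K and V concatenated on the last dim: (1, 4, 10, 2 * d_k), d_k = 64.

# Second chunk: keys/values now span the cached 10 frames plus 5 new ones.
x2 = torch.randn(1, 5, 256)
mask2 = torch.ones(1, 5, 15, dtype=torch.bool)
out2, cache = mha(x2, x2, x2, mask=mask2, cache=cache)
assert out2.shape == (1, 5, 256) and cache.shape == (1, 4, 15, 128)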
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/convolution.py
    ADDED
    
@@ -0,0 +1,145 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct an ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (int): Whether use causal convolution or not
        """
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for none causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) meas fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
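
ConvolutionModule follows the same (output, new_cache) convention: in causal mode the last kernel_size - 1 frames are returned as left context for the next chunk. A rough usage sketch, not from the commit; the import path and sizes are assumptions, and eval() is used so BatchNorm1d accepts a single-example batch:

import torch

# Assumed package path for the file added above.
from chatterbox.models.s3gen.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True).eval()

with torch.no_grad():
    # First chunk: empty cache, so the module zero-pads lorder = 14 frames on the left.
    x1 = torch.randn(1, 10, 256)                     # (batch, time, channels)
    y1, cache = conv(x1, mask_pad=torch.ones(1, 1, 10, dtype=torch.bool))
    assert y1.shape == (1, 10, 256) and cache.shape == (1, 256, 14)

    # Next chunk: the cached frames replace the zero padding.
    x2 = torch.randn(1, 4, 256)
    y2, cache = conv(x2, mask_pad=torch.ones(1, 1, 4, dtype=torch.bool), cache=cache)
    assert y2.shape == (1, 4, 256) and cache.shape == (1, 256, 14)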
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/embedding.py
    ADDED
    
@@ -0,0 +1,294 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F
import numpy as np


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        self.pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = self.pe.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: for compatibility to RelPositionalEncoding
        """

        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a none
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # for batched streaming decoding on GPU
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            class PositionalEncoding(torch.nn.Module):
         | 
| 27 | 
            +
                """Positional encoding.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                :param int d_model: embedding dim
         | 
| 30 | 
            +
                :param float dropout_rate: dropout rate
         | 
| 31 | 
            +
                :param int max_len: maximum input length
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
         | 
| 34 | 
            +
                PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __init__(self,
         | 
| 38 | 
            +
                             d_model: int,
         | 
| 39 | 
            +
                             dropout_rate: float,
         | 
| 40 | 
            +
                             max_len: int = 5000,
         | 
| 41 | 
            +
                             reverse: bool = False):
         | 
| 42 | 
            +
                    """Construct an PositionalEncoding object."""
         | 
| 43 | 
            +
                    super().__init__()
         | 
| 44 | 
            +
                    self.d_model = d_model
         | 
| 45 | 
            +
                    self.xscale = math.sqrt(self.d_model)
         | 
| 46 | 
            +
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 47 | 
            +
                    self.max_len = max_len
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    self.pe = torch.zeros(self.max_len, self.d_model)
         | 
| 50 | 
            +
                    position = torch.arange(0, self.max_len,
         | 
| 51 | 
            +
                                            dtype=torch.float32).unsqueeze(1)
         | 
| 52 | 
            +
                    div_term = torch.exp(
         | 
| 53 | 
            +
                        torch.arange(0, self.d_model, 2, dtype=torch.float32) *
         | 
| 54 | 
            +
                        -(math.log(10000.0) / self.d_model))
         | 
| 55 | 
            +
                    self.pe[:, 0::2] = torch.sin(position * div_term)
         | 
| 56 | 
            +
                    self.pe[:, 1::2] = torch.cos(position * div_term)
         | 
| 57 | 
            +
                    self.pe = self.pe.unsqueeze(0)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def forward(self,
         | 
| 60 | 
            +
                            x: torch.Tensor,
         | 
| 61 | 
            +
                            offset: Union[int, torch.Tensor] = 0) \
         | 
| 62 | 
            +
                        -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 63 | 
            +
                    """Add positional encoding.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    Args:
         | 
| 66 | 
            +
                        x (torch.Tensor): Input. Its shape is (batch, time, ...)
         | 
| 67 | 
            +
                        offset (int, torch.tensor): position offset
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    Returns:
         | 
| 70 | 
            +
                        torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
         | 
| 71 | 
            +
                        torch.Tensor: for compatibility to RelPositionalEncoding
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    self.pe = self.pe.to(x.device)
         | 
| 75 | 
            +
                    pos_emb = self.position_encoding(offset, x.size(1), False)
         | 
| 76 | 
            +
                    x = x * self.xscale + pos_emb
         | 
| 77 | 
            +
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def position_encoding(self,
         | 
| 80 | 
            +
                                      offset: Union[int, torch.Tensor],
         | 
| 81 | 
            +
                                      size: int,
         | 
| 82 | 
            +
                                      apply_dropout: bool = True) -> torch.Tensor:
         | 
| 83 | 
            +
                    """ For getting encoding in a streaming fashion
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    Attention!!!!!
         | 
| 86 | 
            +
                we apply dropout only once at the whole utterance level in the non-
         | 
| 87 | 
            +
                streaming case, but this function will be called several times with
         | 
| 88 | 
            +
                    increasing input size in a streaming scenario, so the dropout will
         | 
| 89 | 
            +
                    be applied several times.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    Args:
         | 
| 92 | 
            +
                        offset (int or torch.tensor): start offset
         | 
| 93 | 
            +
                        size (int): required size of position encoding
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    Returns:
         | 
| 96 | 
            +
                        torch.Tensor: Corresponding encoding
         | 
| 97 | 
            +
                    """
         | 
| 98 | 
            +
                    # How to subscript a Union type:
         | 
| 99 | 
            +
                    #   https://github.com/pytorch/pytorch/issues/69434
         | 
| 100 | 
            +
                    if isinstance(offset, int):
         | 
| 101 | 
            +
                        assert offset + size <= self.max_len
         | 
| 102 | 
            +
                        pos_emb = self.pe[:, offset:offset + size]
         | 
| 103 | 
            +
                    elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
         | 
| 104 | 
            +
                        assert offset + size <= self.max_len
         | 
| 105 | 
            +
                        pos_emb = self.pe[:, offset:offset + size]
         | 
| 106 | 
            +
                    else:  # for batched streaming decoding on GPU
         | 
| 107 | 
            +
                        assert torch.max(offset) + size <= self.max_len
         | 
| 108 | 
            +
                        index = offset.unsqueeze(1) + \
         | 
| 109 | 
            +
                            torch.arange(0, size).to(offset.device)  # B X T
         | 
| 110 | 
            +
                        flag = index > 0
         | 
| 111 | 
            +
                        # remove negative offset
         | 
| 112 | 
            +
                        index = index * flag
         | 
| 113 | 
            +
                        pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    if apply_dropout:
         | 
| 116 | 
            +
                        pos_emb = self.dropout(pos_emb)
         | 
| 117 | 
            +
                    return pos_emb
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            class RelPositionalEncoding(PositionalEncoding):
         | 
| 121 | 
            +
                """Relative positional encoding module.
         | 
| 122 | 
            +
                See : Appendix B in https://arxiv.org/abs/1901.02860
         | 
| 123 | 
            +
                Args:
         | 
| 124 | 
            +
                    d_model (int): Embedding dimension.
         | 
| 125 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 126 | 
            +
                    max_len (int): Maximum input length.
         | 
| 127 | 
            +
                """
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         | 
| 130 | 
            +
                    """Initialize class."""
         | 
| 131 | 
            +
                    super().__init__(d_model, dropout_rate, max_len, reverse=True)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def forward(self,
         | 
| 134 | 
            +
                            x: torch.Tensor,
         | 
| 135 | 
            +
                            offset: Union[int, torch.Tensor] = 0) \
         | 
| 136 | 
            +
                        -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 137 | 
            +
                    """Compute positional encoding.
         | 
| 138 | 
            +
                    Args:
         | 
| 139 | 
            +
                        x (torch.Tensor): Input tensor (batch, time, `*`).
         | 
| 140 | 
            +
                    Returns:
         | 
| 141 | 
            +
                        torch.Tensor: Encoded tensor (batch, time, `*`).
         | 
| 142 | 
            +
                        torch.Tensor: Positional embedding tensor (1, time, `*`).
         | 
| 143 | 
            +
                    """
         | 
| 144 | 
            +
                    self.pe = self.pe.to(x.device)
         | 
| 145 | 
            +
                    x = x * self.xscale
         | 
| 146 | 
            +
                    pos_emb = self.position_encoding(offset, x.size(1), False)
         | 
| 147 | 
            +
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
            class WhisperPositionalEncoding(PositionalEncoding):
         | 
| 151 | 
            +
                """ Sinusoids position encoding used in openai-whisper.encoder
         | 
| 152 | 
            +
                """
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
         | 
| 155 | 
            +
                    super().__init__(d_model, dropout_rate, max_len)
         | 
| 156 | 
            +
                    self.xscale = 1.0
         | 
| 157 | 
            +
                    log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
         | 
| 158 | 
            +
                    inv_timescales = torch.exp(-log_timescale_increment *
         | 
| 159 | 
            +
                                               torch.arange(d_model // 2))
         | 
| 160 | 
            +
                    scaled_time = torch.arange(max_len)[:, np.newaxis] * \
         | 
| 161 | 
            +
                        inv_timescales[np.newaxis, :]
         | 
| 162 | 
            +
                    pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
         | 
| 163 | 
            +
                    delattr(self, "pe")
         | 
| 164 | 
            +
                    self.register_buffer("pe", pe.unsqueeze(0))
         | 
| 165 | 
            +
             | 
| 166 | 
            +
             | 
| 167 | 
            +
            class LearnablePositionalEncoding(PositionalEncoding):
         | 
| 168 | 
            +
                """ Learnable position encoding used in openai-whisper.decoder
         | 
| 169 | 
            +
                """
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
         | 
| 172 | 
            +
                    super().__init__(d_model, dropout_rate, max_len)
         | 
| 173 | 
            +
                    # NOTE(xcsong): overwrite self.pe & self.xscale
         | 
| 174 | 
            +
                    self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
         | 
| 175 | 
            +
                    self.xscale = 1.0
         | 
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
            class NoPositionalEncoding(torch.nn.Module):
         | 
| 179 | 
            +
                """ No position encoding
         | 
| 180 | 
            +
                """
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def __init__(self, d_model: int, dropout_rate: float):
         | 
| 183 | 
            +
                    super().__init__()
         | 
| 184 | 
            +
                    self.d_model = d_model
         | 
| 185 | 
            +
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def forward(self,
         | 
| 188 | 
            +
                            x: torch.Tensor,
         | 
| 189 | 
            +
                            offset: Union[int, torch.Tensor] = 0) \
         | 
| 190 | 
            +
                        -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 191 | 
            +
                    """ Just return zero vector for interface compatibility
         | 
| 192 | 
            +
                    """
         | 
| 193 | 
            +
                    pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
         | 
| 194 | 
            +
                    return self.dropout(x), pos_emb
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                def position_encoding(self, offset: Union[int, torch.Tensor],
         | 
| 197 | 
            +
                                      size: int) -> torch.Tensor:
         | 
| 198 | 
            +
                    return torch.zeros(1, size, self.d_model)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            class EspnetRelPositionalEncoding(torch.nn.Module):
         | 
| 202 | 
            +
                """Relative positional encoding module (new implementation).
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                Details can be found in https://github.com/espnet/espnet/pull/2816.
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                See : Appendix B in https://arxiv.org/abs/1901.02860
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                Args:
         | 
| 209 | 
            +
                    d_model (int): Embedding dimension.
         | 
| 210 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 211 | 
            +
                    max_len (int): Maximum input length.
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                """
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         | 
| 216 | 
            +
                    """Construct an PositionalEncoding object."""
         | 
| 217 | 
            +
                    super(EspnetRelPositionalEncoding, self).__init__()
         | 
| 218 | 
            +
                    self.d_model = d_model
         | 
| 219 | 
            +
                    self.xscale = math.sqrt(self.d_model)
         | 
| 220 | 
            +
                    self.dropout = torch.nn.Dropout(p=dropout_rate)
         | 
| 221 | 
            +
                    self.pe = None
         | 
| 222 | 
            +
                    self.extend_pe(torch.tensor(0.0).expand(1, max_len))
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                def extend_pe(self, x: torch.Tensor):
         | 
| 225 | 
            +
                    """Reset the positional encodings."""
         | 
| 226 | 
            +
                    if self.pe is not None:
         | 
| 227 | 
            +
                        # self.pe contains both positive and negative parts
         | 
| 228 | 
            +
                        # the length of self.pe is 2 * input_len - 1
         | 
| 229 | 
            +
                        if self.pe.size(1) >= x.size(1) * 2 - 1:
         | 
| 230 | 
            +
                            if self.pe.dtype != x.dtype or self.pe.device != x.device:
         | 
| 231 | 
            +
                                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
         | 
| 232 | 
            +
                            return
         | 
| 233 | 
            +
                # Suppose `i` denotes the position of the query vector and `j` the
         | 
| 234 | 
            +
                # position of the key vector. We use positive relative positions when keys
         | 
| 235 | 
            +
                    # are to the left (i>j) and negative relative positions otherwise (i<j).
         | 
| 236 | 
            +
                    pe_positive = torch.zeros(x.size(1), self.d_model)
         | 
| 237 | 
            +
                    pe_negative = torch.zeros(x.size(1), self.d_model)
         | 
| 238 | 
            +
                    position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         | 
| 239 | 
            +
                    div_term = torch.exp(
         | 
| 240 | 
            +
                        torch.arange(0, self.d_model, 2, dtype=torch.float32)
         | 
| 241 | 
            +
                        * -(math.log(10000.0) / self.d_model)
         | 
| 242 | 
            +
                    )
         | 
| 243 | 
            +
                    pe_positive[:, 0::2] = torch.sin(position * div_term)
         | 
| 244 | 
            +
                    pe_positive[:, 1::2] = torch.cos(position * div_term)
         | 
| 245 | 
            +
                    pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
         | 
| 246 | 
            +
                    pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                # Reverse the order of positive indices and concat both positive and
         | 
| 249 | 
            +
                    # negative indices. This is used to support the shifting trick
         | 
| 250 | 
            +
                    # as in https://arxiv.org/abs/1901.02860
         | 
| 251 | 
            +
                    pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
         | 
| 252 | 
            +
                    pe_negative = pe_negative[1:].unsqueeze(0)
         | 
| 253 | 
            +
                    pe = torch.cat([pe_positive, pe_negative], dim=1)
         | 
| 254 | 
            +
                    self.pe = pe.to(device=x.device, dtype=x.dtype)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
         | 
| 257 | 
            +
                        -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 258 | 
            +
                    """Add positional encoding.
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                    Args:
         | 
| 261 | 
            +
                        x (torch.Tensor): Input tensor (batch, time, `*`).
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    Returns:
         | 
| 264 | 
            +
                        torch.Tensor: Encoded tensor (batch, time, `*`).
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    """
         | 
| 267 | 
            +
                    self.extend_pe(x)
         | 
| 268 | 
            +
                    x = x * self.xscale
         | 
| 269 | 
            +
                    pos_emb = self.position_encoding(size=x.size(1), offset=offset)
         | 
| 270 | 
            +
                    return self.dropout(x), self.dropout(pos_emb)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                def position_encoding(self,
         | 
| 273 | 
            +
                                      offset: Union[int, torch.Tensor],
         | 
| 274 | 
            +
                                      size: int) -> torch.Tensor:
         | 
| 275 | 
            +
                    """ For getting encoding in a streaming fashion
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    Attention!!!!!
         | 
| 278 | 
            +
                we apply dropout only once at the whole utterance level in the non-
         | 
| 279 | 
            +
                streaming case, but this function will be called several times with
         | 
| 280 | 
            +
                    increasing input size in a streaming scenario, so the dropout will
         | 
| 281 | 
            +
                    be applied several times.
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    Args:
         | 
| 284 | 
            +
                        offset (int or torch.tensor): start offset
         | 
| 285 | 
            +
                        size (int): required size of position encoding
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    Returns:
         | 
| 288 | 
            +
                        torch.Tensor: Corresponding encoding
         | 
| 289 | 
            +
                    """
         | 
| 290 | 
            +
                    pos_emb = self.pe[
         | 
| 291 | 
            +
                        :,
         | 
| 292 | 
            +
                        self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
         | 
| 293 | 
            +
                    ]
         | 
| 294 | 
            +
                    return pos_emb
         | 
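A minimal standalone sketch (not part of this commit; it only assumes torch and math, with toy values for d_model, max_len and size) reproduces the docstring formula PE(pos, 2i) = sin(pos/10000^(2i/d_model)) and shows why EspnetRelPositionalEncoding.position_encoding slices a centred window of length 2*size - 1 out of its concatenated positive/negative table:

import math
import torch

d_model, max_len, size = 8, 16, 4

# Absolute sinusoidal table, as PositionalEncoding.__init__ builds it.
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)      # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                     * -(math.log(10000.0) / d_model))                     # (d_model/2,)
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
abs_chunk = pe.unsqueeze(0)[:, 3:3 + size]                                 # offset=3 -> (1, size, d_model)

# Relative table: reversed positive positions followed by negative positions,
# mirroring EspnetRelPositionalEncoding.extend_pe; total length is 2*max_len - 1.
pe_neg = torch.zeros(max_len, d_model)
pe_neg[:, 0::2] = torch.sin(-position * div_term)
pe_neg[:, 1::2] = torch.cos(-position * div_term)
rel_pe = torch.cat([torch.flip(pe, [0]).unsqueeze(0), pe_neg[1:].unsqueeze(0)], dim=1)

# The centre index (rel_pe.size(1) // 2) is relative position 0, so the slice
# covers relative positions size-1 ... -(size-1), i.e. 2*size - 1 embeddings.
center = rel_pe.size(1) // 2
rel_chunk = rel_pe[:, center - size + 1: center + size]
print(abs_chunk.shape, rel_chunk.shape)   # torch.Size([1, 4, 8]) torch.Size([1, 7, 8])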
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/encoder_layer.py
    ADDED
    
    | @@ -0,0 +1,236 @@ | |
| 1 | 
            +
            # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
         | 
| 2 | 
            +
            #               2022 Xingchen Song ([email protected])
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            # Modified from ESPnet(https://github.com/espnet/espnet)
         | 
| 16 | 
            +
            """Encoder self-attention layer definition."""
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from typing import Optional, Tuple
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import torch
         | 
| 21 | 
            +
            from torch import nn
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            class TransformerEncoderLayer(nn.Module):
         | 
| 25 | 
            +
                """Encoder layer module.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                Args:
         | 
| 28 | 
            +
                    size (int): Input dimension.
         | 
| 29 | 
            +
                    self_attn (torch.nn.Module): Self-attention module instance.
         | 
| 30 | 
            +
                        `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
         | 
| 31 | 
            +
                        instance can be used as the argument.
         | 
| 32 | 
            +
                    feed_forward (torch.nn.Module): Feed-forward module instance.
         | 
| 33 | 
            +
                        `PositionwiseFeedForward`, instance can be used as the argument.
         | 
| 34 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 35 | 
            +
                    normalize_before (bool):
         | 
| 36 | 
            +
                        True: use layer_norm before each sub-block.
         | 
| 37 | 
            +
                    False: use layer_norm after each sub-block.
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                def __init__(
         | 
| 41 | 
            +
                    self,
         | 
| 42 | 
            +
                    size: int,
         | 
| 43 | 
            +
                    self_attn: torch.nn.Module,
         | 
| 44 | 
            +
                    feed_forward: torch.nn.Module,
         | 
| 45 | 
            +
                    dropout_rate: float,
         | 
| 46 | 
            +
                    normalize_before: bool = True,
         | 
| 47 | 
            +
                ):
         | 
| 48 | 
            +
                    """Construct an EncoderLayer object."""
         | 
| 49 | 
            +
                    super().__init__()
         | 
| 50 | 
            +
                    self.self_attn = self_attn
         | 
| 51 | 
            +
                    self.feed_forward = feed_forward
         | 
| 52 | 
            +
                    self.norm1 = nn.LayerNorm(size, eps=1e-12)
         | 
| 53 | 
            +
                    self.norm2 = nn.LayerNorm(size, eps=1e-12)
         | 
| 54 | 
            +
                    self.dropout = nn.Dropout(dropout_rate)
         | 
| 55 | 
            +
                    self.size = size
         | 
| 56 | 
            +
                    self.normalize_before = normalize_before
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def forward(
         | 
| 59 | 
            +
                    self,
         | 
| 60 | 
            +
                    x: torch.Tensor,
         | 
| 61 | 
            +
                    mask: torch.Tensor,
         | 
| 62 | 
            +
                    pos_emb: torch.Tensor,
         | 
| 63 | 
            +
                    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 64 | 
            +
                    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 65 | 
            +
                    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 66 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 67 | 
            +
                    """Compute encoded features.
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    Args:
         | 
| 70 | 
            +
                        x (torch.Tensor): (#batch, time, size)
         | 
| 71 | 
            +
                    mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
         | 
| 72 | 
            +
                            (0, 0, 0) means fake mask.
         | 
| 73 | 
            +
                        pos_emb (torch.Tensor): just for interface compatibility
         | 
| 74 | 
            +
                            to ConformerEncoderLayer
         | 
| 75 | 
            +
                    mask_pad (torch.Tensor): not used in the transformer layer,
         | 
| 76 | 
            +
                            just for unified api with conformer.
         | 
| 77 | 
            +
                        att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
         | 
| 78 | 
            +
                            (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
         | 
| 79 | 
            +
                        cnn_cache (torch.Tensor): Convolution cache in conformer layer
         | 
| 80 | 
            +
                            (#batch=1, size, cache_t2), not used here, it's for interface
         | 
| 81 | 
            +
                            compatibility to ConformerEncoderLayer.
         | 
| 82 | 
            +
                    Returns:
         | 
| 83 | 
            +
                        torch.Tensor: Output tensor (#batch, time, size).
         | 
| 84 | 
            +
                        torch.Tensor: Mask tensor (#batch, time, time).
         | 
| 85 | 
            +
                        torch.Tensor: att_cache tensor,
         | 
| 86 | 
            +
                            (#batch=1, head, cache_t1 + time, d_k * 2).
         | 
| 87 | 
            +
                    torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    """
         | 
| 90 | 
            +
                    residual = x
         | 
| 91 | 
            +
                    if self.normalize_before:
         | 
| 92 | 
            +
                        x = self.norm1(x)
         | 
| 93 | 
            +
                    x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
         | 
| 94 | 
            +
                    x = residual + self.dropout(x_att)
         | 
| 95 | 
            +
                    if not self.normalize_before:
         | 
| 96 | 
            +
                        x = self.norm1(x)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    residual = x
         | 
| 99 | 
            +
                    if self.normalize_before:
         | 
| 100 | 
            +
                        x = self.norm2(x)
         | 
| 101 | 
            +
                    x = residual + self.dropout(self.feed_forward(x))
         | 
| 102 | 
            +
                    if not self.normalize_before:
         | 
| 103 | 
            +
                        x = self.norm2(x)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
         | 
| 106 | 
            +
                    return x, mask, new_att_cache, fake_cnn_cache
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            class ConformerEncoderLayer(nn.Module):
         | 
| 110 | 
            +
                """Encoder layer module.
         | 
| 111 | 
            +
                Args:
         | 
| 112 | 
            +
                    size (int): Input dimension.
         | 
| 113 | 
            +
                    self_attn (torch.nn.Module): Self-attention module instance.
         | 
| 114 | 
            +
                        `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
         | 
| 115 | 
            +
                        instance can be used as the argument.
         | 
| 116 | 
            +
                    feed_forward (torch.nn.Module): Feed-forward module instance.
         | 
| 117 | 
            +
                        `PositionwiseFeedForward` instance can be used as the argument.
         | 
| 118 | 
            +
                    feed_forward_macaron (torch.nn.Module): Additional feed-forward module
         | 
| 119 | 
            +
                         instance.
         | 
| 120 | 
            +
                        `PositionwiseFeedForward` instance can be used as the argument.
         | 
| 121 | 
            +
                    conv_module (torch.nn.Module): Convolution module instance.
         | 
| 122 | 
            +
                    `ConvolutionModule` instance can be used as the argument.
         | 
| 123 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 124 | 
            +
                    normalize_before (bool):
         | 
| 125 | 
            +
                        True: use layer_norm before each sub-block.
         | 
| 126 | 
            +
                        False: use layer_norm after each sub-block.
         | 
| 127 | 
            +
                """
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                def __init__(
         | 
| 130 | 
            +
                    self,
         | 
| 131 | 
            +
                    size: int,
         | 
| 132 | 
            +
                    self_attn: torch.nn.Module,
         | 
| 133 | 
            +
                    feed_forward: Optional[nn.Module] = None,
         | 
| 134 | 
            +
                    feed_forward_macaron: Optional[nn.Module] = None,
         | 
| 135 | 
            +
                    conv_module: Optional[nn.Module] = None,
         | 
| 136 | 
            +
                    dropout_rate: float = 0.1,
         | 
| 137 | 
            +
                    normalize_before: bool = True,
         | 
| 138 | 
            +
                ):
         | 
| 139 | 
            +
                    """Construct an EncoderLayer object."""
         | 
| 140 | 
            +
                    super().__init__()
         | 
| 141 | 
            +
                    self.self_attn = self_attn
         | 
| 142 | 
            +
                    self.feed_forward = feed_forward
         | 
| 143 | 
            +
                    self.feed_forward_macaron = feed_forward_macaron
         | 
| 144 | 
            +
                    self.conv_module = conv_module
         | 
| 145 | 
            +
                    self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
         | 
| 146 | 
            +
                    self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
         | 
| 147 | 
            +
                    if feed_forward_macaron is not None:
         | 
| 148 | 
            +
                        self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
         | 
| 149 | 
            +
                        self.ff_scale = 0.5
         | 
| 150 | 
            +
                    else:
         | 
| 151 | 
            +
                        self.ff_scale = 1.0
         | 
| 152 | 
            +
                    if self.conv_module is not None:
         | 
| 153 | 
            +
                        self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
         | 
| 154 | 
            +
                        self.norm_final = nn.LayerNorm(
         | 
| 155 | 
            +
                            size, eps=1e-12)  # for the final output of the block
         | 
| 156 | 
            +
                    self.dropout = nn.Dropout(dropout_rate)
         | 
| 157 | 
            +
                    self.size = size
         | 
| 158 | 
            +
                    self.normalize_before = normalize_before
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                def forward(
         | 
| 161 | 
            +
                    self,
         | 
| 162 | 
            +
                    x: torch.Tensor,
         | 
| 163 | 
            +
                    mask: torch.Tensor,
         | 
| 164 | 
            +
                    pos_emb: torch.Tensor,
         | 
| 165 | 
            +
                    mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
         | 
| 166 | 
            +
                    att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 167 | 
            +
                    cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
         | 
| 168 | 
            +
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 169 | 
            +
                    """Compute encoded features.
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    Args:
         | 
| 172 | 
            +
                        x (torch.Tensor): (#batch, time, size)
         | 
| 173 | 
            +
                    mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
         | 
| 174 | 
            +
                            (0, 0, 0) means fake mask.
         | 
| 175 | 
            +
                        pos_emb (torch.Tensor): positional encoding, must not be None
         | 
| 176 | 
            +
                            for ConformerEncoderLayer.
         | 
| 177 | 
            +
                        mask_pad (torch.Tensor): batch padding mask used for conv module.
         | 
| 178 | 
            +
                        (#batch, 1, time), (0, 0, 0) means fake mask.
         | 
| 179 | 
            +
                        att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
         | 
| 180 | 
            +
                            (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
         | 
| 181 | 
            +
                        cnn_cache (torch.Tensor): Convolution cache in conformer layer
         | 
| 182 | 
            +
                            (#batch=1, size, cache_t2)
         | 
| 183 | 
            +
                    Returns:
         | 
| 184 | 
            +
                        torch.Tensor: Output tensor (#batch, time, size).
         | 
| 185 | 
            +
                        torch.Tensor: Mask tensor (#batch, time, time).
         | 
| 186 | 
            +
                        torch.Tensor: att_cache tensor,
         | 
| 187 | 
            +
                            (#batch=1, head, cache_t1 + time, d_k * 2).
         | 
| 188 | 
            +
                    torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
         | 
| 189 | 
            +
                    """
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    # whether to use macaron style
         | 
| 192 | 
            +
                    if self.feed_forward_macaron is not None:
         | 
| 193 | 
            +
                        residual = x
         | 
| 194 | 
            +
                        if self.normalize_before:
         | 
| 195 | 
            +
                            x = self.norm_ff_macaron(x)
         | 
| 196 | 
            +
                        x = residual + self.ff_scale * self.dropout(
         | 
| 197 | 
            +
                            self.feed_forward_macaron(x))
         | 
| 198 | 
            +
                        if not self.normalize_before:
         | 
| 199 | 
            +
                            x = self.norm_ff_macaron(x)
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # multi-headed self-attention module
         | 
| 202 | 
            +
                    residual = x
         | 
| 203 | 
            +
                    if self.normalize_before:
         | 
| 204 | 
            +
                        x = self.norm_mha(x)
         | 
| 205 | 
            +
                    x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
         | 
| 206 | 
            +
                                                          att_cache)
         | 
| 207 | 
            +
                    x = residual + self.dropout(x_att)
         | 
| 208 | 
            +
                    if not self.normalize_before:
         | 
| 209 | 
            +
                        x = self.norm_mha(x)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    # convolution module
         | 
| 212 | 
            +
                    # Fake new cnn cache here, and then change it in conv_module
         | 
| 213 | 
            +
                    new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
         | 
| 214 | 
            +
                    if self.conv_module is not None:
         | 
| 215 | 
            +
                        residual = x
         | 
| 216 | 
            +
                        if self.normalize_before:
         | 
| 217 | 
            +
                            x = self.norm_conv(x)
         | 
| 218 | 
            +
                        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
         | 
| 219 | 
            +
                        x = residual + self.dropout(x)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                        if not self.normalize_before:
         | 
| 222 | 
            +
                            x = self.norm_conv(x)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # feed forward module
         | 
| 225 | 
            +
                    residual = x
         | 
| 226 | 
            +
                    if self.normalize_before:
         | 
| 227 | 
            +
                        x = self.norm_ff(x)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
         | 
| 230 | 
            +
                    if not self.normalize_before:
         | 
| 231 | 
            +
                        x = self.norm_ff(x)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    if self.conv_module is not None:
         | 
| 234 | 
            +
                        x = self.norm_final(x)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    return x, mask, new_att_cache, new_cnn_cache
         | 
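Both TransformerEncoderLayer and ConformerEncoderLayer above repeat the same residual wrapper around every sub-block (attention, convolution, feed-forward). A small illustrative sketch of that pattern, not taken from the commit, with the macaron ff_scale generalised to a scale argument and a plain nn.Sequential standing in for the feed-forward module:

import torch
from torch import nn

def sublayer(x: torch.Tensor, norm: nn.Module, fn, dropout: nn.Module,
             normalize_before: bool = True, scale: float = 1.0) -> torch.Tensor:
    """Residual wrapper applied around each sub-block of the encoder layers."""
    residual = x
    if normalize_before:              # pre-norm: normalize before the sub-block
        x = norm(x)
    x = residual + scale * dropout(fn(x))
    if not normalize_before:          # post-norm: normalize after the residual add
        x = norm(x)
    return x

size = 16
x = torch.randn(2, 10, size)
ff = nn.Sequential(nn.Linear(size, 4 * size), nn.ReLU(), nn.Linear(4 * size, size))
y = sublayer(x, nn.LayerNorm(size, eps=1e-12), ff, nn.Dropout(0.1))             # plain FFN block
y = sublayer(y, nn.LayerNorm(size, eps=1e-12), ff, nn.Dropout(0.1), scale=0.5)  # macaron-style half step
print(y.shape)                        # torch.Size([2, 10, 16])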
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py
    ADDED
    
    | @@ -0,0 +1,115 @@ | |
| 1 | 
            +
            # Copyright (c) 2019 Shigeki Karita
         | 
| 2 | 
            +
            #               2020 Mobvoi Inc (Binbin Zhang)
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            """Positionwise feed forward layer definition."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class PositionwiseFeedForward(torch.nn.Module):
         | 
| 21 | 
            +
                """Positionwise feed forward layer.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                FeedForward is applied at each position of the sequence.
         | 
| 24 | 
            +
                The output dim is the same as the input dim.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    idim (int): Input dimension.
         | 
| 28 | 
            +
                    hidden_units (int): The number of hidden units.
         | 
| 29 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 30 | 
            +
                    activation (torch.nn.Module): Activation function
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def __init__(
         | 
| 34 | 
            +
                        self,
         | 
| 35 | 
            +
                        idim: int,
         | 
| 36 | 
            +
                        hidden_units: int,
         | 
| 37 | 
            +
                        dropout_rate: float,
         | 
| 38 | 
            +
                        activation: torch.nn.Module = torch.nn.ReLU(),
         | 
| 39 | 
            +
                ):
         | 
| 40 | 
            +
                    """Construct a PositionwiseFeedForward object."""
         | 
| 41 | 
            +
                    super(PositionwiseFeedForward, self).__init__()
         | 
| 42 | 
            +
                    self.w_1 = torch.nn.Linear(idim, hidden_units)
         | 
| 43 | 
            +
                    self.activation = activation
         | 
| 44 | 
            +
                    self.dropout = torch.nn.Dropout(dropout_rate)
         | 
| 45 | 
            +
                    self.w_2 = torch.nn.Linear(hidden_units, idim)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def forward(self, xs: torch.Tensor) -> torch.Tensor:
         | 
| 48 | 
            +
                    """Forward function.
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    Args:
         | 
| 51 | 
            +
                        xs: input tensor (B, L, D)
         | 
| 52 | 
            +
                    Returns:
         | 
| 53 | 
            +
                        output tensor, (B, L, D)
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    return self.w_2(self.dropout(self.activation(self.w_1(xs))))
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            class MoEFFNLayer(torch.nn.Module):
         | 
| 59 | 
            +
                """
         | 
| 60 | 
            +
                Mixture of experts with positionwise feed forward layers
         | 
| 61 | 
            +
                See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
         | 
| 62 | 
            +
                The output dim is the same as the input dim.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
         | 
| 65 | 
            +
                              https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
         | 
| 66 | 
            +
                Args:
         | 
| 67 | 
            +
                    n_expert: number of expert.
         | 
| 68 | 
            +
                    n_expert_per_token: The actual number of experts used for each frame
         | 
| 69 | 
            +
                    idim (int): Input dimension.
         | 
| 70 | 
            +
                    hidden_units (int): The number of hidden units.
         | 
| 71 | 
            +
                    dropout_rate (float): Dropout rate.
         | 
| 72 | 
            +
                    activation (torch.nn.Module): Activation function
         | 
| 73 | 
            +
                """
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def __init__(
         | 
| 76 | 
            +
                        self,
         | 
| 77 | 
            +
                        n_expert: int,
         | 
| 78 | 
            +
                        n_expert_per_token: int,
         | 
| 79 | 
            +
                        idim: int,
         | 
| 80 | 
            +
                        hidden_units: int,
         | 
| 81 | 
            +
                        dropout_rate: float,
         | 
| 82 | 
            +
                        activation: torch.nn.Module = torch.nn.ReLU(),
         | 
| 83 | 
            +
                ):
         | 
| 84 | 
            +
                    super(MoEFFNLayer, self).__init__()
         | 
| 85 | 
            +
                    self.gate = torch.nn.Linear(idim, n_expert, bias=False)
         | 
| 86 | 
            +
                    self.experts = torch.nn.ModuleList(
         | 
| 87 | 
            +
                        PositionwiseFeedForward(idim, hidden_units, dropout_rate,
         | 
| 88 | 
            +
                                                activation) for _ in range(n_expert))
         | 
| 89 | 
            +
                    self.n_expert_per_token = n_expert_per_token
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def forward(self, xs: torch.Tensor) -> torch.Tensor:
         | 
| 92 | 
            +
                    """Foward function.
         | 
| 93 | 
            +
                    Args:
         | 
| 94 | 
            +
                        xs: input tensor (B, L, D)
         | 
| 95 | 
            +
                    Returns:
         | 
| 96 | 
            +
                        output tensor, (B, L, D)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    """
         | 
| 99 | 
            +
                    B, L, D = xs.size(
         | 
| 100 | 
            +
                    )  # batch size, sequence length, embedding dimension (idim)
         | 
| 101 | 
            +
                    xs = xs.view(-1, D)  # (B*L, D)
         | 
| 102 | 
            +
                    router = self.gate(xs)  # (B*L, n_expert)
         | 
| 103 | 
            +
                    logits, indices = torch.topk(
         | 
| 104 | 
            +
                        router, self.n_expert_per_token
         | 
| 105 | 
            +
                )  # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
         | 
| 106 | 
            +
                    weights = torch.nn.functional.softmax(
         | 
| 107 | 
            +
                        logits, dim=1,
         | 
| 108 | 
            +
                        dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
         | 
| 109 | 
            +
                    output = torch.zeros_like(xs)  # (B*L, D)
         | 
| 110 | 
            +
                    for i, expert in enumerate(self.experts):
         | 
| 111 | 
            +
                        mask = indices == i
         | 
| 112 | 
            +
                        batch_idx, ith_expert = torch.where(mask)
         | 
| 113 | 
            +
                        output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
         | 
| 114 | 
            +
                            xs[batch_idx])
         | 
| 115 | 
            +
                    return output.view(B, L, D)
         | 
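The top-k routing inside MoEFFNLayer.forward is easier to follow on a toy tensor. A minimal sketch (not part of the commit) with nn.Identity standing in for the PositionwiseFeedForward experts, so that only the gating, the softmax over the selected logits, and the weighted scatter-add remain:

import torch
from torch import nn

B, L, D, n_expert, k = 2, 3, 4, 4, 2
xs = torch.randn(B, L, D)
gate = nn.Linear(D, n_expert, bias=False)
experts = nn.ModuleList(nn.Identity() for _ in range(n_expert))   # stand-ins for PositionwiseFeedForward

flat = xs.view(-1, D)                                  # (B*L, D)
logits, indices = torch.topk(gate(flat), k)            # both (B*L, k): scores and ids of chosen experts
weights = torch.softmax(logits, dim=1)                 # renormalise over the k selected experts only
out = torch.zeros_like(flat)
for i, expert in enumerate(experts):
    rows, slot = torch.where(indices == i)             # tokens routed to expert i, and their top-k slot
    out[rows] += weights[rows, slot, None] * expert(flat[rows])
print(out.view(B, L, D).shape)                         # torch.Size([2, 3, 4])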
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/subsampling.py
    ADDED
    
    | @@ -0,0 +1,383 @@ | |
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch


class BaseSubsampling(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class EmbedinigNoSubsampling(BaseSubsampling):
    """Embedding input without subsampling."""

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time.
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time.

        """
        x = self.embed(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform of the input without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a LinearNoSubsampling object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time.
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time.

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to 1/2 length).
       It is designed for Whisper, ref:
       https://github.com/openai/whisper/blob/main/whisper/model.py

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv1dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 4 = (3 - 1) * 1 + (3 - 1) * 1
        self.right_context = 4

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.
            torch.Tensor: positional encoding

        """
        time = x.size(1)
        x = x.transpose(1, 2)  # (b, f, t)
        x = self.conv(x)
        x = x.transpose(1, 2)  # (b, t, f)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]


class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
            torch.Tensor: positional encoding

        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]


class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                      odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]


class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]


class LegacyLinearNoSubsampling(BaseSubsampling):
    """Linear transform of the input without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a LegacyLinearNoSubsampling object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time.
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time.

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask
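
A minimal shape-check sketch for the subsampling classes above (not part of the commit). The import path and the tiny IdentityPosEnc stand-in are assumptions for illustration; in the real model a positional-encoding class is supplied by the caller. The sketch only shows that LinearNoSubsampling keeps the time axis while Conv2dSubsampling4 shrinks both the features and the mask to roughly a quarter.

import torch

# Hypothetical import path; adjust to wherever this package is installed.
from chatterbox.models.s3gen.transformer.subsampling import (
    LinearNoSubsampling, Conv2dSubsampling4)


class IdentityPosEnc(torch.nn.Module):
    """Stand-in positional encoding: returns the input unchanged plus a
    zero positional embedding, just to satisfy the pos_enc interface."""

    def forward(self, x: torch.Tensor, offset=0):
        return x, torch.zeros(1, x.size(1), x.size(2))

    def position_encoding(self, offset, size: int) -> torch.Tensor:
        return torch.zeros(1, size, 1)


idim, odim = 80, 256
x = torch.randn(2, 100, idim)                    # (batch, time, feature)
mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (batch, 1, time)

lin = LinearNoSubsampling(idim, odim, 0.1, IdentityPosEnc())
y, _, y_mask = lin(x, mask)
print(y.shape, y_mask.shape)        # (2, 100, 256), (2, 1, 100): length kept

sub4 = Conv2dSubsampling4(idim, odim, 0.1, IdentityPosEnc())
y, _, y_mask = sub4(x, mask)
print(y.shape, y_mask.shape)        # (2, 24, 256), (2, 1, 24): roughly time // 4
print(sub4.subsampling_rate, sub4.right_context)   # 4, 6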
    	
        chatterbox/src/chatterbox/models/s3gen/transformer/upsample_encoder.py
    ADDED
    
@@ -0,0 +1,318 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .convolution import ConvolutionModule
from .encoder_layer import ConformerEncoderLayer
from .positionwise_feed_forward import PositionwiseFeedForward
from ..utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from ..utils.mask import make_pad_mask
from ..utils.mask import add_optional_chunk_mask


class Upsample1D(nn.Module):
    """A 1D upsampling layer followed by a convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        out_channels (`int`):
            number of output channels.
        stride (`int`, default `2`):
            upsampling factor applied to the time axis.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat-interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        # look ahead: right-pad so each frame sees pre_lookahead_len future frames
        outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # causal conv (left pad) over the look-ahead features
        outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs


class UpsampleConformerEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 512,
        attention_heads: int = 8,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        input_layer: str = "linear",
        pos_enc_layer_type: str = "rel_pos_espnet",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = False,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = False,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use a dynamic chunk size for
                training. You can only use a fixed chunk (chunk_size > 0)
                or a dynamic chunk size (use_dynamic_chunk = True).
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use a dynamic left chunk in
                dynamic chunk training
            key_bias: whether to use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(4)
        ])

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        # lookahead + conformer encoder
        xs = self.pre_lookahead_layer(xs)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # upsample + conformer encoder
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size * self.up_layer.stride,
                                              num_decoding_left_chunks)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
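
A shape-check sketch for the two helper modules above (not part of the commit; the import path is an assumption about how the package is installed). Upsample1D operates on (batch, channels, time) and exactly doubles the time axis with stride=2, since the left pad of stride * 2 compensates for the kernel-size stride * 2 + 1 convolution; PreLookaheadLayer operates on (batch, time, channels) and keeps the length while letting each frame mix in pre_lookahead_len future frames before the residual add.

import torch

# Hypothetical import path; adjust to wherever this package is installed.
from chatterbox.models.s3gen.transformer.upsample_encoder import (
    Upsample1D, PreLookaheadLayer)

torch.manual_seed(0)

# Upsample1D: nearest-neighbour interpolation by `stride`, a left pad of
# `stride * 2`, then a kernel-size `stride * 2 + 1` conv gives length 2 * T.
up = Upsample1D(channels=512, out_channels=512, stride=2)
tokens = torch.randn(2, 512, 50)                 # (B, C, T)
lengths = torch.tensor([50, 37])
out, out_lengths = up(tokens, lengths)
print(out.shape, out_lengths.tolist())           # torch.Size([2, 512, 100]) [100, 74]

# PreLookaheadLayer: right-pads by pre_lookahead_len so each frame can peek at
# 3 future frames, then adds the result back onto the input as a residual.
look = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
feats = torch.randn(2, 50, 512)                  # (B, T, C)
print(look(feats).shape)                         # torch.Size([2, 50, 512])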
    	
        chatterbox/src/chatterbox/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc
    ADDED
    
Binary file (1.93 kB).