martin committed
Commit 67c46fd · 1 Parent(s): 703c9a0
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +7 -35
  2. .gitignore +1 -0
  3. Dockerfile +46 -0
  4. __init__.py +0 -0
  5. app.py +173 -0
  6. assets/assistant.png +3 -0
  7. assets/user.png +3 -0
  8. clone_hf_model.sh +57 -0
  9. cosyvoice/__init__.py +0 -0
  10. cosyvoice/cli/__init__.py +0 -0
  11. cosyvoice/cli/cosyvoice.py +68 -0
  12. cosyvoice/cli/frontend.py +106 -0
  13. cosyvoice/cli/model.py +32 -0
  14. cosyvoice/flow/decoder.py +238 -0
  15. cosyvoice/flow/flow.py +196 -0
  16. cosyvoice/flow/flow_matching.py +315 -0
  17. cosyvoice/flow/length_regulator.py +65 -0
  18. cosyvoice/hifigan/f0_predictor.py +55 -0
  19. cosyvoice/hifigan/generator.py +566 -0
  20. cosyvoice/matcha/audio.py +90 -0
  21. cosyvoice/matcha/decoder.py +511 -0
  22. cosyvoice/matcha/flow_matching.py +141 -0
  23. cosyvoice/matcha/transformer.py +443 -0
  24. cosyvoice/transformer/__init__.py +0 -0
  25. cosyvoice/transformer/activation.py +87 -0
  26. cosyvoice/transformer/attention.py +322 -0
  27. cosyvoice/transformer/convolution.py +147 -0
  28. cosyvoice/transformer/decoder.py +418 -0
  29. cosyvoice/transformer/decoder_layer.py +132 -0
  30. cosyvoice/transformer/embedding.py +293 -0
  31. cosyvoice/transformer/encoder.py +633 -0
  32. cosyvoice/transformer/encoder_layer.py +237 -0
  33. cosyvoice/transformer/label_smoothing_loss.py +98 -0
  34. cosyvoice/transformer/positionwise_feed_forward.py +116 -0
  35. cosyvoice/transformer/subsampling.py +391 -0
  36. cosyvoice/utils/__init__.py +0 -0
  37. cosyvoice/utils/audio.py +90 -0
  38. cosyvoice/utils/class_utils.py +78 -0
  39. cosyvoice/utils/common.py +169 -0
  40. cosyvoice/utils/executor.py +151 -0
  41. cosyvoice/utils/file_utils.py +49 -0
  42. cosyvoice/utils/frontend_utils.py +142 -0
  43. cosyvoice/utils/mask.py +226 -0
  44. cosyvoice/utils/scheduler.py +761 -0
  45. cosyvoice/utils/train_utils.py +350 -0
  46. funasr_detach/__init__.py +38 -0
  47. funasr_detach/auto/__init__.py +0 -0
  48. funasr_detach/auto/auto_frontend.py +90 -0
  49. funasr_detach/auto/auto_model.py +573 -0
  50. funasr_detach/auto/auto_tokenizer.py +7 -0
.gitattributes CHANGED
@@ -1,35 +1,7 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.wav filter=lfs diff=lfs merge=lfs -text
3
+ assets/user.png filter=lfs diff=lfs merge=lfs -text
4
+ assets/assistant.png filter=lfs diff=lfs merge=lfs -text
5
+ speakers/闫雨婷_prompt.wav filter=lfs diff=lfs merge=lfs -text
6
+ speakers/闫雨婷RAP_prompt.wav filter=lfs diff=lfs merge=lfs -text
7
+ speakers/闫雨婷VOCAL_prompt.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,46 @@
1
+ FROM nvidia/cuda:12.1.0-base-ubuntu20.04
2
+
3
+ ENV TZ=Asia/Shanghai
4
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \
5
+ && echo $TZ > /etc/timezone
6
+
7
+ RUN apt-get update \
8
+ && apt-get install -y build-essential \
9
+ && apt-get install -y wget \
10
+ && apt-get install -y software-properties-common curl zip unzip git-lfs awscli libssl-dev openssh-server vim \
11
+ && apt-get install -y net-tools iputils-ping iproute2
12
+
13
+ RUN apt-get install --reinstall ca-certificates && update-ca-certificates
14
+
15
+ RUN add-apt-repository -y 'ppa:deadsnakes/ppa' && apt update
16
+ RUN apt install python3.10 python3.10-dev python3.10-distutils python3.10-venv -y \
17
+ && apt-get clean \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ RUN wget -qO- https://bootstrap.pypa.io/get-pip.py | python3.10
21
+ RUN ln -s /usr/bin/python3.10 /usr/bin/python
22
+ RUN pip uninstall -y Pillow && pip install pillow
23
+
24
+ # https://huggingface.co/docs/hub/spaces-sdks-docker#permissions
25
+ RUN useradd -m -u 1000 user
26
+ USER user
27
+
28
+ ENV HOME="/home/user" \
29
+ PATH="/home/user/.local/bin:${PATH}"
30
+
31
+ RUN python3.10 -m pip install pipx
32
+ RUN pipx install poetry
33
+
34
+ RUN poetry --version || { echo 'Poetry installation check failed' ; exit 1; }
35
+
36
+ WORKDIR /workspace
37
+
38
+ COPY --chown=user requirements.txt .
39
+ RUN pip install -r requirements.txt
40
+
41
+ COPY --chown=user . .
42
+
43
+ RUN pip install gradio
44
+ RUN chmod +x clone_hf_model.sh
45
+ ENV HF_MODEL_PATH="/tmp/hf_model"
46
+ CMD ./clone_hf_model.sh "$HF_MODEL_PATH" && python app.py --model-path "$HF_MODEL_PATH"
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,173 @@
1
+ import gradio as gr
2
+ import time
3
+ from pathlib import Path
4
+ import torchaudio
5
+ from stepaudio import StepAudio
6
+
7
+ from funasr import AutoModel
8
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
9
+
10
+ CACHE_DIR = "/tmp/gradio/"
11
+ system_prompt = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}
12
+
13
+
14
+ class CustomAsr:
15
+ def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
16
+ self.model = AutoModel(
17
+ model=model_name,
18
+ vad_model="fsmn-vad",
19
+ vad_kwargs={"max_single_segment_time": 30000},
20
+ device=device,
21
+ )
22
+
23
+ def run(self, audio_path):
24
+ res = self.model.generate(
25
+ input=audio_path,
26
+ cache={},
27
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
28
+ use_itn=True,
29
+ batch_size_s=60,
30
+ merge_vad=True, #
31
+ merge_length_s=15,
32
+ )
33
+ text = rich_transcription_postprocess(res[0]["text"])
34
+ return text
35
+
36
+
37
+ def add_message(chatbot, history, mic, text, asr_model):
38
+ if not mic and not text:
39
+ return chatbot, history, "Input is empty"
40
+
41
+ if text:
42
+ chatbot.append({"role": "user", "content": text})
43
+ history.append({"role": "user", "content": text})
44
+ elif mic and Path(mic).exists():
45
+ chatbot.append({"role": "user", "content": {"path": mic}})
46
+ # use the ASR result of the user's speech to speed up inference
47
+ text = asr_model.run(mic)
48
+ chatbot.append({"role": "user", "content": text})
49
+ history.append({"role": "user", "content": text})
50
+
51
+ print(f"{history=}")
52
+ return chatbot, history, None
53
+
54
+
55
+ def reset_state():
56
+ """Reset the chat history."""
57
+ return [], [system_prompt]
58
+
59
+
60
+ def save_tmp_audio(audio, sr):
61
+ import tempfile
62
+
63
+ with tempfile.NamedTemporaryFile(
64
+ dir=CACHE_DIR, delete=False, suffix=".wav"
65
+ ) as temp_audio:
66
+ temp_audio_path = temp_audio.name
67
+ torchaudio.save(temp_audio_path, audio, sr)
68
+
69
+ return temp_audio.name
70
+
71
+
72
+ def predict(chatbot, history, audio_model):
73
+ """Generate a response from the model."""
74
+ try:
75
+ text, audio, sr = audio_model(history, "闫雨婷")
76
+ print(f"predict {text=}")
77
+ audio_path = save_tmp_audio(audio, sr)
78
+ chatbot.append({"role": "assistant", "content": {"path": audio_path}})
79
+ chatbot.append({"role": "assistant", "content": text})
80
+ history.append({"role": "assistant", "content": text})
81
+ except Exception as e:
82
+ print(e)
83
+ gr.Warning(f"Some error happend, retry submit")
84
+ return chatbot, history
85
+
86
+
87
+ def _launch_demo(args, audio_model, asr_model):
88
+ with gr.Blocks(delete_cache=(86400, 86400)) as demo:
89
+ gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
90
+ chatbot = gr.Chatbot(
91
+ elem_id="chatbot",
92
+ avatar_images=["assets/user.png", "assets/assistant.png"],
93
+ min_height=800,
94
+ type="messages",
95
+ )
96
+ # keep the chat history in state so the message format does not need to be rebuilt each turn
97
+ history = gr.State([system_prompt])
98
+ mic = gr.Audio(type="filepath")
99
+ text = gr.Textbox(placeholder="Enter message ...")
100
+
101
+ with gr.Row():
102
+ clean_btn = gr.Button("🧹 Clear History (清除历史)")
103
+ regen_btn = gr.Button("🤔️ Regenerate (重试)")
104
+ submit_btn = gr.Button("🚀 Submit")
105
+
106
+ def on_submit(chatbot, history, mic, text):
107
+ chatbot, history, error = add_message(
108
+ chatbot, history, mic, text, asr_model
109
+ )
110
+ if error:
111
+ gr.Warning(error)  # show the warning message
112
+ return chatbot, history, None, None
113
+ else:
114
+ chatbot, history = predict(chatbot, history, audio_model)
115
+ return chatbot, history, None, None
116
+
117
+ submit_btn.click(
118
+ fn=on_submit,
119
+ inputs=[chatbot, history, mic, text],
120
+ outputs=[chatbot, history, mic, text],
121
+ concurrency_limit=4,
122
+ concurrency_id="gpu_queue",
123
+ )
124
+ clean_btn.click(
125
+ reset_state,
126
+ outputs=[chatbot, history],
127
+ show_progress=True,
128
+ )
129
+
130
+ def regenerate(chatbot, history):
131
+ while chatbot and chatbot[-1]["role"] == "assistant":
132
+ chatbot.pop()
133
+ while history and history[-1]["role"] == "assistant":
134
+ print(f"discard {history[-1]}")
135
+ history.pop()
136
+ return predict(chatbot, history, audio_model)
137
+
138
+ regen_btn.click(
139
+ regenerate,
140
+ [chatbot, history],
141
+ [chatbot, history],
142
+ show_progress=True,
143
+ concurrency_id="gpu_queue",
144
+ )
145
+
146
+ demo.queue().launch(
147
+ share=False,
148
+ server_port=args.server_port,
149
+ server_name=args.server_name,
150
+ )
151
+
152
+
153
+ if __name__ == "__main__":
154
+ from argparse import ArgumentParser
155
+ import os
156
+
157
+ parser = ArgumentParser()
158
+ parser.add_argument("--model-path", type=str, required=True, help="Model path.")
159
+ parser.add_argument(
160
+ "--server-port", type=int, default=7860, help="Demo server port."
161
+ )
162
+ parser.add_argument(
163
+ "--server-name", type=str, default="0.0.0.0", help="Demo server name."
164
+ )
165
+ args = parser.parse_args()
166
+
167
+ audio_model = StepAudio(
168
+ tokenizer_path=os.path.join(args.model_path, "Step-Audio-Tokenizer"),
169
+ tts_path=os.path.join(args.model_path, "Step-Audio-TTS-3B"),
170
+ llm_path=os.path.join(args.model_path, "Step-Audio-Chat"),
171
+ )
172
+ asr_model = CustomAsr()
173
+ _launch_demo(args, audio_model, asr_model)
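
A minimal, editor-added sketch (not part of the commit) of the message-dict format that add_message/predict maintain above: a system prompt followed by alternating user/assistant turns, duplicated into the Gradio chatbot list. FakeAsr and FakeAudioModel are hypothetical stand-ins for CustomAsr and StepAudio so the flow can be traced without GPUs, checkpoints, or Gradio.

# Sketch only: stub models, no audio saving, no UI.
system_prompt = {"role": "system", "content": "Answer briefly."}

class FakeAsr:
    def run(self, audio_path):
        # a real CustomAsr would transcribe the wav at audio_path
        return f"transcript of {audio_path}"

class FakeAudioModel:
    def __call__(self, history, speaker):
        # a real StepAudio returns (text, waveform tensor, sample_rate)
        last_user = next(m for m in reversed(history) if m["role"] == "user")
        return f"echo: {last_user['content']}", None, 22050

history = [system_prompt]
chatbot = []

# a text turn, as add_message would record it
history.append({"role": "user", "content": "hello"})
chatbot.append({"role": "user", "content": "hello"})

# the model turn, as predict would record it (audio path omitted in this sketch)
text, _, _ = FakeAudioModel()(history, "speaker")
history.append({"role": "assistant", "content": text})
chatbot.append({"role": "assistant", "content": text})

print(history)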
assets/assistant.png ADDED
Git LFS Details

  • SHA256: 9e86f2162eb1ba508dc3faf4526ce1d124c9a388deb575729b4b4d3aea0fda20
  • Pointer size: 129 Bytes
  • Size of remote file: 2.23 kB
assets/user.png ADDED
Git LFS Details

  • SHA256: 972dbabce7049a264e4d4ad7ead51ce5a42bc3106f19f0eddbddc04419d3575e
  • Pointer size: 129 Bytes
  • Size of remote file: 8.15 kB
clone_hf_model.sh ADDED
@@ -0,0 +1,57 @@
1
+ #!/bin/bash
2
+
3
+ if [ -z "$HF_USER_NAME" ]; then
4
+ echo "错误:环境变量 HF_USER_NAME 未设置!"
5
+ exit 1
6
+ fi
7
+
8
+ if [ -z "$HF_USER_TOKEN" ]; then
9
+ echo "错误:环境变量 HF_USER_TOKEN 未设置!"
10
+ exit 1
11
+ fi
12
+
13
+ # enable Git LFS support
14
+ git lfs install --force
15
+
16
+ # list of repositories to clone
17
+ BASE_REPO_URL="https://${HF_USER_NAME}:${HF_USER_TOKEN}@huggingface.co/stepfun-ai"
18
+ REPOSITORIES=(
19
+ "Step-Audio-Tokenizer"
20
+ "Step-Audio-TTS-3B"
21
+ "Step-Audio-Chat"
22
+ )
23
+
24
+ # local directory for the repositories (defaults to the current directory)
25
+ LOCAL_DIR="${1:-$(pwd)}"
26
+
27
+ # clone helper with an unlimited retry loop
28
+ clone_with_retry() {
29
+ local repo_name=$1
30
+ local repo_url="${BASE_REPO_URL}/${repo_name}"
31
+ local target_dir="${LOCAL_DIR}/${repo_name}"
32
+
33
+ # skip if the target directory already exists
34
+ if [ -d "${target_dir}" ]; then
35
+ echo "目录 ${target_dir} 已存在,跳过克隆。"
36
+ return 0
37
+ fi
38
+
39
+ # retry until the clone succeeds
40
+ while true; do
41
+ echo "正在尝试克隆 ${repo_name} 到 ${target_dir}..."
42
+ if git clone "${repo_url}" "${target_dir}"; then
43
+ echo "成功克隆 ${repo_name} 到 ${target_dir}"
44
+ return 0
45
+ else
46
+ echo "克隆失败, 5秒后重试..."
47
+ sleep 5
48
+ fi
49
+ done
50
+ }
51
+
52
+ # clone every repository in the list
53
+ for repo in "${REPOSITORIES[@]}"; do
54
+ clone_with_retry "${repo}"
55
+ done
56
+
57
+ echo "所有仓库已成功下载!"
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/cli/__init__.py ADDED
File without changes
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import uuid
16
+ import time
17
+ from tqdm import tqdm
18
+ import torch
19
+ import torchaudio
20
+ from hyperpyyaml import load_hyperpyyaml
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel
23
+
24
+
25
+ class CosyVoice:
26
+
27
+ def __init__(
28
+ self,
29
+ model_dir,
30
+ ):
31
+ self.model_dir = model_dir
32
+ with open("{}/cosyvoice.yaml".format(model_dir), "r") as f:
33
+ configs = load_hyperpyyaml(f)
34
+ self.frontend = CosyVoiceFrontEnd(
35
+ configs["feat_extractor"],
36
+ "{}/campplus.onnx".format(model_dir),
37
+ "{}/speech_tokenizer_v1.onnx".format(model_dir),
38
+ )
39
+ self.model = CosyVoiceModel(configs["flow"], configs["hift"])
40
+ self.model.load(
41
+ "{}/flow.pt".format(model_dir),
42
+ "{}/hift.pt".format(model_dir),
43
+ )
44
+ self.model.flow = self.model.flow.to(torch.bfloat16)
45
+ del configs
46
+
47
+ def token_to_wav_offline(
48
+ self,
49
+ speech_token,
50
+ speech_feat,
51
+ speech_feat_len,
52
+ prompt_token,
53
+ prompt_token_len,
54
+ embedding,
55
+ ):
56
+ tts_mel = self.model.flow.inference(
57
+ token=speech_token.to(self.model.device),
58
+ token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to(
59
+ self.model.device
60
+ ),
61
+ prompt_token=prompt_token.to(self.model.device),
62
+ prompt_token_len=prompt_token_len.to(self.model.device),
63
+ prompt_feat=speech_feat.to(self.model.device),
64
+ prompt_feat_len=speech_feat_len.to(self.model.device),
65
+ embedding=embedding.to(self.model.device),
66
+ )
67
+ tts_speech = self.model.hift.inference(mel=tts_mel.float())[0].cpu()
68
+ return tts_speech
cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import onnxruntime
15
+ import torch
16
+ import numpy as np
17
+ import whisper
18
+ from typing import Callable
19
+ import torchaudio.compliance.kaldi as kaldi
20
+
21
+
22
+ class CosyVoiceFrontEnd:
23
+
24
+ def __init__(
25
+ self,
26
+ feat_extractor: Callable,
27
+ campplus_model: str,
28
+ speech_tokenizer_model: str,
29
+ ):
30
+ self.feat_extractor = feat_extractor
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ option = onnxruntime.SessionOptions()
33
+ option.graph_optimization_level = (
34
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
35
+ )
36
+ option.intra_op_num_threads = 1
37
+ self.campplus_session = onnxruntime.InferenceSession(
38
+ campplus_model, sess_options=option, providers=["CPUExecutionProvider"]
39
+ )
40
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(
41
+ speech_tokenizer_model,
42
+ sess_options=option,
43
+ providers=[
44
+ (
45
+ "CUDAExecutionProvider"
46
+ if torch.cuda.is_available()
47
+ else "CPUExecutionProvider"
48
+ )
49
+ ],
50
+ )
51
+
52
+ def _extract_speech_token(self, speech):
53
+ assert (
54
+ speech.shape[1] / 16000 <= 30
55
+ ), "do not support extract speech token for audio longer than 30s"
56
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
57
+ speech_token = (
58
+ self.speech_tokenizer_session.run(
59
+ None,
60
+ {
61
+ self.speech_tokenizer_session.get_inputs()[0]
62
+ .name: feat.detach()
63
+ .cpu()
64
+ .numpy(),
65
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array(
66
+ [feat.shape[2]], dtype=np.int32
67
+ ),
68
+ },
69
+ )[0]
70
+ .flatten()
71
+ .tolist()
72
+ )
73
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
74
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(
75
+ self.device
76
+ )
77
+ return speech_token, speech_token_len
78
+
79
+ def _extract_spk_embedding(self, speech):
80
+ feat = kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000)
81
+ feat = feat - feat.mean(dim=0, keepdim=True)
82
+ embedding = (
83
+ self.campplus_session.run(
84
+ None,
85
+ {
86
+ self.campplus_session.get_inputs()[0]
87
+ .name: feat.unsqueeze(dim=0)
88
+ .cpu()
89
+ .numpy()
90
+ },
91
+ )[0]
92
+ .flatten()
93
+ .tolist()
94
+ )
95
+ embedding = torch.tensor([embedding]).to(self.device)
96
+ return embedding
97
+
98
+ def _extract_speech_feat(self, speech):
99
+ speech_feat = (
100
+ self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
101
+ )
102
+ speech_feat = speech_feat.unsqueeze(dim=0)
103
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(
104
+ self.device
105
+ )
106
+ return speech_feat, speech_feat_len
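
An editor-added sketch (not part of the commit) of the speaker-embedding preprocessing in _extract_spk_embedding above: 80-bin Kaldi fbank features, per-utterance mean normalization, then a batch dimension for the CAM++ ONNX session. Random audio is used so it runs without campplus.onnx; the actual session.run call is left as a comment.

# Requires torch and torchaudio only.
import torch
import torchaudio.compliance.kaldi as kaldi

speech = torch.rand(1, 16000)  # 1 s of fake 16 kHz audio, shape (channels, samples)
feat = kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)   # per-utterance mean normalization
inputs = feat.unsqueeze(0).numpy()             # (1, frames, 80) for the ONNX session
print(inputs.shape)
# embedding = campplus_session.run(None, {campplus_session.get_inputs()[0].name: inputs})[0]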
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+
17
+ class CosyVoiceModel:
18
+
19
+ def __init__(
20
+ self,
21
+ flow: torch.nn.Module,
22
+ hift: torch.nn.Module,
23
+ ):
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ self.flow = flow
26
+ self.hift = hift
27
+
28
+ def load(self, flow_model, hift_model):
29
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
30
+ self.flow.to(self.device).eval()
31
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
32
+ self.hift.to(self.device).eval()
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import pack, rearrange, repeat
17
+ from cosyvoice.matcha.decoder import (
18
+ SinusoidalPosEmb,
19
+ Block1D,
20
+ ResnetBlock1D,
21
+ Downsample1D,
22
+ TimestepEmbedding,
23
+ Upsample1D,
24
+ )
25
+ from cosyvoice.matcha.transformer import BasicTransformerBlock
26
+
27
+
28
+ class ConditionalDecoder(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ out_channels,
33
+ channels=(256, 256),
34
+ dropout=0.05,
35
+ attention_head_dim=64,
36
+ n_blocks=1,
37
+ num_mid_blocks=2,
38
+ num_heads=4,
39
+ act_fn="snake",
40
+ ):
41
+ """
42
+ This decoder requires an input with the same shape as the target. So, if your text content
43
+ is shorter or longer than the outputs, please resample it before feeding it to the decoder.
44
+ """
45
+ super().__init__()
46
+ channels = tuple(channels)
47
+ self.in_channels = in_channels
48
+ self.out_channels = out_channels
49
+
50
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
51
+ time_embed_dim = channels[0] * 4
52
+ self.time_mlp = TimestepEmbedding(
53
+ in_channels=in_channels,
54
+ time_embed_dim=time_embed_dim,
55
+ act_fn="silu",
56
+ )
57
+ self.down_blocks = nn.ModuleList([])
58
+ self.mid_blocks = nn.ModuleList([])
59
+ self.up_blocks = nn.ModuleList([])
60
+
61
+ output_channel = in_channels
62
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
63
+ input_channel = output_channel
64
+ output_channel = channels[i]
65
+ is_last = i == len(channels) - 1
66
+ resnet = ResnetBlock1D(
67
+ dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
68
+ )
69
+ transformer_blocks = nn.ModuleList(
70
+ [
71
+ BasicTransformerBlock(
72
+ dim=output_channel,
73
+ num_attention_heads=num_heads,
74
+ attention_head_dim=attention_head_dim,
75
+ dropout=dropout,
76
+ activation_fn=act_fn,
77
+ )
78
+ for _ in range(n_blocks)
79
+ ]
80
+ )
81
+ downsample = (
82
+ Downsample1D(output_channel)
83
+ if not is_last
84
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
85
+ )
86
+ self.down_blocks.append(
87
+ nn.ModuleList([resnet, transformer_blocks, downsample])
88
+ )
89
+
90
+ for _ in range(num_mid_blocks):
91
+ input_channel = channels[-1]
92
+ out_channels = channels[-1]
93
+ resnet = ResnetBlock1D(
94
+ dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
95
+ )
96
+
97
+ transformer_blocks = nn.ModuleList(
98
+ [
99
+ BasicTransformerBlock(
100
+ dim=output_channel,
101
+ num_attention_heads=num_heads,
102
+ attention_head_dim=attention_head_dim,
103
+ dropout=dropout,
104
+ activation_fn=act_fn,
105
+ )
106
+ for _ in range(n_blocks)
107
+ ]
108
+ )
109
+
110
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
111
+
112
+ channels = channels[::-1] + (channels[0],)
113
+ for i in range(len(channels) - 1):
114
+ input_channel = channels[i] * 2
115
+ output_channel = channels[i + 1]
116
+ is_last = i == len(channels) - 2
117
+ resnet = ResnetBlock1D(
118
+ dim=input_channel,
119
+ dim_out=output_channel,
120
+ time_emb_dim=time_embed_dim,
121
+ )
122
+ transformer_blocks = nn.ModuleList(
123
+ [
124
+ BasicTransformerBlock(
125
+ dim=output_channel,
126
+ num_attention_heads=num_heads,
127
+ attention_head_dim=attention_head_dim,
128
+ dropout=dropout,
129
+ activation_fn=act_fn,
130
+ )
131
+ for _ in range(n_blocks)
132
+ ]
133
+ )
134
+ upsample = (
135
+ Upsample1D(output_channel, use_conv_transpose=True)
136
+ if not is_last
137
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
138
+ )
139
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
140
+ self.final_block = Block1D(channels[-1], channels[-1])
141
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
142
+ self.initialize_weights()
143
+
144
+ def initialize_weights(self):
145
+ for m in self.modules():
146
+ if isinstance(m, nn.Conv1d):
147
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
148
+ if m.bias is not None:
149
+ nn.init.constant_(m.bias, 0)
150
+ elif isinstance(m, nn.GroupNorm):
151
+ nn.init.constant_(m.weight, 1)
152
+ nn.init.constant_(m.bias, 0)
153
+ elif isinstance(m, nn.Linear):
154
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
155
+ if m.bias is not None:
156
+ nn.init.constant_(m.bias, 0)
157
+
158
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
159
+ """Forward pass of the UNet1DConditional model.
160
+
161
+ Args:
162
+ x (torch.Tensor): shape (batch_size, in_channels, time)
163
+ mask (_type_): shape (batch_size, 1, time)
164
+ t (_type_): shape (batch_size)
165
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
166
+ cond (_type_, optional): placeholder for future use. Defaults to None.
167
+
168
+ Raises:
169
+ ValueError: _description_
170
+ ValueError: _description_
171
+
172
+ Returns:
173
+ _type_: _description_
174
+ """
175
+
176
+ t = self.time_embeddings(t).to(t.dtype)
177
+ t = self.time_mlp(t)
178
+
179
+ x = pack([x, mu], "b * t")[0]
180
+
181
+ if spks is not None:
182
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
183
+ x = pack([x, spks], "b * t")[0]
184
+ if cond is not None:
185
+ x = pack([x, cond], "b * t")[0]
186
+
187
+ hiddens = []
188
+ masks = [mask]
189
+ for resnet, transformer_blocks, downsample in self.down_blocks:
190
+ mask_down = masks[-1]
191
+ x = resnet(
192
+ x.to(torch.bfloat16), mask_down.to(torch.bfloat16), t.to(torch.bfloat16)
193
+ )
194
+ x = rearrange(x, "b c t -> b t c").contiguous()
195
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
196
+ for transformer_block in transformer_blocks:
197
+ x = transformer_block(
198
+ hidden_states=x,
199
+ # attention_mask=attn_mask,
200
+ timestep=t,
201
+ )
202
+ x = rearrange(x, "b t c -> b c t").contiguous()
203
+ hiddens.append(x) # Save hidden states for skip connections
204
+ x = downsample(x * mask_down)
205
+ masks.append(mask_down[:, :, ::2])
206
+ masks = masks[:-1]
207
+ mask_mid = masks[-1]
208
+
209
+ for resnet, transformer_blocks in self.mid_blocks:
210
+ x = resnet(x, mask_mid, t)
211
+ x = rearrange(x, "b c t -> b t c").contiguous()
212
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
213
+ for transformer_block in transformer_blocks:
214
+ x = transformer_block(
215
+ hidden_states=x,
216
+ # attention_mask=attn_mask,
217
+ timestep=t,
218
+ )
219
+ x = rearrange(x, "b t c -> b c t").contiguous()
220
+
221
+ for resnet, transformer_blocks, upsample in self.up_blocks:
222
+ mask_up = masks.pop()
223
+ skip = hiddens.pop()
224
+ x = pack([x[:, :, : skip.shape[-1]], skip], "b * t")[0]
225
+ x = resnet(x, mask_up, t)
226
+ x = rearrange(x, "b c t -> b t c").contiguous()
227
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
228
+ for transformer_block in transformer_blocks:
229
+ x = transformer_block(
230
+ hidden_states=x,
231
+ # attention_mask=attn_mask,
232
+ timestep=t,
233
+ )
234
+ x = rearrange(x, "b t c -> b c t").contiguous()
235
+ x = upsample(x * mask_up)
236
+ x = self.final_block(x, mask_up)
237
+ output = self.final_proj(x * mask_up)
238
+ return output * mask
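
An editor-added, shape-only sketch (not part of the commit) of the skip-connection and mask bookkeeping in ConditionalDecoder.forward above: each down block saves its hidden state and strides the mask by two (mask_down[:, :, ::2]), and each up block pops one saved state and concatenates it along channels before mapping the width back down. Plain slicing and interpolate stand in for Downsample1D/Upsample1D and the ResnetBlock1D.

# Random tensors only; no model weights involved.
import torch

x = torch.randn(1, 256, 64)          # (batch, channels, time)
mask = torch.ones(1, 1, 64)

hiddens, masks = [], [mask]
for _ in range(2):                   # two "down blocks"
    hiddens.append(x)                # saved for the skip connection
    x = x[:, :, ::2]                 # stand-in for Downsample1D
    masks.append(masks[-1][:, :, ::2])
masks = masks[:-1]

for _ in range(2):                   # two "up blocks"
    mask_up = masks.pop()
    skip = hiddens.pop()
    x = torch.nn.functional.interpolate(x, size=skip.shape[-1])  # stand-in for Upsample1D
    x = torch.cat([x, skip], dim=1)  # channel concat, like pack([x, skip], "b * t")
    x = x[:, :256]                   # a real up block maps channels back down with a ResnetBlock1D
print(x.shape, mask_up.shape)        # torch.Size([1, 256, 64]) torch.Size([1, 1, 64])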
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,196 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+ import time
23
+
24
+
25
+ class MaskedDiffWithXvec(torch.nn.Module):
26
+ def __init__(
27
+ self,
28
+ input_size: int = 512,
29
+ output_size: int = 80,
30
+ spk_embed_dim: int = 192,
31
+ output_type: str = "mel",
32
+ vocab_size: int = 4096,
33
+ input_frame_rate: int = 50,
34
+ only_mask_loss: bool = True,
35
+ encoder: torch.nn.Module = None,
36
+ length_regulator: torch.nn.Module = None,
37
+ decoder: torch.nn.Module = None,
38
+ decoder_conf: Dict = {
39
+ "in_channels": 240,
40
+ "out_channel": 80,
41
+ "spk_emb_dim": 80,
42
+ "n_spks": 1,
43
+ "cfm_params": DictConfig(
44
+ {
45
+ "sigma_min": 1e-06,
46
+ "solver": "euler",
47
+ "t_scheduler": "cosine",
48
+ "training_cfg_rate": 0.2,
49
+ "inference_cfg_rate": 0.7,
50
+ "reg_loss_type": "l1",
51
+ }
52
+ ),
53
+ "decoder_params": {
54
+ "channels": [256, 256],
55
+ "dropout": 0.0,
56
+ "attention_head_dim": 64,
57
+ "n_blocks": 4,
58
+ "num_mid_blocks": 12,
59
+ "num_heads": 8,
60
+ "act_fn": "gelu",
61
+ },
62
+ },
63
+ mel_feat_conf: Dict = {
64
+ "n_fft": 1024,
65
+ "num_mels": 80,
66
+ "sampling_rate": 22050,
67
+ "hop_size": 256,
68
+ "win_size": 1024,
69
+ "fmin": 0,
70
+ "fmax": 8000,
71
+ },
72
+ ):
73
+ super().__init__()
74
+ self.input_size = input_size
75
+ self.output_size = output_size
76
+ self.decoder_conf = decoder_conf
77
+ self.mel_feat_conf = mel_feat_conf
78
+ self.vocab_size = vocab_size
79
+ self.output_type = output_type
80
+ self.input_frame_rate = input_frame_rate
81
+ logging.info(f"input frame rate={self.input_frame_rate}")
82
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
83
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
84
+ self.encoder = encoder
85
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
86
+ self.decoder = decoder
87
+ self.length_regulator = length_regulator
88
+ self.only_mask_loss = only_mask_loss
89
+
90
+ def forward(
91
+ self,
92
+ batch: dict,
93
+ device: torch.device,
94
+ ) -> Dict[str, Optional[torch.Tensor]]:
95
+ token = batch["speech_token"].to(device)
96
+ token_len = batch["speech_token_len"].to(device)
97
+ feat = batch["speech_feat"].to(device)
98
+ feat_len = batch["speech_feat_len"].to(device)
99
+ embedding = batch["embedding"].to(device)
100
+
101
+ # xvec projection
102
+ embedding = F.normalize(embedding, dim=1)
103
+ embedding = self.spk_embed_affine_layer(embedding)
104
+
105
+ # concat text and prompt_text
106
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
107
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
108
+
109
+ # text encode
110
+ h, h_lengths = self.encoder(token, token_len)
111
+ h = self.encoder_proj(h)
112
+ h, h_lengths = self.length_regulator(h, feat_len)
113
+
114
+ # get conditions
115
+ conds = torch.zeros(feat.shape, device=token.device)
116
+ for i, j in enumerate(feat_len):
117
+ if random.random() < 0.5:
118
+ continue
119
+ index = random.randint(0, int(0.3 * j))
120
+ conds[i, :index] = feat[i, :index]
121
+ conds = conds.transpose(1, 2)
122
+
123
+ mask = (~make_pad_mask(feat_len)).to(h)
124
+ feat = F.interpolate(
125
+ feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest"
126
+ ).squeeze(dim=1)
127
+ loss, _ = self.decoder.compute_loss(
128
+ feat.transpose(1, 2).contiguous(),
129
+ mask.unsqueeze(1),
130
+ h.transpose(1, 2).contiguous(),
131
+ embedding,
132
+ cond=conds,
133
+ )
134
+ return {"loss": loss}
135
+
136
+ @torch.inference_mode()
137
+ def inference(
138
+ self,
139
+ token,
140
+ token_len,
141
+ prompt_token,
142
+ prompt_token_len,
143
+ prompt_feat,
144
+ prompt_feat_len,
145
+ embedding,
146
+ ):
147
+ assert token.shape[0] == 1
148
+ # xvec projection
149
+ embedding = F.normalize(embedding, dim=1)
150
+ embedding = self.spk_embed_affine_layer(embedding)
151
+
152
+ # concat text and prompt_text
153
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
154
+ # text encode
155
+ token, token_len = (
156
+ torch.concat([prompt_token, token], dim=1),
157
+ prompt_token_len + token_len,
158
+ )
159
+ token = self.input_embedding(torch.clamp(token, min=0))
160
+ h, _ = self.encoder.inference(token, token_len)
161
+ h = self.encoder_proj(h)
162
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(
163
+ token_len2
164
+ / self.input_frame_rate
165
+ * self.mel_feat_conf["sampling_rate"]
166
+ / self.mel_feat_conf["hop_size"]
167
+ )
168
+
169
+ h, _ = self.length_regulator.inference(
170
+ h[:, :token_len1],
171
+ h[:, token_len1:],
172
+ mel_len1,
173
+ mel_len2,
174
+ )
175
+
176
+ # get conditions
177
+ conds = torch.zeros(
178
+ [1, mel_len1 + mel_len2, self.output_size], device=token.device
179
+ )
180
+ conds[:, :mel_len1] = prompt_feat
181
+ conds = conds.transpose(1, 2)
182
+
183
+ # mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
184
+ mask = torch.ones(
185
+ [1, mel_len1 + mel_len2], device=h.device, dtype=torch.bfloat16
186
+ )
187
+ feat = self.decoder(
188
+ mu=h.transpose(1, 2).contiguous(),
189
+ mask=mask.unsqueeze(1),
190
+ spks=embedding,
191
+ cond=conds,
192
+ n_timesteps=10,
193
+ )
194
+ feat = feat[:, :, mel_len1:]
195
+ assert feat.shape[2] == mel_len2
196
+ return feat
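
An editor-added sketch (not part of the commit) of the sizing and prompt conditioning in MaskedDiffWithXvec.inference above: the target mel length is token_count / input_frame_rate * sampling_rate / hop_size, and conds carries the prompt mel in its first mel_len1 frames with zeros elsewhere. The config values mirror the defaults above; the concrete lengths are illustrative.

# Arithmetic and conditioning only; no encoder or decoder is run here.
import torch

input_frame_rate = 50          # speech tokens per second
sampling_rate, hop_size = 22050, 256
n_mels = 80

token_len2 = 100               # 2 s worth of generated speech tokens
mel_len1 = 150                 # frames of prompt mel (given)
mel_len2 = int(token_len2 / input_frame_rate * sampling_rate / hop_size)
print(mel_len2)                # 172 target mel frames for the new content

prompt_feat = torch.randn(1, mel_len1, n_mels)
conds = torch.zeros(1, mel_len1 + mel_len2, n_mels)
conds[:, :mel_len1] = prompt_feat      # decoder is conditioned on the prompt mel only
conds = conds.transpose(1, 2)          # (1, 80, mel_len1 + mel_len2), as the decoder expects
print(conds.shape)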
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import time
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from cosyvoice.matcha.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(
22
+ self,
23
+ in_channels,
24
+ cfm_params,
25
+ n_spks=1,
26
+ spk_emb_dim=64,
27
+ estimator: torch.nn.Module = None,
28
+ ):
29
+ super().__init__(
30
+ n_feats=in_channels,
31
+ cfm_params=cfm_params,
32
+ n_spks=n_spks,
33
+ spk_emb_dim=spk_emb_dim,
34
+ )
35
+ self.t_scheduler = cfm_params.t_scheduler
36
+ self.training_cfg_rate = cfm_params.training_cfg_rate
37
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
38
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
39
+ # Just change the architecture of the estimator here
40
+ self.estimator = estimator
41
+ self.inference_graphs = {}
42
+ self.inference_buffers = {}
43
+ # self.capture_inference()
44
+
45
+ @torch.inference_mode()
46
+ def forward(
47
+ self,
48
+ mu,
49
+ mask,
50
+ n_timesteps,
51
+ temperature=1.0,
52
+ spks=None,
53
+ cond=None,
54
+ ):
55
+ """Forward diffusion
56
+
57
+ Args:
58
+ mu (torch.Tensor): output of encoder
59
+ shape: (batch_size, n_feats, mel_timesteps)
60
+ mask (torch.Tensor): output_mask
61
+ shape: (batch_size, 1, mel_timesteps)
62
+ n_timesteps (int): number of diffusion steps
63
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
64
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
65
+ shape: (batch_size, spk_emb_dim)
66
+ cond: Not used but kept for future purposes
67
+
68
+ Returns:
69
+ sample: generated mel-spectrogram
70
+ shape: (batch_size, n_feats, mel_timesteps)
71
+ """
72
+ z = torch.randn_like(mu) * temperature
73
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
74
+ if self.t_scheduler == "cosine":
75
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
76
+ return self.solve_euler(
77
+ z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
78
+ )
79
+
80
+ @torch.inference_mode()
81
+ def capture_inference(self, seq_len_to_capture=list(range(128, 512, 8))):
82
+ start_time = time.time()
83
+ print(
84
+ f"capture_inference for ConditionalCFM solve euler, seq_len_to_capture: {seq_len_to_capture}"
85
+ )
86
+ for seq_len in seq_len_to_capture:
87
+ static_z = torch.randn(
88
+ 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
89
+ )
90
+ static_t_span = torch.linspace(
91
+ 0, 1, 11, device=torch.device("cuda"), dtype=torch.bfloat16
92
+ ) # only capture at 10 steps
93
+ static_mu = torch.randn(
94
+ 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
95
+ )
96
+ static_mask = torch.ones(
97
+ 1, 1, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
98
+ )
99
+ static_spks = torch.randn(
100
+ 1, 80, device=torch.device("cuda"), dtype=torch.bfloat16
101
+ )
102
+ static_cond = torch.randn(
103
+ 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
104
+ )
105
+ static_out = torch.randn(
106
+ 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.bfloat16
107
+ )
108
+
109
+ self._solve_euler_impl(
110
+ static_z,
111
+ t_span=static_t_span,
112
+ mu=static_mu,
113
+ mask=static_mask,
114
+ spks=static_spks,
115
+ cond=static_cond,
116
+ )
117
+ torch.cuda.synchronize()
118
+
119
+ g = torch.cuda.CUDAGraph()
120
+ with torch.cuda.graph(g):
121
+ static_out = self._solve_euler_impl(
122
+ static_z,
123
+ t_span=static_t_span,
124
+ mu=static_mu,
125
+ mask=static_mask,
126
+ spks=static_spks,
127
+ cond=static_cond,
128
+ )
129
+
130
+ self.inference_buffers[seq_len] = {
131
+ "z": static_z,
132
+ "t_span": static_t_span,
133
+ "mu": static_mu,
134
+ "mask": static_mask,
135
+ "spks": static_spks,
136
+ "cond": static_cond,
137
+ "out": static_out,
138
+ }
139
+ self.inference_graphs[seq_len] = g
140
+ end_time = time.time()
141
+ print(
142
+ f"capture_inference for ConditionalCFM solve euler, time elapsed: {end_time - start_time}"
143
+ )
144
+
145
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
146
+ if hasattr(self, "inference_graphs") and len(self.inference_graphs) > 0:
147
+ curr_seq_len = x.shape[2]
148
+
149
+ available_lengths = sorted(list(self.inference_graphs.keys()))
150
+
151
+ if curr_seq_len <= max(available_lengths):
152
+ target_len = min(available_lengths, key=lambda x: abs(x - curr_seq_len))
153
+ if target_len == curr_seq_len:
154
+ padded_x = x
155
+ padded_mu = mu
156
+ padded_mask = mask
157
+ if cond is not None:
158
+ padded_cond = cond
159
+ else:
160
+ padded_x = torch.randn(
161
+ (x.shape[0], x.shape[1], target_len),
162
+ dtype=x.dtype,
163
+ device=x.device,
164
+ )
165
+ padded_x[:, :, :curr_seq_len] = x
166
+
167
+ padded_mu = torch.randn(
168
+ (mu.shape[0], mu.shape[1], target_len),
169
+ dtype=mu.dtype,
170
+ device=mu.device,
171
+ )
172
+ padded_mu[:, :, :curr_seq_len] = mu
173
+
174
+ # FIXME(ys): uses zeros and maskgroupnorm
175
+ padded_mask = torch.ones(
176
+ (mask.shape[0], mask.shape[1], target_len),
177
+ dtype=mask.dtype,
178
+ device=mask.device,
179
+ )
180
+
181
+ if cond is not None:
182
+ padded_cond = torch.randn(
183
+ (cond.shape[0], cond.shape[1], target_len),
184
+ dtype=cond.dtype,
185
+ device=cond.device,
186
+ )
187
+ padded_cond[:, :, :curr_seq_len] = cond
188
+
189
+ buffer = self.inference_buffers[target_len]
190
+ buffer["z"].copy_(padded_x)
191
+ buffer["t_span"].copy_(t_span)
192
+ buffer["mu"].copy_(padded_mu)
193
+ buffer["mask"].copy_(padded_mask)
194
+ buffer["spks"].copy_(spks)
195
+ if cond is not None:
196
+ buffer["cond"].copy_(padded_cond)
197
+
198
+ self.inference_graphs[target_len].replay()
199
+
200
+ output = buffer["out"][:, :, :curr_seq_len]
201
+ return output
202
+
203
+ return self._solve_euler_impl(x, t_span, mu, mask, spks, cond)
204
+
205
+ def _solve_euler_impl(self, x, t_span, mu, mask, spks, cond):
206
+ """
207
+ Fixed euler solver for ODEs.
208
+ Args:
209
+ x (torch.Tensor): random noise
210
+ t_span (torch.Tensor): n_timesteps interpolated
211
+ shape: (n_timesteps + 1,)
212
+ mu (torch.Tensor): output of encoder
213
+ shape: (batch_size, n_feats, mel_timesteps)
214
+ mask (torch.Tensor): output_mask
215
+ shape: (batch_size, 1, mel_timesteps)
216
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
217
+ shape: (batch_size, spk_emb_dim)
218
+ cond: Not used but kept for future purposes
219
+ """
220
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
221
+ t = t.unsqueeze(dim=0)
222
+
223
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
224
+ # Or in future might add like a return_all_steps flag
225
+ sol = []
226
+
227
+ for step in range(1, len(t_span)):
228
+ if self.inference_cfg_rate > 0:
229
+ x_double = torch.cat([x, x], dim=0)
230
+ mask_double = torch.cat([mask, mask], dim=0)
231
+ mu_double = torch.cat([mu, torch.zeros_like(mu)], dim=0)
232
+ t_double = torch.cat([t, t], dim=0)
233
+ spks_double = (
234
+ torch.cat([spks, torch.zeros_like(spks)], dim=0)
235
+ if spks is not None
236
+ else None
237
+ )
238
+ cond_double = torch.cat([cond, torch.zeros_like(cond)], dim=0)
239
+
240
+ dphi_dt_double = self.forward_estimator(
241
+ x_double, mask_double, mu_double, t_double, spks_double, cond_double
242
+ )
243
+
244
+ dphi_dt, cfg_dphi_dt = torch.chunk(dphi_dt_double, 2, dim=0)
245
+ dphi_dt = (
246
+ 1.0 + self.inference_cfg_rate
247
+ ) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt
248
+ else:
249
+ dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
250
+
251
+ x = x + dt * dphi_dt
252
+ t = t + dt
253
+ sol.append(x)
254
+ if step < len(t_span) - 1:
255
+ dt = t_span[step + 1] - t
256
+
257
+ return sol[-1]
258
+
259
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
260
+ if isinstance(self.estimator, torch.nn.Module):
261
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
262
+ else:
263
+ ort_inputs = {
264
+ "x": x.cpu().numpy(),
265
+ "mask": mask.cpu().numpy(),
266
+ "mu": mu.cpu().numpy(),
267
+ "t": t.cpu().numpy(),
268
+ "spks": spks.cpu().numpy(),
269
+ "cond": cond.cpu().numpy(),
270
+ }
271
+ output = self.estimator.run(None, ort_inputs)[0]
272
+ return torch.tensor(output, dtype=x.dtype, device=x.device)
273
+
274
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
275
+ """Computes diffusion loss
276
+
277
+ Args:
278
+ x1 (torch.Tensor): Target
279
+ shape: (batch_size, n_feats, mel_timesteps)
280
+ mask (torch.Tensor): target mask
281
+ shape: (batch_size, 1, mel_timesteps)
282
+ mu (torch.Tensor): output of encoder
283
+ shape: (batch_size, n_feats, mel_timesteps)
284
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
285
+ shape: (batch_size, spk_emb_dim)
286
+
287
+ Returns:
288
+ loss: conditional flow matching loss
289
+ y: conditional flow
290
+ shape: (batch_size, n_feats, mel_timesteps)
291
+ """
292
+ b, _, t = mu.shape
293
+
294
+ # random timestep
295
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
296
+ if self.t_scheduler == "cosine":
297
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
298
+ # sample noise p(x_0)
299
+ z = torch.randn_like(x1)
300
+
301
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
302
+ u = x1 - (1 - self.sigma_min) * z
303
+
304
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
305
+ if self.training_cfg_rate > 0:
306
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
307
+ mu = mu * cfg_mask.view(-1, 1, 1)
308
+ spks = spks * cfg_mask.view(-1, 1)
309
+ cond = cond * cfg_mask.view(-1, 1, 1)
310
+
311
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
312
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (
313
+ torch.sum(mask) * u.shape[1]
314
+ )
315
+ return loss, y
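
An editor-added, self-contained sketch (not part of the commit) of the conditional flow-matching pieces used above: the training pair y = (1 - (1 - sigma_min)*t)*z + t*x1 with regression target u = x1 - (1 - sigma_min)*z, the cosine t-schedule, and the Euler update with classifier-free-guidance mixing. A small Linear layer is a stand-in for the real estimator, which additionally takes t, spks and cond.

# Toy shapes: (batch, feats, frames) as in ConditionalCFM.
import torch

sigma_min, inference_cfg_rate = 1e-6, 0.7
b, c, t_frames = 2, 80, 16

# training target construction (compute_loss above)
x1 = torch.randn(b, c, t_frames)                      # target mel
z = torch.randn_like(x1)                              # noise sample
t = torch.rand(b, 1, 1)
t = 1 - torch.cos(t * 0.5 * torch.pi)                 # cosine t-scheduler
y = (1 - (1 - sigma_min) * t) * z + t * x1            # point on the probability path
u = x1 - (1 - sigma_min) * z                          # flow-matching regression target

# Euler inference with CFG (solve_euler above), toy estimator
est = torch.nn.Linear(c, c)
def estimator(x, mu):                                 # real estimator also takes t, spks, cond
    return est((x + mu).transpose(1, 2)).transpose(1, 2)

mu = torch.randn(b, c, t_frames)                      # encoder output
x = torch.randn_like(mu)                              # x_0 ~ N(0, I)
t_span = 1 - torch.cos(torch.linspace(0, 1, 11) * 0.5 * torch.pi)
for step in range(1, len(t_span)):
    dt = t_span[step] - t_span[step - 1]
    dphi = estimator(x, mu)
    dphi_uncond = estimator(x, torch.zeros_like(mu))  # condition dropped for CFG
    dphi = (1 + inference_cfg_rate) * dphi - inference_cfg_rate * dphi_uncond
    x = x + dt * dphi
print(y.shape, u.shape, x.shape)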
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(nn.Conv1d(channels, out_channels, 1, 1))
40
+ self.model = nn.Sequential(*model)
41
+
42
+ def forward(self, x, ylens=None):
43
+ # x in (B, T, D)
44
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
45
+ x = F.interpolate(
46
+ x.transpose(1, 2).contiguous(), size=ylens.max(), mode="linear"
47
+ )
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2):
53
+ # x in (B, T, D)
54
+ x2 = F.interpolate(
55
+ x2.transpose(1, 2).contiguous(), size=mel_len2, mode="linear"
56
+ )
57
+ if x1.shape[1] != 0:
58
+ x1 = F.interpolate(
59
+ x1.transpose(1, 2).contiguous(), size=mel_len1, mode="linear"
60
+ )
61
+ x = torch.concat([x1, x2], dim=2)
62
+ else:
63
+ x = x2
64
+ out = self.model(x).transpose(1, 2).contiguous()
65
+ return out, mel_len1 + mel_len2
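
An editor-added sketch (not part of the commit) of the core of InterpolateRegulator above: the encoder frames are simply interpolated linearly along time to the target mel length; the Conv1d/GroupNorm/Mish stack that follows in the module is omitted here.

# Random tensors only.
import torch
import torch.nn.functional as F

h = torch.randn(1, 120, 512)          # (B, T, D) encoder frames
target_len = 172                      # desired number of mel frames
stretched = F.interpolate(h.transpose(1, 2), size=target_len, mode="linear")
stretched = stretched.transpose(1, 2)  # back to (B, T_mel, D)
print(stretched.shape)                 # torch.Size([1, 172, 512])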
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(
21
+ self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512
22
+ ):
23
+ super().__init__()
24
+
25
+ self.num_class = num_class
26
+ self.condnet = nn.Sequential(
27
+ weight_norm(
28
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
29
+ ),
30
+ nn.ELU(),
31
+ weight_norm(
32
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
33
+ ),
34
+ nn.ELU(),
35
+ weight_norm(
36
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
37
+ ),
38
+ nn.ELU(),
39
+ weight_norm(
40
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
41
+ ),
42
+ nn.ELU(),
43
+ weight_norm(
44
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
45
+ ),
46
+ nn.ELU(),
47
+ )
48
+ self.classifier = nn.Linear(
49
+ in_features=cond_channels, out_features=self.num_class
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,566 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import time
19
+ import numpy as np
20
+ from scipy.signal import get_window
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.nn import Conv1d
25
+ from torch.nn import ConvTranspose1d
26
+ from torch.nn.utils import remove_weight_norm
27
+ from torch.nn.utils import weight_norm
28
+ from torch.distributions.uniform import Uniform
29
+
30
+ from cosyvoice.transformer.activation import Snake
31
+ from cosyvoice.utils.common import get_padding
32
+ from cosyvoice.utils.common import init_weights
33
+
34
+
35
+ """hifigan based generator implementation.
36
+
37
+ This code is modified from https://github.com/jik876/hifi-gan
38
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
39
+ https://github.com/NVIDIA/BigVGAN
40
+
41
+ """
42
+
43
+
44
+ class ResBlock(torch.nn.Module):
45
+ """Residual block module in HiFiGAN/BigVGAN."""
46
+
47
+ def __init__(
48
+ self,
49
+ channels: int = 512,
50
+ kernel_size: int = 3,
51
+ dilations: tp.List[int] = [1, 3, 5],
52
+ ):
53
+ super(ResBlock, self).__init__()
54
+ self.convs1 = nn.ModuleList()
55
+ self.convs2 = nn.ModuleList()
56
+
57
+ for dilation in dilations:
58
+ self.convs1.append(
59
+ weight_norm(
60
+ Conv1d(
61
+ channels,
62
+ channels,
63
+ kernel_size,
64
+ 1,
65
+ dilation=dilation,
66
+ padding=get_padding(kernel_size, dilation),
67
+ )
68
+ )
69
+ )
70
+ self.convs2.append(
71
+ weight_norm(
72
+ Conv1d(
73
+ channels,
74
+ channels,
75
+ kernel_size,
76
+ 1,
77
+ dilation=1,
78
+ padding=get_padding(kernel_size, 1),
79
+ )
80
+ )
81
+ )
82
+ self.convs1.apply(init_weights)
83
+ self.convs2.apply(init_weights)
84
+ self.activations1 = nn.ModuleList(
85
+ [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs1))]
86
+ )
87
+ self.activations2 = nn.ModuleList(
88
+ [Snake(channels, alpha_logscale=False) for _ in range(len(self.convs2))]
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ for idx in range(len(self.convs1)):
93
+ xt = self.activations1[idx](x)
94
+ xt = self.convs1[idx](xt)
95
+ xt = self.activations2[idx](xt)
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
99
+
100
+ def remove_weight_norm(self):
101
+ for idx in range(len(self.convs1)):
102
+ remove_weight_norm(self.convs1[idx])
103
+ remove_weight_norm(self.convs2[idx])
104
+
105
+
106
+ class SineGen(torch.nn.Module):
107
+ """Definition of sine generator
108
+ SineGen(samp_rate, harmonic_num = 0,
109
+ sine_amp = 0.1, noise_std = 0.003,
110
+ voiced_threshold = 0,
111
+ flag_for_pulse=False)
112
+ samp_rate: sampling rate in Hz
113
+ harmonic_num: number of harmonic overtones (default 0)
114
+ sine_amp: amplitude of sine-waveform (default 0.1)
115
+ noise_std: std of Gaussian noise (default 0.003)
116
+ voiced_threshold: F0 threshold for U/V classification (default 0)
117
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
118
+ Note: when flag_for_pulse is True, the first time step of a voiced
119
+ segment is always sin(np.pi) or cos(0)
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ samp_rate,
125
+ harmonic_num=0,
126
+ sine_amp=0.1,
127
+ noise_std=0.003,
128
+ voiced_threshold=0,
129
+ ):
130
+ super(SineGen, self).__init__()
131
+ self.sine_amp = sine_amp
132
+ self.noise_std = noise_std
133
+ self.harmonic_num = harmonic_num
134
+ self.sampling_rate = samp_rate
135
+ self.voiced_threshold = voiced_threshold
136
+
137
+ def _f02uv(self, f0):
138
+ # generate uv signal
139
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
140
+ return uv
141
+
142
+ @torch.no_grad()
143
+ def forward(self, f0):
144
+ """
145
+ :param f0: [B, 1, sample_len], Hz
146
+ :return: [B, 1, sample_len]
147
+ """
148
+
149
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(
150
+ f0.device
151
+ )
152
+ for i in range(self.harmonic_num + 1):
153
+ F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate
154
+
155
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
156
+ u_dist = Uniform(low=-np.pi, high=np.pi)
157
+ phase_vec = u_dist.sample(
158
+ sample_shape=(f0.size(0), self.harmonic_num + 1, 1)
159
+ ).to(F_mat.device)
160
+ phase_vec[:, 0, :] = 0
161
+
162
+ # generate sine waveforms
163
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
164
+
165
+ # generate uv signal
166
+ uv = self._f02uv(f0)
167
+
168
+ # noise: for unvoiced should be similar to sine_amp
169
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
170
+ # . for voiced regions is self.noise_std
171
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
172
+ noise = noise_amp * torch.randn_like(sine_waves)
173
+
174
+ # first: set the unvoiced part to 0 by uv
175
+ # then: additive noise
176
+ sine_waves = sine_waves * uv + noise
177
+ return sine_waves, uv, noise
178
+
179
+
180
+ class SourceModuleHnNSF(torch.nn.Module):
181
+ """SourceModule for hn-nsf
182
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
183
+ add_noise_std=0.003, voiced_threshod=0)
184
+ sampling_rate: sampling_rate in Hz
185
+ harmonic_num: number of harmonic above F0 (default: 0)
186
+ sine_amp: amplitude of sine source signal (default: 0.1)
187
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
188
+ note that the amplitude of noise in unvoiced regions is decided
189
+ by sine_amp
190
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
191
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
192
+ F0_sampled (batchsize, length, 1)
193
+ Sine_source (batchsize, length, 1)
194
+ noise_source (batchsize, length, 1)
195
+ uv (batchsize, length, 1)
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ sampling_rate,
201
+ upsample_scale,
202
+ harmonic_num=0,
203
+ sine_amp=0.1,
204
+ add_noise_std=0.003,
205
+ voiced_threshod=0,
206
+ ):
207
+ super(SourceModuleHnNSF, self).__init__()
208
+
209
+ self.sine_amp = sine_amp
210
+ self.noise_std = add_noise_std
211
+
212
+ # to produce sine waveforms
213
+ self.l_sin_gen = SineGen(
214
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
215
+ )
216
+
217
+ # to merge source harmonics into a single excitation
218
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
219
+ self.l_tanh = torch.nn.Tanh()
220
+
221
+ def forward(self, x):
222
+ """
223
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
224
+ F0_sampled (batchsize, length, 1)
225
+ Sine_source (batchsize, length, 1)
226
+ noise_source (batchsize, length, 1)
227
+ """
228
+ # source for harmonic branch
229
+ with torch.no_grad():
230
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
231
+ sine_wavs = sine_wavs.transpose(1, 2)
232
+ uv = uv.transpose(1, 2)
233
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
234
+
235
+ # source for noise branch, in the same shape as uv
236
+ noise = torch.randn_like(uv) * self.sine_amp / 3
237
+ return sine_merge, noise, uv
238
+
239
+
240
+ class HiFTGenerator(nn.Module):
241
+ """
242
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
243
+ https://arxiv.org/abs/2309.09493
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ in_channels: int = 80,
249
+ base_channels: int = 512,
250
+ nb_harmonics: int = 8,
251
+ sampling_rate: int = 22050,
252
+ nsf_alpha: float = 0.1,
253
+ nsf_sigma: float = 0.003,
254
+ nsf_voiced_threshold: float = 10,
255
+ upsample_rates: tp.List[int] = [8, 8],
256
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
257
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
258
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
259
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [
260
+ [1, 3, 5],
261
+ [1, 3, 5],
262
+ [1, 3, 5],
263
+ ],
264
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
265
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
266
+ lrelu_slope: float = 0.1,
267
+ audio_limit: float = 0.99,
268
+ f0_predictor: torch.nn.Module = None,
269
+ ):
270
+ super(HiFTGenerator, self).__init__()
271
+
272
+ self.out_channels = 1
273
+ self.nb_harmonics = nb_harmonics
274
+ self.sampling_rate = sampling_rate
275
+ self.istft_params = istft_params
276
+ self.lrelu_slope = lrelu_slope
277
+ self.audio_limit = audio_limit
278
+
279
+ self.num_kernels = len(resblock_kernel_sizes)
280
+ self.num_upsamples = len(upsample_rates)
281
+ self.upsample_rates = upsample_rates
282
+ self.m_source = SourceModuleHnNSF(
283
+ sampling_rate=sampling_rate,
284
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
285
+ harmonic_num=nb_harmonics,
286
+ sine_amp=nsf_alpha,
287
+ add_noise_std=nsf_sigma,
288
+ voiced_threshod=nsf_voiced_threshold,
289
+ )
290
+ self.f0_upsamp = torch.nn.Upsample(
291
+ scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]
292
+ )
293
+
294
+ self.conv_pre = weight_norm(Conv1d(in_channels, base_channels, 7, 1, padding=3))
295
+
296
+ # Up
297
+ self.ups = nn.ModuleList()
298
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
299
+ self.ups.append(
300
+ weight_norm(
301
+ ConvTranspose1d(
302
+ base_channels // (2**i),
303
+ base_channels // (2 ** (i + 1)),
304
+ k,
305
+ u,
306
+ padding=(k - u) // 2,
307
+ )
308
+ )
309
+ )
310
+
311
+ # Down
312
+ self.source_downs = nn.ModuleList()
313
+ self.source_resblocks = nn.ModuleList()
314
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
315
+ downsample_cum_rates = np.cumprod(downsample_rates)
316
+ for i, (u, k, d) in enumerate(
317
+ zip(
318
+ downsample_cum_rates[::-1],
319
+ source_resblock_kernel_sizes,
320
+ source_resblock_dilation_sizes,
321
+ )
322
+ ):
323
+ if u == 1:
324
+ self.source_downs.append(
325
+ Conv1d(
326
+ istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1
327
+ )
328
+ )
329
+ else:
330
+ self.source_downs.append(
331
+ Conv1d(
332
+ istft_params["n_fft"] + 2,
333
+ base_channels // (2 ** (i + 1)),
334
+ u * 2,
335
+ u,
336
+ padding=(u // 2),
337
+ )
338
+ )
339
+
340
+ self.source_resblocks.append(
341
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
342
+ )
343
+
344
+ self.resblocks = nn.ModuleList()
345
+ for i in range(len(self.ups)):
346
+ ch = base_channels // (2 ** (i + 1))
347
+ for _, (k, d) in enumerate(
348
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
349
+ ):
350
+ self.resblocks.append(ResBlock(ch, k, d))
351
+
352
+ self.conv_post = weight_norm(
353
+ Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)
354
+ )
355
+ self.ups.apply(init_weights)
356
+ self.conv_post.apply(init_weights)
357
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
358
+ self.stft_window = torch.from_numpy(
359
+ get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)
360
+ ).cuda()
361
+ self.f0_predictor = f0_predictor
362
+ self.inference_buffers = {}
363
+ self.inference_graphs = {}
364
+
365
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
366
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
367
+
368
+ har_source, _, _ = self.m_source(f0)
369
+ return har_source.transpose(1, 2)
370
+
371
+ def _stft(self, x):
372
+ spec = torch.stft(
373
+ x,
374
+ self.istft_params["n_fft"],
375
+ self.istft_params["hop_len"],
376
+ self.istft_params["n_fft"],
377
+ window=self.stft_window,
378
+ return_complex=True,
379
+ )
380
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
381
+ return spec[..., 0], spec[..., 1]
382
+
383
+ def _istft(self, magnitude, phase):
384
+ magnitude = torch.clip(magnitude, max=1e2)
385
+ real = magnitude * torch.cos(phase)
386
+ img = magnitude * torch.sin(phase)
387
+ inverse_transform = torch.istft(
388
+ torch.complex(real, img),
389
+ self.istft_params["n_fft"],
390
+ self.istft_params["hop_len"],
391
+ self.istft_params["n_fft"],
392
+ window=self.stft_window,
393
+ )
394
+ return inverse_transform
395
+
396
+ def forward(
397
+ self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
398
+ ) -> torch.Tensor:
399
+ f0 = self.f0_predictor(x)
400
+ s = self._f02source(f0)
401
+
402
+ # use cache_source to avoid glitch
403
+ if cache_source.shape[2] != 0:
404
+ s[:, :, : cache_source.shape[2]] = cache_source
405
+
406
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
407
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
408
+
409
+ x = self.conv_pre(x)
410
+ for i in range(self.num_upsamples):
411
+ x = F.leaky_relu(x, self.lrelu_slope)
412
+ x = self.ups[i](x)
413
+
414
+ if i == self.num_upsamples - 1:
415
+ x = self.reflection_pad(x)
416
+
417
+ # fusion
418
+ si = self.source_downs[i](s_stft)
419
+ si = self.source_resblocks[i](si)
420
+ x = x + si
421
+
422
+ xs = None
423
+ for j in range(self.num_kernels):
424
+ if xs is None:
425
+ xs = self.resblocks[i * self.num_kernels + j](x)
426
+ else:
427
+ xs += self.resblocks[i * self.num_kernels + j](x)
428
+ x = xs / self.num_kernels
429
+
430
+ x = F.leaky_relu(x)
431
+ x = self.conv_post(x)
432
+ magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
433
+ phase = torch.sin(
434
+ x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
435
+ ) # note: applying sin here is technically redundant
436
+
437
+ x = self._istft(magnitude, phase)
438
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
439
+ return x, s
440
+
441
+ def remove_weight_norm(self):
442
+ print("Removing weight norm...")
443
+ for l in self.ups:
444
+ remove_weight_norm(l)
445
+ for l in self.resblocks:
446
+ l.remove_weight_norm()
447
+ remove_weight_norm(self.conv_pre)
448
+ remove_weight_norm(self.conv_post)
449
+ # NOTE: self.m_source (SourceModuleHnNSF) contains no weight-normalized layers, so there is nothing to remove for it
450
+ for l in self.source_downs:
451
+ remove_weight_norm(l)
452
+ for l in self.source_resblocks:
453
+ l.remove_weight_norm()
454
+
455
+ @torch.inference_mode()
456
+ def _inference_impl(self, mel: torch.Tensor, s_stft: torch.Tensor) -> torch.Tensor:
457
+ x = self.conv_pre(mel)
458
+ for i in range(self.num_upsamples):
459
+ x = F.leaky_relu(x, self.lrelu_slope)
460
+ x = self.ups[i](x)
461
+
462
+ if i == self.num_upsamples - 1:
463
+ x = self.reflection_pad(x)
464
+
465
+ # fusion
466
+ si = self.source_downs[i](s_stft)
467
+ si = self.source_resblocks[i](si)
468
+ x = x + si
469
+
470
+ xs = None
471
+ for j in range(self.num_kernels):
472
+ if xs is None:
473
+ xs = self.resblocks[i * self.num_kernels + j](x)
474
+ else:
475
+ xs += self.resblocks[i * self.num_kernels + j](x)
476
+ x = xs / self.num_kernels
477
+
478
+ x = F.leaky_relu(x)
479
+ x = self.conv_post(x)
480
+ magnitude = torch.exp(x[:, : self.istft_params["n_fft"] // 2 + 1, :])
481
+ phase = torch.sin(
482
+ x[:, self.istft_params["n_fft"] // 2 + 1 :, :]
483
+ ) # note: applying sin here is technically redundant
484
+ # print(f"mel: {mel.shape}, magnitude: {magnitude.shape}, phase: {phase.shape}")
485
+ return magnitude, phase
486
+
487
+ @torch.inference_mode()
488
+ def inference(
489
+ self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)
490
+ ) -> torch.Tensor:
491
+ curr_seq_len = mel.shape[2]
492
+ f0 = self.f0_predictor(mel)
493
+ s = self._f02source(f0)
494
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
495
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
496
+
497
+ target_len = None
498
+ for seq_len in sorted(self.inference_buffers.keys()):
499
+ if curr_seq_len <= seq_len:
500
+ target_len = seq_len
501
+ break
502
+
503
+ if target_len is not None:
504
+ buffer = self.inference_buffers[target_len]
505
+
506
+ if curr_seq_len < target_len:
507
+ padded_mel = torch.zeros_like(buffer["mel"])
508
+ padded_mel[:, :, :curr_seq_len] = mel
509
+ buffer["mel"].copy_(padded_mel)
510
+ padded_s_stft = torch.zeros_like(buffer["s_stft"])
511
+ cur_s_stft_len = s_stft.shape[2]
512
+ padded_s_stft[:, :, :cur_s_stft_len] = s_stft
513
+ buffer["s_stft"].copy_(padded_s_stft)
514
+
515
+ else:
516
+ buffer["mel"].copy_(mel)
517
+ buffer["s_stft"].copy_(s_stft)
518
+ cur_s_stft_len = s_stft.shape[2]
519
+
520
+ self.inference_graphs[target_len].replay()
521
+
522
+ magnitude, phase = (
523
+ buffer["magnitude"][:, :, :cur_s_stft_len],
524
+ buffer["phase"][:, :, :cur_s_stft_len],
525
+ )
526
+ else:
527
+ magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
528
+
529
+ x = self._istft(magnitude, phase)
530
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
531
+ return x, s
532
+
533
+ @torch.inference_mode()
534
+ def capture_inference(self, seq_len_to_capture=[64, 128, 256, 512, 1024]):
535
+ start_time = time.time()
536
+ print(
537
+ f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture}"
538
+ )
539
+ for seq_len in seq_len_to_capture:
540
+ mel = torch.randn(
541
+ 1, 80, seq_len, device=torch.device("cuda"), dtype=torch.float32
542
+ )
543
+ f0 = self.f0_predictor(mel)
544
+ s = self._f02source(f0)
545
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
546
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
547
+
548
+ magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
549
+ torch.cuda.synchronize()
550
+
551
+ g = torch.cuda.CUDAGraph()
552
+ with torch.cuda.graph(g):
553
+ magnitude, phase = self._inference_impl(mel=mel, s_stft=s_stft)
554
+ inference_buffer = {
555
+ "mel": mel,
556
+ "s_stft": s_stft,
557
+ "magnitude": magnitude,
558
+ "phase": phase,
559
+ }
560
+ self.inference_buffers[seq_len] = inference_buffer
561
+ self.inference_graphs[seq_len] = g
562
+
563
+ end_time = time.time()
564
+ print(
565
+ f"capture inference for HiFTGenerator with seq_len_to_capture: {seq_len_to_capture} takes {end_time - start_time} seconds"
566
+ )
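A minimal usage sketch for the HiFTGenerator above: mel spectrogram in, waveform out. It assumes a CUDA device (the STFT window is created with .cuda() in the constructor), an 80-bin mel input, and an f0_predictor with the interface of the ConvRNNF0Predictor shipped in cosyvoice/hifigan/f0_predictor.py; the constructor defaults shown are illustrative, not a required configuration.

import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor  # assumed interface
from cosyvoice.hifigan.generator import HiFTGenerator

# default settings: 22.05 kHz, upsample_rates (8, 8), istft hop_len 4
hift = HiFTGenerator(f0_predictor=ConvRNNF0Predictor()).cuda().eval()

mel = torch.randn(1, 80, 256, device="cuda")      # (B, n_mels, T_mel)
with torch.inference_mode():
    wav, source = hift.inference(mel)             # wav: (B, T_audio), source: (B, 1, T_audio)

# each mel frame expands to prod(upsample_rates) * hop_len = 8 * 8 * 4 = 256 samples
print(wav.shape)                                  # torch.Size([1, 65536])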
cosyvoice/matcha/audio.py ADDED
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
+ def load_wav(full_path):
11
+ sampling_rate, data = read(full_path)
12
+ return data, sampling_rate
13
+
14
+
15
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
+
18
+
19
+ def dynamic_range_decompression(x, C=1):
20
+ return np.exp(x) / C
21
+
22
+
23
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
+ return torch.log(torch.clamp(x, min=clip_val) * C)
25
+
26
+
27
+ def dynamic_range_decompression_torch(x, C=1):
28
+ return torch.exp(x) / C
29
+
30
+
31
+ def spectral_normalize_torch(magnitudes):
32
+ output = dynamic_range_compression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ def spectral_de_normalize_torch(magnitudes):
37
+ output = dynamic_range_decompression_torch(magnitudes)
38
+ return output
39
+
40
+
41
+ mel_basis = {}
42
+ hann_window = {}
43
+
44
+
45
+ def mel_spectrogram(
46
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
47
+ ):
48
+ if torch.min(y) < -1.0:
49
+ print("min value is ", torch.min(y))
50
+ if torch.max(y) > 1.0:
51
+ print("max value is ", torch.max(y))
52
+
53
+ global mel_basis, hann_window # pylint: disable=global-statement
54
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
55
+ mel = librosa_mel_fn(
56
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
57
+ )
58
+ mel_basis[str(fmax) + "_" + str(y.device)] = (
59
+ torch.from_numpy(mel).float().to(y.device)
60
+ )
61
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
62
+
63
+ y = torch.nn.functional.pad(
64
+ y.unsqueeze(1),
65
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
66
+ mode="reflect",
67
+ )
68
+ y = y.squeeze(1)
69
+
70
+ spec = torch.view_as_real(
71
+ torch.stft(
72
+ y,
73
+ n_fft,
74
+ hop_length=hop_size,
75
+ win_length=win_size,
76
+ window=hann_window[str(y.device)],
77
+ center=center,
78
+ pad_mode="reflect",
79
+ normalized=False,
80
+ onesided=True,
81
+ return_complex=True,
82
+ )
83
+ )
84
+
85
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
86
+
87
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
88
+ spec = spectral_normalize_torch(spec)
89
+
90
+ return spec
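A quick sketch of how mel_spectrogram above is typically called. The audio must already be scaled to [-1, 1]; the STFT/mel settings below are illustrative values for a 22.05 kHz HiFi-GAN-style front end, not a configuration mandated by this repository.

import torch
from cosyvoice.matcha.audio import mel_spectrogram

wav = torch.rand(1, 22050) * 2 - 1   # one second of dummy audio in [-1, 1]
mel = mel_spectrogram(
    wav, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)                      # (1, 80, n_frames); values are natural-log compressed magnitudes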
cosyvoice/matcha/decoder.py ADDED
@@ -0,0 +1,511 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from conformer import ConformerBlock
8
+ from diffusers.models.activations import get_activation
9
+ from einops import pack, rearrange, repeat
10
+
11
+ from cosyvoice.matcha.transformer import BasicTransformerBlock
12
+
13
+
14
+ class SinusoidalPosEmb(torch.nn.Module):
15
+ def __init__(self, dim):
16
+ super().__init__()
17
+ self.dim = dim
18
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
19
+
20
+ def forward(self, x, scale=1000):
21
+ if x.ndim < 1:
22
+ x = x.unsqueeze(0)
23
+ device = x.device
24
+ half_dim = self.dim // 2
25
+ emb = math.log(10000) / (half_dim - 1)
26
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
27
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
28
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29
+ return emb
30
+
31
+
32
+ class MaskedGroupNorm(nn.GroupNorm):
33
+ """
34
+ Masked verstion of the Group normalization.
35
+
36
+ Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
37
+
38
+ Receives a N-dim tensor of sequence lengths per batch element
39
+ along with the regular input for masking.
40
+
41
+ Check pytorch's GroupNorm implementation for argument details.
42
+ """
43
+
44
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
45
+ super(MaskedGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
46
+
47
+ def forward(self, inp, mask=None):
48
+ assert (
49
+ inp.shape[1] % self.num_groups == 0
50
+ ), "Feature size not divisible by groups"
51
+
52
+ # 计算有效长度
53
+ seq_lengths = mask.sum(-1, keepdim=True) # [batch_size, 1]
54
+
55
+ # 将输入reshape为groups
56
+ features_per_group = inp.shape[1] // self.num_groups
57
+ inp_r = inp.reshape(
58
+ inp.shape[0], self.num_groups, features_per_group, inp.shape[-1]
59
+ )
60
+ mask_r = mask.unsqueeze(1) # [batch_size, 1, 1, length]
61
+
62
+ # 计算masked mean和variance
63
+ masked_inp = inp_r * mask_r
64
+ n = seq_lengths * features_per_group # 每组的有效元素数量
65
+ mean = masked_inp.sum([2, 3], keepdim=True) / (n.view(-1, 1, 1, 1) + 1e-5)
66
+ var = ((masked_inp - mean * mask_r) ** 2).sum([2, 3], keepdim=True) / (
67
+ n.view(-1, 1, 1, 1) + 1e-5
68
+ )
69
+
70
+ # 标准化
71
+ inp_r = (inp_r - mean) / (torch.sqrt(var + self.eps))
72
+ out = inp_r.reshape(inp.shape[0], self.num_channels, inp.shape[-1])
73
+
74
+ # 应用仿射变换
75
+ if self.affine:
76
+ out = out * self.weight[None, :, None] + self.bias[None, :, None]
77
+
78
+ return out
79
+
80
+
81
+ class Block1D(torch.nn.Module):
82
+ def __init__(self, dim, dim_out, groups=8):
83
+ super().__init__()
84
+ self.block = torch.nn.Sequential(
85
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
86
+ torch.nn.GroupNorm(groups, dim_out),
87
+ # MaskedGroupNorm(groups, dim_out),
88
+ nn.Mish(),
89
+ )
90
+
91
+ def forward(self, x, mask):
92
+ output = self.block(x * mask)
93
+ return output * mask
94
+ return x * mask
95
+
96
+
97
+ class ResnetBlock1D(torch.nn.Module):
98
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
99
+ super().__init__()
100
+ self.mlp = torch.nn.Sequential(
101
+ nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
102
+ )
103
+
104
+ self.block1 = Block1D(dim, dim_out, groups=groups)
105
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
106
+
107
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
108
+
109
+ def forward(self, x, mask, time_emb):
110
+ h = self.block1(x, mask)
111
+ h += self.mlp(time_emb).unsqueeze(-1)
112
+ h = self.block2(h, mask)
113
+ output = h + self.res_conv(x * mask)
114
+ return output
115
+
116
+
117
+ class Downsample1D(nn.Module):
118
+ def __init__(self, dim):
119
+ super().__init__()
120
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
121
+
122
+ def forward(self, x):
123
+ return self.conv(x)
124
+
125
+
126
+ class TimestepEmbedding(nn.Module):
127
+ def __init__(
128
+ self,
129
+ in_channels: int,
130
+ time_embed_dim: int,
131
+ act_fn: str = "silu",
132
+ out_dim: int = None,
133
+ post_act_fn: Optional[str] = None,
134
+ cond_proj_dim=None,
135
+ ):
136
+ super().__init__()
137
+
138
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
139
+
140
+ if cond_proj_dim is not None:
141
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
142
+ else:
143
+ self.cond_proj = None
144
+
145
+ self.act = get_activation(act_fn)
146
+
147
+ if out_dim is not None:
148
+ time_embed_dim_out = out_dim
149
+ else:
150
+ time_embed_dim_out = time_embed_dim
151
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
152
+
153
+ if post_act_fn is None:
154
+ self.post_act = None
155
+ else:
156
+ self.post_act = get_activation(post_act_fn)
157
+
158
+ def forward(self, sample, condition=None):
159
+ if condition is not None:
160
+ sample = sample + self.cond_proj(condition)
161
+ sample = self.linear_1(sample)
162
+
163
+ if self.act is not None:
164
+ sample = self.act(sample)
165
+
166
+ sample = self.linear_2(sample)
167
+
168
+ if self.post_act is not None:
169
+ sample = self.post_act(sample)
170
+ return sample
171
+
172
+
173
+ class Upsample1D(nn.Module):
174
+ """A 1D upsampling layer with an optional convolution.
175
+
176
+ Parameters:
177
+ channels (`int`):
178
+ number of channels in the inputs and outputs.
179
+ use_conv (`bool`, default `False`):
180
+ option to use a convolution.
181
+ use_conv_transpose (`bool`, default `False`):
182
+ option to use a convolution transpose.
183
+ out_channels (`int`, optional):
184
+ number of output channels. Defaults to `channels`.
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ channels,
190
+ use_conv=False,
191
+ use_conv_transpose=True,
192
+ out_channels=None,
193
+ name="conv",
194
+ ):
195
+ super().__init__()
196
+ self.channels = channels
197
+ self.out_channels = out_channels or channels
198
+ self.use_conv = use_conv
199
+ self.use_conv_transpose = use_conv_transpose
200
+ self.name = name
201
+
202
+ self.conv = None
203
+ if use_conv_transpose:
204
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
205
+ elif use_conv:
206
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
207
+
208
+ def forward(self, inputs):
209
+ assert inputs.shape[1] == self.channels
210
+ if self.use_conv_transpose:
211
+ return self.conv(inputs)
212
+
213
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
214
+
215
+ if self.use_conv:
216
+ outputs = self.conv(outputs)
217
+
218
+ return outputs
219
+
220
+
221
+ class ConformerWrapper(ConformerBlock):
222
+ def __init__( # pylint: disable=useless-super-delegation
223
+ self,
224
+ *,
225
+ dim,
226
+ dim_head=64,
227
+ heads=8,
228
+ ff_mult=4,
229
+ conv_expansion_factor=2,
230
+ conv_kernel_size=31,
231
+ attn_dropout=0,
232
+ ff_dropout=0,
233
+ conv_dropout=0,
234
+ conv_causal=False,
235
+ ):
236
+ super().__init__(
237
+ dim=dim,
238
+ dim_head=dim_head,
239
+ heads=heads,
240
+ ff_mult=ff_mult,
241
+ conv_expansion_factor=conv_expansion_factor,
242
+ conv_kernel_size=conv_kernel_size,
243
+ attn_dropout=attn_dropout,
244
+ ff_dropout=ff_dropout,
245
+ conv_dropout=conv_dropout,
246
+ conv_causal=conv_causal,
247
+ )
248
+
249
+ def forward(
250
+ self,
251
+ hidden_states,
252
+ attention_mask,
253
+ encoder_hidden_states=None,
254
+ encoder_attention_mask=None,
255
+ timestep=None,
256
+ ):
257
+ return super().forward(x=hidden_states, mask=attention_mask.bool())
258
+
259
+
260
+ class Decoder(nn.Module):
261
+ def __init__(
262
+ self,
263
+ in_channels,
264
+ out_channels,
265
+ channels=(256, 256),
266
+ dropout=0.05,
267
+ attention_head_dim=64,
268
+ n_blocks=1,
269
+ num_mid_blocks=2,
270
+ num_heads=4,
271
+ act_fn="snake",
272
+ down_block_type="transformer",
273
+ mid_block_type="transformer",
274
+ up_block_type="transformer",
275
+ ):
276
+ super().__init__()
277
+ channels = tuple(channels)
278
+ self.in_channels = in_channels
279
+ self.out_channels = out_channels
280
+
281
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
282
+ time_embed_dim = channels[0] * 4
283
+ self.time_mlp = TimestepEmbedding(
284
+ in_channels=in_channels,
285
+ time_embed_dim=time_embed_dim,
286
+ act_fn="silu",
287
+ )
288
+
289
+ self.down_blocks = nn.ModuleList([])
290
+ self.mid_blocks = nn.ModuleList([])
291
+ self.up_blocks = nn.ModuleList([])
292
+
293
+ output_channel = in_channels
294
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
295
+ input_channel = output_channel
296
+ output_channel = channels[i]
297
+ is_last = i == len(channels) - 1
298
+ resnet = ResnetBlock1D(
299
+ dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
300
+ )
301
+ transformer_blocks = nn.ModuleList(
302
+ [
303
+ self.get_block(
304
+ down_block_type,
305
+ output_channel,
306
+ attention_head_dim,
307
+ num_heads,
308
+ dropout,
309
+ act_fn,
310
+ )
311
+ for _ in range(n_blocks)
312
+ ]
313
+ )
314
+ downsample = (
315
+ Downsample1D(output_channel)
316
+ if not is_last
317
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
318
+ )
319
+
320
+ self.down_blocks.append(
321
+ nn.ModuleList([resnet, transformer_blocks, downsample])
322
+ )
323
+
324
+ for i in range(num_mid_blocks):
325
+ input_channel = channels[-1]
326
+ out_channels = channels[-1]
327
+
328
+ resnet = ResnetBlock1D(
329
+ dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
330
+ )
331
+
332
+ transformer_blocks = nn.ModuleList(
333
+ [
334
+ self.get_block(
335
+ mid_block_type,
336
+ output_channel,
337
+ attention_head_dim,
338
+ num_heads,
339
+ dropout,
340
+ act_fn,
341
+ )
342
+ for _ in range(n_blocks)
343
+ ]
344
+ )
345
+
346
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
347
+
348
+ channels = channels[::-1] + (channels[0],)
349
+ for i in range(len(channels) - 1):
350
+ input_channel = channels[i]
351
+ output_channel = channels[i + 1]
352
+ is_last = i == len(channels) - 2
353
+
354
+ resnet = ResnetBlock1D(
355
+ dim=2 * input_channel,
356
+ dim_out=output_channel,
357
+ time_emb_dim=time_embed_dim,
358
+ )
359
+ transformer_blocks = nn.ModuleList(
360
+ [
361
+ self.get_block(
362
+ up_block_type,
363
+ output_channel,
364
+ attention_head_dim,
365
+ num_heads,
366
+ dropout,
367
+ act_fn,
368
+ )
369
+ for _ in range(n_blocks)
370
+ ]
371
+ )
372
+ upsample = (
373
+ Upsample1D(output_channel, use_conv_transpose=True)
374
+ if not is_last
375
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
376
+ )
377
+
378
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
379
+
380
+ self.final_block = Block1D(channels[-1], channels[-1])
381
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
382
+
383
+ self.initialize_weights()
384
+ # nn.init.normal_(self.final_proj.weight)
385
+
386
+ @staticmethod
387
+ def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
388
+ if block_type == "conformer":
389
+ block = ConformerWrapper(
390
+ dim=dim,
391
+ dim_head=attention_head_dim,
392
+ heads=num_heads,
393
+ ff_mult=1,
394
+ conv_expansion_factor=2,
395
+ ff_dropout=dropout,
396
+ attn_dropout=dropout,
397
+ conv_dropout=dropout,
398
+ conv_kernel_size=31,
399
+ )
400
+ elif block_type == "transformer":
401
+ block = BasicTransformerBlock(
402
+ dim=dim,
403
+ num_attention_heads=num_heads,
404
+ attention_head_dim=attention_head_dim,
405
+ dropout=dropout,
406
+ activation_fn=act_fn,
407
+ )
408
+ else:
409
+ raise ValueError(f"Unknown block type {block_type}")
410
+
411
+ return block
412
+
413
+ def initialize_weights(self):
414
+ for m in self.modules():
415
+ if isinstance(m, nn.Conv1d):
416
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
417
+
418
+ if m.bias is not None:
419
+ nn.init.constant_(m.bias, 0)
420
+
421
+ elif isinstance(m, nn.GroupNorm):
422
+ nn.init.constant_(m.weight, 1)
423
+ nn.init.constant_(m.bias, 0)
424
+
425
+ elif isinstance(m, nn.Linear):
426
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
427
+
428
+ if m.bias is not None:
429
+ nn.init.constant_(m.bias, 0)
430
+
431
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
432
+ """Forward pass of the UNet1DConditional model.
433
+
434
+ Args:
435
+ x (torch.Tensor): shape (batch_size, in_channels, time)
436
+ mask (_type_): shape (batch_size, 1, time)
437
+ t (_type_): shape (batch_size)
438
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
439
+ cond (_type_, optional): placeholder for future use. Defaults to None.
440
+
441
+ Raises:
442
+ ValueError: _description_
443
+ ValueError: _description_
444
+
445
+ Returns:
446
+ _type_: _description_
447
+ """
448
+
449
+ t = self.time_embeddings(t)
450
+ t = self.time_mlp(t)
451
+
452
+ x = pack([x, mu], "b * t")[0]
453
+
454
+ if spks is not None:
455
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
456
+ x = pack([x, spks], "b * t")[0]
457
+
458
+ hiddens = []
459
+ masks = [mask]
460
+ for resnet, transformer_blocks, downsample in self.down_blocks:
461
+ mask_down = masks[-1]
462
+ x = resnet(x, mask_down, t)
463
+ x = rearrange(x, "b c t -> b t c")
464
+ mask_down = rearrange(mask_down, "b 1 t -> b t")
465
+ for transformer_block in transformer_blocks:
466
+ x = transformer_block(
467
+ hidden_states=x,
468
+ attention_mask=mask_down,
469
+ timestep=t,
470
+ )
471
+ x = rearrange(x, "b t c -> b c t")
472
+ mask_down = rearrange(mask_down, "b t -> b 1 t")
473
+ hiddens.append(x) # Save hidden states for skip connections
474
+ x = downsample(x * mask_down)
475
+ masks.append(mask_down[:, :, ::2])
476
+
477
+ masks = masks[:-1]
478
+ mask_mid = masks[-1]
479
+
480
+ for resnet, transformer_blocks in self.mid_blocks:
481
+ x = resnet(x, mask_mid, t)
482
+ x = rearrange(x, "b c t -> b t c")
483
+ mask_mid = rearrange(mask_mid, "b 1 t -> b t")
484
+ for transformer_block in transformer_blocks:
485
+ x = transformer_block(
486
+ hidden_states=x,
487
+ attention_mask=mask_mid,
488
+ timestep=t,
489
+ )
490
+ x = rearrange(x, "b t c -> b c t")
491
+ mask_mid = rearrange(mask_mid, "b t -> b 1 t")
492
+
493
+ for resnet, transformer_blocks, upsample in self.up_blocks:
494
+ mask_up = masks.pop()
495
+ x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
496
+ x = rearrange(x, "b c t -> b t c")
497
+ mask_up = rearrange(mask_up, "b 1 t -> b t")
498
+ for transformer_block in transformer_blocks:
499
+ x = transformer_block(
500
+ hidden_states=x,
501
+ attention_mask=mask_up,
502
+ timestep=t,
503
+ )
504
+ x = rearrange(x, "b t c -> b c t")
505
+ mask_up = rearrange(mask_up, "b t -> b 1 t")
506
+ x = upsample(x * mask_up)
507
+
508
+ x = self.final_block(x, mask_up)
509
+ output = self.final_proj(x * mask_up)
510
+
511
+ return output * mask
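A shape-only sketch of the Decoder above (the flow-matching U-Net estimator). The channel sizes are illustrative; it assumes PyTorch >= 2.0 and a diffusers version compatible with the imports at the top of this file, and act_fn="gelu" is passed explicitly because the FeedForward in cosyvoice/matcha/transformer.py has no branch for the default "snake" string.

import torch
from cosyvoice.matcha.decoder import Decoder

B, n_feats, T = 2, 80, 128            # T even so the down/upsampling path round-trips exactly
decoder = Decoder(in_channels=2 * n_feats, out_channels=n_feats,
                  channels=(256, 256), act_fn="gelu")

x = torch.randn(B, n_feats, T)        # noisy sample
mu = torch.randn(B, n_feats, T)       # encoder output, packed with x inside forward
mask = torch.ones(B, 1, T)
t = torch.rand(B)                     # one flow-matching timestep per batch element

out = decoder(x, mask, mu, t)         # estimated vector field
print(out.shape)                      # torch.Size([2, 80, 128])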
cosyvoice/matcha/flow_matching.py ADDED
@@ -0,0 +1,141 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from cosyvoice.matcha.decoder import Decoder
7
+
8
+
9
+ class BASECFM(torch.nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ n_feats,
13
+ cfm_params,
14
+ n_spks=1,
15
+ spk_emb_dim=128,
16
+ ):
17
+ super().__init__()
18
+ self.n_feats = n_feats
19
+ self.n_spks = n_spks
20
+ self.spk_emb_dim = spk_emb_dim
21
+ self.solver = cfm_params.solver
22
+ if hasattr(cfm_params, "sigma_min"):
23
+ self.sigma_min = cfm_params.sigma_min
24
+ else:
25
+ self.sigma_min = 1e-4
26
+
27
+ self.estimator = None
28
+
29
+ @torch.inference_mode()
30
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31
+ """Forward diffusion
32
+
33
+ Args:
34
+ mu (torch.Tensor): output of encoder
35
+ shape: (batch_size, n_feats, mel_timesteps)
36
+ mask (torch.Tensor): output_mask
37
+ shape: (batch_size, 1, mel_timesteps)
38
+ n_timesteps (int): number of diffusion steps
39
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
41
+ shape: (batch_size, spk_emb_dim)
42
+ cond: Not used but kept for future purposes
43
+
44
+ Returns:
45
+ sample: generated mel-spectrogram
46
+ shape: (batch_size, n_feats, mel_timesteps)
47
+ """
48
+ z = torch.randn_like(mu) * temperature
49
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
+ return self.solve_euler(
51
+ z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
52
+ )
53
+
54
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
55
+ """
56
+ Fixed euler solver for ODEs.
57
+ Args:
58
+ x (torch.Tensor): random noise
59
+ t_span (torch.Tensor): n_timesteps interpolated
60
+ shape: (n_timesteps + 1,)
61
+ mu (torch.Tensor): output of encoder
62
+ shape: (batch_size, n_feats, mel_timesteps)
63
+ mask (torch.Tensor): output_mask
64
+ shape: (batch_size, 1, mel_timesteps)
65
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
66
+ shape: (batch_size, spk_emb_dim)
67
+ cond: Not used but kept for future purposes
68
+ """
69
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
70
+
71
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
72
+ # Or in future might add like a return_all_steps flag
73
+ sol = []
74
+
75
+ for step in range(1, len(t_span)):
76
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
77
+
78
+ x = x + dt * dphi_dt
79
+ t = t + dt
80
+ sol.append(x)
81
+ if step < len(t_span) - 1:
82
+ dt = t_span[step + 1] - t
83
+
84
+ return sol[-1]
85
+
86
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
87
+ """Computes diffusion loss
88
+
89
+ Args:
90
+ x1 (torch.Tensor): Target
91
+ shape: (batch_size, n_feats, mel_timesteps)
92
+ mask (torch.Tensor): target mask
93
+ shape: (batch_size, 1, mel_timesteps)
94
+ mu (torch.Tensor): output of encoder
95
+ shape: (batch_size, n_feats, mel_timesteps)
96
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
97
+ shape: (batch_size, spk_emb_dim)
98
+
99
+ Returns:
100
+ loss: conditional flow matching loss
101
+ y: conditional flow
102
+ shape: (batch_size, n_feats, mel_timesteps)
103
+ """
104
+ b, _, t = mu.shape
105
+
106
+ # random timestep
107
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
108
+ # sample noise p(x_0)
109
+ z = torch.randn_like(x1)
110
+
111
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
112
+ u = x1 - (1 - self.sigma_min) * z
113
+
114
+ loss = F.mse_loss(
115
+ self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
116
+ ) / (torch.sum(mask) * u.shape[1])
117
+ return loss, y
118
+
119
+
120
+ class CFM(BASECFM):
121
+ def __init__(
122
+ self,
123
+ in_channels,
124
+ out_channel,
125
+ cfm_params,
126
+ decoder_params,
127
+ n_spks=1,
128
+ spk_emb_dim=64,
129
+ ):
130
+ super().__init__(
131
+ n_feats=in_channels,
132
+ cfm_params=cfm_params,
133
+ n_spks=n_spks,
134
+ spk_emb_dim=spk_emb_dim,
135
+ )
136
+
137
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
138
+ # Just change the architecture of the estimator here
139
+ self.estimator = Decoder(
140
+ in_channels=in_channels, out_channels=out_channel, **decoder_params
141
+ )
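A toy end-to-end sketch of the CFM class above: compute the OT-CFM training loss, then sample with the fixed-step Euler solver. The cfm_params and decoder_params values are illustrative only (SimpleNamespace stands in for whatever config object the caller normally supplies), and act_fn="gelu" is used for the same reason as in the decoder sketch above.

import types
import torch
from cosyvoice.matcha.flow_matching import CFM

cfm_params = types.SimpleNamespace(solver="euler", sigma_min=1e-4)
decoder_params = dict(channels=(256, 256), act_fn="gelu")
cfm = CFM(in_channels=160, out_channel=80,
          cfm_params=cfm_params, decoder_params=decoder_params)

mu = torch.randn(2, 80, 128)    # encoder output
mask = torch.ones(2, 1, 128)
x1 = torch.randn(2, 80, 128)    # target mel

loss, y = cfm.compute_loss(x1, mask, mu)   # regresses u = x1 - (1 - sigma_min) * z
mel = cfm(mu, mask, n_timesteps=10)        # 10 Euler steps from noise toward the target distribution
print(loss.item(), mel.shape)              # scalar loss, torch.Size([2, 80, 128])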
cosyvoice/matcha/transformer.py ADDED
@@ -0,0 +1,443 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
+ class SnakeBeta(nn.Module):
18
+ """
19
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+ Parameters:
24
+ - alpha - trainable parameter that controls frequency
25
+ - beta - trainable parameter that controls magnitude
26
+ References:
27
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
+ https://arxiv.org/abs/2006.08195
29
+ Examples:
30
+ >>> a1 = SnakeBeta(256, 256)
31
+ >>> x = torch.randn(256)
32
+ >>> x = a1(x)
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ in_features,
38
+ out_features,
39
+ alpha=1.0,
40
+ alpha_trainable=True,
41
+ alpha_logscale=True,
42
+ ):
43
+ """
44
+ Initialization.
45
+ INPUT:
46
+ - in_features: shape of the input
47
+ - alpha - trainable parameter that controls frequency
48
+ - beta - trainable parameter that controls magnitude
49
+ alpha is initialized to 1 by default, higher values = higher-frequency.
50
+ beta is initialized to 1 by default, higher values = higher-magnitude.
51
+ alpha will be trained along with the rest of your model.
52
+ """
53
+ super().__init__()
54
+ self.in_features = (
55
+ out_features if isinstance(out_features, list) else [out_features]
56
+ )
57
+ self.proj = LoRACompatibleLinear(in_features, out_features)
58
+
59
+ # initialize alpha
60
+ self.alpha_logscale = alpha_logscale
61
+ if self.alpha_logscale: # log scale alphas initialized to zeros
62
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
63
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
64
+ else: # linear scale alphas initialized to ones
65
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
66
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
67
+
68
+ self.alpha.requires_grad = alpha_trainable
69
+ self.beta.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ """
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
78
+ """
79
+ x = self.proj(x)
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(self.alpha)
82
+ beta = torch.exp(self.beta)
83
+ else:
84
+ alpha = self.alpha
85
+ beta = self.beta
86
+
87
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
88
+ torch.sin(x * alpha), 2
89
+ )
90
+
91
+ return x
92
+
93
+
94
+ class FeedForward(nn.Module):
95
+ r"""
96
+ A feed-forward layer.
97
+
98
+ Parameters:
99
+ dim (`int`): The number of channels in the input.
100
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
101
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
102
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
103
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
104
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
105
+ """
106
+
107
+ def __init__(
108
+ self,
109
+ dim: int,
110
+ dim_out: Optional[int] = None,
111
+ mult: int = 4,
112
+ dropout: float = 0.0,
113
+ activation_fn: str = "geglu",
114
+ final_dropout: bool = False,
115
+ ):
116
+ super().__init__()
117
+ inner_dim = int(dim * mult)
118
+ dim_out = dim_out if dim_out is not None else dim
119
+
120
+ if activation_fn == "gelu":
121
+ act_fn = GELU(dim, inner_dim)
122
+ if activation_fn == "gelu-approximate":
123
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
124
+ elif activation_fn == "geglu":
125
+ act_fn = GEGLU(dim, inner_dim)
126
+ elif activation_fn == "geglu-approximate":
127
+ act_fn = ApproximateGELU(dim, inner_dim)
128
+ elif activation_fn == "snakebeta":
129
+ act_fn = SnakeBeta(dim, inner_dim)
130
+
131
+ self.net = nn.ModuleList([])
132
+ # project in
133
+ self.net.append(act_fn)
134
+ # project dropout
135
+ self.net.append(nn.Dropout(dropout))
136
+ # project out
137
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
138
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
139
+ if final_dropout:
140
+ self.net.append(nn.Dropout(dropout))
141
+
142
+ def forward(self, hidden_states):
143
+ for module in self.net:
144
+ hidden_states = module(hidden_states)
145
+ return hidden_states
146
+
147
+
148
+ @maybe_allow_in_graph
149
+ class BasicTransformerBlock(nn.Module):
150
+ r"""
151
+ A basic Transformer block.
152
+
153
+ Parameters:
154
+ dim (`int`): The number of channels in the input and output.
155
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
156
+ attention_head_dim (`int`): The number of channels in each head.
157
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
158
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
159
+ only_cross_attention (`bool`, *optional*):
160
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
161
+ double_self_attention (`bool`, *optional*):
162
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
163
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
164
+ num_embeds_ada_norm (:
165
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
166
+ attention_bias (:
167
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ dim: int,
173
+ num_attention_heads: int,
174
+ attention_head_dim: int,
175
+ dropout=0.0,
176
+ cross_attention_dim: Optional[int] = None,
177
+ activation_fn: str = "geglu",
178
+ num_embeds_ada_norm: Optional[int] = None,
179
+ attention_bias: bool = False,
180
+ only_cross_attention: bool = False,
181
+ double_self_attention: bool = False,
182
+ upcast_attention: bool = False,
183
+ norm_elementwise_affine: bool = True,
184
+ norm_type: str = "layer_norm",
185
+ final_dropout: bool = False,
186
+ ):
187
+ super().__init__()
188
+ self.only_cross_attention = only_cross_attention
189
+
190
+ self.use_ada_layer_norm_zero = (
191
+ num_embeds_ada_norm is not None
192
+ ) and norm_type == "ada_norm_zero"
193
+ self.use_ada_layer_norm = (
194
+ num_embeds_ada_norm is not None
195
+ ) and norm_type == "ada_norm"
196
+
197
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
198
+ raise ValueError(
199
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
200
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
201
+ )
202
+
203
+ # Define 3 blocks. Each block has its own normalization layer.
204
+ # 1. Self-Attn
205
+ if self.use_ada_layer_norm:
206
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
207
+ elif self.use_ada_layer_norm_zero:
208
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
209
+ else:
210
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
211
+ self.attn1 = Attention(
212
+ query_dim=dim,
213
+ heads=num_attention_heads,
214
+ dim_head=attention_head_dim,
215
+ dropout=dropout,
216
+ bias=attention_bias,
217
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
218
+ upcast_attention=upcast_attention,
219
+ )
220
+
221
+ # 2. Cross-Attn
222
+ if cross_attention_dim is not None or double_self_attention:
223
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
224
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
225
+ # the second cross attention block.
226
+ self.norm2 = (
227
+ AdaLayerNorm(dim, num_embeds_ada_norm)
228
+ if self.use_ada_layer_norm
229
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
230
+ )
231
+ self.attn2 = Attention(
232
+ query_dim=dim,
233
+ cross_attention_dim=(
234
+ cross_attention_dim if not double_self_attention else None
235
+ ),
236
+ heads=num_attention_heads,
237
+ dim_head=attention_head_dim,
238
+ dropout=dropout,
239
+ bias=attention_bias,
240
+ upcast_attention=upcast_attention,
241
+ # scale_qk=False, # uncomment this to not to use flash attention
242
+ ) # is self-attn if encoder_hidden_states is none
243
+ else:
244
+ self.norm2 = None
245
+ self.attn2 = None
246
+
247
+ # 3. Feed-forward
248
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
249
+ self.ff = FeedForward(
250
+ dim,
251
+ dropout=dropout,
252
+ activation_fn=activation_fn,
253
+ final_dropout=final_dropout,
254
+ )
255
+
256
+ # let chunk size default to None
257
+ self._chunk_size = None
258
+ self._chunk_dim = 0
259
+
260
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
261
+ # Sets chunk feed-forward
262
+ self._chunk_size = chunk_size
263
+ self._chunk_dim = dim
264
+
265
+ def forward_native(
266
+ self,
267
+ hidden_states: torch.FloatTensor,
268
+ attention_mask: Optional[torch.FloatTensor] = None,
269
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
270
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
271
+ timestep: Optional[torch.LongTensor] = None,
272
+ cross_attention_kwargs: Dict[str, Any] = None,
273
+ class_labels: Optional[torch.LongTensor] = None,
274
+ ):
275
+ # Notice that normalization is always applied before the real computation in the following blocks.
276
+ # 1. Self-Attention
277
+ if self.use_ada_layer_norm:
278
+ norm_hidden_states = self.norm1(hidden_states, timestep)
279
+ elif self.use_ada_layer_norm_zero:
280
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
281
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
282
+ )
283
+ else:
284
+ norm_hidden_states = self.norm1(hidden_states)
285
+
286
+ cross_attention_kwargs = (
287
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
288
+ )
289
+
290
+ attn_output = self.attn1(
291
+ norm_hidden_states,
292
+ encoder_hidden_states=(
293
+ encoder_hidden_states if self.only_cross_attention else None
294
+ ),
295
+ attention_mask=(
296
+ encoder_attention_mask if self.only_cross_attention else attention_mask
297
+ ),
298
+ **cross_attention_kwargs,
299
+ )
300
+ if self.use_ada_layer_norm_zero:
301
+ attn_output = gate_msa.unsqueeze(1) * attn_output
302
+ hidden_states = attn_output + hidden_states
303
+
304
+ # 2. Cross-Attention
305
+ if self.attn2 is not None:
306
+ norm_hidden_states = (
307
+ self.norm2(hidden_states, timestep)
308
+ if self.use_ada_layer_norm
309
+ else self.norm2(hidden_states)
310
+ )
311
+
312
+ attn_output = self.attn2(
313
+ norm_hidden_states,
314
+ encoder_hidden_states=encoder_hidden_states,
315
+ attention_mask=encoder_attention_mask,
316
+ **cross_attention_kwargs,
317
+ )
318
+ hidden_states = attn_output + hidden_states
319
+
320
+ # 3. Feed-forward
321
+ norm_hidden_states = self.norm3(hidden_states)
322
+
323
+ if self.use_ada_layer_norm_zero:
324
+ norm_hidden_states = (
325
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
326
+ )
327
+
328
+ if self._chunk_size is not None:
329
+ # "feed_forward_chunk_size" can be used to save memory
330
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
331
+ raise ValueError(
332
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
333
+ )
334
+
335
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
336
+ ff_output = torch.cat(
337
+ [
338
+ self.ff(hid_slice)
339
+ for hid_slice in norm_hidden_states.chunk(
340
+ num_chunks, dim=self._chunk_dim
341
+ )
342
+ ],
343
+ dim=self._chunk_dim,
344
+ )
345
+ else:
346
+ ff_output = self.ff(norm_hidden_states)
347
+
348
+ if self.use_ada_layer_norm_zero:
349
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
350
+
351
+ hidden_states = ff_output + hidden_states
352
+
353
+ return hidden_states
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: torch.FloatTensor,
358
+ attention_mask: Optional[torch.FloatTensor] = None,
359
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
360
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
361
+ timestep: Optional[torch.LongTensor] = None,
362
+ cross_attention_kwargs: Dict[str, Any] = None,
363
+ class_labels: Optional[torch.LongTensor] = None,
364
+ ):
365
+ # Notice that normalization is always applied before the real computation in the following blocks.
366
+ # 1. Self-Attention
367
+ if self.use_ada_layer_norm:
368
+ norm_hidden_states = self.norm1(hidden_states, timestep)
369
+ elif self.use_ada_layer_norm_zero:
370
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
371
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
372
+ )
373
+ else:
374
+ norm_hidden_states = self.norm1(hidden_states)
375
+
376
+ cross_attention_kwargs = (
377
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
378
+ )
379
+
380
+ attn_output = self.attn1(
381
+ norm_hidden_states,
382
+ encoder_hidden_states=(
383
+ encoder_hidden_states if self.only_cross_attention else None
384
+ ),
385
+ attention_mask=(
386
+ encoder_attention_mask if self.only_cross_attention else attention_mask
387
+ ),
388
+ **cross_attention_kwargs,
389
+ )
390
+ if self.use_ada_layer_norm_zero:
391
+ attn_output = gate_msa.unsqueeze(1) * attn_output
392
+ hidden_states = attn_output + hidden_states
393
+
394
+ # 2. Cross-Attention
395
+ if self.attn2 is not None:
396
+ norm_hidden_states = (
397
+ self.norm2(hidden_states, timestep)
398
+ if self.use_ada_layer_norm
399
+ else self.norm2(hidden_states)
400
+ )
401
+
402
+ attn_output = self.attn2(
403
+ norm_hidden_states,
404
+ encoder_hidden_states=encoder_hidden_states,
405
+ attention_mask=encoder_attention_mask,
406
+ **cross_attention_kwargs,
407
+ )
408
+ hidden_states = attn_output + hidden_states
409
+
410
+ # 3. Feed-forward
411
+ norm_hidden_states = self.norm3(hidden_states)
412
+
413
+ if self.use_ada_layer_norm_zero:
414
+ norm_hidden_states = (
415
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
416
+ )
417
+
418
+ if self._chunk_size is not None:
419
+ # "feed_forward_chunk_size" can be used to save memory
420
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
421
+ raise ValueError(
422
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
423
+ )
424
+
425
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
426
+ ff_output = torch.cat(
427
+ [
428
+ self.ff(hid_slice)
429
+ for hid_slice in norm_hidden_states.chunk(
430
+ num_chunks, dim=self._chunk_dim
431
+ )
432
+ ],
433
+ dim=self._chunk_dim,
434
+ )
435
+ else:
436
+ ff_output = self.ff(norm_hidden_states)
437
+
438
+ if self.use_ada_layer_norm_zero:
439
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
440
+
441
+ hidden_states = ff_output + hidden_states
442
+
443
+ return hidden_states
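A standalone smoke test of the BasicTransformerBlock above, with arbitrary example dimensions. attention_mask is left as None here (masking is normally handled by the calling decoder), and the block assumes a diffusers version compatible with the imports at the top of this file.

import torch
from cosyvoice.matcha.transformer import BasicTransformerBlock

block = BasicTransformerBlock(dim=256, num_attention_heads=4,
                              attention_head_dim=64, activation_fn="gelu")
h = torch.randn(2, 50, 256)     # (batch, time, channels)
out = block(h)                  # self-attention + feed-forward, shape preserved
print(out.shape)                # torch.Size([2, 50, 256])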
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/activation.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct a Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ """
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = Snake(256)
47
+ >>> x = torch.randn(1, 256, 20)
48
+ >>> x = a1(x)
49
+ """
50
+
51
+ def __init__(
52
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
53
+ ):
54
+ """
55
+ Initialization.
56
+ INPUT:
57
+ - in_features: shape of the input
58
+ - alpha: trainable parameter
59
+ alpha is initialized to 1 by default, higher values = higher-frequency.
60
+ alpha will be trained along with the rest of your model.
61
+ """
62
+ super(Snake, self).__init__()
63
+ self.in_features = in_features
64
+
65
+ # initialize alpha
66
+ self.alpha_logscale = alpha_logscale
67
+ if self.alpha_logscale: # log scale alphas initialized to zeros
68
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
69
+ else: # linear scale alphas initialized to ones
70
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
71
+
72
+ self.alpha.requires_grad = alpha_trainable
73
+
74
+ self.no_div_by_zero = 0.000000001
75
+
76
+ def forward(self, x):
77
+ """
78
+ Forward pass of the function.
79
+ Applies the function to the input elementwise.
80
+ Snake(x) := x + (1/alpha) * sin^2(alpha * x)
81
+ """
82
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(alpha)
85
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
86
+
87
+ return x
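A quick usage sketch for the Snake activation just added; the import path follows the new file, and the shapes are illustrative.

import torch
from cosyvoice.transformer.activation import Snake

act = Snake(in_features=80, alpha=1.0, alpha_logscale=False)
x = torch.randn(4, 80, 100)   # (B, C, T), one alpha per channel
y = act(x)                    # y = x + (1/alpha) * sin^2(alpha * x), elementwise
print(y.shape)                # torch.Size([4, 80, 100])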
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,322 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(
37
+ self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
38
+ ):
39
+ """Construct an MultiHeadedAttention object."""
40
+ super().__init__()
41
+ assert n_feat % n_head == 0
42
+ # We assume d_v always equals d_k
43
+ self.d_k = n_feat // n_head
44
+ self.h = n_head
45
+ self.linear_q = nn.Linear(n_feat, n_feat)
46
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
47
+ self.linear_v = nn.Linear(n_feat, n_feat)
48
+ self.linear_out = nn.Linear(n_feat, n_feat)
49
+ self.dropout = nn.Dropout(p=dropout_rate)
50
+
51
+ def forward_qkv(
52
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
53
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
54
+ """Transform query, key and value.
55
+
56
+ Args:
57
+ query (torch.Tensor): Query tensor (#batch, time1, size).
58
+ key (torch.Tensor): Key tensor (#batch, time2, size).
59
+ value (torch.Tensor): Value tensor (#batch, time2, size).
60
+
61
+ Returns:
62
+ torch.Tensor: Transformed query tensor, size
63
+ (#batch, n_head, time1, d_k).
64
+ torch.Tensor: Transformed key tensor, size
65
+ (#batch, n_head, time2, d_k).
66
+ torch.Tensor: Transformed value tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+
69
+ """
70
+ n_batch = query.size(0)
71
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
72
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
73
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
74
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
75
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
76
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
77
+
78
+ return q, k, v
79
+
80
+ def forward_attention(
81
+ self,
82
+ value: torch.Tensor,
83
+ scores: torch.Tensor,
84
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
85
+ ) -> torch.Tensor:
86
+ """Compute attention context vector.
87
+
88
+ Args:
89
+ value (torch.Tensor): Transformed value, size
90
+ (#batch, n_head, time2, d_k).
91
+ scores (torch.Tensor): Attention score, size
92
+ (#batch, n_head, time1, time2).
93
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
94
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
95
+
96
+ Returns:
97
+ torch.Tensor: Transformed value (#batch, time1, d_model)
98
+ weighted by the attention score (#batch, time1, time2).
99
+
100
+ """
101
+ n_batch = value.size(0)
102
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
103
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
104
+ # 1st chunk to ease the onnx export.]
105
+ # 2. pytorch training
106
+ if mask.size(2) > 0: # time2 > 0
107
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
108
+ # For last chunk, time2 might be larger than scores.size(-1)
109
+ mask = mask[:, :, :, : scores.size(-1)] # (batch, 1, *, time2)
110
+ scores = scores.masked_fill(mask, -float("inf"))
111
+ attn = torch.softmax(scores, dim=-1).masked_fill(
112
+ mask, 0.0
113
+ ) # (batch, head, time1, time2)
114
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
115
+ # 1. onnx(16/-1, -1/-1, 16/0)
116
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
117
+ else:
118
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
119
+
120
+ p_attn = self.dropout(attn)
121
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
122
+ x = (
123
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
124
+ ) # (batch, time1, d_model)
125
+
126
+ return self.linear_out(x) # (batch, time1, d_model)
127
+
128
+ def forward(
129
+ self,
130
+ query: torch.Tensor,
131
+ key: torch.Tensor,
132
+ value: torch.Tensor,
133
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
134
+ pos_emb: torch.Tensor = torch.empty(0),
135
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute scaled dot product attention.
138
+
139
+ Args:
140
+ query (torch.Tensor): Query tensor (#batch, time1, size).
141
+ key (torch.Tensor): Key tensor (#batch, time2, size).
142
+ value (torch.Tensor): Value tensor (#batch, time2, size).
143
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
144
+ (#batch, time1, time2).
145
+ 1.When applying cross attention between decoder and encoder,
146
+ the batch padding mask for input is in (#batch, 1, T) shape.
147
+ 2.When applying self attention of encoder,
148
+ the mask is in (#batch, T, T) shape.
149
+ 3.When applying self attention of decoder,
150
+ the mask is in (#batch, L, L) shape.
151
+ 4.If the different position in decoder see different block
152
+ of the encoder, such as Mocha, the passed in mask could be
153
+ in (#batch, L, T) shape. But there is no such case in current
154
+ CosyVoice.
155
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
156
+ where `cache_t == chunk_size * num_decoding_left_chunks`
157
+ and `head * d_k == size`
158
+
159
+
160
+ Returns:
161
+ torch.Tensor: Output tensor (#batch, time1, d_model).
162
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
163
+ where `cache_t == chunk_size * num_decoding_left_chunks`
164
+ and `head * d_k == size`
165
+
166
+ """
167
+ q, k, v = self.forward_qkv(query, key, value)
168
+
169
+ # NOTE(xcsong):
170
+ # when export onnx model, for 1st chunk, we feed
171
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
172
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
173
+ # In all modes, `if cache.size(0) > 0` will always be `True`
174
+ # and we will always do splitting and
175
+ # concatenation (this will simplify onnx export). Note that
176
+ # it's OK to concat & split zero-shaped tensors(see code below).
177
+ # when export jit model, for 1st chunk, we always feed
178
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
179
+ # >>> a = torch.ones((1, 2, 0, 4))
180
+ # >>> b = torch.ones((1, 2, 3, 4))
181
+ # >>> c = torch.cat((a, b), dim=2)
182
+ # >>> torch.equal(b, c) # True
183
+ # >>> d = torch.split(a, 2, dim=-1)
184
+ # >>> torch.equal(d[0], d[1]) # True
185
+ if cache.size(0) > 0:
186
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
187
+ k = torch.cat([key_cache, k], dim=2)
188
+ v = torch.cat([value_cache, v], dim=2)
189
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
190
+ # non-trivial to calculate `next_cache_start` here.
191
+ new_cache = torch.cat((k, v), dim=-1)
192
+
193
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
194
+ return self.forward_attention(v, scores, mask), new_cache
195
+
196
+
197
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
198
+ """Multi-Head Attention layer with relative position encoding.
199
+ Paper: https://arxiv.org/abs/1901.02860
200
+ Args:
201
+ n_head (int): The number of heads.
202
+ n_feat (int): The number of features.
203
+ dropout_rate (float): Dropout rate.
204
+ """
205
+
206
+ def __init__(
207
+ self, n_head: int, n_feat: int, dropout_rate: float, key_bias: bool = True
208
+ ):
209
+ """Construct an RelPositionMultiHeadedAttention object."""
210
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
211
+ # linear transformation for positional encoding
212
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
213
+ # these two learnable bias are used in matrix c and matrix d
214
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
215
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
216
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
217
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
218
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
219
+
220
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
221
+ """Compute relative positional encoding.
222
+
223
+ Args:
224
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
225
+ time1 means the length of query vector.
226
+
227
+ Returns:
228
+ torch.Tensor: Output tensor.
229
+
230
+ """
231
+ zero_pad = torch.zeros(
232
+ (x.size()[0], x.size()[1], x.size()[2], 1), device=x.device, dtype=x.dtype
233
+ )
234
+ x_padded = torch.cat([zero_pad, x], dim=-1)
235
+
236
+ x_padded = x_padded.view(x.size()[0], x.size()[1], x.size(3) + 1, x.size(2))
237
+ x = x_padded[:, :, 1:].view_as(x)[
238
+ :, :, :, : x.size(-1) // 2 + 1
239
+ ] # only keep the positions from 0 to time2
240
+ return x
241
+
242
+ def forward(
243
+ self,
244
+ query: torch.Tensor,
245
+ key: torch.Tensor,
246
+ value: torch.Tensor,
247
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
248
+ pos_emb: torch.Tensor = torch.empty(0),
249
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
250
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
252
+ Args:
253
+ query (torch.Tensor): Query tensor (#batch, time1, size).
254
+ key (torch.Tensor): Key tensor (#batch, time2, size).
255
+ value (torch.Tensor): Value tensor (#batch, time2, size).
256
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
257
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
258
+ pos_emb (torch.Tensor): Positional embedding tensor
259
+ (#batch, time2, size).
260
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
261
+ where `cache_t == chunk_size * num_decoding_left_chunks`
262
+ and `head * d_k == size`
263
+ Returns:
264
+ torch.Tensor: Output tensor (#batch, time1, d_model).
265
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
266
+ where `cache_t == chunk_size * num_decoding_left_chunks`
267
+ and `head * d_k == size`
268
+ """
269
+ q, k, v = self.forward_qkv(query, key, value)
270
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
271
+
272
+ # NOTE(xcsong):
273
+ # when export onnx model, for 1st chunk, we feed
274
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
275
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
276
+ # In all modes, `if cache.size(0) > 0` will always be `True`
277
+ # and we will always do splitting and
278
+ # concatenation (this will simplify onnx export). Note that
279
+ # it's OK to concat & split zero-shaped tensors(see code below).
280
+ # when export jit model, for 1st chunk, we always feed
281
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
282
+ # >>> a = torch.ones((1, 2, 0, 4))
283
+ # >>> b = torch.ones((1, 2, 3, 4))
284
+ # >>> c = torch.cat((a, b), dim=2)
285
+ # >>> torch.equal(b, c) # True
286
+ # >>> d = torch.split(a, 2, dim=-1)
287
+ # >>> torch.equal(d[0], d[1]) # True
288
+ if cache.size(0) > 0:
289
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
290
+ k = torch.cat([key_cache, k], dim=2)
291
+ v = torch.cat([value_cache, v], dim=2)
292
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
293
+ # non-trivial to calculate `next_cache_start` here.
294
+ new_cache = torch.cat((k, v), dim=-1)
295
+
296
+ n_batch_pos = pos_emb.size(0)
297
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
298
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
299
+
300
+ # (batch, head, time1, d_k)
301
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
302
+ # (batch, head, time1, d_k)
303
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
304
+
305
+ # compute attention score
306
+ # first compute matrix a and matrix c
307
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
308
+ # (batch, head, time1, time2)
309
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
310
+
311
+ # compute matrix b and matrix d
312
+ # (batch, head, time1, time2)
313
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
314
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
315
+ if matrix_ac.shape != matrix_bd.shape:
316
+ matrix_bd = self.rel_shift(matrix_bd)
317
+
318
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
319
+ self.d_k
320
+ ) # (batch, head, time1, time2)
321
+
322
+ return self.forward_attention(v, scores, mask), new_cache
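A small sketch of the streaming cache contract implemented by MultiHeadedAttention above: keys and values are packed along the last dimension of `cache`, and concatenating the zero-sized default cache on the first chunk is a no-op. Shapes below are illustrative.

import torch
from cosyvoice.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
x1 = torch.randn(1, 10, 256)               # first chunk
y1, cache = attn(x1, x1, x1)               # cache: (1, 4, 10, 2 * 64) = keys | values
x2 = torch.randn(1, 6, 256)                # next chunk reuses cached keys/values
y2, cache = attn(x2, x2, x2, cache=cache)  # cache grows to (1, 4, 16, 128)
print(y2.shape, cache.shape)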
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,147 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(
28
+ self,
29
+ channels: int,
30
+ kernel_size: int = 15,
31
+ activation: nn.Module = nn.ReLU(),
32
+ norm: str = "batch_norm",
33
+ causal: bool = False,
34
+ bias: bool = True,
35
+ ):
36
+ """Construct an ConvolutionModule object.
37
+ Args:
38
+ channels (int): The number of channels of conv layers.
39
+ kernel_size (int): Kernel size of conv layers.
40
+ causal (bool): Whether to use causal convolution or not
41
+ """
42
+ super().__init__()
43
+
44
+ self.pointwise_conv1 = nn.Conv1d(
45
+ channels,
46
+ 2 * channels,
47
+ kernel_size=1,
48
+ stride=1,
49
+ padding=0,
50
+ bias=bias,
51
+ )
52
+ # self.lorder is used to distinguish if it's a causal convolution,
53
+ # if self.lorder > 0: it's a causal convolution, the input will be
54
+ # padded with self.lorder frames on the left in forward.
55
+ # else: it's a symmetrical convolution
56
+ if causal:
57
+ padding = 0
58
+ self.lorder = kernel_size - 1
59
+ else:
60
+ # kernel_size should be an odd number for non-causal convolution
61
+ assert (kernel_size - 1) % 2 == 0
62
+ padding = (kernel_size - 1) // 2
63
+ self.lorder = 0
64
+ self.depthwise_conv = nn.Conv1d(
65
+ channels,
66
+ channels,
67
+ kernel_size,
68
+ stride=1,
69
+ padding=padding,
70
+ groups=channels,
71
+ bias=bias,
72
+ )
73
+
74
+ assert norm in ["batch_norm", "layer_norm"]
75
+ if norm == "batch_norm":
76
+ self.use_layer_norm = False
77
+ self.norm = nn.BatchNorm1d(channels)
78
+ else:
79
+ self.use_layer_norm = True
80
+ self.norm = nn.LayerNorm(channels)
81
+
82
+ self.pointwise_conv2 = nn.Conv1d(
83
+ channels,
84
+ channels,
85
+ kernel_size=1,
86
+ stride=1,
87
+ padding=0,
88
+ bias=bias,
89
+ )
90
+ self.activation = activation
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
96
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
97
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
98
+ """Compute convolution module.
99
+ Args:
100
+ x (torch.Tensor): Input tensor (#batch, time, channels).
101
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
102
+ (0, 0, 0) means fake mask.
103
+ cache (torch.Tensor): left context cache, it is only
104
+ used in causal convolution (#batch, channels, cache_t),
105
+ (0, 0, 0) means fake cache.
106
+ Returns:
107
+ torch.Tensor: Output tensor (#batch, time, channels).
108
+ """
109
+ # exchange the temporal dimension and the feature dimension
110
+ x = x.transpose(1, 2) # (#batch, channels, time)
111
+
112
+ # mask batch padding
113
+ if mask_pad.size(2) > 0: # time > 0
114
+ x.masked_fill_(~mask_pad, 0.0)
115
+
116
+ if self.lorder > 0:
117
+ if cache.size(2) == 0: # cache_t == 0
118
+ x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
119
+ else:
120
+ assert cache.size(0) == x.size(0) # equal batch
121
+ assert cache.size(1) == x.size(1) # equal channel
122
+ x = torch.cat((cache, x), dim=2)
123
+ assert x.size(2) > self.lorder
124
+ new_cache = x[:, :, -self.lorder :]
125
+ else:
126
+ # It's better we just return None if no cache is required,
127
+ # However, for JIT export, here we just fake one tensor instead of
128
+ # None.
129
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
130
+
131
+ # GLU mechanism
132
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
133
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
134
+
135
+ # 1D Depthwise Conv
136
+ x = self.depthwise_conv(x)
137
+ if self.use_layer_norm:
138
+ x = x.transpose(1, 2)
139
+ x = self.activation(self.norm(x))
140
+ if self.use_layer_norm:
141
+ x = x.transpose(1, 2)
142
+ x = self.pointwise_conv2(x)
143
+ # mask batch padding
144
+ if mask_pad.size(2) > 0: # time > 0
145
+ x.masked_fill_(~mask_pad, 0.0)
146
+
147
+ return x.transpose(1, 2), new_cache
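A sketch of how the causal ConvolutionModule above carries left context across chunks: with `causal=True` the input is left-padded by `kernel_size - 1` frames, and exactly those frames are returned as the cache for the next chunk. Shapes are illustrative.

import torch
from cosyvoice.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
conv.eval()                                # keep batch norm in inference mode
with torch.no_grad():
    x1 = torch.randn(1, 20, 256)           # (B, T, C), first chunk
    y1, cache = conv(x1)                   # cache: (1, 256, 14) of left context
    x2 = torch.randn(1, 20, 256)           # next chunk, no look-ahead needed
    y2, cache = conv(x2, cache=cache)
print(y2.shape, cache.shape)               # (1, 20, 256) and (1, 256, 14)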
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import (
25
+ PositionwiseFeedForward,
26
+ )
27
+ from cosyvoice.utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_ATTENTION_CLASSES,
30
+ COSYVOICE_ACTIVATION_CLASSES,
31
+ )
32
+ from cosyvoice.utils.mask import subsequent_mask, make_pad_mask
33
+
34
+
35
+ class TransformerDecoder(torch.nn.Module):
36
+ """Base class of Transfomer decoder module.
37
+ Args:
38
+ vocab_size: output dim
39
+ encoder_output_size: dimension of attention
40
+ attention_heads: the number of heads of multi head attention
41
+ linear_units: the hidden units number of position-wise feedforward
42
+ num_blocks: the number of decoder blocks
43
+ dropout_rate: dropout rate
44
+ self_attention_dropout_rate: dropout rate for attention
45
+ input_layer: input layer type
46
+ use_output_layer: whether to use output layer
47
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
48
+ normalize_before:
49
+ True: use layer_norm before each sub-block of a layer.
50
+ False: use layer_norm after each sub-block of a layer.
51
+ src_attention: if false, encoder-decoder cross attention is not
52
+ applied, such as CIF model
53
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
54
+ gradient_checkpointing: rerunning a forward-pass segment for each
55
+ checkpointed segment during backward.
56
+ tie_word_embedding: Tie or clone module weights depending of whether we are
57
+ using TorchScript or not
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ vocab_size: int,
63
+ encoder_output_size: int,
64
+ attention_heads: int = 4,
65
+ linear_units: int = 2048,
66
+ num_blocks: int = 6,
67
+ dropout_rate: float = 0.1,
68
+ positional_dropout_rate: float = 0.1,
69
+ self_attention_dropout_rate: float = 0.0,
70
+ src_attention_dropout_rate: float = 0.0,
71
+ input_layer: str = "embed",
72
+ use_output_layer: bool = True,
73
+ normalize_before: bool = True,
74
+ src_attention: bool = True,
75
+ key_bias: bool = True,
76
+ activation_type: str = "relu",
77
+ gradient_checkpointing: bool = False,
78
+ tie_word_embedding: bool = False,
79
+ ):
80
+ super().__init__()
81
+ attention_dim = encoder_output_size
82
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
83
+
84
+ self.embed = torch.nn.Sequential(
85
+ (
86
+ torch.nn.Identity()
87
+ if input_layer == "no_pos"
88
+ else torch.nn.Embedding(vocab_size, attention_dim)
89
+ ),
90
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim, positional_dropout_rate),
91
+ )
92
+
93
+ self.normalize_before = normalize_before
94
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
95
+ self.use_output_layer = use_output_layer
96
+ if use_output_layer:
97
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
98
+ else:
99
+ self.output_layer = torch.nn.Identity()
100
+ self.num_blocks = num_blocks
101
+ self.decoders = torch.nn.ModuleList(
102
+ [
103
+ DecoderLayer(
104
+ attention_dim,
105
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
106
+ attention_heads,
107
+ attention_dim,
108
+ self_attention_dropout_rate,
109
+ key_bias,
110
+ ),
111
+ (
112
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
113
+ attention_heads,
114
+ attention_dim,
115
+ src_attention_dropout_rate,
116
+ key_bias,
117
+ )
118
+ if src_attention
119
+ else None
120
+ ),
121
+ PositionwiseFeedForward(
122
+ attention_dim, linear_units, dropout_rate, activation
123
+ ),
124
+ dropout_rate,
125
+ normalize_before,
126
+ )
127
+ for _ in range(self.num_blocks)
128
+ ]
129
+ )
130
+
131
+ self.gradient_checkpointing = gradient_checkpointing
132
+ self.tie_word_embedding = tie_word_embedding
133
+
134
+ def forward(
135
+ self,
136
+ memory: torch.Tensor,
137
+ memory_mask: torch.Tensor,
138
+ ys_in_pad: torch.Tensor,
139
+ ys_in_lens: torch.Tensor,
140
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
141
+ reverse_weight: float = 0.0,
142
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
143
+ """Forward decoder.
144
+ Args:
145
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
146
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
147
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
148
+ ys_in_lens: input lengths of this batch (batch)
149
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
150
+ with bidirectional decoder
151
+ reverse_weight: not used in transformer decoder, in order to unify
152
+ api with bidirectional decode
153
+ Returns:
154
+ (tuple): tuple containing:
155
+ x: decoded token score before softmax (batch, maxlen_out,
156
+ vocab_size) if use_output_layer is True,
157
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
158
+ olens: (batch, )
159
+ NOTE(xcsong):
160
+ We pass the `__call__` method of the modules instead of `forward` to the
161
+ checkpointing API because `__call__` attaches all the hooks of the module.
162
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
163
+ """
164
+ tgt = ys_in_pad
165
+ maxlen = tgt.size(1)
166
+ # tgt_mask: (B, 1, L)
167
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
168
+ tgt_mask = tgt_mask.to(tgt.device)
169
+ # m: (1, L, L)
170
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
171
+ # tgt_mask: (B, L, L)
172
+ tgt_mask = tgt_mask & m
173
+ x, _ = self.embed(tgt)
174
+ if self.gradient_checkpointing and self.training:
175
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory, memory_mask)
176
+ else:
177
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
178
+ if self.normalize_before:
179
+ x = self.after_norm(x)
180
+ if self.use_output_layer:
181
+ x = self.output_layer(x)
182
+ olens = tgt_mask.sum(1)
183
+ return x, torch.tensor(0.0), olens
184
+
185
+ def forward_layers(
186
+ self,
187
+ x: torch.Tensor,
188
+ tgt_mask: torch.Tensor,
189
+ memory: torch.Tensor,
190
+ memory_mask: torch.Tensor,
191
+ ) -> torch.Tensor:
192
+ for layer in self.decoders:
193
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, memory_mask)
194
+ return x
195
+
196
+ @torch.jit.unused
197
+ def forward_layers_checkpointed(
198
+ self,
199
+ x: torch.Tensor,
200
+ tgt_mask: torch.Tensor,
201
+ memory: torch.Tensor,
202
+ memory_mask: torch.Tensor,
203
+ ) -> torch.Tensor:
204
+ for layer in self.decoders:
205
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
206
+ layer.__call__, x, tgt_mask, memory, memory_mask
207
+ )
208
+ return x
209
+
210
+ def forward_one_step(
211
+ self,
212
+ memory: torch.Tensor,
213
+ memory_mask: torch.Tensor,
214
+ tgt: torch.Tensor,
215
+ tgt_mask: torch.Tensor,
216
+ cache: Optional[List[torch.Tensor]] = None,
217
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
218
+ """Forward one step.
219
+ This is only used for decoding.
220
+ Args:
221
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
222
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
223
+ tgt: input token ids, int64 (batch, maxlen_out)
224
+ tgt_mask: input token mask, (batch, maxlen_out)
225
+ dtype=torch.uint8 in PyTorch 1.2-
226
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
227
+ cache: cached output list of (batch, max_time_out-1, size)
228
+ Returns:
229
+ y, cache: NN output value and cache per `self.decoders`.
230
+ `y.shape` is (batch, maxlen_out, token)
231
+ """
232
+ x, _ = self.embed(tgt)
233
+ new_cache = []
234
+ for i, decoder in enumerate(self.decoders):
235
+ if cache is None:
236
+ c = None
237
+ else:
238
+ c = cache[i]
239
+ x, tgt_mask, memory, memory_mask = decoder(
240
+ x, tgt_mask, memory, memory_mask, cache=c
241
+ )
242
+ new_cache.append(x)
243
+ if self.normalize_before:
244
+ y = self.after_norm(x[:, -1])
245
+ else:
246
+ y = x[:, -1]
247
+ if self.use_output_layer:
248
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
249
+ return y, new_cache
250
+
251
+ def tie_or_clone_weights(self, jit_mode: bool = True):
252
+ """Tie or clone module weights (between word_emb and output_layer)
253
+ depending on whether we are using TorchScript or not"""
254
+ if not self.use_output_layer:
255
+ return
256
+ if jit_mode:
257
+ logging.info("clone emb.weight to output.weight")
258
+ self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone())
259
+ else:
260
+ logging.info("tie emb.weight with output.weight")
261
+ self.output_layer.weight = self.embed[0].weight
262
+
263
+ if getattr(self.output_layer, "bias", None) is not None:
264
+ self.output_layer.bias.data = torch.nn.functional.pad(
265
+ self.output_layer.bias.data,
266
+ (
267
+ 0,
268
+ self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0],
269
+ ),
270
+ "constant",
271
+ 0,
272
+ )
273
+
274
+
275
+ class BiTransformerDecoder(torch.nn.Module):
276
+ """Base class of Transfomer decoder module.
277
+ Args:
278
+ vocab_size: output dim
279
+ encoder_output_size: dimension of attention
280
+ attention_heads: the number of heads of multi head attention
281
+ linear_units: the hidden units number of position-wise feedforward
282
+ num_blocks: the number of decoder blocks
283
+ r_num_blocks: the number of right to left decoder blocks
284
+ dropout_rate: dropout rate
285
+ self_attention_dropout_rate: dropout rate for attention
286
+ input_layer: input layer type
287
+ use_output_layer: whether to use output layer
288
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
289
+ normalize_before:
290
+ True: use layer_norm before each sub-block of a layer.
291
+ False: use layer_norm after each sub-block of a layer.
292
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ vocab_size: int,
298
+ encoder_output_size: int,
299
+ attention_heads: int = 4,
300
+ linear_units: int = 2048,
301
+ num_blocks: int = 6,
302
+ r_num_blocks: int = 0,
303
+ dropout_rate: float = 0.1,
304
+ positional_dropout_rate: float = 0.1,
305
+ self_attention_dropout_rate: float = 0.0,
306
+ src_attention_dropout_rate: float = 0.0,
307
+ input_layer: str = "embed",
308
+ use_output_layer: bool = True,
309
+ normalize_before: bool = True,
310
+ key_bias: bool = True,
311
+ gradient_checkpointing: bool = False,
312
+ tie_word_embedding: bool = False,
313
+ ):
314
+
315
+ super().__init__()
316
+ self.tie_word_embedding = tie_word_embedding
317
+ self.left_decoder = TransformerDecoder(
318
+ vocab_size,
319
+ encoder_output_size,
320
+ attention_heads,
321
+ linear_units,
322
+ num_blocks,
323
+ dropout_rate,
324
+ positional_dropout_rate,
325
+ self_attention_dropout_rate,
326
+ src_attention_dropout_rate,
327
+ input_layer,
328
+ use_output_layer,
329
+ normalize_before,
330
+ key_bias=key_bias,
331
+ gradient_checkpointing=gradient_checkpointing,
332
+ tie_word_embedding=tie_word_embedding,
333
+ )
334
+
335
+ self.right_decoder = TransformerDecoder(
336
+ vocab_size,
337
+ encoder_output_size,
338
+ attention_heads,
339
+ linear_units,
340
+ r_num_blocks,
341
+ dropout_rate,
342
+ positional_dropout_rate,
343
+ self_attention_dropout_rate,
344
+ src_attention_dropout_rate,
345
+ input_layer,
346
+ use_output_layer,
347
+ normalize_before,
348
+ key_bias=key_bias,
349
+ gradient_checkpointing=gradient_checkpointing,
350
+ tie_word_embedding=tie_word_embedding,
351
+ )
352
+
353
+ def forward(
354
+ self,
355
+ memory: torch.Tensor,
356
+ memory_mask: torch.Tensor,
357
+ ys_in_pad: torch.Tensor,
358
+ ys_in_lens: torch.Tensor,
359
+ r_ys_in_pad: torch.Tensor,
360
+ reverse_weight: float = 0.0,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
362
+ """Forward decoder.
363
+ Args:
364
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
365
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
366
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
367
+ ys_in_lens: input lengths of this batch (batch)
368
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
369
+ used for right to left decoder
370
+ reverse_weight: used for right to left decoder
371
+ Returns:
372
+ (tuple): tuple containing:
373
+ x: decoded token score before softmax (batch, maxlen_out,
374
+ vocab_size) if use_output_layer is True,
375
+ r_x: decoded token score (right to left decoder)
376
+ before softmax (batch, maxlen_out, vocab_size)
377
+ if use_output_layer is True,
378
+ olens: (batch, )
379
+ """
380
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
381
+ r_x = torch.tensor(0.0)
382
+ if reverse_weight > 0.0:
383
+ r_x, _, olens = self.right_decoder(
384
+ memory, memory_mask, r_ys_in_pad, ys_in_lens
385
+ )
386
+ return l_x, r_x, olens
387
+
388
+ def forward_one_step(
389
+ self,
390
+ memory: torch.Tensor,
391
+ memory_mask: torch.Tensor,
392
+ tgt: torch.Tensor,
393
+ tgt_mask: torch.Tensor,
394
+ cache: Optional[List[torch.Tensor]] = None,
395
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
396
+ """Forward one step.
397
+ This is only used for decoding.
398
+ Args:
399
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
400
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
401
+ tgt: input token ids, int64 (batch, maxlen_out)
402
+ tgt_mask: input token mask, (batch, maxlen_out)
403
+ dtype=torch.uint8 in PyTorch 1.2-
404
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
405
+ cache: cached output list of (batch, max_time_out-1, size)
406
+ Returns:
407
+ y, cache: NN output value and cache per `self.decoders`.
408
+ `y.shape` is (batch, maxlen_out, token)
409
+ """
410
+ return self.left_decoder.forward_one_step(
411
+ memory, memory_mask, tgt, tgt_mask, cache
412
+ )
413
+
414
+ def tie_or_clone_weights(self, jit_mode: bool = True):
415
+ """Tie or clone module weights (between word_emb and output_layer)
416
+ depending on whether we are using TorchScript or not"""
417
+ self.left_decoder.tie_or_clone_weights(jit_mode)
418
+ self.right_decoder.tie_or_clone_weights(jit_mode)
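A minimal forward-pass sketch for the TransformerDecoder defined above, assuming the default "embed" / "selfattn" / "relu" entries in cosyvoice.utils.class_utils resolve as elsewhere in this repository; all shapes are illustrative.

import torch
from cosyvoice.transformer.decoder import TransformerDecoder

dec = TransformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
memory = torch.randn(2, 50, 256)                      # (B, maxlen_in, feat)
memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (B, 1, maxlen_in)
ys_in_pad = torch.randint(0, 100, (2, 12))            # (B, maxlen_out) int64 token ids
ys_in_lens = torch.tensor([12, 9])                    # true lengths before padding
logits, _, olens = dec(memory, memory_mask, ys_in_pad, ys_in_lens)
print(logits.shape)                                   # (2, 12, 100): scores before softmax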
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class DecoderLayer(nn.Module):
23
+ """Single decoder layer module.
24
+
25
+ Args:
26
+ size (int): Input dimension.
27
+ self_attn (torch.nn.Module): Self-attention module instance.
28
+ `MultiHeadedAttention` instance can be used as the argument.
29
+ src_attn (torch.nn.Module): Inter-attention module instance.
30
+ `MultiHeadedAttention` instance can be used as the argument.
31
+ If `None` is passed, Inter-attention is not used, such as
32
+ CIF, GPT, and other decoder only model.
33
+ feed_forward (torch.nn.Module): Feed-forward module instance.
34
+ `PositionwiseFeedForward` instance can be used as the argument.
35
+ dropout_rate (float): Dropout rate.
36
+ normalize_before (bool):
37
+ True: use layer_norm before each sub-block.
38
+ False: to use layer_norm after each sub-block.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ self_attn: nn.Module,
45
+ src_attn: Optional[nn.Module],
46
+ feed_forward: nn.Module,
47
+ dropout_rate: float,
48
+ normalize_before: bool = True,
49
+ ):
50
+ """Construct an DecoderLayer object."""
51
+ super().__init__()
52
+ self.size = size
53
+ self.self_attn = self_attn
54
+ self.src_attn = src_attn
55
+ self.feed_forward = feed_forward
56
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
+ self.dropout = nn.Dropout(dropout_rate)
60
+ self.normalize_before = normalize_before
61
+
62
+ def forward(
63
+ self,
64
+ tgt: torch.Tensor,
65
+ tgt_mask: torch.Tensor,
66
+ memory: torch.Tensor,
67
+ memory_mask: torch.Tensor,
68
+ cache: Optional[torch.Tensor] = None,
69
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """Compute decoded features.
71
+
72
+ Args:
73
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
+ tgt_mask (torch.Tensor): Mask for input tensor
75
+ (#batch, maxlen_out).
76
+ memory (torch.Tensor): Encoded memory
77
+ (#batch, maxlen_in, size).
78
+ memory_mask (torch.Tensor): Encoded memory mask
79
+ (#batch, maxlen_in).
80
+ cache (torch.Tensor): cached tensors.
81
+ (#batch, maxlen_out - 1, size).
82
+
83
+ Returns:
84
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
+
89
+ """
90
+ residual = tgt
91
+ if self.normalize_before:
92
+ tgt = self.norm1(tgt)
93
+
94
+ if cache is None:
95
+ tgt_q = tgt
96
+ tgt_q_mask = tgt_mask
97
+ else:
98
+ # compute only the last frame query keeping dim: max_time_out -> 1
99
+ assert cache.shape == (
100
+ tgt.shape[0],
101
+ tgt.shape[1] - 1,
102
+ self.size,
103
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
+ tgt_q = tgt[:, -1:, :]
105
+ residual = residual[:, -1:, :]
106
+ tgt_q_mask = tgt_mask[:, -1:, :]
107
+
108
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
109
+ if not self.normalize_before:
110
+ x = self.norm1(x)
111
+
112
+ if self.src_attn is not None:
113
+ residual = x
114
+ if self.normalize_before:
115
+ x = self.norm2(x)
116
+ x = residual + self.dropout(
117
+ self.src_attn(x, memory, memory, memory_mask)[0]
118
+ )
119
+ if not self.normalize_before:
120
+ x = self.norm2(x)
121
+
122
+ residual = x
123
+ if self.normalize_before:
124
+ x = self.norm3(x)
125
+ x = residual + self.dropout(self.feed_forward(x))
126
+ if not self.normalize_before:
127
+ x = self.norm3(x)
128
+
129
+ if cache is not None:
130
+ x = torch.cat([cache, x], dim=1)
131
+
132
+ return x, tgt_mask, memory, memory_mask
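A sketch of the incremental-decoding cache contract of DecoderLayer above: when a cache of shape (batch, maxlen_out - 1, size) is passed, only the newest position is recomputed and the result is concatenated back onto the cache. The feed-forward construction below assumes PositionwiseFeedForward defaults its activation to ReLU, as in the WeNet/ESPnet version of that module.

import torch
from cosyvoice.transformer.decoder_layer import DecoderLayer
from cosyvoice.transformer.attention import MultiHeadedAttention
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = DecoderLayer(
    size=256,
    self_attn=MultiHeadedAttention(4, 256, 0.0),
    src_attn=MultiHeadedAttention(4, 256, 0.0),
    feed_forward=PositionwiseFeedForward(256, 1024, 0.0),
    dropout_rate=0.0,
)
tgt = torch.randn(1, 5, 256)                       # 5 decoded steps so far
tgt_mask = torch.tril(torch.ones(1, 5, 5)).bool()  # causal mask
memory = torch.randn(1, 30, 256)
memory_mask = torch.ones(1, 1, 30, dtype=torch.bool)
cache = torch.randn(1, 4, 256)                     # layer outputs of the first 4 steps
x, *_ = layer(tgt, tgt_mask, memory, memory_mask, cache=cache)
print(x.shape)                                     # (1, 5, 256): cache + new step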
cosyvoice/transformer/embedding.py ADDED
@@ -0,0 +1,293 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ d_model: int,
40
+ dropout_rate: float,
41
+ max_len: int = 5000,
42
+ reverse: bool = False,
43
+ ):
44
+ """Construct an PositionalEncoding object."""
45
+ super().__init__()
46
+ self.d_model = d_model
47
+ self.xscale = math.sqrt(self.d_model)
48
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
49
+ self.max_len = max_len
50
+
51
+ self.pe = torch.zeros(self.max_len, self.d_model)
52
+ position = torch.arange(0, self.max_len, dtype=torch.float32).unsqueeze(1)
53
+ div_term = torch.exp(
54
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
55
+ * -(math.log(10000.0) / self.d_model)
56
+ )
57
+ self.pe[:, 0::2] = torch.sin(position * div_term)
58
+ self.pe[:, 1::2] = torch.cos(position * div_term)
59
+ self.pe = self.pe.unsqueeze(0)
60
+
61
+ def forward(
62
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
63
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
64
+ """Add positional encoding.
65
+
66
+ Args:
67
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
68
+ offset (int, torch.tensor): position offset
69
+
70
+ Returns:
71
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
72
+ torch.Tensor: for compatibility to RelPositionalEncoding
73
+ """
74
+
75
+ self.pe = self.pe.to(x.device)
76
+ pos_emb = self.position_encoding(offset, x.size(1), False)
77
+ x = x * self.xscale + pos_emb
78
+ return self.dropout(x), self.dropout(pos_emb)
79
+
80
+ def position_encoding(
81
+ self, offset: Union[int, torch.Tensor], size: int, apply_dropout: bool = True
82
+ ) -> torch.Tensor:
83
+ """For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a non-
87
+ streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset : offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset : offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + torch.arange(0, size).to(
109
+ offset.device
110
+ ) # B X T
111
+ flag = index > 0
112
+ # remove negative offset
113
+ index = index * flag
114
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
115
+
116
+ if apply_dropout:
117
+ pos_emb = self.dropout(pos_emb)
118
+ return pos_emb
119
+
120
+
121
+ class RelPositionalEncoding(PositionalEncoding):
122
+ """Relative positional encoding module.
123
+ See : Appendix B in https://arxiv.org/abs/1901.02860
124
+ Args:
125
+ d_model (int): Embedding dimension.
126
+ dropout_rate (float): Dropout rate.
127
+ max_len (int): Maximum input length.
128
+ """
129
+
130
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
131
+ """Initialize class."""
132
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
133
+
134
+ def forward(
135
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """Sinusoids position encoding used in openai-whisper.encoder"""
152
+
153
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
154
+ super().__init__(d_model, dropout_rate, max_len)
155
+ self.xscale = 1.0
156
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
157
+ inv_timescales = torch.exp(
158
+ -log_timescale_increment * torch.arange(d_model // 2)
159
+ )
160
+ scaled_time = (
161
+ torch.arange(max_len)[:, np.newaxis] * inv_timescales[np.newaxis, :]
162
+ )
163
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
164
+ delattr(self, "pe")
165
+ self.register_buffer("pe", pe.unsqueeze(0))
166
+
167
+
168
+ class LearnablePositionalEncoding(PositionalEncoding):
169
+ """Learnable position encoding used in openai-whisper.decoder"""
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """No position encoding"""
180
+
181
+ def __init__(self, d_model: int, dropout_rate: float):
182
+ super().__init__()
183
+ self.d_model = d_model
184
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
185
+
186
+ def forward(
187
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
188
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
189
+ """Just return zero vector for interface compatibility"""
190
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
191
+ return self.dropout(x), pos_emb
192
+
193
+ def position_encoding(
194
+ self, offset: Union[int, torch.Tensor], size: int
195
+ ) -> torch.Tensor:
196
+ return torch.zeros(1, size, self.d_model)
197
+
198
+
199
+ class EspnetRelPositionalEncoding(torch.nn.Module):
200
+ """Relative positional encoding module (new implementation).
201
+
202
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
203
+
204
+ See : Appendix B in https://arxiv.org/abs/1901.02860
205
+
206
+ Args:
207
+ d_model (int): Embedding dimension.
208
+ dropout_rate (float): Dropout rate.
209
+ max_len (int): Maximum input length.
210
+
211
+ """
212
+
213
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
214
+ """Construct an PositionalEncoding object."""
215
+ super(EspnetRelPositionalEncoding, self).__init__()
216
+ self.d_model = d_model
217
+ self.xscale = math.sqrt(self.d_model)
218
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
219
+ self.pe = None
220
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
221
+
222
+ def extend_pe(self, x: torch.Tensor):
223
+ """Reset the positional encodings."""
224
+ if self.pe is not None:
225
+ # self.pe contains both positive and negative parts
226
+ # the length of self.pe is 2 * input_len - 1
227
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
228
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
229
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
230
+ return
231
+ # Suppose `i` denotes the position of the query vector and `j` the
232
+ # position of the key vector. We use positive relative positions when keys
233
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
234
+ pe_positive = torch.zeros(x.size(1), self.d_model)
235
+ pe_negative = torch.zeros(x.size(1), self.d_model)
236
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
237
+ div_term = torch.exp(
238
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
239
+ * -(math.log(10000.0) / self.d_model)
240
+ )
241
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
242
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
243
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
244
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
245
+
246
+ # Reverse the order of positive indices and concat both positive and
247
+ # negative indices. This is used to support the shifting trick
248
+ # as in https://arxiv.org/abs/1901.02860
249
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
250
+ pe_negative = pe_negative[1:].unsqueeze(0)
251
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
252
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
253
+
254
+ def forward(
255
+ self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0
256
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
257
+ """Add positional encoding.
258
+
259
+ Args:
260
+ x (torch.Tensor): Input tensor (batch, time, `*`).
261
+
262
+ Returns:
263
+ torch.Tensor: Encoded tensor (batch, time, `*`).
264
+
265
+ """
266
+ self.extend_pe(x)
267
+ x = x * self.xscale
268
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
269
+ return self.dropout(x), self.dropout(pos_emb)
270
+
271
+ def position_encoding(
272
+ self, offset: Union[int, torch.Tensor], size: int
273
+ ) -> torch.Tensor:
274
+ """For getting encoding in a streaming fashion
275
+
276
+ Attention!!!!!
277
+ we apply dropout only once at the whole utterance level in a non-
278
+ streaming way, but will call this function several times with
279
+ increasing input size in a streaming scenario, so the dropout will
280
+ be applied several times.
281
+
282
+ Args:
283
+ offset (int or torch.tensor): start offset
284
+ size (int): required size of position encoding
285
+
286
+ Returns:
287
+ torch.Tensor: Corresponding encoding
288
+ """
289
+ pos_emb = self.pe[
290
+ :,
291
+ self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
292
+ ]
293
+ return pos_emb
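A quick shape check contrasting the absolute and ESPnet-style relative positional encodings defined above: the absolute variant returns a (1, T, d) embedding, while EspnetRelPositionalEncoding returns a (1, 2*T - 1, d) embedding covering both positive and negative offsets, which is what rel_shift in attention.py expects. Shapes are illustrative.

import torch
from cosyvoice.transformer.embedding import (
    PositionalEncoding,
    EspnetRelPositionalEncoding,
)

x = torch.randn(2, 10, 256)
abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
rel_pe = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
_, pos_abs = abs_pe(x)        # (1, 10, 256): one embedding per position
_, pos_rel = rel_pe(x)        # (1, 19, 256): 2 * 10 - 1 relative offsets
print(pos_abs.shape, pos_rel.shape)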
cosyvoice/transformer/encoder.py ADDED
@@ -0,0 +1,633 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+ import time
20
+
21
+ import torch
22
+ import torch.utils.checkpoint as ckpt
23
+ import torch.nn.functional as F
24
+
25
+ from cosyvoice.transformer.convolution import ConvolutionModule
26
+ from cosyvoice.transformer.encoder_layer import (
27
+ TransformerEncoderLayer,
28
+ )
29
+ from cosyvoice.transformer.encoder_layer import (
30
+ ConformerEncoderLayer,
31
+ )
32
+ from cosyvoice.transformer.positionwise_feed_forward import (
33
+ PositionwiseFeedForward,
34
+ )
35
+ from cosyvoice.utils.class_utils import (
36
+ COSYVOICE_EMB_CLASSES,
37
+ COSYVOICE_SUBSAMPLE_CLASSES,
38
+ COSYVOICE_ATTENTION_CLASSES,
39
+ COSYVOICE_ACTIVATION_CLASSES,
40
+ )
41
+ from cosyvoice.utils.mask import make_pad_mask
42
+ from cosyvoice.utils.mask import add_optional_chunk_mask
43
+
44
+
45
+ class BaseEncoder(torch.nn.Module):
46
+
47
+ def __init__(
48
+ self,
49
+ input_size: int,
50
+ output_size: int = 256,
51
+ attention_heads: int = 4,
52
+ linear_units: int = 2048,
53
+ num_blocks: int = 6,
54
+ dropout_rate: float = 0.1,
55
+ positional_dropout_rate: float = 0.1,
56
+ attention_dropout_rate: float = 0.0,
57
+ input_layer: str = "conv2d",
58
+ pos_enc_layer_type: str = "abs_pos",
59
+ normalize_before: bool = True,
60
+ static_chunk_size: int = 0,
61
+ use_dynamic_chunk: bool = False,
62
+ global_cmvn: torch.nn.Module = None,
63
+ use_dynamic_left_chunk: bool = False,
64
+ gradient_checkpointing: bool = False,
65
+ ):
66
+ """
67
+ Args:
68
+ input_size (int): input dim
69
+ output_size (int): dimension of attention
70
+ attention_heads (int): the number of heads of multi head attention
71
+ linear_units (int): the hidden units number of position-wise feed
72
+ forward
73
+ num_blocks (int): the number of decoder blocks
74
+ dropout_rate (float): dropout rate
75
+ attention_dropout_rate (float): dropout rate in attention
76
+ positional_dropout_rate (float): dropout rate after adding
77
+ positional encoding
78
+ input_layer (str): input layer type.
79
+ optional [linear, conv2d, conv2d6, conv2d8]
80
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
81
+ opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
82
+ normalize_before (bool):
83
+ True: use layer_norm before each sub-block of a layer.
84
+ False: use layer_norm after each sub-block of a layer.
85
+ static_chunk_size (int): chunk size for static chunk training and
86
+ decoding
87
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
88
+ training or not. You can only use a fixed chunk (chunk_size > 0)
89
+ or a dynamic chunk size (use_dynamic_chunk = True)
90
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
91
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
92
+ dynamic chunk training
93
+ key_bias: whether to use bias in attention.linear_k; False for whisper models.
94
+ gradient_checkpointing: if True, rerun each checkpointed forward-pass
95
+ segment during the backward pass to trade compute for memory.
96
+ """
97
+ super().__init__()
98
+ self._output_size = output_size
99
+
100
+ self.global_cmvn = global_cmvn
101
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
102
+ input_size,
103
+ output_size,
104
+ dropout_rate,
105
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
106
+ output_size, positional_dropout_rate
107
+ ),
108
+ )
109
+
110
+ self.normalize_before = normalize_before
111
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
112
+ self.static_chunk_size = static_chunk_size
113
+ self.use_dynamic_chunk = use_dynamic_chunk
114
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
115
+ self.gradient_checkpointing = gradient_checkpointing
116
+
117
+ def output_size(self) -> int:
118
+ return self._output_size
119
+
120
+ def forward(
121
+ self,
122
+ xs: torch.Tensor,
123
+ xs_lens: torch.Tensor,
124
+ decoding_chunk_size: int = 0,
125
+ num_decoding_left_chunks: int = -1,
126
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
127
+ """Embed positions in tensor.
128
+
129
+ Args:
130
+ xs: padded input tensor (B, T, D)
131
+ xs_lens: input length (B)
132
+ decoding_chunk_size: decoding chunk size for dynamic chunk
133
+ 0: default for training, use random dynamic chunk.
134
+ <0: for decoding, use full chunk.
135
+ >0: for decoding, use fixed chunk size as set.
136
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
137
+ the chunk size is decoding_chunk_size.
138
+ >=0: use num_decoding_left_chunks
139
+ <0: use all left chunks
140
+ Returns:
141
+ encoder output tensor xs, and subsampled masks
142
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
143
+ masks: torch.Tensor batch padding mask after subsample
144
+ (B, 1, T' ~= T/subsample_rate)
145
+ NOTE(xcsong):
146
+ We pass the `__call__` method of the modules instead of `forward` to the
147
+ checkpointing API because `__call__` attaches all the hooks of the module.
148
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
149
+ """
150
+ T = xs.size(1)
151
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
152
+ if self.global_cmvn is not None:
153
+ xs = self.global_cmvn(xs)
154
+ xs, pos_emb, masks = self.embed(xs, masks)
155
+ mask_pad = masks # (B, 1, T/subsample_rate)
156
+ chunk_masks = add_optional_chunk_mask(
157
+ xs,
158
+ masks,
159
+ self.use_dynamic_chunk,
160
+ self.use_dynamic_left_chunk,
161
+ decoding_chunk_size,
162
+ self.static_chunk_size,
163
+ num_decoding_left_chunks,
164
+ )
166
+ if self.gradient_checkpointing and self.training:
167
+ xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb, mask_pad)
168
+ else:
169
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
170
+ if self.normalize_before:
171
+ xs = self.after_norm(xs)
172
+ # Here we assume the mask is not changed in encoder layers, so just
173
+ # return the masks before encoder layers, and the masks will be used
174
+ # for cross attention with decoder later
175
+ return xs, masks
176
+
177
+ def forward_layers(
178
+ self,
179
+ xs: torch.Tensor,
180
+ chunk_masks: torch.Tensor,
181
+ pos_emb: torch.Tensor,
182
+ mask_pad: torch.Tensor,
183
+ ) -> torch.Tensor:
184
+ for layer in self.encoders:
185
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
186
+ return xs
187
+
188
+ @torch.jit.unused
189
+ def forward_layers_checkpointed(
190
+ self,
191
+ xs: torch.Tensor,
192
+ chunk_masks: torch.Tensor,
193
+ pos_emb: torch.Tensor,
194
+ mask_pad: torch.Tensor,
195
+ ) -> torch.Tensor:
196
+ for layer in self.encoders:
197
+ xs, chunk_masks, _, _ = ckpt.checkpoint(
198
+ layer.__call__, xs, chunk_masks, pos_emb, mask_pad
199
+ )
200
+ return xs
201
+
202
+ @torch.jit.export
203
+ def forward_chunk(
204
+ self,
205
+ xs: torch.Tensor,
206
+ offset: int,
207
+ required_cache_size: int,
208
+ att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
209
+ cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
210
+ att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
211
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
212
+ """ Forward just one chunk
213
+
214
+ Args:
215
+ xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
216
+ where `time == (chunk_size - 1) * subsample_rate + \
217
+ subsample.right_context + 1`
218
+ offset (int): current offset in encoder output time stamp
219
+ required_cache_size (int): cache size required for next chunk
220
+ compuation
221
+ >=0: actual cache size
222
+ <0: means all history cache is required
223
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
224
+ transformer/conformer attention, with shape
225
+ (elayers, head, cache_t1, d_k * 2), where
226
+ `head * d_k == hidden-dim` and
227
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
228
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
229
+ (elayers, b=1, hidden-dim, cache_t2), where
230
+ `cache_t2 == cnn.lorder - 1`
231
+
232
+ Returns:
233
+ torch.Tensor: output of current input xs,
234
+ with shape (b=1, chunk_size, hidden-dim).
235
+ torch.Tensor: new attention cache required for next chunk, with
236
+ dynamic shape (elayers, head, ?, d_k * 2)
237
+ depending on required_cache_size.
238
+ torch.Tensor: new conformer cnn cache required for next chunk, with
239
+ same shape as the original cnn_cache.
240
+
241
+ """
242
+ assert xs.size(0) == 1
243
+ # tmp_masks is just for interface compatibility
244
+ tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
245
+ tmp_masks = tmp_masks.unsqueeze(1)
246
+ if self.global_cmvn is not None:
247
+ xs = self.global_cmvn(xs)
248
+ # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
249
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
250
+ # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
251
+ elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
252
+ chunk_size = xs.size(1)
253
+ attention_key_size = cache_t1 + chunk_size
254
+ pos_emb = self.embed.position_encoding(
255
+ offset=offset - cache_t1, size=attention_key_size
256
+ )
257
+ if required_cache_size < 0:
258
+ next_cache_start = 0
259
+ elif required_cache_size == 0:
260
+ next_cache_start = attention_key_size
261
+ else:
262
+ next_cache_start = max(attention_key_size - required_cache_size, 0)
263
+ r_att_cache = []
264
+ r_cnn_cache = []
265
+ for i, layer in enumerate(self.encoders):
266
+ # NOTE(xcsong): Before layer.forward
267
+ # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
268
+ # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
269
+ xs, _, new_att_cache, new_cnn_cache = layer(
270
+ xs,
271
+ att_mask,
272
+ pos_emb,
273
+ att_cache=att_cache[i : i + 1] if elayers > 0 else att_cache,
274
+ cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
275
+ )
276
+ # NOTE(xcsong): After layer.forward
277
+ # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
278
+ # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
279
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
280
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
281
+ if self.normalize_before:
282
+ xs = self.after_norm(xs)
283
+
284
+ # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
285
+ # ? may be larger than cache_t1, it depends on required_cache_size
286
+ r_att_cache = torch.cat(r_att_cache, dim=0)
287
+ # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
288
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
289
+
290
+ return (xs, r_att_cache, r_cnn_cache)
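The cache-trimming rule above is easy to check in isolation. A minimal sketch, using hypothetical sizes (cache_t1=16, chunk_size=8) that are not taken from any real configuration:

def next_cache_start(attention_key_size: int, required_cache_size: int) -> int:
    # mirrors the branch above: <0 keeps all history, 0 keeps none,
    # >0 keeps only the last `required_cache_size` frames
    if required_cache_size < 0:
        return 0
    if required_cache_size == 0:
        return attention_key_size
    return max(attention_key_size - required_cache_size, 0)

# attention_key_size = cache_t1 + chunk_size = 16 + 8 = 24
assert next_cache_start(24, -1) == 0    # r_att_cache keeps all 24 frames
assert next_cache_start(24, 0) == 24    # r_att_cache keeps nothing
assert next_cache_start(24, 16) == 8    # r_att_cache keeps the last 16 frames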
291
+
292
+ @torch.jit.unused
293
+ def forward_chunk_by_chunk(
294
+ self,
295
+ xs: torch.Tensor,
296
+ decoding_chunk_size: int,
297
+ num_decoding_left_chunks: int = -1,
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ """Forward input chunk by chunk with chunk_size like a streaming
300
+ fashion
301
+
302
+ Here we should pay special attention to computation cache in the
303
+ streaming style forward chunk by chunk. Three things should be taken
304
+ into account for computation in the current network:
305
+ 1. transformer/conformer encoder layers output cache
306
+ 2. convolution in conformer
307
+ 3. convolution in subsampling
308
+
309
+ However, we don't implement subsampling cache for:
310
+ 1. We can control subsampling module to output the right result by
311
+ overlapping input instead of cache left context, even though it
312
+ wastes some computation, but subsampling only takes a very
313
+ small fraction of computation in the whole model.
314
+ 2. Typically, there are several convolution layers with subsampling
315
+ in subsampling module, it is tricky and complicated to do cache
316
+ with different convolution layers with different subsampling
317
+ rate.
318
+ 3. Currently, nn.Sequential is used to stack all the convolution
319
+ layers in subsampling, we need to rewrite it to make it work
320
+ with cache, which is not preferred.
321
+ Args:
322
+ xs (torch.Tensor): (1, max_len, dim)
323
+ decoding_chunk_size (int): decoding chunk size
324
+ """
325
+ assert decoding_chunk_size > 0
326
+ # The model is trained by static or dynamic chunk
327
+ assert self.static_chunk_size > 0 or self.use_dynamic_chunk
328
+ subsampling = self.embed.subsampling_rate
329
+ context = self.embed.right_context + 1 # Add current frame
330
+ stride = subsampling * decoding_chunk_size
331
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
332
+ num_frames = xs.size(1)
333
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
334
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
335
+ outputs = []
336
+ offset = 0
337
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
338
+
339
+ # Feed forward overlap input step by step
340
+ for cur in range(0, num_frames - context + 1, stride):
341
+ end = min(cur + decoding_window, num_frames)
342
+ chunk_xs = xs[:, cur:end, :]
343
+ (y, att_cache, cnn_cache) = self.forward_chunk(
344
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache
345
+ )
346
+ outputs.append(y)
347
+ offset += y.size(1)
348
+ ys = torch.cat(outputs, 1)
349
+ masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
350
+ return ys, masks
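A small sketch of the chunking arithmetic used above, assuming a conv2d-style front-end (subsampling_rate=4, right_context=6) and decoding_chunk_size=16; these numbers are illustrative, not taken from a shipped configuration:

subsampling = 4                                        # hypothetical self.embed.subsampling_rate
right_context = 6                                      # hypothetical self.embed.right_context
decoding_chunk_size = 16

context = right_context + 1                            # extra input frames needed per chunk
stride = subsampling * decoding_chunk_size             # hop between chunk starts, in input frames
decoding_window = (decoding_chunk_size - 1) * subsampling + context

print(stride, decoding_window)                         # 64 67: consecutive windows overlap by 3 frames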
351
+
352
+
353
+ class TransformerEncoder(BaseEncoder):
354
+ """Transformer encoder module."""
355
+
356
+ def __init__(
357
+ self,
358
+ input_size: int,
359
+ output_size: int = 256,
360
+ attention_heads: int = 4,
361
+ linear_units: int = 2048,
362
+ num_blocks: int = 6,
363
+ dropout_rate: float = 0.1,
364
+ positional_dropout_rate: float = 0.1,
365
+ attention_dropout_rate: float = 0.0,
366
+ input_layer: str = "conv2d",
367
+ pos_enc_layer_type: str = "abs_pos",
368
+ normalize_before: bool = True,
369
+ static_chunk_size: int = 0,
370
+ use_dynamic_chunk: bool = False,
371
+ global_cmvn: torch.nn.Module = None,
372
+ use_dynamic_left_chunk: bool = False,
373
+ key_bias: bool = True,
374
+ selfattention_layer_type: str = "selfattn",
375
+ activation_type: str = "relu",
376
+ gradient_checkpointing: bool = False,
377
+ ):
378
+ """Construct TransformerEncoder
379
+
380
+ See Encoder for the meaning of each parameter.
381
+ """
382
+ super().__init__(
383
+ input_size,
384
+ output_size,
385
+ attention_heads,
386
+ linear_units,
387
+ num_blocks,
388
+ dropout_rate,
389
+ positional_dropout_rate,
390
+ attention_dropout_rate,
391
+ input_layer,
392
+ pos_enc_layer_type,
393
+ normalize_before,
394
+ static_chunk_size,
395
+ use_dynamic_chunk,
396
+ global_cmvn,
397
+ use_dynamic_left_chunk,
398
+ gradient_checkpointing,
399
+ )
400
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
401
+ self.encoders = torch.nn.ModuleList(
402
+ [
403
+ TransformerEncoderLayer(
404
+ output_size,
405
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
406
+ attention_heads, output_size, attention_dropout_rate, key_bias
407
+ ),
408
+ PositionwiseFeedForward(
409
+ output_size, linear_units, dropout_rate, activation
410
+ ),
411
+ dropout_rate,
412
+ normalize_before,
413
+ )
414
+ for _ in range(num_blocks)
415
+ ]
416
+ )
417
+
418
+
419
+ class ConformerEncoder(BaseEncoder):
420
+ """Conformer encoder module."""
421
+
422
+ def __init__(
423
+ self,
424
+ input_size: int,
425
+ output_size: int = 256,
426
+ attention_heads: int = 4,
427
+ linear_units: int = 2048,
428
+ num_blocks: int = 6,
429
+ dropout_rate: float = 0.1,
430
+ positional_dropout_rate: float = 0.1,
431
+ attention_dropout_rate: float = 0.0,
432
+ input_layer: str = "conv2d",
433
+ pos_enc_layer_type: str = "rel_pos",
434
+ normalize_before: bool = True,
435
+ static_chunk_size: int = 0,
436
+ use_dynamic_chunk: bool = False,
437
+ global_cmvn: torch.nn.Module = None,
438
+ use_dynamic_left_chunk: bool = False,
439
+ positionwise_conv_kernel_size: int = 1,
440
+ macaron_style: bool = True,
441
+ selfattention_layer_type: str = "rel_selfattn",
442
+ activation_type: str = "swish",
443
+ use_cnn_module: bool = True,
444
+ cnn_module_kernel: int = 15,
445
+ causal: bool = False,
446
+ cnn_module_norm: str = "batch_norm",
447
+ key_bias: bool = True,
448
+ gradient_checkpointing: bool = False,
449
+ ):
450
+ """Construct ConformerEncoder
451
+
452
+ Args:
453
+ input_size to use_dynamic_chunk, see in BaseEncoder
454
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
455
+ conv1d layer.
456
+ macaron_style (bool): Whether to use macaron style for
457
+ positionwise layer.
458
+ selfattention_layer_type (str): Encoder attention layer type,
459
+ the parameter has no effect now, it's just for configure
460
+ compatibility.
461
+ activation_type (str): Encoder activation function type.
462
+ use_cnn_module (bool): Whether to use convolution module.
463
+ cnn_module_kernel (int): Kernel size of convolution module.
464
+ causal (bool): whether to use causal convolution or not.
465
+ key_bias: whether to use bias in attention.linear_k; False for whisper models.
466
+ """
467
+ super().__init__(
468
+ input_size,
469
+ output_size,
470
+ attention_heads,
471
+ linear_units,
472
+ num_blocks,
473
+ dropout_rate,
474
+ positional_dropout_rate,
475
+ attention_dropout_rate,
476
+ input_layer,
477
+ pos_enc_layer_type,
478
+ normalize_before,
479
+ static_chunk_size,
480
+ use_dynamic_chunk,
481
+ global_cmvn,
482
+ use_dynamic_left_chunk,
483
+ gradient_checkpointing,
484
+ )
485
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
486
+
487
+ # self-attention module definition
488
+ encoder_selfattn_layer_args = (
489
+ attention_heads,
490
+ output_size,
491
+ attention_dropout_rate,
492
+ key_bias,
493
+ )
494
+ # feed-forward module definition
495
+ positionwise_layer_args = (
496
+ output_size,
497
+ linear_units,
498
+ dropout_rate,
499
+ activation,
500
+ )
501
+ # convolution module definition
502
+ convolution_layer_args = (
503
+ output_size,
504
+ cnn_module_kernel,
505
+ activation,
506
+ cnn_module_norm,
507
+ causal,
508
+ )
509
+
510
+ self.encoders = torch.nn.ModuleList(
511
+ [
512
+ ConformerEncoderLayer(
513
+ output_size,
514
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
515
+ *encoder_selfattn_layer_args
516
+ ),
517
+ PositionwiseFeedForward(*positionwise_layer_args),
518
+ (
519
+ PositionwiseFeedForward(*positionwise_layer_args)
520
+ if macaron_style
521
+ else None
522
+ ),
523
+ (
524
+ ConvolutionModule(*convolution_layer_args)
525
+ if use_cnn_module
526
+ else None
527
+ ),
528
+ dropout_rate,
529
+ normalize_before,
530
+ )
531
+ for _ in range(num_blocks)
532
+ ]
533
+ )
534
+ self.inference_buffers = {}
535
+ self.inference_graphs = {}
536
+
537
+ @torch.inference_mode()
538
+ def capture_inference(self, seq_len_to_capture=[128, 256, 512, 1024]):
539
+ device = next(self.parameters()).device
540
+ start_time = time.time()
541
+ print(
542
+ f"Start capture_inference for ConformerEncoder, seq_len_to_capture: {seq_len_to_capture}"
543
+ )
544
+
545
+ for seq_len in seq_len_to_capture:
546
+ xs = torch.randn(
547
+ 1, seq_len, self._output_size, device=device, dtype=torch.bfloat16
548
+ )
549
+ xs_lens = torch.tensor([seq_len], device=device, dtype=torch.int32)
550
+ decoding_chunk_size = 0
551
+ num_decoding_left_chunks = -1
552
+
553
+ T = xs.size(1)
554
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
555
+ if self.global_cmvn is not None:
556
+ xs = self.global_cmvn(xs)
557
+ xs, pos_emb, masks = self.embed(xs, masks)
558
+ mask_pad = masks # (B, 1, T/subsample_rate)
559
+ chunk_masks = add_optional_chunk_mask(
560
+ xs,
561
+ masks,
562
+ self.use_dynamic_chunk,
563
+ self.use_dynamic_left_chunk,
564
+ decoding_chunk_size,
565
+ self.static_chunk_size,
566
+ num_decoding_left_chunks,
567
+ )
568
+
569
+ g = torch.cuda.CUDAGraph()
570
+ with torch.cuda.graph(g):
571
+ out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
572
+
573
+ self.inference_graphs[seq_len] = g
574
+ self.inference_buffers[seq_len] = {
575
+ "xs": xs,
576
+ "chunk_masks": chunk_masks,
577
+ "pos_emb": pos_emb,
578
+ "mask_pad": mask_pad,
579
+ "out": out,
580
+ }
581
+ end_time = time.time()
582
+ print(
583
+ f"Finish capture_inference for ConformerEncoder, time elapsed: {end_time - start_time}"
584
+ )
585
+
586
+ @torch.inference_mode()
587
+ def inference(self, xs: torch.Tensor, xs_lens: torch.Tensor):
588
+ curr_seq_len = xs.shape[1]
589
+ target_len = None
590
+
591
+ for seq_len in sorted(self.inference_graphs.keys()):
592
+ if seq_len >= curr_seq_len:
593
+ target_len = seq_len
594
+ break
595
+
596
+ if target_len is not None:
597
+ xs = F.pad(xs, (0, 0, 0, target_len - curr_seq_len), "constant", 0)
598
+
599
+ decoding_chunk_size = 0
600
+ num_decoding_left_chunks = -1
601
+
602
+ T = xs.size(1)
603
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
604
+ if self.global_cmvn is not None:
605
+ xs = self.global_cmvn(xs)
606
+ xs, pos_emb, masks = self.embed(xs, masks)
607
+ mask_pad = masks # (B, 1, T/subsample_rate)
608
+ chunk_masks = add_optional_chunk_mask(
609
+ xs,
610
+ masks,
611
+ self.use_dynamic_chunk,
612
+ self.use_dynamic_left_chunk,
613
+ decoding_chunk_size,
614
+ self.static_chunk_size,
615
+ num_decoding_left_chunks,
616
+ )
617
+
618
+ if target_len is not None:
619
+ buffer = self.inference_buffers[target_len]
620
+ buffer["xs"].copy_(xs)
621
+ buffer["chunk_masks"].copy_(chunk_masks)
622
+ buffer["pos_emb"].copy_(pos_emb)
623
+ buffer["mask_pad"].copy_(mask_pad)
624
+
625
+ self.inference_graphs[target_len].replay()
626
+
627
+ out = buffer["out"][:, :curr_seq_len, :]
628
+ else:
629
+ out = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
630
+
631
+ if self.normalize_before:
632
+ out = self.after_norm(out)
633
+ return out, masks
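A hedged usage sketch of the CUDA-graph path above. It assumes a CUDA device, bfloat16 weights, and input_size == output_size == 512 with the "linear" input layer so that capture_inference's random buffers match the embed layer; none of these values are mandated by this file.

import torch
from cosyvoice.transformer.encoder import ConformerEncoder

encoder = (
    ConformerEncoder(input_size=512, output_size=512, input_layer="linear")
    .cuda()
    .to(torch.bfloat16)
    .eval()
)
encoder.capture_inference(seq_len_to_capture=[128, 256])   # records one CUDA graph per length bucket

xs = torch.randn(1, 200, 512, device="cuda", dtype=torch.bfloat16)
xs_lens = torch.tensor([200], device="cuda", dtype=torch.int32)
out, masks = encoder.inference(xs, xs_lens)                 # padded to the 256 bucket, graph replayed
print(out.shape)                                            # (1, 200, 512) after trimming the padding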
cosyvoice/transformer/encoder_layer.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+ `PositionwiseFeedForward`, instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+ mask_pad (torch.Tensor): not used in the transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+ torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(
94
+ x, x, x, mask, pos_emb=pos_emb, cache=att_cache
95
+ )
96
+ x = residual + self.dropout(x_att)
97
+ if not self.normalize_before:
98
+ x = self.norm1(x)
99
+
100
+ residual = x
101
+ if self.normalize_before:
102
+ x = self.norm2(x)
103
+ x = residual + self.dropout(self.feed_forward(x))
104
+ if not self.normalize_before:
105
+ x = self.norm2(x)
106
+
107
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
108
+ return x, mask, new_att_cache, fake_cnn_cache
109
+
110
+
111
+ class ConformerEncoderLayer(nn.Module):
112
+ """Encoder layer module.
113
+ Args:
114
+ size (int): Input dimension.
115
+ self_attn (torch.nn.Module): Self-attention module instance.
116
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
117
+ instance can be used as the argument.
118
+ feed_forward (torch.nn.Module): Feed-forward module instance.
119
+ `PositionwiseFeedForward` instance can be used as the argument.
120
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
121
+ instance.
122
+ `PositionwiseFeedForward` instance can be used as the argument.
123
+ conv_module (torch.nn.Module): Convolution module instance.
124
+ `ConvolutionModule` instance can be used as the argument.
125
+ dropout_rate (float): Dropout rate.
126
+ normalize_before (bool):
127
+ True: use layer_norm before each sub-block.
128
+ False: use layer_norm after each sub-block.
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ size: int,
134
+ self_attn: torch.nn.Module,
135
+ feed_forward: Optional[nn.Module] = None,
136
+ feed_forward_macaron: Optional[nn.Module] = None,
137
+ conv_module: Optional[nn.Module] = None,
138
+ dropout_rate: float = 0.1,
139
+ normalize_before: bool = True,
140
+ ):
141
+ """Construct an EncoderLayer object."""
142
+ super().__init__()
143
+ self.self_attn = self_attn
144
+ self.feed_forward = feed_forward
145
+ self.feed_forward_macaron = feed_forward_macaron
146
+ self.conv_module = conv_module
147
+ self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
148
+ self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
149
+ if feed_forward_macaron is not None:
150
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
151
+ self.ff_scale = 0.5
152
+ else:
153
+ self.ff_scale = 1.0
154
+ if self.conv_module is not None:
155
+ self.norm_conv = nn.LayerNorm(size, eps=1e-5) # for the CNN module
156
+ self.norm_final = nn.LayerNorm(
157
+ size, eps=1e-5
158
+ ) # for the final output of the block
159
+ self.dropout = nn.Dropout(dropout_rate)
160
+ self.size = size
161
+ self.normalize_before = normalize_before
162
+
163
+ def forward(
164
+ self,
165
+ x: torch.Tensor,
166
+ mask: torch.Tensor,
167
+ pos_emb: torch.Tensor,
168
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
169
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
170
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
171
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
+ """Compute encoded features.
173
+
174
+ Args:
175
+ x (torch.Tensor): (#batch, time, size)
176
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
177
+ (0, 0, 0) means fake mask.
178
+ pos_emb (torch.Tensor): positional encoding, must not be None
179
+ for ConformerEncoderLayer.
180
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
181
+ (#batch, 1,time), (0, 0, 0) means fake mask.
182
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
183
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
184
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
185
+ (#batch=1, size, cache_t2)
186
+ Returns:
187
+ torch.Tensor: Output tensor (#batch, time, size).
188
+ torch.Tensor: Mask tensor (#batch, time, time).
189
+ torch.Tensor: att_cache tensor,
190
+ (#batch=1, head, cache_t1 + time, d_k * 2).
191
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
192
+ """
193
+
194
+ # whether to use macaron style
195
+ if self.feed_forward_macaron is not None:
196
+ residual = x
197
+ if self.normalize_before:
198
+ x = self.norm_ff_macaron(x)
199
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
200
+ if not self.normalize_before:
201
+ x = self.norm_ff_macaron(x)
202
+
203
+ # multi-headed self-attention module
204
+ residual = x
205
+ if self.normalize_before:
206
+ x = self.norm_mha(x)
207
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
208
+ x = residual + self.dropout(x_att)
209
+ if not self.normalize_before:
210
+ x = self.norm_mha(x)
211
+
212
+ # convolution module
213
+ # Fake new cnn cache here, and then change it in conv_module
214
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
215
+ if self.conv_module is not None:
216
+ residual = x
217
+ if self.normalize_before:
218
+ x = self.norm_conv(x)
219
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
220
+ x = residual + self.dropout(x)
221
+
222
+ if not self.normalize_before:
223
+ x = self.norm_conv(x)
224
+
225
+ # feed forward module
226
+ residual = x
227
+ if self.normalize_before:
228
+ x = self.norm_ff(x)
229
+
230
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
231
+ if not self.normalize_before:
232
+ x = self.norm_ff(x)
233
+
234
+ if self.conv_module is not None:
235
+ x = self.norm_final(x)
236
+
237
+ return x, mask, new_att_cache, new_cnn_cache
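Both layer classes above implement the same pre-norm/post-norm residual pattern. A distilled sketch of that switch, where `block` stands in for the attention, feed-forward, or convolution sub-module and is not part of the repository:

import torch
from torch import nn

def sublayer(x: torch.Tensor, block: nn.Module, norm: nn.LayerNorm,
             dropout: nn.Dropout, normalize_before: bool) -> torch.Tensor:
    residual = x
    if normalize_before:        # pre-norm: normalize the input, then add the residual
        x = norm(x)
    x = residual + dropout(block(x))
    if not normalize_before:    # post-norm: normalize the residual sum instead
        x = norm(x)
    return x

block, norm = nn.Linear(8, 8), nn.LayerNorm(8)
y = sublayer(torch.randn(2, 4, 8), block, norm, nn.Dropout(0.0), normalize_before=True)
print(y.shape)                  # torch.Size([2, 4, 8])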
cosyvoice/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Label smoothing module."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ class LabelSmoothingLoss(nn.Module):
22
+ """Label-smoothing loss.
23
+
24
+ In a standard CE loss, the label's data distribution is:
25
+ [0,1,2] ->
26
+ [
27
+ [1.0, 0.0, 0.0],
28
+ [0.0, 1.0, 0.0],
29
+ [0.0, 0.0, 1.0],
30
+ ]
31
+
32
+ In the smoothed version of the CE loss, some probabilities
33
+ are taken from the true label prob (1.0) and are divided
34
+ among other labels.
35
+
36
+ e.g.
37
+ smoothing=0.1
38
+ [0,1,2] ->
39
+ [
40
+ [0.9, 0.05, 0.05],
41
+ [0.05, 0.9, 0.05],
42
+ [0.05, 0.05, 0.9],
43
+ ]
44
+
45
+ Args:
46
+ size (int): the number of classes
47
+ padding_idx (int): padding class id which will be ignored for loss
48
+ smoothing (float): smoothing rate (0.0 means the conventional CE)
49
+ normalize_length (bool):
50
+ normalize loss by sequence length if True
51
+ normalize loss by batch size if False
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ size: int,
57
+ padding_idx: int,
58
+ smoothing: float,
59
+ normalize_length: bool = False,
60
+ ):
61
+ """Construct an LabelSmoothingLoss object."""
62
+ super(LabelSmoothingLoss, self).__init__()
63
+ self.criterion = nn.KLDivLoss(reduction="none")
64
+ self.padding_idx = padding_idx
65
+ self.confidence = 1.0 - smoothing
66
+ self.smoothing = smoothing
67
+ self.size = size
68
+ self.normalize_length = normalize_length
69
+
70
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
71
+ """Compute loss between x and target.
72
+
73
+ The model outputs and data labels tensors are flatten to
74
+ (batch*seqlen, class) shape and a mask is applied to the
75
+ padding part which should not be calculated for loss.
76
+
77
+ Args:
78
+ x (torch.Tensor): prediction (batch, seqlen, class)
79
+ target (torch.Tensor):
80
+ target signal masked with self.padding_id (batch, seqlen)
81
+ Returns:
82
+ loss (torch.Tensor) : The KL loss, scalar float value
83
+ """
84
+ assert x.size(2) == self.size
85
+ batch_size = x.size(0)
86
+ x = x.view(-1, self.size)
87
+ target = target.view(-1)
88
+ # use zeros_like instead of torch.no_grad() for true_dist,
89
+ # since no_grad() can not be exported by JIT
90
+ true_dist = torch.zeros_like(x)
91
+ true_dist.fill_(self.smoothing / (self.size - 1))
92
+ ignore = target == self.padding_idx # (B,)
93
+ total = len(target) - ignore.sum().item()
94
+ target = target.masked_fill(ignore, 0) # avoid -1 index
95
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
96
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
97
+ denom = total if self.normalize_length else batch_size
98
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
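A brief usage sketch for the loss above; the vocabulary size, smoothing value, and padding id below are illustrative, not values used by CosyVoice.

import torch
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=10, padding_idx=-1, smoothing=0.1, normalize_length=True)
logits = torch.randn(2, 5, 10)          # (batch, seqlen, class)
target = torch.randint(0, 10, (2, 5))   # (batch, seqlen)
target[:, -1] = -1                      # padded positions are ignored by the loss
print(criterion(logits, target).item())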
cosyvoice/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward is applied on each position of the sequence.
24
+ The output dim is same with the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimension.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+ Mixture of expert with Positionwise feed forward layer
61
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+ The output dim is same with the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+ n_expert: number of expert.
68
+ n_expert_per_token: The actual number of experts used for each frame
69
+ idim (int): Input dimension.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate, activation)
88
+ for _ in range(n_expert)
89
+ )
90
+ self.n_expert_per_token = n_expert_per_token
91
+
92
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
93
+ """Foward function.
94
+ Args:
95
+ xs: input tensor (B, L, D)
96
+ Returns:
97
+ output tensor, (B, L, D)
98
+
99
+ """
100
+ B, L, D = xs.size() # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
106
+ weights = torch.nn.functional.softmax(logits, dim=1, dtype=torch.float).to(
107
+ dtype=xs.dtype
108
+ ) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx]
115
+ )
116
+ return output.view(B, L, D)
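A short shape-check sketch for the MoE layer above; the expert count and dimensions are arbitrary.

import torch
from cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=16, hidden_units=32, dropout_rate=0.0)
xs = torch.randn(3, 7, 16)              # (B, L, D)
ys = moe(xs)                            # every frame is routed to its top-2 experts
print(ys.shape)                         # torch.Size([3, 7, 16])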
cosyvoice/transformer/subsampling.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(
31
+ self, offset: Union[int, torch.Tensor], size: int
32
+ ) -> torch.Tensor:
33
+ return self.pos_enc.position_encoding(offset, size)
34
+
35
+
36
+ class EmbedinigNoSubsampling(BaseSubsampling):
37
+ """Embedding input without subsampling"""
38
+
39
+ def __init__(
40
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
41
+ ):
42
+ super().__init__()
43
+ self.embed = torch.nn.Embedding(idim, odim)
44
+ self.pos_enc = pos_enc_class
45
+
46
+ def forward(
47
+ self,
48
+ x: torch.Tensor,
49
+ x_mask: torch.Tensor,
50
+ offset: Union[int, torch.Tensor] = 0,
51
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
52
+ """Input x.
53
+
54
+ Args:
55
+ x (torch.Tensor): Input tensor (#batch, time, idim).
56
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
57
+
58
+ Returns:
59
+ torch.Tensor: linear input tensor (#batch, time', odim),
60
+ where time' = time .
61
+ torch.Tensor: linear input mask (#batch, 1, time'),
62
+ where time' = time .
63
+
64
+ """
65
+ x = self.embed(x)
66
+ x, pos_emb = self.pos_enc(x, offset)
67
+ return x, pos_emb, x_mask
68
+
69
+
70
+ class LinearNoSubsampling(BaseSubsampling):
71
+ """Linear transform the input without subsampling
72
+
73
+ Args:
74
+ idim (int): Input dimension.
75
+ odim (int): Output dimension.
76
+ dropout_rate (float): Dropout rate.
77
+
78
+ """
79
+
80
+ def __init__(
81
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
82
+ ):
83
+ """Construct an linear object."""
84
+ super().__init__()
85
+ self.out = torch.nn.Sequential(
86
+ torch.nn.Linear(idim, odim),
87
+ torch.nn.LayerNorm(odim, eps=1e-5),
88
+ torch.nn.Dropout(dropout_rate),
89
+ )
90
+ self.pos_enc = pos_enc_class
91
+ self.right_context = 0
92
+ self.subsampling_rate = 1
93
+
94
+ def forward(
95
+ self,
96
+ x: torch.Tensor,
97
+ x_mask: torch.Tensor,
98
+ offset: Union[int, torch.Tensor] = 0,
99
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
100
+ """Input x.
101
+
102
+ Args:
103
+ x (torch.Tensor): Input tensor (#batch, time, idim).
104
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
105
+
106
+ Returns:
107
+ torch.Tensor: linear input tensor (#batch, time', odim),
108
+ where time' = time .
109
+ torch.Tensor: linear input mask (#batch, 1, time'),
110
+ where time' = time .
111
+
112
+ """
113
+ x = self.out(x)
114
+ x, pos_emb = self.pos_enc(x, offset)
115
+ return x, pos_emb, x_mask
116
+
117
+
118
+ class Conv1dSubsampling2(BaseSubsampling):
119
+ """Convolutional 1D subsampling (to 1/2 length).
120
+ It is designed for Whisper, ref:
121
+ https://github.com/openai/whisper/blob/main/whisper/model.py
122
+
123
+ Args:
124
+ idim (int): Input dimension.
125
+ odim (int): Output dimension.
126
+ dropout_rate (float): Dropout rate.
127
+
128
+ """
129
+
130
+ def __init__(
131
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
132
+ ):
133
+ """Construct an Conv1dSubsampling2 object."""
134
+ super().__init__()
135
+ self.conv = torch.nn.Sequential(
136
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
137
+ torch.nn.GELU(),
138
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
139
+ torch.nn.GELU(),
140
+ )
141
+ self.pos_enc = pos_enc_class
142
+ # The right context for every conv layer is computed by:
143
+ # (kernel_size - 1) * frame_rate_of_this_layer
144
+ self.subsampling_rate = 2
145
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
146
+ self.right_context = 4
147
+
148
+ def forward(
149
+ self,
150
+ x: torch.Tensor,
151
+ x_mask: torch.Tensor,
152
+ offset: Union[int, torch.Tensor] = 0,
153
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
154
+ """Subsample x.
155
+
156
+ Args:
157
+ x (torch.Tensor): Input tensor (#batch, time, idim).
158
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
159
+
160
+ Returns:
161
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
162
+ where time' = time // 2.
163
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
164
+ where time' = time // 2.
165
+ torch.Tensor: positional encoding
166
+
167
+ """
168
+ time = x.size(1)
169
+ x = x.transpose(1, 2) # (b, f, t)
170
+ x = self.conv(x)
171
+ x = x.transpose(1, 2) # (b, t, f)
172
+ x, pos_emb = self.pos_enc(x, offset)
173
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2 :: 2]
174
+
175
+
176
+ class Conv2dSubsampling4(BaseSubsampling):
177
+ """Convolutional 2D subsampling (to 1/4 length).
178
+
179
+ Args:
180
+ idim (int): Input dimension.
181
+ odim (int): Output dimension.
182
+ dropout_rate (float): Dropout rate.
183
+
184
+ """
185
+
186
+ def __init__(
187
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
188
+ ):
189
+ """Construct an Conv2dSubsampling4 object."""
190
+ super().__init__()
191
+ self.conv = torch.nn.Sequential(
192
+ torch.nn.Conv2d(1, odim, 3, 2),
193
+ torch.nn.ReLU(),
194
+ torch.nn.Conv2d(odim, odim, 3, 2),
195
+ torch.nn.ReLU(),
196
+ )
197
+ self.out = torch.nn.Sequential(
198
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
199
+ )
200
+ self.pos_enc = pos_enc_class
201
+ # The right context for every conv layer is computed by:
202
+ # (kernel_size - 1) * frame_rate_of_this_layer
203
+ self.subsampling_rate = 4
204
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
205
+ self.right_context = 6
206
+
207
+ def forward(
208
+ self,
209
+ x: torch.Tensor,
210
+ x_mask: torch.Tensor,
211
+ offset: Union[int, torch.Tensor] = 0,
212
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
213
+ """Subsample x.
214
+
215
+ Args:
216
+ x (torch.Tensor): Input tensor (#batch, time, idim).
217
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
218
+
219
+ Returns:
220
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
221
+ where time' = time // 4.
222
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
223
+ where time' = time // 4.
224
+ torch.Tensor: positional encoding
225
+
226
+ """
227
+ x = x.unsqueeze(1) # (b, c=1, t, f)
228
+ x = self.conv(x)
229
+ b, c, t, f = x.size()
230
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
231
+ x, pos_emb = self.pos_enc(x, offset)
232
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
233
+
234
+
235
+ class Conv2dSubsampling6(BaseSubsampling):
236
+ """Convolutional 2D subsampling (to 1/6 length).
237
+ Args:
238
+ idim (int): Input dimension.
239
+ odim (int): Output dimension.
240
+ dropout_rate (float): Dropout rate.
241
+ pos_enc (torch.nn.Module): Custom position encoding layer.
242
+ """
243
+
244
+ def __init__(
245
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
246
+ ):
247
+ """Construct an Conv2dSubsampling6 object."""
248
+ super().__init__()
249
+ self.conv = torch.nn.Sequential(
250
+ torch.nn.Conv2d(1, odim, 3, 2),
251
+ torch.nn.ReLU(),
252
+ torch.nn.Conv2d(odim, odim, 5, 3),
253
+ torch.nn.ReLU(),
254
+ )
255
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
256
+ self.pos_enc = pos_enc_class
257
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
258
+ self.subsampling_rate = 6
259
+ self.right_context = 10
260
+
261
+ def forward(
262
+ self,
263
+ x: torch.Tensor,
264
+ x_mask: torch.Tensor,
265
+ offset: Union[int, torch.Tensor] = 0,
266
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
267
+ """Subsample x.
268
+ Args:
269
+ x (torch.Tensor): Input tensor (#batch, time, idim).
270
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
271
+
272
+ Returns:
273
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
274
+ where time' = time // 6.
275
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
276
+ where time' = time // 6.
277
+ torch.Tensor: positional encoding
278
+ """
279
+ x = x.unsqueeze(1) # (b, c, t, f)
280
+ x = self.conv(x)
281
+ b, c, t, f = x.size()
282
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
283
+ x, pos_emb = self.pos_enc(x, offset)
284
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
285
+
286
+
287
+ class Conv2dSubsampling8(BaseSubsampling):
288
+ """Convolutional 2D subsampling (to 1/8 length).
289
+
290
+ Args:
291
+ idim (int): Input dimension.
292
+ odim (int): Output dimension.
293
+ dropout_rate (float): Dropout rate.
294
+
295
+ """
296
+
297
+ def __init__(
298
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
299
+ ):
300
+ """Construct an Conv2dSubsampling8 object."""
301
+ super().__init__()
302
+ self.conv = torch.nn.Sequential(
303
+ torch.nn.Conv2d(1, odim, 3, 2),
304
+ torch.nn.ReLU(),
305
+ torch.nn.Conv2d(odim, odim, 3, 2),
306
+ torch.nn.ReLU(),
307
+ torch.nn.Conv2d(odim, odim, 3, 2),
308
+ torch.nn.ReLU(),
309
+ )
310
+ self.linear = torch.nn.Linear(
311
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim
312
+ )
313
+ self.pos_enc = pos_enc_class
314
+ self.subsampling_rate = 8
315
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
316
+ self.right_context = 14
317
+
318
+ def forward(
319
+ self,
320
+ x: torch.Tensor,
321
+ x_mask: torch.Tensor,
322
+ offset: Union[int, torch.Tensor] = 0,
323
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
324
+ """Subsample x.
325
+
326
+ Args:
327
+ x (torch.Tensor): Input tensor (#batch, time, idim).
328
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
329
+
330
+ Returns:
331
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
332
+ where time' = time // 8.
333
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
334
+ where time' = time // 8.
335
+ torch.Tensor: positional encoding
336
+ """
337
+ x = x.unsqueeze(1) # (b, c, t, f)
338
+ x = self.conv(x)
339
+ b, c, t, f = x.size()
340
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
341
+ x, pos_emb = self.pos_enc(x, offset)
342
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
343
+
344
+
345
+ class LegacyLinearNoSubsampling(BaseSubsampling):
346
+ """Linear transform the input without subsampling
347
+
348
+ Args:
349
+ idim (int): Input dimension.
350
+ odim (int): Output dimension.
351
+ dropout_rate (float): Dropout rate.
352
+
353
+ """
354
+
355
+ def __init__(
356
+ self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
357
+ ):
358
+ """Construct an linear object."""
359
+ super().__init__()
360
+ self.out = torch.nn.Sequential(
361
+ torch.nn.Linear(idim, odim),
362
+ torch.nn.LayerNorm(odim, eps=1e-5),
363
+ torch.nn.Dropout(dropout_rate),
364
+ torch.nn.ReLU(),
365
+ )
366
+ self.pos_enc = pos_enc_class
367
+ self.right_context = 0
368
+ self.subsampling_rate = 1
369
+
370
+ def forward(
371
+ self,
372
+ x: torch.Tensor,
373
+ x_mask: torch.Tensor,
374
+ offset: Union[int, torch.Tensor] = 0,
375
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
376
+ """Input x.
377
+
378
+ Args:
379
+ x (torch.Tensor): Input tensor (#batch, time, idim).
380
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
381
+
382
+ Returns:
383
+ torch.Tensor: linear input tensor (#batch, time', odim),
384
+ where time' = time .
385
+ torch.Tensor: linear input mask (#batch, 1, time'),
386
+ where time' = time .
387
+
388
+ """
389
+ x = self.out(x)
390
+ x, pos_emb = self.pos_enc(x, offset)
391
+ return x, pos_emb, x_mask
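A small shape sketch for Conv2dSubsampling4 above, assuming an 80-dim mel input and the wenet-style PositionalEncoding(d_model, dropout_rate) from cosyvoice.transformer.embedding; the concrete sizes are examples only.

import torch
from cosyvoice.transformer.embedding import PositionalEncoding
from cosyvoice.transformer.subsampling import Conv2dSubsampling4

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=PositionalEncoding(256, 0.1))
x = torch.randn(2, 100, 80)                       # (batch, time, mel)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, pos_emb, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)                      # (2, 24, 256) and (2, 1, 24): roughly time // 4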
cosyvoice/utils/__init__.py ADDED
File without changes
cosyvoice/utils/audio.py ADDED
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
+ def load_wav(full_path):
11
+ sampling_rate, data = read(full_path)
12
+ return data, sampling_rate
13
+
14
+
15
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
+
18
+
19
+ def dynamic_range_decompression(x, C=1):
20
+ return np.exp(x) / C
21
+
22
+
23
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
+ return torch.log(torch.clamp(x, min=clip_val) * C)
25
+
26
+
27
+ def dynamic_range_decompression_torch(x, C=1):
28
+ return torch.exp(x) / C
29
+
30
+
31
+ def spectral_normalize_torch(magnitudes):
32
+ output = dynamic_range_compression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ def spectral_de_normalize_torch(magnitudes):
37
+ output = dynamic_range_decompression_torch(magnitudes)
38
+ return output
39
+
40
+
41
+ mel_basis = {}
42
+ hann_window = {}
43
+
44
+
45
+ def mel_spectrogram(
46
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
47
+ ):
48
+ # if torch.min(y) < -1.0:
49
+ # print("min value is ", torch.min(y))
50
+ # if torch.max(y) > 1.0:
51
+ # print("max value is ", torch.max(y))
52
+
53
+ global mel_basis, hann_window # pylint: disable=global-statement
54
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
55
+ mel = librosa_mel_fn(
56
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
57
+ )
58
+ mel_basis[str(fmax) + "_" + str(y.device)] = (
59
+ torch.from_numpy(mel).float().to(y.device)
60
+ )
61
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
62
+
63
+ y = torch.nn.functional.pad(
64
+ y.unsqueeze(1),
65
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
66
+ mode="reflect",
67
+ )
68
+ y = y.squeeze(1)
69
+
70
+ spec = torch.view_as_real(
71
+ torch.stft(
72
+ y,
73
+ n_fft,
74
+ hop_length=hop_size,
75
+ win_length=win_size,
76
+ window=hann_window[str(y.device)],
77
+ center=center,
78
+ pad_mode="reflect",
79
+ normalized=False,
80
+ onesided=True,
81
+ return_complex=True,
82
+ )
83
+ )
84
+
85
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
86
+
87
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
88
+ spec = spectral_normalize_torch(spec)
89
+
90
+ return spec
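A hedged example of calling mel_spectrogram above; the 22.05 kHz / 80-mel settings are common HiFi-GAN-style values, not ones required by this file.

import torch
from cosyvoice.utils.audio import mel_spectrogram

y = torch.randn(1, 22050).clamp(-1, 1)    # one second of audio in [-1, 1], (batch, samples)
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000)
print(mel.shape)                          # torch.Size([1, 80, 86]); frames ≈ samples / hop_size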
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (
27
+ PositionalEncoding,
28
+ RelPositionalEncoding,
29
+ WhisperPositionalEncoding,
30
+ LearnablePositionalEncoding,
31
+ NoPositionalEncoding,
32
+ )
33
+ from cosyvoice.transformer.attention import (
34
+ MultiHeadedAttention,
35
+ RelPositionMultiHeadedAttention,
36
+ )
37
+ from cosyvoice.transformer.embedding import (
38
+ EspnetRelPositionalEncoding,
39
+ )
40
+ from cosyvoice.transformer.subsampling import (
41
+ LegacyLinearNoSubsampling,
42
+ )
43
+
44
+
45
+ COSYVOICE_ACTIVATION_CLASSES = {
46
+ "hardtanh": torch.nn.Hardtanh,
47
+ "tanh": torch.nn.Tanh,
48
+ "relu": torch.nn.ReLU,
49
+ "selu": torch.nn.SELU,
50
+ "swish": getattr(torch.nn, "SiLU", Swish),
51
+ "gelu": torch.nn.GELU,
52
+ }
53
+
54
+ COSYVOICE_SUBSAMPLE_CLASSES = {
55
+ "linear": LinearNoSubsampling,
56
+ "linear_legacy": LegacyLinearNoSubsampling,
57
+ "embed": EmbedinigNoSubsampling,
58
+ "conv1d2": Conv1dSubsampling2,
59
+ "conv2d": Conv2dSubsampling4,
60
+ "conv2d6": Conv2dSubsampling6,
61
+ "conv2d8": Conv2dSubsampling8,
62
+ "paraformer_dummy": torch.nn.Identity,
63
+ }
64
+
65
+ COSYVOICE_EMB_CLASSES = {
66
+ "embed": PositionalEncoding,
67
+ "abs_pos": PositionalEncoding,
68
+ "rel_pos": RelPositionalEncoding,
69
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
70
+ "no_pos": NoPositionalEncoding,
71
+ "abs_pos_whisper": WhisperPositionalEncoding,
72
+ "embed_learnable_pe": LearnablePositionalEncoding,
73
+ }
74
+
75
+ COSYVOICE_ATTENTION_CLASSES = {
76
+ "selfattn": MultiHeadedAttention,
77
+ "rel_selfattn": RelPositionMultiHeadedAttention,
78
+ }
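A tiny sketch of how these registries are looked up by the encoders above; the chosen keys are only examples, and the attention class takes the same (heads, dim, dropout, key_bias) arguments the encoders pass.

from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
)

activation = COSYVOICE_ACTIVATION_CLASSES["swish"]()                   # nn.SiLU on recent PyTorch
attention = COSYVOICE_ATTENTION_CLASSES["selfattn"](4, 256, 0.0, True)
print(type(activation).__name__, type(attention).__name__)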
cosyvoice/utils/common.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Unility functions for Transformer."""
17
+
18
+ import random
19
+ from typing import List
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ IGNORE_ID = -1
25
+
26
+
27
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
28
+ """Perform padding for the list of tensors.
29
+
30
+ Args:
31
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
32
+ pad_value (float): Value for padding.
33
+
34
+ Returns:
35
+ Tensor: Padded tensor (B, Tmax, `*`).
36
+
37
+ Examples:
38
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
39
+ >>> x
40
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
41
+ >>> pad_list(x, 0)
42
+ tensor([[1., 1., 1., 1.],
43
+ [1., 1., 0., 0.],
44
+ [1., 0., 0., 0.]])
45
+
46
+ """
47
+ max_len = max([len(item) for item in xs])
48
+ batchs = len(xs)
49
+ ndim = xs[0].ndim
50
+ if ndim == 1:
51
+ pad_res = torch.zeros(batchs, max_len, dtype=xs[0].dtype, device=xs[0].device)
52
+ elif ndim == 2:
53
+ pad_res = torch.zeros(
54
+ batchs, max_len, xs[0].shape[1], dtype=xs[0].dtype, device=xs[0].device
55
+ )
56
+ elif ndim == 3:
57
+ pad_res = torch.zeros(
58
+ batchs,
59
+ max_len,
60
+ xs[0].shape[1],
61
+ xs[0].shape[2],
62
+ dtype=xs[0].dtype,
63
+ device=xs[0].device,
64
+ )
65
+ else:
66
+ raise ValueError(f"Unsupported ndim: {ndim}")
67
+ pad_res.fill_(pad_value)
68
+ for i in range(batchs):
69
+ pad_res[i, : len(xs[i])] = xs[i]
70
+ return pad_res
71
+
72
+
73
+ def th_accuracy(
74
+ pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
75
+ ) -> torch.Tensor:
76
+ """Calculate accuracy.
77
+
78
+ Args:
79
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
80
+ pad_targets (LongTensor): Target label tensors (B, Lmax).
81
+ ignore_label (int): Ignore label id.
82
+
83
+ Returns:
84
+ torch.Tensor: Accuracy value (0.0 - 1.0).
85
+
86
+ """
87
+ pad_pred = pad_outputs.view(
88
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
89
+ ).argmax(2)
90
+ mask = pad_targets != ignore_label
91
+ numerator = torch.sum(
92
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
93
+ )
94
+ denominator = torch.sum(mask)
95
+ return (numerator / denominator).detach()
96
+
97
+
98
+ def get_padding(kernel_size, dilation=1):
99
+ return int((kernel_size * dilation - dilation) / 2)
100
+
101
+
102
+ def init_weights(m, mean=0.0, std=0.01):
103
+ classname = m.__class__.__name__
104
+ if classname.find("Conv") != -1:
105
+ m.weight.data.normal_(mean, std)
106
+
107
+
108
+ # Repetition Aware Sampling in VALL-E 2
109
+ def ras_sampling(
110
+ weighted_scores,
111
+ decoded_tokens,
112
+ sampling,
113
+ top_p=0.8,
114
+ top_k=25,
115
+ win_size=10,
116
+ tau_r=0.1,
117
+ ):
118
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
119
+ rep_num = (
120
+ (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids)
121
+ .sum()
122
+ .item()
123
+ )
124
+ if rep_num >= win_size * tau_r:
125
+ top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
126
+ return top_ids
127
+
128
+
129
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
130
+ prob, indices = [], []
131
+ cum_prob = 0.0
132
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(
133
+ descending=True, stable=True
134
+ )
135
+ for i in range(len(sorted_idx)):
136
+ # keep candidates until either the top-p cumulative mass or the top-k count is reached.
137
+ if cum_prob < top_p and len(prob) < top_k:
138
+ cum_prob += sorted_value[i]
139
+ prob.append(sorted_value[i])
140
+ indices.append(sorted_idx[i])
141
+ else:
142
+ break
143
+ prob = torch.tensor(prob).to(weighted_scores)
144
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
145
+ top_ids = indices[prob.multinomial(1, replacement=True)]
146
+ return top_ids
147
+
148
+
149
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
150
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
151
+ return top_ids
152
+
153
+
154
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
155
+ device = fade_in_mel.device
156
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
157
+ mel_overlap_len = int(window.shape[0] / 2)
158
+ fade_in_mel[..., :mel_overlap_len] = (
159
+ fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
160
+ + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
161
+ )
162
+ return fade_in_mel.to(device)
163
+
164
+
165
+ def set_all_random_seed(seed):
166
+ random.seed(seed)
167
+ np.random.seed(seed)
168
+ torch.manual_seed(seed)
169
+ torch.cuda.manual_seed_all(seed)
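A minimal usage sketch (not part of this commit) for the padding and accuracy helpers above, assuming the repo is on PYTHONPATH:

import torch
from cosyvoice.utils.common import IGNORE_ID, pad_list, th_accuracy, set_all_random_seed

set_all_random_seed(1234)                      # seeds python, numpy and torch (CPU and CUDA)

xs = [torch.ones(4), torch.ones(2), torch.ones(1)]
print(pad_list(xs, 0))                         # (3, 4) tensor, shorter rows padded with 0

logits = torch.randn(2 * 5, 10)                # (B * Lmax, D)
targets = torch.randint(0, 10, (2, 5))         # (B, Lmax)
targets[1, 3:] = IGNORE_ID                     # ignored positions are excluded from the accuracy
print(th_accuracy(logits, targets, ignore_label=IGNORE_ID))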
cosyvoice/utils/executor.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ from contextlib import nullcontext
18
+ import os
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ from cosyvoice.utils.train_utils import (
24
+ update_parameter_and_lr,
25
+ log_per_step,
26
+ log_per_save,
27
+ batch_forward,
28
+ batch_backward,
29
+ save_model,
30
+ cosyvoice_join,
31
+ )
32
+
33
+
34
+ class Executor:
35
+
36
+ def __init__(self):
37
+ self.step = 0
38
+ self.epoch = 0
39
+ self.rank = int(os.environ.get("RANK", 0))
40
+ self.device = torch.device("cuda:{}".format(self.rank))
41
+
42
+ def train_one_epoc(
43
+ self,
44
+ model,
45
+ optimizer,
46
+ scheduler,
47
+ train_data_loader,
48
+ cv_data_loader,
49
+ writer,
50
+ info_dict,
51
+ group_join,
52
+ ):
53
+ """Train one epoch"""
54
+
55
+ lr = optimizer.param_groups[0]["lr"]
56
+ logging.info(
57
+ "Epoch {} TRAIN info lr {} rank {}".format(self.epoch, lr, self.rank)
58
+ )
59
+ logging.info(
60
+ "using accumulate grad, new batch size is {} times"
61
+ " larger than before".format(info_dict["accum_grad"])
62
+ )
63
+ # A context manager to be used in conjunction with an instance of
64
+ # torch.nn.parallel.DistributedDataParallel to be able to train
65
+ # with uneven inputs across participating processes.
66
+ model.train()
67
+ model_context = (
68
+ model.join if info_dict["train_engine"] == "torch_ddp" else nullcontext
69
+ )
70
+ with model_context():
71
+ for batch_idx, batch_dict in enumerate(train_data_loader):
72
+ info_dict["tag"] = "TRAIN"
73
+ info_dict["step"] = self.step
74
+ info_dict["epoch"] = self.epoch
75
+ info_dict["batch_idx"] = batch_idx
76
+ if cosyvoice_join(group_join, info_dict):
77
+ break
78
+
79
+ # Disable gradient synchronizations across DDP processes.
80
+ # Within this context, gradients will be accumulated on module
81
+ # variables, which will later be synchronized.
82
+ if (
83
+ info_dict["train_engine"] == "torch_ddp"
84
+ and (batch_idx + 1) % info_dict["accum_grad"] != 0
85
+ ):
86
+ context = model.no_sync
87
+ # Used for single gpu training and DDP gradient synchronization
88
+ # processes.
89
+ else:
90
+ context = nullcontext
91
+
92
+ with context():
93
+ info_dict = batch_forward(model, batch_dict, info_dict)
94
+ info_dict = batch_backward(model, info_dict)
95
+
96
+ info_dict = update_parameter_and_lr(
97
+ model, optimizer, scheduler, info_dict
98
+ )
99
+ log_per_step(writer, info_dict)
100
+ # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
101
+ if (
102
+ info_dict["save_per_step"] > 0
103
+ and (self.step + 1) % info_dict["save_per_step"] == 0
104
+ and (batch_idx + 1) % info_dict["accum_grad"] == 0
105
+ ):
106
+ dist.barrier()
107
+ self.cv(
108
+ model, cv_data_loader, writer, info_dict, on_batch_end=False
109
+ )
110
+ model.train()
111
+ if (batch_idx + 1) % info_dict["accum_grad"] == 0:
112
+ self.step += 1
113
+ dist.barrier()
114
+ self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
115
+
116
+ @torch.inference_mode()
117
+ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
118
+ """Cross validation on"""
119
+ logging.info(
120
+ "Epoch {} Step {} on_batch_end {} CV rank {}".format(
121
+ self.epoch, self.step + 1, on_batch_end, self.rank
122
+ )
123
+ )
124
+ model.eval()
125
+ total_num_utts, total_loss_dict = 0, {} # avoid division by 0
126
+ for batch_idx, batch_dict in enumerate(cv_data_loader):
127
+ info_dict["tag"] = "CV"
128
+ info_dict["step"] = self.step
129
+ info_dict["epoch"] = self.epoch
130
+ info_dict["batch_idx"] = batch_idx
131
+
132
+ num_utts = len(batch_dict["utts"])
133
+ total_num_utts += num_utts
134
+
135
+ info_dict = batch_forward(model, batch_dict, info_dict)
136
+
137
+ for k, v in info_dict["loss_dict"].items():
138
+ if k not in total_loss_dict:
139
+ total_loss_dict[k] = []
140
+ total_loss_dict[k].append(v.item() * num_utts)
141
+ log_per_step(None, info_dict)
142
+ for k, v in total_loss_dict.items():
143
+ total_loss_dict[k] = sum(v) / total_num_utts
144
+ info_dict["loss_dict"] = total_loss_dict
145
+ log_per_save(writer, info_dict)
146
+ model_name = (
147
+ "epoch_{}_whole".format(self.epoch)
148
+ if on_batch_end
149
+ else "epoch_{}_step_{}".format(self.epoch, self.step + 1)
150
+ )
151
+ save_model(model, model_name, info_dict)
cosyvoice/utils/file_utils.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import torchaudio
18
+ import logging
19
+
20
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
21
+ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
22
+
23
+
24
+ def read_lists(list_file):
25
+ lists = []
26
+ with open(list_file, "r", encoding="utf8") as fin:
27
+ for line in fin:
28
+ lists.append(line.strip())
29
+ return lists
30
+
31
+
32
+ def read_json_lists(list_file):
33
+ lists = read_lists(list_file)
34
+ results = {}
35
+ for fn in lists:
36
+ with open(fn, "r", encoding="utf8") as fin:
37
+ results.update(json.load(fin))
38
+ return results
39
+
40
+
41
+ def load_wav(wav, target_sr):
42
+ speech, sample_rate = torchaudio.load(wav)
43
+ speech = speech.mean(dim=0, keepdim=True)
44
+ if sample_rate != target_sr:
45
+ # assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
46
+ speech = torchaudio.transforms.Resample(
47
+ orig_freq=sample_rate, new_freq=target_sr
48
+ )(speech)
49
+ return speech
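A minimal usage sketch (not part of this commit) for load_wav; "dummy.wav" is a throwaway file written here only to keep the example self-contained:

import torch
import torchaudio
from cosyvoice.utils.file_utils import load_wav

# One second of stereo noise at 48 kHz stands in for a real prompt recording.
torchaudio.save("dummy.wav", torch.randn(2, 48000).clamp(-1, 1), 48000)
speech = load_wav("dummy.wav", target_sr=16000)
print(speech.shape)   # torch.Size([1, 16000]): downmixed to mono, resampled to 16 kHz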
cosyvoice/utils/frontend_utils.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
18
+
19
+
20
+ # whether the text contains Chinese characters
21
+ def contains_chinese(text):
22
+ return bool(chinese_char_pattern.search(text))
23
+
24
+
25
+ # replace special symbol
26
+ def replace_corner_mark(text):
27
+ text = text.replace("²", "平方")
28
+ text = text.replace("³", "立方")
29
+ return text
30
+
31
+
32
+ # remove meaningless symbol
33
+ def remove_bracket(text):
34
+ text = text.replace("(", "").replace(")", "")
35
+ text = text.replace("【", "").replace("】", "")
36
+ text = text.replace("`", "").replace("`", "")
37
+ text = text.replace("——", " ")
38
+ return text
39
+
40
+
41
+ # spell out Arabic numerals
42
+ def spell_out_number(text: str, inflect_parser):
43
+ new_text = []
44
+ st = None
45
+ for i, c in enumerate(text):
46
+ if not c.isdigit():
47
+ if st is not None:
48
+ num_str = inflect_parser.number_to_words(text[st:i])
49
+ new_text.append(num_str)
50
+ st = None
51
+ new_text.append(c)
52
+ else:
53
+ if st is None:
54
+ st = i
55
+ if st is not None and st < len(text):
56
+ num_str = inflect_parser.number_to_words(text[st:])
57
+ new_text.append(num_str)
58
+ return "".join(new_text)
59
+
60
+
61
+ # split paragrah logic:
62
+ # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
63
+ # 2. cal sentence len according to lang
64
+ # 3. split sentence according to puncatation
65
+ def split_paragraph(
66
+ text: str,
67
+ tokenize,
68
+ lang="zh",
69
+ token_max_n=80,
70
+ token_min_n=60,
71
+ merge_len=20,
72
+ comma_split=False,
73
+ ):
74
+ def calc_utt_length(_text: str):
75
+ if lang == "zh":
76
+ return len(_text)
77
+ else:
78
+ return len(tokenize(_text))
79
+
80
+ def should_merge(_text: str):
81
+ if lang == "zh":
82
+ return len(_text) < merge_len
83
+ else:
84
+ return len(tokenize(_text)) < merge_len
85
+
86
+ if lang == "zh":
87
+ pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
88
+ else:
89
+ pounc = [".", "?", "!", ";", ":"]
90
+ if comma_split:
91
+ pounc.extend([",", ","])
92
+
93
+ if text[-1] not in pounc:
94
+ if lang == "zh":
95
+ text += "。"
96
+ else:
97
+ text += "."
98
+
99
+ st = 0
100
+ utts = []
101
+ for i, c in enumerate(text):
102
+ if c in pounc:
103
+ if len(text[st:i]) > 0:
104
+ utts.append(text[st:i] + c)
105
+ if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
106
+ tmp = utts.pop(-1)
107
+ utts.append(tmp + text[i + 1])
108
+ st = i + 2
109
+ else:
110
+ st = i + 1
111
+
112
+ final_utts = []
113
+ cur_utt = ""
114
+ for utt in utts:
115
+ if (
116
+ calc_utt_length(cur_utt + utt) > token_max_n
117
+ and calc_utt_length(cur_utt) > token_min_n
118
+ ):
119
+ final_utts.append(cur_utt)
120
+ cur_utt = ""
121
+ cur_utt = cur_utt + utt
122
+ if len(cur_utt) > 0:
123
+ if should_merge(cur_utt) and len(final_utts) != 0:
124
+ final_utts[-1] = final_utts[-1] + cur_utt
125
+ else:
126
+ final_utts.append(cur_utt)
127
+
128
+ return final_utts
129
+
130
+
131
+ # remove blank between chinese character
132
+ def replace_blank(text: str):
133
+ out_str = []
134
+ for i, c in enumerate(text):
135
+ if c == " ":
136
+ if (text[i + 1].isascii() and text[i + 1] != " ") and (
137
+ text[i - 1].isascii() and text[i - 1] != " "
138
+ ):
139
+ out_str.append(c)
140
+ else:
141
+ out_str.append(c)
142
+ return "".join(out_str)
cosyvoice/utils/mask.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+
19
+ '''
20
+ def subsequent_mask(
21
+ size: int,
22
+ device: torch.device = torch.device("cpu"),
23
+ ) -> torch.Tensor:
24
+ """Create mask for subsequent steps (size, size).
25
+
26
+ This mask is used only in decoder which works in an auto-regressive mode.
27
+ This means the current step could only do attention with its left steps.
28
+
29
+ In encoder, full attention is used when streaming is not necessary and
30
+ the sequence is not long. In this case, no attention mask is needed.
31
+
32
+ When streaming is needed, chunk-based attention is used in encoder. See
33
+ subsequent_chunk_mask for the chunk-based attention mask.
34
+
35
+ Args:
36
+ size (int): size of mask
37
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
38
+ dtype (torch.device): result dtype
39
+
40
+ Returns:
41
+ torch.Tensor: mask
42
+
43
+ Examples:
44
+ >>> subsequent_mask(3)
45
+ [[1, 0, 0],
46
+ [1, 1, 0],
47
+ [1, 1, 1]]
48
+ """
49
+ ret = torch.ones(size, size, device=device, dtype=torch.bool)
50
+ return torch.tril(ret)
51
+ '''
52
+
53
+
54
+ def subsequent_mask(
55
+ size: int,
56
+ device: torch.device = torch.device("cpu"),
57
+ ) -> torch.Tensor:
58
+ """Create mask for subsequent steps (size, size).
59
+
60
+ This mask is used only in decoder which works in an auto-regressive mode.
61
+ This means the current step could only do attention with its left steps.
62
+
63
+ In encoder, full attention is used when streaming is not necessary and
64
+ the sequence is not long. In this case, no attention mask is needed.
65
+
66
+ When streaming is needed, chunk-based attention is used in encoder. See
67
+ subsequent_chunk_mask for the chunk-based attention mask.
68
+
69
+ Args:
70
+ size (int): size of mask
71
+ str device (str): "cpu" or "cuda" or torch.Tensor.device
72
+ dtype (torch.device): result dtype
73
+
74
+ Returns:
75
+ torch.Tensor: mask
76
+
77
+ Examples:
78
+ >>> subsequent_mask(3)
79
+ [[1, 0, 0],
80
+ [1, 1, 0],
81
+ [1, 1, 1]]
82
+ """
83
+ arange = torch.arange(size, device=device)
84
+ mask = arange.expand(size, size)
85
+ arange = arange.unsqueeze(-1)
86
+ mask = mask <= arange
87
+ return mask
88
+
89
+
90
+ def subsequent_chunk_mask(
91
+ size: int,
92
+ chunk_size: int,
93
+ num_left_chunks: int = -1,
94
+ device: torch.device = torch.device("cpu"),
95
+ ) -> torch.Tensor:
96
+ """Create mask for subsequent steps (size, size) with chunk size,
97
+ this is for streaming encoder
98
+
99
+ Args:
100
+ size (int): size of mask
101
+ chunk_size (int): size of chunk
102
+ num_left_chunks (int): number of left chunks
103
+ <0: use full chunk
104
+ >=0: use num_left_chunks
105
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
106
+
107
+ Returns:
108
+ torch.Tensor: mask
109
+
110
+ Examples:
111
+ >>> subsequent_chunk_mask(4, 2)
112
+ [[1, 1, 0, 0],
113
+ [1, 1, 0, 0],
114
+ [1, 1, 1, 1],
115
+ [1, 1, 1, 1]]
116
+ """
117
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
118
+ for i in range(size):
119
+ if num_left_chunks < 0:
120
+ start = 0
121
+ else:
122
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
123
+ ending = min((i // chunk_size + 1) * chunk_size, size)
124
+ ret[i, start:ending] = True
125
+ return ret
126
+
127
+
128
+ def add_optional_chunk_mask(
129
+ xs: torch.Tensor,
130
+ masks: torch.Tensor,
131
+ use_dynamic_chunk: bool,
132
+ use_dynamic_left_chunk: bool,
133
+ decoding_chunk_size: int,
134
+ static_chunk_size: int,
135
+ num_decoding_left_chunks: int,
136
+ enable_full_context: bool = True,
137
+ ):
138
+ """Apply optional mask for encoder.
139
+
140
+ Args:
141
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
142
+ mask (torch.Tensor): mask for xs, (B, 1, L)
143
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
144
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
145
+ training.
146
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
147
+ 0: default for training, use random dynamic chunk.
148
+ <0: for decoding, use full chunk.
149
+ >0: for decoding, use fixed chunk size as set.
150
+ static_chunk_size (int): chunk size for static chunk training/decoding
151
+ if it's greater than 0, if use_dynamic_chunk is true,
152
+ this parameter will be ignored
153
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
154
+ the chunk size is decoding_chunk_size.
155
+ >=0: use num_decoding_left_chunks
156
+ <0: use all left chunks
157
+ enable_full_context (bool):
158
+ True: chunk size is either [1, 25] or full context(max_len)
159
+ False: chunk size ~ U[1, 25]
160
+
161
+ Returns:
162
+ torch.Tensor: chunk mask of the input xs.
163
+ """
164
+ # Whether to use chunk mask or not
165
+ if use_dynamic_chunk:
166
+ max_len = xs.size(1)
167
+ if decoding_chunk_size < 0:
168
+ chunk_size = max_len
169
+ num_left_chunks = -1
170
+ elif decoding_chunk_size > 0:
171
+ chunk_size = decoding_chunk_size
172
+ num_left_chunks = num_decoding_left_chunks
173
+ else:
174
+ # chunk size is either [1, 25] or full context(max_len).
175
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
176
+ # delay, the maximum frame is 100 / 4 = 25.
177
+ chunk_size = torch.randint(1, max_len, (1,)).item()
178
+ num_left_chunks = -1
179
+ if chunk_size > max_len // 2 and enable_full_context:
180
+ chunk_size = max_len
181
+ else:
182
+ chunk_size = chunk_size % 25 + 1
183
+ if use_dynamic_left_chunk:
184
+ max_left_chunks = (max_len - 1) // chunk_size
185
+ num_left_chunks = torch.randint(0, max_left_chunks, (1,)).item()
186
+ chunk_masks = subsequent_chunk_mask(
187
+ xs.size(1), chunk_size, num_left_chunks, xs.device
188
+ ) # (L, L)
189
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
190
+ chunk_masks = masks & chunk_masks # (B, L, L)
191
+ elif static_chunk_size > 0:
192
+ num_left_chunks = num_decoding_left_chunks
193
+ chunk_masks = subsequent_chunk_mask(
194
+ xs.size(1), static_chunk_size, num_left_chunks, xs.device
195
+ ) # (L, L)
196
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
197
+ chunk_masks = masks & chunk_masks # (B, L, L)
198
+ else:
199
+ chunk_masks = masks
200
+ return chunk_masks
201
+
202
+
203
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
204
+ """Make mask tensor containing indices of padded part.
205
+
206
+ See description of make_non_pad_mask.
207
+
208
+ Args:
209
+ lengths (torch.Tensor): Batch of lengths (B,).
210
+ Returns:
211
+ torch.Tensor: Mask tensor containing indices of padded part.
212
+
213
+ Examples:
214
+ >>> lengths = [5, 3, 2]
215
+ >>> make_pad_mask(lengths)
216
+ masks = [[0, 0, 0, 0 ,0],
217
+ [0, 0, 0, 1, 1],
218
+ [0, 0, 1, 1, 1]]
219
+ """
220
+ batch_size = lengths.size(0)
221
+ max_len = max_len if max_len > 0 else lengths.max().item()
222
+ seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
223
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
224
+ seq_length_expand = lengths.unsqueeze(-1)
225
+ mask = seq_range_expand >= seq_length_expand
226
+ return mask
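A minimal usage sketch (not part of this commit) that reproduces the docstring examples above:

import torch
from cosyvoice.utils.mask import subsequent_mask, subsequent_chunk_mask, make_pad_mask

print(subsequent_mask(3).int())            # causal lower-triangular mask
print(subsequent_chunk_mask(4, 2).int())   # 2-frame chunks with full left context
lengths = torch.tensor([5, 3, 2])
print(make_pad_mask(lengths).int())        # 1 marks the padded positions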
cosyvoice/utils/scheduler.py ADDED
@@ -0,0 +1,761 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2022 Ximalaya Inc (Yuguang Yang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ # NeMo(https://github.com/NVIDIA/NeMo)
18
+
19
+ from typing import Union
20
+
21
+ import math
22
+ import warnings
23
+ import torch
24
+ from torch.optim.lr_scheduler import _LRScheduler
25
+
26
+
27
+ class WarmupLR(_LRScheduler):
28
+ """The WarmupLR scheduler
29
+
30
+ This scheduler is almost the same as the NoamLR scheduler except for the following
31
+ difference:
32
+
33
+ NoamLR:
34
+ lr = optimizer.lr * model_size ** -0.5
35
+ * min(step ** -0.5, step * warmup_step ** -1.5)
36
+ WarmupLR:
37
+ lr = optimizer.lr * warmup_step ** 0.5
38
+ * min(step ** -0.5, step * warmup_step ** -1.5)
39
+
40
+ Note that the maximum lr equals optimizer.lr in this scheduler.
41
+
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ optimizer: torch.optim.Optimizer,
47
+ warmup_steps: Union[int, float] = 25000,
48
+ last_epoch: int = -1,
49
+ ):
50
+ self.warmup_steps = warmup_steps
51
+
52
+ # __init__() must be invoked before setting field
53
+ # because step() is also invoked in __init__()
54
+ super().__init__(optimizer, last_epoch)
55
+
56
+ def __repr__(self):
57
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
58
+
59
+ def get_lr(self):
60
+ step_num = self.last_epoch + 1
61
+ if self.warmup_steps == 0:
62
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
63
+ else:
64
+ return [
65
+ lr
66
+ * self.warmup_steps**0.5
67
+ * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
68
+ for lr in self.base_lrs
69
+ ]
70
+
71
+ def set_step(self, step: int):
72
+ self.last_epoch = step
73
+
74
+
75
+ class WarmupPolicy(_LRScheduler):
76
+ """Adds warmup kwargs and warmup logic to lr policy.
77
+ All arguments should be passed as kwargs for clarity,
78
+ Args:
79
+ warmup_steps: Number of training steps in warmup stage
80
+ warmup_ratio: Ratio of warmup steps to total steps
81
+ max_steps: Total number of steps while training or `None` for
82
+ infinite training
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ optimizer,
88
+ *,
89
+ warmup_steps=None,
90
+ warmup_ratio=None,
91
+ max_steps=None,
92
+ min_lr=0.0,
93
+ last_epoch=-1,
94
+ ):
95
+ assert not (
96
+ warmup_steps is not None and warmup_ratio is not None
97
+ ), "Either use particular number of step or ratio"
98
+ assert (
99
+ warmup_ratio is None or max_steps is not None
100
+ ), "If there is a ratio, there should be a total steps"
101
+
102
+ # It is necessary to assign all attributes *before* __init__,
103
+ # as class is wrapped by an inner class.
104
+ self.max_steps = max_steps
105
+ if warmup_steps is not None:
106
+ self.warmup_steps = warmup_steps
107
+ elif warmup_ratio is not None:
108
+ self.warmup_steps = int(warmup_ratio * max_steps)
109
+ else:
110
+ self.warmup_steps = 0
111
+
112
+ self.min_lr = min_lr
113
+ super().__init__(optimizer, last_epoch)
114
+
115
+ def get_lr(self):
116
+ if not self._get_lr_called_within_step:
117
+ warnings.warn(
118
+ "To get the last learning rate computed "
119
+ "by the scheduler, please use `get_last_lr()`.",
120
+ UserWarning,
121
+ stacklevel=2,
122
+ )
123
+
124
+ step = self.last_epoch
125
+
126
+ if step <= self.warmup_steps and self.warmup_steps > 0:
127
+ return self._get_warmup_lr(step)
128
+
129
+ if step > self.max_steps:
130
+ return [self.min_lr for _ in self.base_lrs]
131
+
132
+ return self._get_lr(step)
133
+
134
+ def _get_warmup_lr(self, step):
135
+ lr_val = (step + 1) / (self.warmup_steps + 1)
136
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
137
+
138
+ def _get_lr(self, step):
139
+ """Simple const lr policy"""
140
+ return self.base_lrs
141
+
142
+
143
+ class SquareRootConstantPolicy(_LRScheduler):
144
+ """Adds warmup kwargs and warmup logic to lr policy.
145
+ All arguments should be passed as kwargs for clarity,
146
+ Args:
147
+ constant_steps: Number of training steps to keep the learning rate constant
148
+ constant_ratio: Ratio of constant steps to total steps
149
+ max_steps: Total number of steps while training or `None` for
150
+ infinite training
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ optimizer,
156
+ *,
157
+ constant_steps=None,
158
+ constant_ratio=None,
159
+ max_steps=None,
160
+ min_lr=0.0,
161
+ last_epoch=-1,
162
+ ):
163
+ assert not (
164
+ constant_steps is not None and constant_ratio is not None
165
+ ), "Either use particular number of step or ratio"
166
+ assert (
167
+ constant_ratio is None or max_steps is not None
168
+ ), "If there is a ratio, there should be a total steps"
169
+
170
+ # It is necessary to assign all attributes *before* __init__,
171
+ # as class is wrapped by an inner class.
172
+ self.max_steps = max_steps
173
+ if constant_steps is not None:
174
+ self.constant_steps = constant_steps
175
+ elif constant_ratio is not None:
176
+ self.constant_steps = int(constant_ratio * max_steps)
177
+ else:
178
+ self.constant_steps = 0
179
+
180
+ self.constant_lr = 1 / (constant_steps**0.5)
181
+ self.min_lr = min_lr
182
+ super().__init__(optimizer, last_epoch)
183
+
184
+ def get_lr(self):
185
+ if not self._get_lr_called_within_step:
186
+ warnings.warn(
187
+ "To get the last learning rate computed "
188
+ "by the scheduler, please use `get_last_lr()`.",
189
+ UserWarning,
190
+ stacklevel=2,
191
+ )
192
+
193
+ step = self.last_epoch
194
+
195
+ if step <= self.constant_steps:
196
+ return [self.constant_lr for _ in self.base_lrs]
197
+
198
+ if step > self.max_steps:
199
+ return [self.min_lr for _ in self.base_lrs]
200
+
201
+ return self._get_lr(step)
202
+
203
+ def _get_lr(self, step):
204
+ """Simple const lr policy"""
205
+ return self.base_lrs
206
+
207
+
208
+ class WarmupHoldPolicy(WarmupPolicy):
209
+ """Variant of WarmupPolicy which maintains high
210
+ learning rate for a defined number of steps.
211
+ All arguments should be passed as kwargs for clarity,
212
+ Args:
213
+ warmup_steps: Number of training steps in warmup stage
214
+ warmup_ratio: Ratio of warmup steps to total steps
215
+ hold_steps: Number of training steps to
216
+ hold the learning rate after warm up
217
+ hold_ratio: Ratio of hold steps to total steps
218
+ max_steps: Total number of steps while training or `None` for
219
+ infinite training
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ optimizer,
225
+ *,
226
+ warmup_steps=None,
227
+ warmup_ratio=None,
228
+ hold_steps=None,
229
+ hold_ratio=None,
230
+ max_steps=None,
231
+ min_lr=0.0,
232
+ last_epoch=-1,
233
+ ):
234
+ assert not (
235
+ hold_steps is not None and hold_ratio is not None
236
+ ), "Either use particular number of step or ratio"
237
+ assert (
238
+ hold_ratio is None or max_steps is not None
239
+ ), "If there is a ratio, there should be a total steps"
240
+
241
+ self.min_lr = min_lr
242
+ self._last_warmup_lr = 0.0
243
+
244
+ # Necessary to duplicate as class attributes are hidden in inner class
245
+ self.max_steps = max_steps
246
+ if warmup_steps is not None:
247
+ self.warmup_steps = warmup_steps
248
+ elif warmup_ratio is not None:
249
+ self.warmup_steps = int(warmup_ratio * max_steps)
250
+ else:
251
+ self.warmup_steps = 0
252
+
253
+ if hold_steps is not None:
254
+ self.hold_steps = hold_steps + self.warmup_steps
255
+ elif hold_ratio is not None:
256
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
257
+ else:
258
+ self.hold_steps = 0
259
+
260
+ super().__init__(
261
+ optimizer,
262
+ warmup_steps=warmup_steps,
263
+ warmup_ratio=warmup_ratio,
264
+ max_steps=max_steps,
265
+ last_epoch=last_epoch,
266
+ min_lr=min_lr,
267
+ )
268
+
269
+ def get_lr(self):
270
+ if not self._get_lr_called_within_step:
271
+ warnings.warn(
272
+ "To get the last learning rate computed by the scheduler,"
273
+ " "
274
+ "please use `get_last_lr()`.",
275
+ UserWarning,
276
+ stacklevel=2,
277
+ )
278
+
279
+ step = self.last_epoch
280
+
281
+ # Warmup phase
282
+ if step <= self.warmup_steps and self.warmup_steps > 0:
283
+ return self._get_warmup_lr(step)
284
+
285
+ # Hold phase
286
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
287
+ return self.base_lrs
288
+
289
+ if step > self.max_steps:
290
+ return [self.min_lr for _ in self.base_lrs]
291
+
292
+ return self._get_lr(step)
293
+
294
+
295
+ class WarmupAnnealHoldPolicy(_LRScheduler):
296
+ """Adds warmup kwargs and warmup logic to lr policy.
297
+ All arguments should be passed as kwargs for clarity,
298
+ Args:
299
+ warmup_steps: Number of training steps in warmup stage
300
+ warmup_ratio: Ratio of warmup steps to total steps
301
+ max_steps: Total number of steps while training or `None` for
302
+ infinite training
303
+ min_lr: Minimum lr to hold the learning rate after decay at.
304
+ constant_steps: Number of steps to keep lr constant at.
305
+ constant_ratio: Ratio of steps to keep lr constant.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ optimizer,
311
+ *,
312
+ warmup_steps=None,
313
+ warmup_ratio=None,
314
+ constant_steps=None,
315
+ constant_ratio=None,
316
+ max_steps=None,
317
+ min_lr=0.0,
318
+ last_epoch=-1,
319
+ ):
320
+ assert not (
321
+ warmup_steps is not None and warmup_ratio is not None
322
+ ), "Either use particular number of step or ratio"
323
+ assert not (
324
+ constant_steps is not None and constant_ratio is not None
325
+ ), "Either use constant_steps or constant_ratio"
326
+ assert (
327
+ warmup_ratio is None or max_steps is not None
328
+ ), "If there is a ratio, there should be a total steps"
329
+
330
+ # It is necessary to assign all attributes *before* __init__,
331
+ # as class is wrapped by an inner class.
332
+ self.max_steps = max_steps
333
+
334
+ if warmup_steps is not None:
335
+ self.warmup_steps = warmup_steps
336
+ elif warmup_ratio is not None:
337
+ self.warmup_steps = int(warmup_ratio * max_steps)
338
+ else:
339
+ self.warmup_steps = 0
340
+
341
+ if constant_steps is not None:
342
+ self.constant_steps = constant_steps
343
+ elif constant_ratio is not None:
344
+ self.constant_steps = int(constant_ratio * max_steps)
345
+ else:
346
+ self.constant_steps = 0
347
+
348
+ self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps)
349
+
350
+ self.min_lr = min_lr
351
+ super().__init__(optimizer, last_epoch)
352
+
353
+ def get_lr(self):
354
+ if not self._get_lr_called_within_step:
355
+ warnings.warn(
356
+ "To get the last learning rate computed "
357
+ "by the scheduler, please use `get_last_lr()`.",
358
+ UserWarning,
359
+ stacklevel=2,
360
+ )
361
+
362
+ step = self.last_epoch
363
+
364
+ # Warmup steps
365
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
366
+ return self._get_warmup_lr(step)
367
+
368
+ # Constant steps after warmup and decay
369
+ if (
370
+ self.constant_steps > 0
371
+ and (self.warmup_steps + self.decay_steps) < step <= self.max_steps
372
+ ):
373
+ return self._get_constant_lr(step)
374
+
375
+ # Min lr after max steps of updates
376
+ if step > self.max_steps:
377
+ return [self.min_lr for _ in self.base_lrs]
378
+
379
+ return self._get_lr(step)
380
+
381
+ def _get_warmup_lr(self, step):
382
+ lr_val = (step + 1) / (self.warmup_steps + 1)
383
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
384
+
385
+ def _get_constant_lr(self, step):
386
+ return [self.min_lr for _ in self.base_lrs]
387
+
388
+ def _get_lr(self, step):
389
+ """Simple const lr policy"""
390
+ return self.base_lrs
391
+
392
+
393
+ def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
394
+ mult = ((max_steps - step) / max_steps) ** 0.5
395
+ out_lr = initial_lr * mult
396
+ out_lr = max(out_lr, min_lr)
397
+ return out_lr
398
+
399
+
400
+ def _square_annealing(initial_lr, step, max_steps, min_lr):
401
+ mult = ((max_steps - step) / max_steps) ** 2
402
+ out_lr = initial_lr * mult
403
+ out_lr = max(out_lr, min_lr)
404
+ return out_lr
405
+
406
+
407
+ def _cosine_annealing(initial_lr, step, max_steps, min_lr):
408
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
409
+ out_lr = (initial_lr - min_lr) * mult + min_lr
410
+ return out_lr
411
+
412
+
413
+ def _linear_warmup_with_cosine_annealing(
414
+ max_lr, warmup_steps, step, decay_steps, min_lr
415
+ ):
416
+ assert max_lr > min_lr
417
+ # Use linear warmup for the initial part.
418
+ if warmup_steps > 0 and step <= warmup_steps:
419
+ return max_lr * float(step) / float(warmup_steps)
420
+
421
+ # For any steps larger than `decay_steps`, use `min_lr`.
422
+ if step > warmup_steps + decay_steps:
423
+ return min_lr
424
+
425
+ # If we are done with the warmup period, use the decay style.
426
+ num_steps_ = step - warmup_steps
427
+ decay_steps_ = decay_steps
428
+ decay_ratio = float(num_steps_) / float(decay_steps_)
429
+ assert decay_ratio >= 0.0
430
+ assert decay_ratio <= 1.0
431
+ delta_lr = max_lr - min_lr
432
+
433
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
434
+
435
+ return min_lr + coeff * delta_lr
436
+
437
+
438
+ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
439
+ if cycle:
440
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
441
+ decay_steps *= multiplier
442
+ else:
443
+ step = min(step, decay_steps)
444
+ p = step / decay_steps
445
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
446
+ lr += min_lr
447
+ return lr
448
+
449
+
450
+ def _noam_hold_annealing(
451
+ initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr
452
+ ):
453
+ # hold_steps = total number of steps
454
+ # to hold the LR, not the warmup + hold steps.
455
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
456
+ T_hold_decay = max(1, (step - hold_steps) ** decay_rate)
457
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
458
+ lr = max(lr, min_lr)
459
+ return lr
460
+
461
+
462
+ class SquareAnnealing(WarmupPolicy):
463
+
464
+ def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs):
465
+ super().__init__(
466
+ optimizer=optimizer,
467
+ max_steps=max_steps,
468
+ last_epoch=last_epoch,
469
+ min_lr=min_lr,
470
+ **kwargs,
471
+ )
472
+
473
+ def _get_lr(self, step):
474
+ new_lrs = [
475
+ _square_annealing(
476
+ initial_lr=initial_lr,
477
+ step=step - self.warmup_steps,
478
+ max_steps=self.max_steps - self.warmup_steps,
479
+ min_lr=self.min_lr,
480
+ )
481
+ for initial_lr in self.base_lrs
482
+ ]
483
+ return new_lrs
484
+
485
+
486
+ class SquareRootAnnealing(WarmupPolicy):
487
+
488
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
489
+ super().__init__(
490
+ optimizer=optimizer,
491
+ max_steps=max_steps,
492
+ last_epoch=last_epoch,
493
+ min_lr=min_lr,
494
+ **kwargs,
495
+ )
496
+
497
+ def _get_lr(self, step):
498
+ new_lrs = [
499
+ _squareroot_annealing(
500
+ initial_lr=initial_lr,
501
+ step=step,
502
+ max_steps=self.max_steps,
503
+ min_lr=self.min_lr,
504
+ )
505
+ for initial_lr in self.base_lrs
506
+ ]
507
+ return new_lrs
508
+
509
+
510
+ class CosineAnnealing(WarmupAnnealHoldPolicy):
511
+
512
+ def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs):
513
+ super().__init__(
514
+ optimizer=optimizer,
515
+ max_steps=max_steps,
516
+ last_epoch=last_epoch,
517
+ min_lr=min_lr,
518
+ **kwargs,
519
+ )
520
+
521
+ def _get_lr(self, step):
522
+ for initial_lr in self.base_lrs:
523
+ if initial_lr < self.min_lr:
524
+ raise ValueError(
525
+ f"{self} received an initial learning rate "
526
+ f"that was lower than the minimum learning rate."
527
+ )
528
+
529
+ if self.constant_steps is None or self.constant_steps == 0:
530
+ new_lrs = [
531
+ _cosine_annealing(
532
+ initial_lr=initial_lr,
533
+ step=step - self.warmup_steps,
534
+ max_steps=self.max_steps - self.warmup_steps,
535
+ min_lr=self.min_lr,
536
+ )
537
+ for initial_lr in self.base_lrs
538
+ ]
539
+ else:
540
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
541
+ return new_lrs
542
+
543
+ def _get_warmup_lr(self, step):
544
+ if self.constant_steps is None or self.constant_steps == 0:
545
+ return super()._get_warmup_lr(step)
546
+ else:
547
+ # Use linear warmup for the initial part.
548
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
549
+
550
+ def _get_constant_lr(self, step):
551
+ # Only called when `constant_steps` > 0.
552
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
553
+
554
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
555
+ # Cosine Schedule for Megatron LM,
556
+ # slightly different warmup schedule + constant LR at the end.
557
+ new_lrs = [
558
+ _linear_warmup_with_cosine_annealing(
559
+ max_lr=self.base_lrs[0],
560
+ warmup_steps=self.warmup_steps,
561
+ step=step,
562
+ decay_steps=self.decay_steps,
563
+ min_lr=self.min_lr,
564
+ )
565
+ for _ in self.base_lrs
566
+ ]
567
+ return new_lrs
568
+
569
+
570
+ class NoamAnnealing(_LRScheduler):
571
+
572
+ def __init__(
573
+ self,
574
+ optimizer,
575
+ *,
576
+ d_model,
577
+ warmup_steps=None,
578
+ warmup_ratio=None,
579
+ max_steps=None,
580
+ min_lr=0.0,
581
+ last_epoch=-1,
582
+ ):
583
+ self._normalize = d_model ** (-0.5)
584
+ assert not (
585
+ warmup_steps is not None and warmup_ratio is not None
586
+ ), "Either use particular number of step or ratio"
587
+ assert (
588
+ warmup_ratio is None or max_steps is not None
589
+ ), "If there is a ratio, there should be a total steps"
590
+
591
+ # It is necessary to assign all attributes *before* __init__,
592
+ # as class is wrapped by an inner class.
593
+ self.max_steps = max_steps
594
+ if warmup_steps is not None:
595
+ self.warmup_steps = warmup_steps
596
+ elif warmup_ratio is not None:
597
+ self.warmup_steps = int(warmup_ratio * max_steps)
598
+ else:
599
+ self.warmup_steps = 0
600
+
601
+ self.min_lr = min_lr
602
+ super().__init__(optimizer, last_epoch)
603
+
604
+ def get_lr(self):
605
+ if not self._get_lr_called_within_step:
606
+ warnings.warn(
607
+ "To get the last learning rate computed "
608
+ "by the scheduler, please use `get_last_lr()`.",
609
+ UserWarning,
610
+ stacklevel=2,
611
+ )
612
+
613
+ step = max(1, self.last_epoch)
614
+
615
+ for initial_lr in self.base_lrs:
616
+ if initial_lr < self.min_lr:
617
+ raise ValueError(
618
+ f"{self} received an initial learning rate "
619
+ f"that was lower than the minimum learning rate."
620
+ )
621
+
622
+ new_lrs = [
623
+ self._noam_annealing(initial_lr=initial_lr, step=step)
624
+ for initial_lr in self.base_lrs
625
+ ]
626
+ return new_lrs
627
+
628
+ def _noam_annealing(self, initial_lr, step):
629
+ if self.warmup_steps > 0:
630
+ mult = self._normalize * min(
631
+ step ** (-0.5), step * (self.warmup_steps ** (-1.5))
632
+ )
633
+ else:
634
+ mult = self._normalize * step ** (-0.5)
635
+
636
+ out_lr = initial_lr * mult
637
+ if step > self.warmup_steps:
638
+ out_lr = max(out_lr, self.min_lr)
639
+ return out_lr
640
+
641
+
642
+ class NoamHoldAnnealing(WarmupHoldPolicy):
643
+
644
+ def __init__(
645
+ self,
646
+ optimizer,
647
+ *,
648
+ max_steps,
649
+ decay_rate=0.5,
650
+ min_lr=0.0,
651
+ last_epoch=-1,
652
+ **kwargs,
653
+ ):
654
+ """
655
+ From Nemo:
656
+ Implementation of the Noam Hold Annealing policy
657
+ from the SqueezeFormer paper.
658
+
659
+ Unlike NoamAnnealing, the peak learning rate
660
+ can be explicitly set for this scheduler.
661
+ The schedule first performs linear warmup,
662
+ then holds the peak LR, then decays with some schedule for
663
+ the remainder of the steps.
664
+ Therefore the min-lr is still dependent
665
+ on the hyperparameters selected.
666
+
667
+ Its schedule is determined by three factors:
668
+
669
+ Warmup Steps: Initial stage, where linear warmup
670
+ occurs until the peak LR is reached. Unlike NoamAnnealing,
671
+ the peak LR is explicitly stated here instead of a scaling factor.
672
+
673
+ Hold Steps: Intermediate stage, where the peak LR
674
+ is maintained for some number of steps. In this region,
675
+ the high peak LR allows the model to converge faster
676
+ if training is stable. However the high LR
677
+ may also cause instability during training.
678
+ Should usually be a significant fraction of training
679
+ steps (around 30-40% of the entire training steps).
680
+
681
+ Decay Steps: Final stage, where the LR rapidly decays
682
+ with some scaling rate (set by decay rate).
683
+ To attain Noam decay, use 0.5,
684
+ for Squeezeformer recommended decay, use 1.0.
685
+ The fast decay after prolonged high LR during
686
+ hold phase allows for rapid convergence.
687
+
688
+ References:
689
+ - [Squeezeformer:
690
+ An Efficient Transformer for Automatic Speech Recognition]
691
+ (https://arxiv.org/abs/2206.00888)
692
+
693
+ Args:
694
+ optimizer: Pytorch compatible Optimizer object.
695
+ warmup_steps: Number of training steps in warmup stage
696
+ warmup_ratio: Ratio of warmup steps to total steps
697
+ hold_steps: Number of training steps to
698
+ hold the learning rate after warm up
699
+ hold_ratio: Ratio of hold steps to total steps
700
+ max_steps: Total number of steps while training or `None` for
701
+ infinite training
702
+ decay_rate: Float value describing the polynomial decay
703
+ after the hold period. Default value
704
+ of 0.5 corresponds to Noam decay.
705
+ min_lr: Minimum learning rate.
706
+ """
707
+ self.decay_rate = decay_rate
708
+ super().__init__(
709
+ optimizer=optimizer,
710
+ max_steps=max_steps,
711
+ last_epoch=last_epoch,
712
+ min_lr=min_lr,
713
+ **kwargs,
714
+ )
715
+
716
+ def _get_lr(self, step):
717
+ if self.warmup_steps is None or self.warmup_steps == 0:
718
+ raise ValueError("Noam scheduler cannot be used without warmup steps")
719
+
720
+ if self.hold_steps > 0:
721
+ hold_steps = self.hold_steps - self.warmup_steps
722
+ else:
723
+ hold_steps = 0
724
+
725
+ new_lrs = [
726
+ _noam_hold_annealing(
727
+ initial_lr,
728
+ step=step,
729
+ warmup_steps=self.warmup_steps,
730
+ hold_steps=hold_steps,
731
+ decay_rate=self.decay_rate,
732
+ min_lr=self.min_lr,
733
+ )
734
+ for initial_lr in self.base_lrs
735
+ ]
736
+ return new_lrs
737
+
738
+ def set_step(self, step: int):
739
+ self.last_epoch = step
740
+
741
+
742
+ class ConstantLR(_LRScheduler):
743
+ """The ConstantLR scheduler
744
+
745
+ This scheduler keeps a constant lr
746
+
747
+ """
748
+
749
+ def __init__(
750
+ self,
751
+ optimizer: torch.optim.Optimizer,
752
+ ):
753
+ # __init__() must be invoked before setting field
754
+ # because step() is also invoked in __init__()
755
+ super().__init__(optimizer)
756
+
757
+ def get_lr(self):
758
+ return self.base_lrs
759
+
760
+ def set_step(self, step: int):
761
+ self.last_epoch = step
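A minimal usage sketch (not part of this commit) for WarmupLR: the lr ramps linearly to the optimizer lr over warmup_steps, then decays proportionally to step**-0.5:

import torch
from cosyvoice.utils.scheduler import WarmupLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=10)

lrs = []
for _ in range(30):
    optimizer.step()       # no gradients here, so this is a no-op; kept for the expected call order
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
print(f"peak lr {max(lrs):.2e} around step {lrs.index(max(lrs)) + 1}")   # ~1e-03 around step 10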
cosyvoice/utils/train_utils.py ADDED
@@ -0,0 +1,350 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2023 Horizon Inc. (authors: Xingchen Song)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from contextlib import nullcontext
18
+ import logging
19
+ import os
20
+ import torch
21
+ import json
22
+ import re
23
+ import datetime
24
+ import yaml
25
+
26
+ import deepspeed
27
+ import torch.optim as optim
28
+ import torch.distributed as dist
29
+
30
+ from torch.utils.tensorboard import SummaryWriter
31
+ from torch.utils.data import DataLoader
32
+ from torch.nn.utils import clip_grad_norm_
33
+
34
+ from deepspeed.runtime.zero.stage_1_and_2 import (
35
+ estimate_zero2_model_states_mem_needs_all_live,
36
+ )
37
+
38
+ from cosyvoice.dataset.dataset import Dataset
39
+ from cosyvoice.utils.scheduler import (
40
+ WarmupLR,
41
+ NoamHoldAnnealing,
42
+ ConstantLR,
43
+ )
44
+
45
+
46
+ def init_distributed(args):
47
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
48
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
49
+ rank = int(os.environ.get("RANK", 0))
50
+ logging.info(
51
+ "training on multiple gpus, this gpu {}".format(local_rank)
52
+ + ", rank {}, world_size {}".format(rank, world_size)
53
+ )
54
+ if args.train_engine == "torch_ddp":
55
+ torch.cuda.set_device(local_rank)
56
+ dist.init_process_group(args.dist_backend)
57
+ else:
58
+ deepspeed.init_distributed(dist_backend=args.dist_backend)
59
+ return world_size, local_rank, rank
60
+
61
+
62
+ def init_dataset_and_dataloader(args, configs):
63
+ train_dataset = Dataset(
64
+ args.train_data,
65
+ data_pipeline=configs["data_pipeline"],
66
+ mode="train",
67
+ shuffle=True,
68
+ partition=True,
69
+ )
70
+ cv_dataset = Dataset(
71
+ args.cv_data,
72
+ data_pipeline=configs["data_pipeline"],
73
+ mode="train",
74
+ shuffle=False,
75
+ partition=False,
76
+ )
77
+
78
+ # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
79
+ train_data_loader = DataLoader(
80
+ train_dataset,
81
+ batch_size=None,
82
+ pin_memory=args.pin_memory,
83
+ num_workers=args.num_workers,
84
+ prefetch_factor=args.prefetch,
85
+ )
86
+ cv_data_loader = DataLoader(
87
+ cv_dataset,
88
+ batch_size=None,
89
+ pin_memory=args.pin_memory,
90
+ num_workers=args.num_workers,
91
+ prefetch_factor=args.prefetch,
92
+ )
93
+ return train_dataset, cv_dataset, train_data_loader, cv_data_loader
94
+
95
+
96
+ def check_modify_and_save_config(args, configs):
97
+ if args.train_engine == "torch_ddp":
98
+ configs["train_conf"]["dtype"] = "fp32"
99
+ else:
100
+ with open(args.deepspeed_config, "r") as fin:
101
+ ds_configs = json.load(fin)
102
+ if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
103
+ configs["train_conf"]["dtype"] = "fp16"
104
+ elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
105
+ configs["train_conf"]["dtype"] = "bf16"
106
+ else:
107
+ configs["train_conf"]["dtype"] = "fp32"
108
+ assert ds_configs["train_micro_batch_size_per_gpu"] == 1
109
+ # if use deepspeed, override ddp config
110
+ configs["train_conf"]["save_per_step"] = int(
111
+ configs["train_conf"]["save_per_step"]
112
+ * configs["train_conf"]["accum_grad"]
113
+ / ds_configs["gradient_accumulation_steps"]
114
+ )
115
+ configs["train_conf"]["accum_grad"] = ds_configs["gradient_accumulation_steps"]
116
+ configs["train_conf"]["grad_clip"] = ds_configs["gradient_clipping"]
117
+ configs["train_conf"]["log_interval"] = ds_configs["steps_per_print"]
118
+ return configs
119
+
120
+
121
+ def wrap_cuda_model(args, model):
122
+ local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
123
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
124
+ if args.train_engine == "torch_ddp": # native pytorch ddp
125
+ assert torch.cuda.is_available()
126
+ model.cuda()
127
+ model = torch.nn.parallel.DistributedDataParallel(
128
+ model, find_unused_parameters=True
129
+ )
130
+ else:
131
+ if int(os.environ.get("RANK", 0)) == 0:
132
+ logging.info("Estimating model states memory needs (zero2)...")
133
+ estimate_zero2_model_states_mem_needs_all_live(
134
+ model,
135
+ num_gpus_per_node=local_world_size,
136
+ num_nodes=world_size // local_world_size,
137
+ )
138
+ return model
139
+
140
+
141
+ def init_optimizer_and_scheduler(args, configs, model):
142
+ if configs["train_conf"]["optim"] == "adam":
143
+ optimizer = optim.Adam(
144
+ model.parameters(), **configs["train_conf"]["optim_conf"]
145
+ )
146
+ elif configs["train_conf"]["optim"] == "adamw":
147
+ optimizer = optim.AdamW(
148
+ model.parameters(), **configs["train_conf"]["optim_conf"]
149
+ )
150
+ else:
151
+ raise ValueError("unknown optimizer: " + configs["train_conf"])
152
+
153
+ if configs["train_conf"]["scheduler"] == "warmuplr":
154
+ scheduler_type = WarmupLR
155
+ scheduler = WarmupLR(optimizer, **configs["train_conf"]["scheduler_conf"])
156
+ elif configs["train_conf"]["scheduler"] == "NoamHoldAnnealing":
157
+ scheduler_type = NoamHoldAnnealing
158
+ scheduler = NoamHoldAnnealing(
159
+ optimizer, **configs["train_conf"]["scheduler_conf"]
160
+ )
161
+ elif configs["train_conf"]["scheduler"] == "constantlr":
162
+ scheduler_type = ConstantLR
163
+ scheduler = ConstantLR(optimizer)
164
+ else:
165
+ raise ValueError("unknown scheduler: " + configs["train_conf"])
166
+
167
+ # use deepspeed optimizer for speedup
168
+ if args.train_engine == "deepspeed":
169
+
170
+ def scheduler(opt):
171
+ return scheduler_type(opt, **configs["train_conf"]["scheduler_conf"])
172
+
173
+ model, optimizer, _, scheduler = deepspeed.initialize(
174
+ args=args,
175
+ model=model,
176
+ optimizer=None,
177
+ lr_scheduler=scheduler,
178
+ model_parameters=model.parameters(),
179
+ )
180
+
181
+ return model, optimizer, scheduler
182
+
183
+
184
+ def init_summarywriter(args):
185
+ writer = None
186
+ if int(os.environ.get("RANK", 0)) == 0:
187
+ os.makedirs(args.model_dir, exist_ok=True)
188
+ writer = SummaryWriter(args.tensorboard_dir)
189
+ return writer
190
+
191
+
192
+ def save_model(model, model_name, info_dict):
193
+ rank = int(os.environ.get("RANK", 0))
194
+ model_dir = info_dict["model_dir"]
195
+ save_model_path = os.path.join(model_dir, "{}.pt".format(model_name))
196
+
197
+ if info_dict["train_engine"] == "torch_ddp":
198
+ if rank == 0:
199
+ torch.save(model.module.state_dict(), save_model_path)
200
+ else:
201
+ with torch.no_grad():
202
+ model.save_checkpoint(
203
+ save_dir=model_dir, tag=model_name, client_state=info_dict
204
+ )
205
+ if rank == 0:
206
+ info_path = re.sub(".pt$", ".yaml", save_model_path)
207
+ info_dict["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S")
208
+ with open(info_path, "w") as fout:
209
+ data = yaml.dump(info_dict)
210
+ fout.write(data)
211
+ logging.info(
212
+ "[Rank {}] Checkpoint: save to checkpoint {}".format(rank, save_model_path)
213
+ )
214
+
215
+
216
+ def cosyvoice_join(group_join, info_dict):
217
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
218
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
219
+ rank = int(os.environ.get("RANK", 0))
220
+
221
+ if info_dict["batch_idx"] != 0:
222
+ # try to join all ranks in both ddp and deepspeed mode, in case different ranks have different lr
223
+ try:
224
+ dist.monitored_barrier(
225
+ group=group_join, timeout=group_join.options._timeout
226
+ )
227
+ return False
228
+ except RuntimeError as e:
229
+ logging.info(
230
+ "Detected uneven workload distribution: {}\n".format(e)
231
+ + "Break current worker to manually join all workers, "
232
+ + "world_size {}, current rank {}, current local_rank {}\n".format(
233
+ world_size, rank, local_rank
234
+ )
235
+ )
236
+ return True
237
+ else:
238
+ return False
239
+
240
+
241
+ def batch_forward(model, batch, info_dict):
242
+ device = int(os.environ.get("LOCAL_RANK", 0))
243
+
244
+ dtype = info_dict["dtype"]
245
+ if dtype == "fp16":
246
+ dtype = torch.float16
247
+ elif dtype == "bf16":
248
+ dtype = torch.bfloat16
249
+ else: # fp32
250
+ dtype = torch.float32
251
+
252
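+ # torch_ddp runs the forward pass in full precision (nullcontext); autocast with the
+ # configured dtype is only enabled for the deepspeed engine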
+ if info_dict["train_engine"] == "torch_ddp":
253
+ autocast = nullcontext()
254
+ else:
255
+ autocast = torch.cuda.amp.autocast(
256
+ enabled=True, dtype=dtype, cache_enabled=False
257
+ )
258
+
259
+ with autocast:
260
+ info_dict["loss_dict"] = model(batch, device)
261
+ return info_dict
262
+
263
+
264
+ def batch_backward(model, info_dict):
265
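+ # deepspeed scales the loss for gradient accumulation inside model.backward();
+ # for torch_ddp the loss is divided by accum_grad manually before backward()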
+ if info_dict["train_engine"] == "deepspeed":
266
+ scaled_loss = model.backward(info_dict["loss_dict"]["loss"])
267
+ else:
268
+ scaled_loss = info_dict["loss_dict"]["loss"] / info_dict["accum_grad"]
269
+ scaled_loss.backward()
270
+
271
+ info_dict["loss_dict"]["loss"] = scaled_loss
272
+ return info_dict
273
+
274
+
275
+ def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
276
+ grad_norm = 0.0
277
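+ # deepspeed handles accumulation boundaries, gradient clipping and the optimizer
+ # step inside model.step(); the torch_ddp path steps manually once per accum_grad
+ # batches and skips the update if the clipped grad norm is not finite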
+ if info_dict["train_engine"] == "deepspeed":
278
+ info_dict["is_gradient_accumulation_boundary"] = (
279
+ model.is_gradient_accumulation_boundary()
280
+ )
281
+ model.step()
282
+ grad_norm = model.get_global_grad_norm()
283
+ elif (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0:
284
+ grad_norm = clip_grad_norm_(model.parameters(), info_dict["grad_clip"])
285
+ if torch.isfinite(grad_norm):
286
+ optimizer.step()
287
+ optimizer.zero_grad()
288
+ scheduler.step()
289
+ info_dict["lr"] = optimizer.param_groups[0]["lr"]
290
+ info_dict["grad_norm"] = grad_norm
291
+ return info_dict
292
+
293
+
294
+ def log_per_step(writer, info_dict):
295
+ tag = info_dict["tag"]
296
+ epoch = info_dict.get("epoch", 0)
297
+ step = info_dict["step"]
298
+ batch_idx = info_dict["batch_idx"]
299
+ loss_dict = info_dict["loss_dict"]
300
+ rank = int(os.environ.get("RANK", 0))
301
+
302
+ # only rank 0 writes to tensorboard to avoid multi-process writes
303
+ if writer is not None:
304
+ if (
305
+ info_dict["train_engine"] == "deepspeed"
306
+ and info_dict["is_gradient_accumulation_boundary"] is True
307
+ ) or (
308
+ info_dict["train_engine"] == "torch_ddp"
309
+ and (info_dict["batch_idx"] + 1) % info_dict["accum_grad"] == 0
310
+ ):
311
+ for k in ["epoch", "lr", "grad_norm"]:
312
+ writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
313
+ for k, v in loss_dict.items():
314
+ writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
315
+
316
+ # TRAIN & CV, Shell log (stdout)
317
+ if (info_dict["batch_idx"] + 1) % info_dict["log_interval"] == 0:
318
+ log_str = "{} Batch {}/{} ".format(tag, epoch, batch_idx + 1)
319
+ for name, value in loss_dict.items():
320
+ log_str += "{} {:.6f} ".format(name, value)
321
+ if tag == "TRAIN":
322
+ log_str += "lr {:.8f} grad_norm {:.6f}".format(
323
+ info_dict["lr"], info_dict["grad_norm"]
324
+ )
325
+ log_str += " rank {}".format(rank)
326
+ logging.debug(log_str)
327
+
328
+
329
+ def log_per_save(writer, info_dict):
330
+ tag = info_dict["tag"]
331
+ epoch = info_dict["epoch"]
332
+ step = info_dict["step"]
333
+ loss_dict = info_dict["loss_dict"]
334
+ lr = info_dict["lr"]
335
+ rank = int(os.environ.get("RANK", 0))
336
+ logging.info(
337
+ "Epoch {} Step {} CV info lr {} {} rank {}".format(
338
+ epoch,
339
+ step + 1,
340
+ lr,
341
+ " ".join(["{}_{}".format(k, v) for k, v in loss_dict.items()]),
342
+ rank,
343
+ )
344
+ )
345
+
346
+ if writer is not None:
347
+ for k in ["epoch", "lr"]:
348
+ writer.add_scalar("{}/{}".format(tag, k), info_dict[k], step + 1)
349
+ for k, v in loss_dict.items():
350
+ writer.add_scalar("{}/{}".format(tag, k), v, step + 1)
funasr_detach/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ """Initialize funasr package."""
2
+
3
+ import os
4
+ import pkgutil
5
+ import importlib
6
+
7
+ dirname = os.path.dirname(__file__)
8
+ version_file = os.path.join(dirname, "version.txt")
9
+ with open(version_file, "r") as f:
10
+ __version__ = f.read().strip()
11
+
12
+
13
+ import importlib
14
+ import pkgutil
15
+
16
+
17
+ def import_submodules(package, recursive=True):
18
+ if isinstance(package, str):
19
+ package = importlib.import_module(package)
20
+ results = {}
21
+ for loader, name, is_pkg in pkgutil.walk_packages(
22
+ package.__path__, package.__name__ + "."
23
+ ):
24
+ try:
25
+ results[name] = importlib.import_module(name)
26
+ except Exception as e:
27
+ # to see the details of failed submodule imports, uncomment the line below
28
+ # print(f"Failed to import {name}: {e}")
29
+ pass
30
+ if recursive and is_pkg:
31
+ results.update(import_submodules(name))
32
+ return results
33
+
34
+
35
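+ # eagerly import every submodule so that model/frontend/tokenizer classes can
+ # register themselves (e.g. into funasr_detach.register.tables) as an import side effect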
+ import_submodules(__name__)
36
+
37
+ from funasr_detach.auto.auto_model import AutoModel
38
+ from funasr_detach.auto.auto_frontend import AutoFrontend
funasr_detach/auto/__init__.py ADDED
File without changes
funasr_detach/auto/auto_frontend.py ADDED
@@ -0,0 +1,90 @@
1
+ import time
2
+ import logging
3
+ from tqdm import tqdm
4
+
5
+ from funasr_detach.register import tables
6
+ from funasr_detach.download.download_from_hub import download_model
7
+ from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank
8
+ from funasr_detach.auto.auto_model import prepare_data_iterator
9
+ from funasr_detach.auto.auto_model import prepare_data_iterator
10
+
11
+
12
+ class AutoFrontend:
13
+ def __init__(self, **kwargs):
14
+ assert "model" in kwargs
15
+ if "model_conf" not in kwargs:
16
+ logging.info(
17
+ "download models from model hub: {}".format(
18
+ kwargs.get("model_hub", "ms")
19
+ )
20
+ )
21
+ kwargs = download_model(**kwargs)
22
+
23
+ # build frontend
24
+ frontend = kwargs.get("frontend", None)
25
+ if frontend is not None:
26
+ frontend_class = tables.frontend_classes.get(frontend)
27
+ frontend = frontend_class(**kwargs["frontend_conf"])
28
+
29
+ self.frontend = frontend
30
+ if "frontend" in kwargs:
31
+ del kwargs["frontend"]
32
+ self.kwargs = kwargs
33
+
34
+ def __call__(self, input, input_len=None, kwargs=None, **cfg):
35
+
36
+ kwargs = self.kwargs if kwargs is None else kwargs
37
+ kwargs.update(cfg)
38
+
39
+ key_list, data_list = prepare_data_iterator(input, input_len=input_len)
40
+ batch_size = kwargs.get("batch_size", 1)
41
+ device = kwargs.get("device", "cpu")
42
+ if device == "cpu":
43
+ batch_size = 1
44
+
45
+ meta_data = {}
46
+
47
+ result_list = []
48
+ num_samples = len(data_list)
49
+ pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
50
+
51
+ time0 = time.perf_counter()
52
+ for beg_idx in range(0, num_samples, batch_size):
53
+ end_idx = min(num_samples, beg_idx + batch_size)
54
+ data_batch = data_list[beg_idx:end_idx]
55
+ key_batch = key_list[beg_idx:end_idx]
56
+
57
+ # extract fbank feats
58
+ time1 = time.perf_counter()
59
+ audio_sample_list = load_audio_text_image_video(
60
+ data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
61
+ )
62
+ time2 = time.perf_counter()
63
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
64
+ speech, speech_lengths = extract_fbank(
65
+ audio_sample_list,
66
+ data_type=kwargs.get("data_type", "sound"),
67
+ frontend=self.frontend,
68
+ **kwargs,
69
+ )
70
+ time3 = time.perf_counter()
71
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
72
+ meta_data["batch_data_time"] = (
73
+ speech_lengths.sum().item()
74
+ * self.frontend.frame_shift
75
+ * self.frontend.lfr_n
76
+ / 1000
77
+ )
78
+
79
+ speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
80
+ batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
81
+ result_list.append(batch)
82
+
83
+ pbar.update(1)
84
+ description = f"{meta_data}, "
85
+ pbar.set_description(description)
86
+
87
+ time_end = time.perf_counter()
88
+ pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
89
+
90
+ return result_list
funasr_detach/auto/auto_model.py ADDED
@@ -0,0 +1,573 @@
1
+ import json
2
+ import time
3
+ import copy
4
+ import torch
5
+ import random
6
+ import string
7
+ import logging
8
+ import os.path
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from funasr_detach.register import tables
13
+ from funasr_detach.utils.load_utils import load_bytes
14
+ from funasr_detach.download.file import download_from_url
15
+ from funasr_detach.download.download_from_hub import download_model
16
+ from funasr_detach.utils.vad_utils import slice_padding_audio_samples
17
+ from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
18
+ from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model
19
+ from funasr_detach.utils.load_utils import load_audio_text_image_video
20
+ from funasr_detach.utils.timestamp_tools import timestamp_sentence
21
+ from funasr_detach.models.campplus.utils import sv_chunk, postprocess, distribute_spk
22
+
23
+ try:
24
+ from funasr_detach.models.campplus.cluster_backend import ClusterBackend
25
+ except Exception:
26
+ print("If you want to use the speaker diarization, please `pip install hdbscan`")
27
+
28
+
29
+ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
30
+ """
31
+
32
+ :param data_in: URL, wav path, .scp/.txt/.json/.jsonl filelist, raw bytes, audio samples, or a list/tuple of these
33
+ :param input_len: optional input lengths
34
+ :param data_type: per-entry data type when data_in is a list/tuple of inputs
35
+ :param key: optional key for a single raw input (random keys are generated otherwise)
36
+ :return: (key_list, data_list)
37
+ """
38
+ data_list = []
39
+ key_list = []
40
+ filelist = [".scp", ".txt", ".json", ".jsonl"]
41
+
42
+ chars = string.ascii_letters + string.digits
43
+ if isinstance(data_in, str) and data_in.startswith("http"): # url
44
+ data_in = download_from_url(data_in)
45
+ if isinstance(data_in, str) and os.path.exists(
46
+ data_in
47
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
48
+ _, file_extension = os.path.splitext(data_in)
49
+ file_extension = file_extension.lower()
50
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
51
+ with open(data_in, encoding="utf-8") as fin:
52
+ for line in fin:
53
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
54
+ if data_in.endswith(
55
+ ".jsonl"
56
+ ): # file.jsonl: json.dumps({"source": data})
57
+ lines = json.loads(line.strip())
58
+ data = lines["source"]
59
+ key = data["key"] if "key" in data else key
60
+ else: # filelist, wav.scp, text.txt: id \t data or data
61
+ lines = line.strip().split(maxsplit=1)
62
+ data = lines[1] if len(lines) > 1 else lines[0]
63
+ key = lines[0] if len(lines) > 1 else key
64
+
65
+ data_list.append(data)
66
+ key_list.append(key)
67
+ else:
68
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
69
+ data_list = [data_in]
70
+ key_list = [key]
71
+ elif isinstance(data_in, (list, tuple)):
72
+ if data_type is not None and isinstance(
73
+ data_type, (list, tuple)
74
+ ): # multiple inputs
75
+ data_list_tmp = []
76
+ for data_in_i, data_type_i in zip(data_in, data_type):
77
+ key_list, data_list_i = prepare_data_iterator(
78
+ data_in=data_in_i, data_type=data_type_i
79
+ )
80
+ data_list_tmp.append(data_list_i)
81
+ data_list = []
82
+ for item in zip(*data_list_tmp):
83
+ data_list.append(item)
84
+ else:
85
+ # [audio sample point, fbank, text]
86
+ data_list = data_in
87
+ key_list = [
88
+ "rand_key_" + "".join(random.choice(chars) for _ in range(13))
89
+ for _ in range(len(data_in))
90
+ ]
91
+ else: # raw text; audio sample point, fbank; bytes
92
+ if isinstance(data_in, bytes): # audio bytes
93
+ data_in = load_bytes(data_in)
94
+ if key is None:
95
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
96
+ data_list = [data_in]
97
+ key_list = [key]
98
+
99
+ return key_list, data_list
100
+
101
+
102
+ class AutoModel:
103
+
104
+ def __init__(self, **kwargs):
105
+ if not kwargs.get("disable_log", False):
106
+ tables.print()
107
+
108
+ model, kwargs = self.build_model(**kwargs)
109
+
110
+ # if vad_model is not None, build vad model else None
111
+ vad_model = kwargs.get("vad_model", None)
112
+ vad_kwargs = kwargs.get("vad_model_revision", None)
113
+ if vad_model is not None:
114
+ logging.info("Building VAD model.")
115
+ vad_kwargs = {
116
+ "model": vad_model,
117
+ "model_revision": vad_kwargs,
118
+ "device": kwargs["device"],
119
+ }
120
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
121
+
122
+ # if punc_model is not None, build punc model else None
123
+ punc_model = kwargs.get("punc_model", None)
124
+ punc_kwargs = kwargs.get("punc_model_revision", None)
125
+ if punc_model is not None:
126
+ logging.info("Building punc model.")
127
+ punc_kwargs = {
128
+ "model": punc_model,
129
+ "model_revision": punc_kwargs,
130
+ "device": kwargs["device"],
131
+ }
132
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
133
+
134
+ # if spk_model is not None, build spk model else None
135
+ spk_model = kwargs.get("spk_model", None)
136
+ spk_kwargs = kwargs.get("spk_model_revision", None)
137
+ if spk_model is not None:
138
+ logging.info("Building SPK model.")
139
+ spk_kwargs = {
140
+ "model": spk_model,
141
+ "model_revision": spk_kwargs,
142
+ "device": kwargs["device"],
143
+ }
144
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
145
+ self.cb_model = ClusterBackend().to(kwargs["device"])
146
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
147
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
148
+ logging.error(
149
+ "spk_mode should be one of default, vad_segment and punc_segment."
150
+ )
151
+ self.spk_mode = spk_mode
152
+
153
+ self.kwargs = kwargs
154
+ self.model = model
155
+ self.vad_model = vad_model
156
+ self.vad_kwargs = vad_kwargs
157
+ self.punc_model = punc_model
158
+ self.punc_kwargs = punc_kwargs
159
+ self.spk_model = spk_model
160
+ self.spk_kwargs = spk_kwargs
161
+ self.model_path = kwargs.get("model_path")
162
+
163
+ def build_model(self, **kwargs):
164
+ assert "model" in kwargs
165
+ if "model_conf" not in kwargs:
166
+ logging.info(
167
+ "download models from model hub: {}".format(
168
+ kwargs.get("model_hub", "ms")
169
+ )
170
+ )
171
+ kwargs = download_model(**kwargs)
172
+
173
+ set_all_random_seed(kwargs.get("seed", 0))
174
+
175
+ device = kwargs.get("device", "cuda")
176
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
177
+ device = "cpu"
178
+ kwargs["batch_size"] = 1
179
+ kwargs["device"] = device
180
+
181
+ if kwargs.get("ncpu", None):
182
+ torch.set_num_threads(kwargs.get("ncpu"))
183
+
184
+ # build tokenizer
185
+ tokenizer = kwargs.get("tokenizer", None)
186
+ if tokenizer is not None:
187
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
188
+ tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
189
+ kwargs["tokenizer"] = tokenizer
190
+ kwargs["token_list"] = tokenizer.token_list
191
+ vocab_size = len(tokenizer.token_list)
192
+ else:
193
+ vocab_size = -1
194
+
195
+ # build frontend
196
+ frontend = kwargs.get("frontend", None)
197
+ if frontend is not None:
198
+ frontend_class = tables.frontend_classes.get(frontend)
199
+ frontend = frontend_class(**kwargs["frontend_conf"])
200
+ kwargs["frontend"] = frontend
201
+ kwargs["input_size"] = frontend.output_size()
202
+
203
+ # build model
204
+ model_class = tables.model_classes.get(kwargs["model"])
205
+ model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
206
+
207
+ model.to(device)
208
+
209
+ # init_param
210
+ init_param = kwargs.get("init_param", None)
211
+ if init_param is not None:
212
+ logging.info(f"Loading pretrained params from {init_param}")
213
+ load_pretrained_model(
214
+ model=model,
215
+ path=init_param,
216
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
217
+ oss_bucket=kwargs.get("oss_bucket", None),
218
+ scope_map=kwargs.get("scope_map", None),
219
+ excludes=kwargs.get("excludes", None),
220
+ )
221
+
222
+ return model, kwargs
223
+
224
+ def __call__(self, *args, **cfg):
225
+ kwargs = self.kwargs
226
+ kwargs.update(cfg)
227
+ res = self.model(*args, kwargs)
228
+ return res
229
+
230
+ def generate(self, input, input_len=None, **cfg):
231
+ if self.vad_model is None:
232
+ return self.inference(input, input_len=input_len, **cfg)
233
+
234
+ else:
235
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
236
+
237
+ def inference(
238
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
239
+ ):
240
+ kwargs = self.kwargs if kwargs is None else kwargs
241
+ kwargs.update(cfg)
242
+ model = self.model if model is None else model
243
+ model = model.cuda()
244
+ model.eval()
245
+
246
+ batch_size = kwargs.get("batch_size", 1)
247
+ # if kwargs.get("device", "cpu") == "cpu":
248
+ # batch_size = 1
249
+
250
+ key_list, data_list = prepare_data_iterator(
251
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
252
+ )
253
+
254
+ speed_stats = {}
255
+ asr_result_list = []
256
+ num_samples = len(data_list)
257
+ disable_pbar = kwargs.get("disable_pbar", False)
258
+ pbar = (
259
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
260
+ if not disable_pbar
261
+ else None
262
+ )
263
+ time_speech_total = 0.0
264
+ time_escape_total = 0.0
265
+ for beg_idx in range(0, num_samples, batch_size):
266
+ end_idx = min(num_samples, beg_idx + batch_size)
267
+ data_batch = data_list[beg_idx:end_idx]
268
+ key_batch = key_list[beg_idx:end_idx]
269
+ batch = {"data_in": data_batch, "key": key_batch}
270
+ if (end_idx - beg_idx) == 1 and kwargs.get(
271
+ "data_type", None
272
+ ) == "fbank": # fbank
273
+ batch["data_in"] = data_batch[0]
274
+ batch["data_lengths"] = input_len
275
+
276
+ time1 = time.perf_counter()
277
+ with torch.no_grad():
278
+ results, meta_data = model.inference(**batch, **kwargs)
279
+ time2 = time.perf_counter()
280
+
281
+ asr_result_list.extend(results)
282
+
283
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
284
+ batch_data_time = meta_data.get("batch_data_time", -1)
285
+ time_escape = time2 - time1
286
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
287
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
288
+ speed_stats["forward"] = f"{time_escape:0.3f}"
289
+ speed_stats["batch_size"] = f"{len(results)}"
290
+ speed_stats["time_cost"] = f"{(time_escape)}"
291
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
292
+ description = f"{speed_stats}, "
293
+ if pbar:
294
+ pbar.update(1)
295
+ pbar.set_description(description)
296
+ time_speech_total += batch_data_time
297
+ time_escape_total += time_escape
298
+
299
+ if pbar:
300
+ # pbar.update(1)
301
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
302
+ torch.cuda.empty_cache()
303
+ return asr_result_list
304
+
305
+ def inference_with_vad(self, input, input_len=None, **cfg):
306
+
307
+ # step.1: compute the vad model
308
+ self.vad_kwargs.update(cfg)
309
+ beg_vad = time.time()
310
+ res = self.inference(
311
+ input,
312
+ input_len=input_len,
313
+ model=self.vad_model,
314
+ kwargs=self.vad_kwargs,
315
+ **cfg,
316
+ )
317
+ end_vad = time.time()
318
+ print(f"time cost vad: {end_vad - beg_vad:0.3f}")
319
+
320
+ # step.2 compute asr model
321
+ model = self.model
322
+ kwargs = self.kwargs
323
+ kwargs.update(cfg)
324
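+ # dynamic batching by audio duration: batch_size_s (seconds of speech per batch)
+ # and batch_size_threshold_s are converted to milliseconds to match vad segments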
+ batch_size = int(kwargs.get("batch_size_s", 300)) * 1000
325
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
326
+ kwargs["batch_size"] = batch_size
327
+
328
+ key_list, data_list = prepare_data_iterator(
329
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
330
+ )
331
+ results_ret_list = []
332
+ time_speech_total_all_samples = 1e-6
333
+
334
+ beg_total = time.time()
335
+ pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
336
+ for i in range(len(res)):
337
+ key = res[i]["key"]
338
+ vadsegments = res[i]["value"]
339
+ input_i = data_list[i]
340
+ speech = load_audio_text_image_video(
341
+ input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000)
342
+ )
343
+ speech_lengths = len(speech)
344
+ n = len(vadsegments)
345
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
346
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
347
+ results_sorted = []
348
+
349
+ if not len(sorted_data):
350
+ logging.info("decoding, utt: {}, empty speech".format(key))
351
+ continue
352
+
353
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
354
+ batch_size = max(
355
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
356
+ )
357
+
358
+ batch_size_ms_cum = 0
359
+ beg_idx = 0
360
+ beg_asr_total = time.time()
361
+ time_speech_total_per_sample = speech_lengths / 16000
362
+ time_speech_total_all_samples += time_speech_total_per_sample
363
+
364
+ all_segments = []
365
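+ # greedily pack the duration-sorted vad segments into one asr batch until either the
+ # accumulated duration would exceed batch_size (ms) or the next segment alone exceeds
+ # batch_size_threshold_ms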
+ for j, _ in enumerate(range(0, n)):
366
+ # pbar_sample.update(1)
367
+ batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
368
+ if (
369
+ j < n - 1
370
+ and (
371
+ batch_size_ms_cum
372
+ + sorted_data[j + 1][0][1]
373
+ - sorted_data[j + 1][0][0]
374
+ )
375
+ < batch_size
376
+ and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
377
+ < batch_size_threshold_ms
378
+ ):
379
+ continue
380
+ batch_size_ms_cum = 0
381
+ end_idx = j + 1
382
+ speech_j, speech_lengths_j = slice_padding_audio_samples(
383
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
384
+ )
385
+ results = self.inference(
386
+ speech_j,
387
+ input_len=None,
388
+ model=model,
389
+ kwargs=kwargs,
390
+ disable_pbar=True,
391
+ **cfg,
392
+ )
393
+ if self.spk_model is not None:
394
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
395
+ for _b in range(len(speech_j)):
396
+ vad_segments = [
397
+ [
398
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
399
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
400
+ np.array(speech_j[_b]),
401
+ ]
402
+ ]
403
+ segments = sv_chunk(vad_segments)
404
+ all_segments.extend(segments)
405
+ speech_b = [i[2] for i in segments]
406
+ spk_res = self.inference(
407
+ speech_b,
408
+ input_len=None,
409
+ model=self.spk_model,
410
+ kwargs=kwargs,
411
+ disable_pbar=True,
412
+ **cfg,
413
+ )
414
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
415
+ beg_idx = end_idx
416
+ if len(results) < 1:
417
+ continue
418
+ results_sorted.extend(results)
419
+
420
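+ # segments were decoded in duration-sorted order; scatter the results back to their
+ # original vad-segment positions before merging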
+ restored_data = [0] * n
421
+ for j in range(n):
422
+ index = sorted_data[j][1]
423
+ restored_data[index] = results_sorted[j]
424
+ result = {}
425
+
426
+ # results combine for texts, timestamps, speaker embeddings and others
427
+ # TODO: rewrite for clean code
428
+ for j in range(n):
429
+ for k, v in restored_data[j].items():
430
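+ # per-segment timestamps are relative to their vad segment, so shift them by the
+ # segment start to make them absolute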
+ if k.startswith("timestamp"):
431
+ if k not in result:
432
+ result[k] = []
433
+ for t in restored_data[j][k]:
434
+ t[0] += vadsegments[j][0]
435
+ t[1] += vadsegments[j][0]
436
+ result[k].extend(restored_data[j][k])
437
+ elif k == "spk_embedding":
438
+ if k not in result:
439
+ result[k] = restored_data[j][k]
440
+ else:
441
+ result[k] = torch.cat(
442
+ [result[k], restored_data[j][k]], dim=0
443
+ )
444
+ elif "text" in k:
445
+ if k not in result:
446
+ result[k] = restored_data[j][k]
447
+ else:
448
+ result[k] += " " + restored_data[j][k]
449
+ else:
450
+ if k not in result:
451
+ result[k] = restored_data[j][k]
452
+ else:
453
+ result[k] += restored_data[j][k]
454
+
455
+ return_raw_text = kwargs.get("return_raw_text", False)
456
+ # step.3 compute punc model
457
+ if self.punc_model is not None:
458
+ self.punc_kwargs.update(cfg)
459
+ punc_res = self.inference(
460
+ result["text"],
461
+ model=self.punc_model,
462
+ kwargs=self.punc_kwargs,
463
+ disable_pbar=True,
464
+ **cfg,
465
+ )
466
+ raw_text = copy.copy(result["text"])
467
+ if return_raw_text:
468
+ result["raw_text"] = raw_text
469
+ result["text"] = punc_res[0]["text"]
470
+ else:
471
+ raw_text = None
472
+
473
+ # speaker embedding cluster after resorted
474
+ if self.spk_model is not None and kwargs.get("return_spk_res", True):
475
+ if raw_text is None:
476
+ logging.error("Missing punc_model, which is required by spk_model.")
477
+ all_segments = sorted(all_segments, key=lambda x: x[0])
478
+ spk_embedding = result["spk_embedding"]
479
+ labels = self.cb_model(
480
+ spk_embedding.cpu(), oracle_num=kwargs.get("preset_spk_num", None)
481
+ )
482
+ # del result['spk_embedding']
483
+ sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
484
+ if self.spk_mode == "vad_segment": # recover sentence_list
485
+ sentence_list = []
486
+ for res, vadsegment in zip(restored_data, vadsegments):
487
+ if "timestamp" not in res:
488
+ logging.error(
489
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
490
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
491
+ can predict timestamp, and speaker diarization relies on timestamps."
492
+ )
493
+ sentence_list.append(
494
+ {
495
+ "start": vadsegment[0],
496
+ "end": vadsegment[1],
497
+ "sentence": res["text"],
498
+ "timestamp": res["timestamp"],
499
+ }
500
+ )
501
+ elif self.spk_mode == "punc_segment":
502
+ if "timestamp" not in result:
503
+ logging.error(
504
+ "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
505
+ and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
506
+ can predict timestamp, and speaker diarization relies on timestamps."
507
+ )
508
+ sentence_list = timestamp_sentence(
509
+ punc_res[0]["punc_array"],
510
+ result["timestamp"],
511
+ raw_text,
512
+ return_raw_text=return_raw_text,
513
+ )
514
+ distribute_spk(sentence_list, sv_output)
515
+ result["sentence_info"] = sentence_list
516
+ elif kwargs.get("sentence_timestamp", False):
517
+ sentence_list = timestamp_sentence(
518
+ punc_res[0]["punc_array"],
519
+ result["timestamp"],
520
+ raw_text,
521
+ return_raw_text=return_raw_text,
522
+ )
523
+ result["sentence_info"] = sentence_list
524
+ if "spk_embedding" in result:
525
+ del result["spk_embedding"]
526
+
527
+ result["key"] = key
528
+ results_ret_list.append(result)
529
+ end_asr_total = time.time()
530
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
531
+ pbar_total.update(1)
532
+ pbar_total.set_description(
533
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
534
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
535
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
536
+ )
537
+
538
+ return results_ret_list
539
+
540
+ def infer_encoder(
541
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
542
+ ):
543
+ kwargs = self.kwargs if kwargs is None else kwargs
544
+ kwargs.update(cfg)
545
+ model = self.model if model is None else model
546
+ model = model.cuda()
547
+ model.eval()
548
+
549
+ batch_size = kwargs.get("batch_size", 1)
550
+
551
+ key_list, data_list = prepare_data_iterator(
552
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
553
+ )
554
+
555
+ asr_result_list = []
556
+ num_samples = len(data_list)
557
+ for beg_idx in range(0, num_samples, batch_size):
558
+ end_idx = min(num_samples, beg_idx + batch_size)
559
+ data_batch = data_list[beg_idx:end_idx]
560
+ key_batch = key_list[beg_idx:end_idx]
561
+ batch = {"data_in": data_batch, "key": key_batch}
562
+ if (end_idx - beg_idx) == 1 and kwargs.get(
563
+ "data_type", None
564
+ ) == "fbank": # fbank
565
+ batch["data_in"] = data_batch[0]
566
+ batch["data_lengths"] = input_len
567
+
568
+ with torch.no_grad():
569
+ results, meta_data, cache = model.infer_encoder(**batch, **kwargs)
570
+ asr_result_list.extend(results)
571
+
572
+ torch.cuda.empty_cache()
573
+ return asr_result_list, cache
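+ # Usage sketch (illustrative only; the model names below are assumptions and must be
+ # resolvable by download_model on the configured model hub):
+ #   from funasr_detach import AutoModel
+ #   m = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc")
+ #   res = m.generate(input="example.wav")
+ #   print(res[0]["text"])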
funasr_detach/auto/auto_tokenizer.py ADDED
@@ -0,0 +1,7 @@
1
+ class AutoTokenizer:
2
+ """
3
+ Placeholder tokenizer; implementation to be added.
4
+ """
5
+
6
+ def __init__(self):
7
+ pass