guanwenhao committed · Commit d979ba0 · verified · 1 Parent(s): 4a94982

Update README.md

Files changed (1)
  1. README.md +307 -0
README.md CHANGED
@@ -12,6 +12,8 @@ tags:
  - tts
  - asr
  - unified_model
+ pipeline_tag: any-to-any
+ library_name: transformers
  ---

  ## 1. Introduction
@@ -35,3 +37,308 @@ By combining autoregression and flow matching, MonoSpeech establishes a foundati
  Please refer to [**Github Repository**](https://github.com/gwh22/MonoSpeech)

## 2. Usage
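
Both examples assume that the `monospeech` package (and, for TTS, the `BigVGAN` repository) from the GitHub repo above is importable, and that local copies of `whisper-large-v3`, `wav2vec2-large-xlsr-53`, and `bigvgan_22k` are available under `hf_ckpts/`; adjust those paths to your environment. The scripts read `LOCAL_RANK` and `WORLD_SIZE` from the environment and initialize an NCCL process group, so they are intended to be launched with a distributed launcher such as `torchrun` rather than with plain `python`.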

For Zero-shot TTS:

```py
import argparse
import json
import multiprocessing as mp
import os
import socket
from typing import List, Optional
from tqdm import tqdm
import random

import transformers
import torch
import torchaudio
import torch.distributed as dist
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers import pipeline

from monospeech.monospeech_model import MonoSpeech
from monospeech.constants import *

from monospeech.utils import read_config_from_file
from monospeech.utils import MelSpec, make_pad_mask, MelSpec_bigvGAN, MelSpec_Taco
from monospeech.tensor_util import spec_to_figure, spec_to_figure_single


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


@torch.no_grad()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--llm_path", type=str, required=True)
    parser.add_argument("--cfg_scale", type=float, required=True)

    args = parser.parse_args()

    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    setup_seed(42)  # random seed, default 42

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llm_path, add_bos_token=True, add_eos_token=True)

    # load model
    model_config = AutoConfig.from_pretrained(args.llm_path)
    model_config.learn_sigma = True
    model_config.tokenizer_max_length = 1024
    model_config.tokenizer_padding_side = 'right'
    model_config.use_flash_attn = False
    # model_config.attn_implementation = "flash_attention_2" if model_config.use_flash_attn else "eager"
    model_config.use_pos_embed = True
    model_config.decoder_t_embed = "add_before_speech_tokens"
    model_config.use_adaln_final_layer = True
    model_config.use_bi_attn_img_tokens = True  # or False for causal DiT
    model_config.add_pos_embed_each_layer = False
    model_config.use_hybrid_attn_mask = False
    model_config.audio_encoder_path = 'hf_ckpts/whisper-large-v3'
    model_config.speaker_encoder_path = 'hf_ckpts/wav2vec2-large-xlsr-53'
    model = MonoSpeech(
        model_config,
        llm_path=args.llm_path,
        tokenizer=tokenizer,
        cfg_scale=args.cfg_scale,
    )
    ckpt_type = args.ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file
        checkpoint = load_file(args.ckpt_path, device='cuda')
    else:
        checkpoint = torch.load(args.ckpt_path, map_location='cuda')
    model.load_state_dict(checkpoint)
    model.eval().cuda()

    # reference wav for the target speaker
    wav_path = "data/LJ001-0001.wav"

    audio, source_sample_rate = torchaudio.load(wav_path)
    if audio.shape[0] > 1:  # downmix to mono
        audio = torch.mean(audio, dim=0, keepdim=True)
    if source_sample_rate != 22050:  # BigVGAN mel features are computed at 22.05 kHz
        resampler = torchaudio.transforms.Resample(source_sample_rate, 22050)
        audio = resampler(audio)
    mel_spectrogram = MelSpec_bigvGAN(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        target_sample_rate=22050,
    )
    mel_spec = mel_spectrogram(audio)
    mel_spec = [mel_spec.squeeze(0).to('cuda')]  # (D, T)
    speechs = [[]]
    flags = [[0]]

    # target duration in seconds, set by yourself
    duration = 6
    target_len = [int(duration * 22050 // 256)]  # number of mel frames, cf. mel_spec[0].shape[1]
    text = ["At once the goat gave a leap, escaped from the soldiers and with bowed head rushed upon the Boolooroo".lower()]

    temp = torch.randn(1).to('cuda')  # placeholder tensor for arguments unused during TTS sampling
    with torch.inference_mode():
        mel_out, mel_gt = model.sample(
            input_ids=temp,
            attention_mask=temp,
            labels=temp,
            mel_spec=mel_spec,
            speechs=speechs,
            flags=flags,
            target_len=target_len,
            text=text,
            wav_path=[wav_path],
        )
    text_name = '_'.join(text[0].strip().split())
    os.makedirs('infers', exist_ok=True)

    # BigVGAN vocoder
    from BigVGAN import bigvgan
    vocoder = bigvgan.BigVGAN.from_pretrained('hf_ckpts/bigvgan_22k', use_cuda_kernel=False)
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to('cuda')

    # generate waveform from mel
    with torch.inference_mode():
        wav_gen = vocoder(mel_out.transpose(0, 1).unsqueeze(0))  # FloatTensor of shape [B(1), 1, T_time], values in [-1, 1]
    wav_gen_float = wav_gen.squeeze(0).cpu()
    # wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16')  # np.ndarray of shape [1, T_time], int16 dtype

    torchaudio.save(f'infers/{text_name}.wav', wav_gen_float, 22050)


if __name__ == "__main__":
    main()
```
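
Here, `target_len` is the number of mel frames to synthesize, derived from the requested duration and the 22.05 kHz sample rate / 256-sample hop of the `MelSpec_bigvGAN` front end used above. A minimal sketch of that arithmetic (the helper name `frames_for_duration` is ours for illustration, not part of the MonoSpeech API):

```py
# Illustrative only: mel-frame count for a requested duration, using the
# 22.05 kHz sample rate and 256-sample hop of MelSpec_bigvGAN above.
def frames_for_duration(duration_s: float, sample_rate: int = 22050, hop_length: int = 256) -> int:
    # each mel frame covers hop_length / sample_rate seconds (~11.6 ms here)
    return int(duration_s * sample_rate // hop_length)

print(frames_for_duration(6))  # 516, i.e. target_len = [516] for the 6-second example above
```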

For ASR:

```py
import argparse
import json
import multiprocessing as mp
import os
import socket
from typing import List, Optional

import transformers
import random
import numpy as np
import torch
import torchaudio
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers import pipeline

from monospeech.monospeech_model import MonoSpeech
from monospeech.constants import *

from monospeech.utils import read_config_from_file


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def preprocess_inputs(tokenizer: transformers.PreTrainedTokenizer, inputs: List[str], speechs: List[torch.Tensor], max_length=512, device='cuda'):
    """
    Currently, only batch size 1 is supported.
    """
    assert len(inputs) == 1

    input_ids, attention_mask = tokenizer(
        inputs,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
        return_tensors="pt",
    ).values()

    if len(speechs) > 0:
        # FIXME: replace the pad token after <|im_start|> with <speech>; the tokenizer
        # cannot correctly tokenize <speech> directly after <|im_start|>
        im_start_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_START_TOKEN)
        im_end_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_END_TOKEN)
        speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
        for cur_input_ids in input_ids:
            for idx in torch.where(cur_input_ids == im_start_token_id):
                if cur_input_ids[idx + 1] == tokenizer.pad_token_id:
                    cur_input_ids[idx + 1] = speech_token_id

        attention_mask = input_ids.ne(tokenizer.pad_token_id)

        flags = [[1]]
    else:
        flags = []

    return {
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
        'speechs': [speechs],
        'flags': flags,
        't': torch.tensor([0]).to(device),
    }


@torch.no_grad()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--llm_path", type=str, required=True)
    args = parser.parse_args()

    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    setup_seed(42)  # random seed, default 42

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.llm_path, add_bos_token=True, add_eos_token=True)

    # load model
    model_config = AutoConfig.from_pretrained(args.llm_path)
    model_config.learn_sigma = True
    model_config.tokenizer_max_length = 1024
    model_config.tokenizer_padding_side = 'right'
    model_config.use_flash_attn = False
    # model_config.attn_implementation = "flash_attention_2" if model_config.use_flash_attn else "eager"
    model_config.use_pos_embed = True
    model_config.decoder_t_embed = "add_before_speech_tokens"
    model_config.use_adaln_final_layer = True
    model_config.use_bi_attn_img_tokens = True  # or False for causal DiT
    model_config.add_pos_embed_each_layer = False
    model_config.use_hybrid_attn_mask = False
    model_config.audio_encoder_path = 'hf_ckpts/whisper-large-v3'
    model_config.speaker_encoder_path = 'hf_ckpts/wav2vec2-large-xlsr-53'
    model = MonoSpeech(
        model_config,
        llm_path=args.llm_path,
        tokenizer=tokenizer,
        cfg_scale=1,
    )
    ckpt_type = args.ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file
        checkpoint = load_file(args.ckpt_path, device='cuda')
    else:
        checkpoint = torch.load(args.ckpt_path, map_location='cuda')
    model.load_state_dict(checkpoint)
    model.eval().cuda()

    feature_extracter = transformers.WhisperFeatureExtractor.from_pretrained('hf_ckpts/whisper-large-v3')

    # wav file to transcribe
    wav_path = "data/LJ001-0001.wav"
    audio, source_sample_rate = torchaudio.load(wav_path)
    if audio.shape[0] > 1:  # downmix to mono
        audio = torch.mean(audio, dim=0, keepdim=True)
    if source_sample_rate != 16000:  # Whisper expects 16 kHz audio
        resampler = torchaudio.transforms.Resample(source_sample_rate, 16000)
        audio = resampler(audio)

    mel_spec = feature_extracter(audio.numpy(), sampling_rate=16000).input_features[0]
    mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
    # speechs and prompt
    speechs = [mel_spec.to('cuda')]
    prompt = f"{DEFAULT_SPEECH_START_TOKEN}{DEFAULT_PAD_TOKEN}{DEFAULT_SPEECH_END_TOKEN}\n"
    inputs = [f"{tokenizer.bos_token}{prompt}"]

    inputs_dict = preprocess_inputs(tokenizer, inputs, speechs)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=inputs_dict['input_ids'],
            attention_mask=inputs_dict['attention_mask'],
            speechs=inputs_dict['speechs'],
            flags=inputs_dict['flags'],
            t=inputs_dict['t'],
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_beams=args.num_beams,
        )
    output_ids = output_ids.replace("\n", " ").replace("<|im_end|>", "")
    print(output_ids)


if __name__ == "__main__":
    main()
```
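
A note on the ASR prompt: it wraps a single `DEFAULT_PAD_TOKEN` between `DEFAULT_SPEECH_START_TOKEN` and `DEFAULT_SPEECH_END_TOKEN`, and `preprocess_inputs` then rewrites that pad position into `DEFAULT_SPEECH_TOKEN` (per the FIXME in the code, the tokenizer cannot tokenize the speech placeholder correctly right after the start token); the Whisper features passed through `speechs` are presumably spliced in at that placeholder inside the model. The value returned by `model.generate` is treated as a decoded string and printed after stripping newlines and the `<|im_end|>` marker.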