guanwenhao committed
Commit 795b0a7 · verified · 1 Parent(s): c719855

Update README.md

Files changed (1):
  1. README.md +1 -302

README.md CHANGED
@@ -34,308 +34,7 @@ By combining autoregression and flow matching, MonoSpeech establishes a foundati
 
 ## 2. Quick Start
 
- Please refer to [**Github Repository**](https://github.com/gwh22/MonoSpeech)
+ Please refer to [**Github Repository**](https://github.com/gwh22/Univoice)
 
 
 
- ## 3. Usage
- For zero-shot TTS:
- ```py
- import argparse
- import json
- import multiprocessing as mp
- import os
- import socket
- from typing import List, Optional
- from tqdm import tqdm
- import random
-
- import transformers
- import torch
- import torchaudio
- import torch.distributed as dist
- import numpy as np
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- from transformers import pipeline
-
- from monospeech.monospeech_model import MonoSpeech
- from monospeech.constants import *
-
- from monospeech.utils import MelSpec, make_pad_mask, MelSpec_bigvGAN, MelSpec_Taco
- from monospeech.tensor_util import spec_to_figure, spec_to_figure_single
-
- def setup_seed(seed):
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     np.random.seed(seed)
-     random.seed(seed)
-     torch.backends.cudnn.deterministic = True
-
- @torch.no_grad()
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--ckpt_path", type=str, required=True)
-     parser.add_argument("--llm_path", type=str, required=True)
-     parser.add_argument("--cfg_scale", type=float, required=True)
-
-     args = parser.parse_args()
-
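-     # LOCAL_RANK and WORLD_SIZE are set by a distributed launcher such as
-     # torchrun, so start the script with one even for single-GPU inference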
-     rank = int(os.environ["LOCAL_RANK"])
-     world_size = int(os.environ["WORLD_SIZE"])
-
-     dist.init_process_group("nccl", rank=rank, world_size=world_size)
-     torch.cuda.set_device(rank)
-     setup_seed(42)  # random seed, default 42
-
-     # load tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(args.llm_path, add_bos_token=True, add_eos_token=True)
-
-     # load model
-     model_config = AutoConfig.from_pretrained(args.llm_path)
-     model_config.learn_sigma = True
-     model_config.tokenizer_max_length = 1024
-     model_config.tokenizer_padding_side = 'right'
-     model_config.use_flash_attn = False
-     # model_config.attn_implementation = "flash_attention_2" if model_config.use_flash_attn else "eager"
-     model_config.use_pos_embed = True
-     model_config.decoder_t_embed = "add_before_speech_tokens"
-     model_config.use_adaln_final_layer = True
-     model_config.use_bi_attn_img_tokens = True  # or False for a causal DiT
-     model_config.add_pos_embed_each_layer = False
-     model_config.use_hybrid_attn_mask = False
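-     # local encoder checkpoints (assumed layout): Whisper for audio content,
-     # wav2vec2 XLSR for speaker characteristics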
-     model_config.audio_encoder_path = 'hf_ckpts/whisper-large-v3'
-     model_config.speaker_encoder_path = 'hf_ckpts/wav2vec2-large-xlsr-53'
-     model = MonoSpeech(
-         model_config,
-         llm_path=args.llm_path,
-         tokenizer=tokenizer,
-         cfg_scale=args.cfg_scale,
-     )
-     ckpt_type = args.ckpt_path.split(".")[-1]
-     if ckpt_type == "safetensors":
-         from safetensors.torch import load_file
-         checkpoint = load_file(args.ckpt_path, device='cuda')
-     else:
-         checkpoint = torch.load(args.ckpt_path, map_location='cuda')
-     model.load_state_dict(checkpoint)
-     model.eval().cuda()
-
-
-     # reference wav for the speaker prompt
-     wav_path = "data/LJ001-0001.wav"
-
-     audio, source_sample_rate = torchaudio.load(wav_path)
-     if audio.shape[0] > 1:  # downmix to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-     if source_sample_rate != 22050:  # resample to BigVGAN's 22.05 kHz
-         resampler = torchaudio.transforms.Resample(source_sample_rate, 22050)
-         audio = resampler(audio)
-     mel_spectrogram = MelSpec_bigvGAN(
-         n_fft=1024,
-         hop_length=256,
-         win_length=1024,
-         n_mel_channels=80,
-         target_sample_rate=22050,
-     )
-     mel_spec = mel_spectrogram(audio)
-     mel_spec = [mel_spec.squeeze(0).to('cuda')]  # (D, T)
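-     # TTS passes no understanding-side speech features: an empty speech list
-     # with flag 0 (the ASR example below uses flag 1)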
-     speechs = [[]]
-     flags = [[0]]
-
-
-     # duration (seconds) of the generated speech; set it yourself
-     duration = 6
-     target_len = [int(duration * 22050 // 256)]  # frames = duration * sample_rate / hop_length (6 s -> 516); or use mel_spec[0].shape[1]
-     text = ["At once the goat gave a leap, escaped from the soldiers and with bowed head rushed upon the Boolooroo".lower()]
-
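-     # one-element placeholder tensors for the text-LM arguments below; the
-     # sampling call is driven by text, mel_spec, and target_len instead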
-     temp = torch.randn(1).to('cuda')
-     with torch.inference_mode():
-         mel_out, mel_gt = model.sample(
-             input_ids=temp,
-             attention_mask=temp,
-             labels=temp,
-             mel_spec=mel_spec,
-             speechs=speechs,
-             flags=flags,
-             target_len=target_len,
-             text=text,
-             wav_path=[wav_path],
-         )
-     text_name = '_'.join(text[0].strip().split())
-     os.makedirs('infers', exist_ok=True)
-     # BigVGAN vocoder
-     from BigVGAN import bigvgan
-     vocoder = bigvgan.BigVGAN.from_pretrained('hf_ckpts/bigvgan_22k', use_cuda_kernel=False)
-     vocoder.remove_weight_norm()
-     vocoder = vocoder.eval().to('cuda')
-
-     # generate the waveform from the mel spectrogram
-     with torch.inference_mode():
-         wav_gen = vocoder(mel_out.transpose(0, 1).unsqueeze(0))  # FloatTensor of shape [B(1), 1, T_time] with values in [-1, 1]
-     wav_gen_float = wav_gen.squeeze(0).cpu()
-     # wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16')  # np.ndarray of shape [1, T_time], int16 dtype
-
-     torchaudio.save(f'infers/{text_name}.wav', wav_gen_float, 22050)
-
-
- if __name__ == "__main__":
-     main()
- ```
-
- For ASR:
- ```py
- import argparse
- import json
- import multiprocessing as mp
- import os
- import socket
- from typing import List, Optional
-
- import transformers
- import random
- import numpy as np
- import torch
- import torchaudio
- import torch.distributed as dist
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- from transformers import pipeline
-
- from monospeech.monospeech_model import MonoSpeech
- from monospeech.constants import *
-
-
- def setup_seed(seed):
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     np.random.seed(seed)
-     random.seed(seed)
-     torch.backends.cudnn.deterministic = True
-
- def preprocess_inputs(tokenizer: transformers.PreTrainedTokenizer, inputs: List[str], speechs: List[torch.Tensor], max_length=512, device='cuda'):
-     """
-     Currently only supports batch size 1.
-     """
-     assert len(inputs) == 1
-
-     input_ids, attention_mask = tokenizer(
-         inputs,
-         max_length=max_length,
-         truncation=True,
-         add_special_tokens=False,
-         return_tensors="pt",
-     ).values()
-
-     if len(speechs) > 0:
-         im_start_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_START_TOKEN)
-         im_end_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_END_TOKEN)
-         speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
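-         # replace the pad placeholder that follows each speech-start token
-         # with the dedicated speech token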
-         for cur_input_ids in input_ids:
-             for idx in torch.where(cur_input_ids == im_start_token_id):
-                 if cur_input_ids[idx + 1] == tokenizer.pad_token_id:
-                     cur_input_ids[idx + 1] = speech_token_id
-
-         attention_mask = input_ids.ne(tokenizer.pad_token_id)
-
-         flags = [[1]]
-     else:
-         flags = []
-
-     return {
-         'input_ids': input_ids.to(device),
-         'attention_mask': attention_mask.to(device),
-         'speechs': [speechs],
-         'flags': flags,
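-         # t: flow-matching time input; held at 0 on the recognition path
-         # (an assumption about its role, based on this example)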
-         't': torch.tensor([0]).to(device),
-     }
-
-
- @torch.no_grad()
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--ckpt_path", type=str, required=True)
-     parser.add_argument("--temperature", type=float, default=0.2)
-     parser.add_argument("--top_p", type=float, default=0.9)
-     parser.add_argument("--top_k", type=int, default=50)
-     parser.add_argument("--num_beams", type=int, default=1)
-     parser.add_argument("--llm_path", type=str, required=True)
-     args = parser.parse_args()
-
-     rank = int(os.environ["LOCAL_RANK"])
-     world_size = int(os.environ["WORLD_SIZE"])
-
-     dist.init_process_group("nccl", rank=rank, world_size=world_size)
-     torch.cuda.set_device(rank)
-     setup_seed(42)  # random seed, default 42
-
-     # load tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(args.llm_path, add_bos_token=True, add_eos_token=True)
-
-     # load model
-     model_config = AutoConfig.from_pretrained(args.llm_path)
-     model_config.learn_sigma = True
-     model_config.tokenizer_max_length = 1024
-     model_config.tokenizer_padding_side = 'right'
-     model_config.use_flash_attn = False
-     # model_config.attn_implementation = "flash_attention_2" if model_config.use_flash_attn else "eager"
-     model_config.use_pos_embed = True
-     model_config.decoder_t_embed = "add_before_speech_tokens"
-     model_config.use_adaln_final_layer = True
-     model_config.use_bi_attn_img_tokens = True  # or False for a causal DiT
-     model_config.add_pos_embed_each_layer = False
-     model_config.use_hybrid_attn_mask = False
-     model_config.audio_encoder_path = 'hf_ckpts/whisper-large-v3'
-     model_config.speaker_encoder_path = 'hf_ckpts/wav2vec2-large-xlsr-53'
-     model = MonoSpeech(
-         model_config,
-         llm_path=args.llm_path,
-         tokenizer=tokenizer,
-         cfg_scale=1,
-     )
-     ckpt_type = args.ckpt_path.split(".")[-1]
-     if ckpt_type == "safetensors":
-         from safetensors.torch import load_file
-         checkpoint = load_file(args.ckpt_path, device='cuda')
-     else:
-         checkpoint = torch.load(args.ckpt_path, map_location='cuda')
-     model.load_state_dict(checkpoint)
-     model.eval().cuda()
-
-     feature_extracter = transformers.WhisperFeatureExtractor.from_pretrained('hf_ckpts/whisper-large-v3')
-
-     # wav to transcribe
-     wav_path = "data/LJ001-0001.wav"
-     audio, source_sample_rate = torchaudio.load(wav_path)
-     if audio.shape[0] > 1:  # downmix to mono
-         audio = torch.mean(audio, dim=0, keepdim=True)
-     if source_sample_rate != 16000:  # Whisper expects 16 kHz
-         resampler = torchaudio.transforms.Resample(source_sample_rate, 16000)
-         audio = resampler(audio)
-
-     mel_spec = feature_extracter(audio.numpy(), sampling_rate=16000).input_features[0]
-     mel_spec = torch.tensor(mel_spec, dtype=torch.float32)
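-     # input_features are Whisper log-mel frames padded/truncated to 30 s
-     # (128 mel bins x 3000 frames for whisper-large-v3)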
-     # speechs and prompt
-     speechs = [mel_spec.to('cuda')]
-     prompt = f"{DEFAULT_SPEECH_START_TOKEN}{DEFAULT_PAD_TOKEN}{DEFAULT_SPEECH_END_TOKEN}\n"
-     inputs = [f"{tokenizer.bos_token}{prompt}"]
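-     # a single pad token sits between the speech start/end markers as a
-     # placeholder; preprocess_inputs swaps it for the speech token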
-
-     inputs_dict = preprocess_inputs(tokenizer, inputs, speechs)
-
-     with torch.inference_mode():
-         output_ids = model.generate(
-             input_ids=inputs_dict['input_ids'],
-             attention_mask=inputs_dict['attention_mask'],
-             speechs=inputs_dict['speechs'],
-             flags=inputs_dict['flags'],
-             t=inputs_dict['t'],
-             temperature=args.temperature,
-             top_p=args.top_p,
-             top_k=args.top_k,
-             num_beams=args.num_beams,
-         )
-     output_ids = output_ids.replace("\n", " ").replace("<|im_end|>", "")
-     print(output_ids)
-
-
- if __name__ == "__main__":
-     main()
- ```
 