Spaces: Running on Zero
Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- README_REPO.md +3 -0
- inference-cli.py +51 -15
README_REPO.md CHANGED
@@ -86,6 +86,9 @@ Currently support 30s for a single generation, which is the **TOTAL** length of
 
 You can either specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` will have the ASR model transcribe the reference audio automatically (uses extra GPU memory). If you encounter a network error, consider using a local checkpoint: just set `ckpt_path` in `inference-cli.py`.
 
+To change the model, use `--ckpt_file` to specify the checkpoint you want to load.
+To change vocab.txt, use `--vocab_file` to provide your own vocab.txt file.
+
 ```bash
 python inference-cli.py \
 --model "F5-TTS" \
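For reference, a full invocation using the two new flags could look like the sketch below. The checkpoint and vocab paths are hypothetical placeholders, and the example assumes the script's existing `--ref_audio`, `--ref_text`, and `--gen_text` options alongside the new `--ckpt_file` and `--vocab_file`:

```bash
# Hypothetical paths: point these at your own fine-tuned checkpoint and vocab.
python inference-cli.py \
--model "F5-TTS" \
--ckpt_file "ckpts/my_finetune/model_last.pt" \
--vocab_file "data/my_dataset/vocab.txt" \
--ref_audio "ref.wav" \
--ref_text "Transcript of the reference audio." \
--gen_text "Text to synthesize with the fine-tuned model."
```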
inference-cli.py CHANGED
@@ -36,6 +36,16 @@ parser.add_argument(
     "--model",
     help="F5-TTS | E2-TTS",
 )
+parser.add_argument(
+    "-p",
+    "--ckpt_file",
+    help="The checkpoint .pt file",
+)
+parser.add_argument(
+    "-v",
+    "--vocab_file",
+    help="The vocab .txt file",
+)
 parser.add_argument(
     "-r",
     "--ref_audio",
@@ -88,6 +98,8 @@ if gen_file:
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
 model = args.model if args.model else config["model"]
+ckpt_file = args.ckpt_file if args.ckpt_file else ""
+vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -125,11 +137,19 @@ speed = 1.0
 # fix_duration = 27  # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
-    if not Path(ckpt_path).exists():
-        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
+def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
+
+    if file_vocab == "":
+        file_vocab = "Emilia_ZH_EN"
+        tokenizer = "pinyin"
+    else:
+        tokenizer = "custom"
+
+    print("\nvocab : ", file_vocab)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path, "\n")
+
+    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(
         transformer=model_cls(
             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
@@ -149,14 +169,12 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
 
     return model
 
-
 # load models
 F5TTS_model_cfg = dict(
     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -184,12 +202,29 @@ def chunk_text(text, max_chars=135):
 
     return chunks
 
+#ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
+#if not Path(ckpt_path).exists():
+#    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
-        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+
+        if ckpt_file == "":
+            repo_name = "F5-TTS"
+            exp_name = "F5TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, file_vocab)
+
     elif model == "E2-TTS":
-        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+        if ckpt_file == "":
+            repo_name = "E2-TTS"
+            exp_name = "E2TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, file_vocab)
 
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -325,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     print(gen_text)
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
@@ -343,10 +378,10 @@ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, ckpt_file, file_vocab, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -371,7 +406,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
+        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -389,4 +424,5 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         aseg.export(f.name, format="wav")
         print(f.name)
 
-process(ref_audio, ref_text, gen_text, model, remove_silence)
+
+process(ref_audio, ref_text, gen_text, model, ckpt_file, vocab_file, remove_silence)
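Taken together, the new code paths default as follows: with no `--ckpt_file`, `infer_batch` pulls the published 1200000-step checkpoint from the Hugging Face Hub; with no `--vocab_file`, `load_model` falls back to the built-in Emilia_ZH_EN pinyin vocab, otherwise it treats the supplied vocab.txt as a custom tokenizer. Below is a minimal sketch of that resolution logic; `resolve_ckpt_and_tokenizer` is a hypothetical helper (not a function in the repo) that mirrors the commit's behavior, and it assumes the `cached_path` package the script already depends on:

```python
from cached_path import cached_path

def resolve_ckpt_and_tokenizer(model: str, ckpt_file: str, vocab_file: str):
    """Hypothetical helper mirroring the commit's checkpoint/vocab fallbacks."""
    if ckpt_file == "":
        # No --ckpt_file given: fetch the default Hub checkpoint for the model.
        repo_name, exp_name, ckpt_step = (
            ("F5-TTS", "F5TTS_Base", 1200000)
            if model == "F5-TTS"
            else ("E2-TTS", "E2TTS_Base", 1200000)
        )
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    if vocab_file == "":
        # No --vocab_file given: use the built-in vocab with pinyin tokenization.
        vocab_file, tokenizer = "Emilia_ZH_EN", "pinyin"
    else:
        # A user-supplied vocab.txt switches the tokenizer to "custom".
        tokenizer = "custom"
    return ckpt_file, vocab_file, tokenizer
```

For example, `resolve_ckpt_and_tokenizer("F5-TTS", "", "")` yields the cached default F5TTS_Base checkpoint with the pinyin tokenizer, while passing a local .pt path and a vocab.txt leaves both untouched and selects the custom tokenizer.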