Sync from GitHub repo
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- src/f5_tts/api.py +1 -1
- src/f5_tts/infer/infer_cli.py +13 -2
- src/f5_tts/infer/utils_infer.py +3 -3
src/f5_tts/api.py CHANGED

@@ -119,7 +119,7 @@ class F5TTS:
         seed_everything(seed)
         self.seed = seed

-        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

         wav, sr, spec = infer_process(
             ref_file,
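With this change the API-side call no longer passes a device into preprocess_ref_audio_text; the device is settled once, when F5TTS is constructed. A minimal usage sketch under that assumption (the device constructor argument and the gen_text keyword do not appear in this diff and are assumptions):

    from f5_tts.api import F5TTS

    # Hypothetical call shape: assumes F5TTS takes a device at construction
    # and that infer() returns the same (wav, sr, spec) triple as infer_process.
    tts = F5TTS(device="cuda")
    wav, sr, spec = tts.infer(
        ref_file="ref.wav",
        ref_text="transcript of the reference clip",
        gen_text="Text to synthesize.",
    )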
src/f5_tts/infer/infer_cli.py CHANGED

@@ -162,6 +162,11 @@ parser.add_argument(
     type=float,
     help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
 )
+parser.add_argument(
+    "--device",
+    type=str,
+    help="Specify the device to run on",
+)
 args = parser.parse_args()


@@ -202,6 +207,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
 sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
 speed = args.speed or config.get("speed", speed)
 fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
+device = args.device


 # patches for pip pkg user
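Unlike the surrounding options, device is read straight from argparse with no config.get fallback, so it is simply None whenever --device is omitted. A standalone sketch of that default behavior (plain argparse, nothing F5-TTS-specific):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, help="Specify the device to run on")

    args = parser.parse_args([])  # no --device on the command line
    print(args.device)            # None: downstream loaders keep their own default
    args = parser.parse_args(["--device", "cuda:1"])
    print(args.device)            # cuda:1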
@@ -239,7 +245,9 @@ if vocoder_name == "vocos":
 elif vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"

-vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
+vocoder = load_vocoder(
+    vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
+)


 # load TTS model

@@ -270,7 +278,9 @@ if not ckpt_file:
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))

 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(
+    model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
+)


 # inference process

@@ -326,6 +336,7 @@ def main():
             sway_sampling_coef=sway_sampling_coef,
             speed=speed,
             fix_duration=fix_duration,
+            device=device,
         )
         generated_audio_segments.append(audio_segment)

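Taken together, these hunks thread the single parsed device value through vocoder loading, model loading, and each infer_process call. Assuming the pip package's console entry point f5-tts_infer-cli and the existing --ref_audio/--ref_text/--gen_text flags (none of which appear in this diff), a run pinned to a second GPU would look like:

    f5-tts_infer-cli --ref_audio ref.wav --ref_text "transcript" --gen_text "Hello." --device cuda:1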
src/f5_tts/infer/utils_infer.py CHANGED

@@ -149,7 +149,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
         dtype = (
             torch.float16
             if "cuda" in device
-            and torch.cuda.get_device_properties(device).major >= 6
+            and torch.cuda.get_device_properties(device).major >= 7
             and not torch.cuda.get_device_name().endswith("[ZLUDA]")
             else torch.float32
         )

@@ -186,7 +186,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
         dtype = (
             torch.float16
             if "cuda" in device
-            and torch.cuda.get_device_properties(device).major >= 6
+            and torch.cuda.get_device_properties(device).major >= 7
             and not torch.cuda.get_device_name().endswith("[ZLUDA]")
             else torch.float32
         )

@@ -289,7 +289,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
 # preprocess reference audio and text


-def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
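Both dtype hunks tighten the half-precision gate: float16 is now used only on CUDA devices with compute capability 7.0 or higher (Volta and newer, the first generation with tensor cores), and never under ZLUDA. A standalone sketch of the same selection logic (the function name pick_dtype is illustrative, not part of the repo):

    import torch

    def pick_dtype(device: str) -> torch.dtype:
        # float16 only on genuine CUDA devices with compute capability >= 7.0;
        # ZLUDA devices report a name ending in "[ZLUDA]" and stay on float32.
        if (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 7
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        ):
            return torch.float16
        return torch.float32

The last hunk drops the trailing device parameter from preprocess_ref_audio_text, which is what forces the one-line caller update in src/f5_tts/api.py above.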