Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| # app.py — ESpeech-TTS с поддержкой ZeroGPU (Hugging Face Spaces) | |
| # ----------------- ZeroGPU / spaces импорт + fallback ----------------- | |
| # В среде ZeroGPU доступен пакет `spaces`, который предоставляет декоратор GPU. | |
| # Для локальной отладки мы делаем fallback — noop-декоратор. | |
| import spaces # provided by Spaces/ZeroGPU environment | |
| GPU_DECORATOR = spaces.GPU | |
| print("spaces module available — ZeroGPU features enabled") | |
| import os | |
| import gc | |
| import json | |
| import tempfile | |
| import traceback | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| from huggingface_hub import hf_hub_download | |
| # Ваши зависимости / локальные импорты | |
| from ruaccent import RUAccent | |
| import onnx_asr | |
| from f5_tts.infer.utils_infer import ( | |
| infer_process, | |
| load_model, | |
| load_vocoder, | |
| preprocess_ref_audio_text, | |
| remove_silence_for_generated_wav, | |
| save_spectrogram, | |
| tempfile_kwargs, | |
| ) | |
| from f5_tts.model import DiT | |
| # Явно включаем ленивый режим кеширования примеров, чтобы примеры не запускались на старте | |
| # (ZeroGPU по умолчанию использует lazy — делаем это явным). | |
| os.environ.setdefault("GRADIO_CACHE_MODE", "lazy") | |
| os.environ.setdefault("GRADIO_CACHE_EXAMPLES", "lazy") | |
| # ----------------- HF hub / модели ----------------- | |
| # Настройте репозитории и имена файлов в Hub под себя | |
| MODEL_REPOS = { | |
| "ESpeech-TTS-1 [RL] V2": { | |
| "repo_id": "ESpeech/ESpeech-TTS-1_RL-V2", | |
| "filename": "espeech_tts_rlv2.pt", | |
| }, | |
| "ESpeech-TTS-1 [RL] V1": { | |
| "repo_id": "ESpeech/ESpeech-TTS-1_RL-V1", | |
| "filename": "espeech_tts_rlv1.pt", | |
| }, | |
| "ESpeech-TTS-1 [SFT] 95K": { | |
| "repo_id": "ESpeech/ESpeech-TTS-1_SFT-95K", | |
| "filename": "espeech_tts_95k.pt", | |
| }, | |
| "ESpeech-TTS-1 [SFT] 265K": { | |
| "repo_id": "ESpeech/ESpeech-TTS-1_SFT-256K", | |
| "filename": "espeech_tts_256k.pt", | |
| }, | |
| "ESpeech-TTS-1 PODCASTER [SFT]": { | |
| "repo_id": "ESpeech/ESpeech-TTS-1_podcaster", | |
| "filename": "espeech_tts_podcaster.pt", | |
| }, | |
| } | |
| # где лежит общий vocab в Hub | |
| VOCAB_REPO = "ESpeech/ESpeech-TTS-1_podcaster" | |
| VOCAB_FILENAME = "vocab.txt" | |
| # токен, если репозитории приватные (в Spaces обычно берут из Secrets) | |
| HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or None | |
| MODEL_CFG = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) | |
| # кэш локальных путей после hf_hub_download | |
| _cached_local_paths = {} | |
| loaded_models = {} # хранит объекты моделей в памяти (по имени выбора) | |
| # Пример текста для демонстрации | |
| EXAMPLE_TEXT = "Экспериментальный центр напоминает вам о том, что кубы не умеют разговаривать. В случае, если грузовой куб все же заговорит, центр настоятельно рекомендует вам игнорировать его советы." | |
| EXAMPLE_REF_AUDIO = "ref/example.mp3" | |
| # ----------------- Вспомогательные функции HF ----------------- | |
| def hf_download_file(repo_id: str, filename: str, token: str = None): | |
| try: | |
| print(f"hf_hub_download: {repo_id}/{filename}") | |
| p = hf_hub_download(repo_id=repo_id, filename=filename, token=token, repo_type="model") | |
| print(" ->", p) | |
| return p | |
| except Exception as e: | |
| print("Download error:", e) | |
| raise | |
| def get_vocab_path(): | |
| key = f"{VOCAB_REPO}::{VOCAB_FILENAME}" | |
| if key in _cached_local_paths and Path(_cached_local_paths[key]).exists(): | |
| return _cached_local_paths[key] | |
| p = hf_download_file(VOCAB_REPO, VOCAB_FILENAME, token=HF_TOKEN) | |
| _cached_local_paths[key] = p | |
| return p | |
| def get_model_local_path(choice: str): | |
| if choice not in MODEL_REPOS: | |
| raise KeyError("Unknown model choice: " + repr(choice)) | |
| repo = MODEL_REPOS[choice] | |
| key = f"{repo['repo_id']}::{repo['filename']}" | |
| if key in _cached_local_paths and Path(_cached_local_paths[key]).exists(): | |
| return _cached_local_paths[key] | |
| p = hf_download_file(repo["repo_id"], repo["filename"], token=HF_TOKEN) | |
| _cached_local_paths[key] = p | |
| return p | |
| def load_model_if_needed(choice: str): | |
| """ | |
| Лениво: если модель уже загружена в loaded_models — вернуть. | |
| Иначе скачать файл (если нужно) и вызвать вашу load_model (возвращает PyTorch модель в CPU). | |
| Не переводим на GPU здесь — это делается внутри GPU-декорированной функции. | |
| """ | |
| if choice in loaded_models: | |
| return loaded_models[choice] | |
| model_file = get_model_local_path(choice) | |
| vocab_file = get_vocab_path() | |
| print(f"Loading model into CPU memory: {choice} from {model_file}") | |
| model = load_model(DiT, MODEL_CFG, model_file, vocab_file=vocab_file) | |
| loaded_models[choice] = model | |
| return model | |
| # ----------------- общие ресурсы (vocoder, RUAccent, ASR) ----------------- | |
| print("Loading RUAccent...") | |
| accentizer = RUAccent() | |
| accentizer.load(omograph_model_size='turbo3.1', use_dictionary=True, tiny_mode=False) | |
| print("RUAccent loaded.") | |
| print("Loading ASR (onnx) ...") | |
| asr_model = onnx_asr.load_model("nemo-fastconformer-ru-rnnt") | |
| print("ASR ready.") | |
| print("Loading vocoder (CPU) ...") | |
| vocoder = load_vocoder() | |
| print("Vocoder loaded.") | |
| # ----------------- Функция для обработки текста с учетом "+" ----------------- | |
| def process_text_with_accent(text, accentizer): | |
| """ | |
| Обрабатывает текст через RUAccent, если в нем нет символа '+'. | |
| Если есть '+' - пользователь сам проставил ударения, не трогаем. | |
| """ | |
| if not text or not text.strip(): | |
| return text | |
| if '+' in text: | |
| # Пользователь сам проставил ударения | |
| return text | |
| else: | |
| # Прогоняем через RUAccent | |
| return accentizer.process_all(text) | |
| # ----------------- Функция для обработки текста без синтеза ----------------- | |
| def process_texts_only(ref_text, gen_text): | |
| """ | |
| Обрабатывает только тексты через RUAccent, не делая синтез. | |
| Возвращает обработанные тексты для обновления полей ввода. | |
| """ | |
| processed_ref_text = process_text_with_accent(ref_text, accentizer) | |
| processed_gen_text = process_text_with_accent(gen_text, accentizer) | |
| return processed_ref_text, processed_gen_text | |
| # ----------------- Основная функция синтеза (GPU-aware) ----------------- | |
| # Декорируем synthesize, чтобы при вызове Space выделял GPU (если доступно). | |
| # duration — сколько секунд просим GPU (адаптируйте под ваш инференс). | |
| def synthesize( | |
| model_choice, | |
| ref_audio, | |
| ref_text, | |
| gen_text, | |
| remove_silence, | |
| seed, | |
| cross_fade_duration=0.15, | |
| nfe_step=32, | |
| speed=1.0, | |
| ): | |
| """ | |
| Эта функция будет выполняться с выделенным GPU в ZeroGPU Spaces. | |
| Подход: | |
| - лениво загружаем модель (в CPU) если надо | |
| - переносим модель и (если требуется) vocoder на cuda | |
| - делаем infer | |
| - возвращаем модели на CPU и очищаем cuda cache | |
| """ | |
| if not ref_audio: | |
| gr.Warning("Please provide reference audio.") | |
| return None, None, ref_text, gen_text | |
| if seed is None or seed < 0 or seed > 2**31 - 1: | |
| seed = np.random.randint(0, 2**31 - 1) | |
| torch.manual_seed(int(seed)) | |
| if not gen_text or not gen_text.strip(): | |
| gr.Warning("Please enter text to generate.") | |
| return None, None, ref_text, gen_text | |
| # ASR если нужно | |
| if not ref_text or not ref_text.strip(): | |
| gr.Info("Reference text is empty. Running ASR to transcribe reference audio...") | |
| try: | |
| waveform, sample_rate = torchaudio.load(ref_audio) | |
| waveform = waveform.numpy() | |
| if waveform.dtype == np.int16: | |
| waveform = waveform / 2**15 | |
| elif waveform.dtype == np.int32: | |
| waveform = waveform / 2**31 | |
| if waveform.ndim == 2: | |
| waveform = waveform.mean(axis=0) | |
| transcribed_text = asr_model.recognize(waveform, sample_rate=sample_rate) | |
| ref_text = transcribed_text | |
| gr.Info(f"ASR transcription: {ref_text}") | |
| except Exception as e: | |
| gr.Warning(f"ASR failed: {e}") | |
| return None, None, ref_text, gen_text | |
| # Акцентирование с учетом наличия символа "+" | |
| processed_ref_text = process_text_with_accent(ref_text, accentizer) | |
| processed_gen_text = process_text_with_accent(gen_text, accentizer) | |
| # Ленивая загрузка модели (в CPU) | |
| try: | |
| model = load_model_if_needed(model_choice) | |
| except Exception as e: | |
| gr.Warning(f"Failed to download/load model {model_choice}: {e}") | |
| return None, None, processed_ref_text, processed_gen_text | |
| # Определяем устройство (в ZeroGPU внутри декоратора должен быть доступен CUDA) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| moved_to_cuda = [] | |
| try: | |
| # Переносим модель на GPU (если есть) | |
| if device.type == "cuda": | |
| try: | |
| model.to(device) | |
| moved_to_cuda.append(("model", model)) | |
| # если vocoder использует torch — переносим его тоже | |
| try: | |
| vocoder.to(device) | |
| moved_to_cuda.append(("vocoder", vocoder)) | |
| except Exception: | |
| # если vocoder не torch-объект — ок | |
| pass | |
| except Exception as e: | |
| print("Warning: failed to move model/vocoder to cuda:", e) | |
| # Препроцессинг рефа (оно ожидает путь/файл) | |
| try: | |
| ref_audio_proc, processed_ref_text_final = preprocess_ref_audio_text( | |
| ref_audio, | |
| processed_ref_text, | |
| show_info=gr.Info | |
| ) | |
| except Exception as e: | |
| gr.Warning(f"Preprocess failed: {e}") | |
| traceback.print_exc() | |
| return None, None, processed_ref_text, processed_gen_text | |
| # Инференс (предполагается, что infer_process корректно работает и на GPU) | |
| try: | |
| final_wave, final_sample_rate, combined_spectrogram = infer_process( | |
| ref_audio_proc, | |
| processed_ref_text_final, | |
| processed_gen_text, | |
| model, | |
| vocoder, | |
| cross_fade_duration=cross_fade_duration, | |
| nfe_step=nfe_step, | |
| speed=speed, | |
| show_info=gr.Info, | |
| progress=gr.Progress(), | |
| ) | |
| except Exception as e: | |
| gr.Warning(f"Infer failed: {e}") | |
| traceback.print_exc() | |
| return None, None, processed_ref_text, processed_gen_text | |
| # Удаление тишин (на CPU) | |
| if remove_silence: | |
| try: | |
| with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f: | |
| temp_path = f.name | |
| sf.write(temp_path, final_wave, final_sample_rate) | |
| remove_silence_for_generated_wav(temp_path) | |
| final_wave_tensor, _ = torchaudio.load(temp_path) | |
| final_wave = final_wave_tensor.squeeze().cpu().numpy() | |
| except Exception as e: | |
| print("Remove silence failed:", e) | |
| # Сохраняем спектрограмму | |
| try: | |
| with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram: | |
| spectrogram_path = tmp_spectrogram.name | |
| save_spectrogram(combined_spectrogram, spectrogram_path) | |
| except Exception as e: | |
| print("Save spectrogram failed:", e) | |
| spectrogram_path = None | |
| return (final_sample_rate, final_wave), spectrogram_path, processed_ref_text_final, processed_gen_text | |
| finally: | |
| # Переносим всё обратно на CPU и очищаем GPU память | |
| if device.type == "cuda": | |
| try: | |
| for name, obj in moved_to_cuda: | |
| try: | |
| obj.to("cpu") | |
| except Exception: | |
| pass | |
| torch.cuda.empty_cache() | |
| # немножко сборки мусора | |
| gc.collect() | |
| except Exception as e: | |
| print("Warning during cuda cleanup:", e) | |
| # ----------------- Gradio UI (как у вас) ----------------- | |
| with gr.Blocks(title="ESpeech-TTS") as app: | |
| gr.Markdown("# ESpeech-TTS") | |
| gr.Markdown("Подробнее см. на https://huggingface.co/ESpeech") | |
| gr.Markdown("💡 **Совет:** Добавьте символ '+' в тексте, чтобы указать пользовательское ударение (например, 'прив+ет'). Текст с '+' не будет обрабатываться RUAccent.") | |
| # Описание моделей на русском языке | |
| gr.Markdown(""" | |
| ## 📋 Описание моделей: | |
| - **ESpeech-TTS-1 [RL] V1** - Первая версия модели с RL | |
| - **ESpeech-TTS-1 [RL] V2** - Вторая версия модели с RL | |
| - **ESpeech-TTS-1 PODCASTER [SFT]** - Модель обученная только на подкастах, лучше генерирует спонтанную речь | |
| - **ESpeech-TTS-1 [SFT] 95K** - чекпоинт с 95000 шагов (на нем основана RL V1) | |
| - **ESpeech-TTS-1 [SFT] 265K** - чекпоинт с 265000 шагов (на нем основана RL V2) | |
| """) | |
| model_choice = gr.Dropdown( | |
| choices=list(MODEL_REPOS.keys()), | |
| label="Select Model", | |
| value=list(MODEL_REPOS.keys())[0], | |
| interactive=True | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| ref_audio_input = gr.Audio(label="Reference Audio", type="filepath") | |
| ref_text_input = gr.Textbox( | |
| label="Reference Text", | |
| lines=2, | |
| placeholder="leave empty → ASR will transcribe" | |
| ) | |
| with gr.Column(): | |
| gen_text_input = gr.Textbox( | |
| label="Text to Generate", | |
| lines=5, | |
| max_lines=20, | |
| placeholder="Enter text to synthesize..." | |
| ) | |
| # Кнопка для обработки текста без синтеза | |
| process_text_btn = gr.Button("✏️ Process Text (Add Accents)", variant="secondary") | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Accordion("Advanced Settings", open=False): | |
| seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0) | |
| remove_silence = gr.Checkbox(label="Remove Silences", value=False) | |
| speed_slider = gr.Slider(label="Speed", minimum=0.3, maximum=2.0, value=1.0, step=0.1) | |
| nfe_slider = gr.Slider(label="NFE Steps", minimum=4, maximum=64, value=48, step=2) | |
| cross_fade_slider = gr.Slider(label="Cross-Fade Duration (s)", minimum=0.0, maximum=1.0, value=0.15, step=0.01) | |
| generate_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg") | |
| with gr.Row(): | |
| audio_output = gr.Audio(label="Generated Audio", type="numpy") | |
| spectrogram_output = gr.Image(label="Spectrogram", type="filepath") | |
| # Примеры | |
| gr.Markdown("## 🎯 Example") | |
| gr.Examples( | |
| examples=[ | |
| [ | |
| EXAMPLE_REF_AUDIO, # ref_audio | |
| "", # ref_text (empty for ASR) | |
| EXAMPLE_TEXT, # gen_text | |
| False, # remove_silence | |
| 42, # seed | |
| 0.15, # cross_fade | |
| 48, # nfe_step | |
| 1.0, # speed | |
| ] | |
| ], | |
| inputs=[ | |
| ref_audio_input, | |
| ref_text_input, | |
| gen_text_input, | |
| remove_silence, | |
| seed_input, | |
| cross_fade_slider, | |
| nfe_slider, | |
| speed_slider, | |
| ], | |
| outputs=[audio_output, spectrogram_output, ref_text_input, gen_text_input], | |
| fn=lambda *args: synthesize(model_choice.value, *args), | |
| cache_examples=True, | |
| run_on_click=True, | |
| ) | |
| # Обработка текста без синтеза | |
| process_text_btn.click( | |
| process_texts_only, | |
| inputs=[ref_text_input, gen_text_input], | |
| outputs=[ref_text_input, gen_text_input] | |
| ) | |
| # Основная генерация | |
| generate_btn.click( | |
| synthesize, | |
| inputs=[ | |
| model_choice, | |
| ref_audio_input, | |
| ref_text_input, | |
| gen_text_input, | |
| remove_silence, | |
| seed_input, | |
| cross_fade_slider, | |
| nfe_slider, | |
| speed_slider, | |
| ], | |
| outputs=[audio_output, spectrogram_output, ref_text_input, gen_text_input] | |
| ) | |
| if __name__ == "__main__": | |
| #app.launch(server_name="0.0.0.0", server_port=7860) | |
| app.launch() |