from flask import Flask, request, send_file
import subprocess
import os
import sys
from huggingface_hub import hf_hub_download

app = Flask(__name__)

# Hugging Face repository hosting the checkpoint and vocab used by run_f5_tts.
HF_REPO_ID = "nguyensu27/TTS"


# =========================
# F5-TTS runner
# =========================
def _wav_mtimes(directory):
    """Return ``{filename: mtime}`` for ``.wav`` files directly in *directory* ({} if it is missing)."""
    if not os.path.isdir(directory):
        return {}
    return {
        name: os.path.getmtime(os.path.join(directory, name))
        for name in os.listdir(directory)
        if name.endswith('.wav')
    }


def run_f5_tts(ref_audio_path, ref_text, gen_text, model="F5TTS_Base", speed=1.2, vocoder_name="vocos"):
    """Synthesize *gen_text* by running the F5-TTS inference CLI in a subprocess.

    Fetches ``vocab.txt`` and ``model_last.pt`` from the ``nguyensu27/TTS``
    Hugging Face repo, runs ``src/f5_tts/infer/infer_cli.py`` (relative to this
    file) with the given arguments, and then looks in the ``tests`` directory
    next to this file for a ``.wav`` file that the run created or rewrote.

    Args:
        ref_audio_path: Path to the reference (voice prompt) audio file.
        ref_text: Text forwarded to the CLI's ``--ref_text`` flag.
        gen_text: Text forwarded to the CLI's ``--gen_text`` flag.
        model: Value forwarded to ``--model``.
        speed: Value forwarded to ``--speed``.
        vocoder_name: Value forwarded to ``--vocoder_name``.

    Returns:
        ``(True, path_to_wav)`` on success, otherwise ``(False, error_message)``.
        Failures are reported through the return value rather than raised.

    NOTE(review): assumes infer_cli writes its output into ``tests_dir``;
    confirm, since the CLI may resolve its default output directory relative
    to the working directory. Concurrent calls share that directory, so
    overlapping runs can still observe each other's output files.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    infer_cli_path = os.path.join(current_dir, "src", "f5_tts", "infer", "infer_cli.py")
    tests_dir = os.path.join(current_dir, "tests")

    try:
        # Kept inside the try so download/network failures are reported as
        # (False, message) like every other failure instead of escaping to Flask.
        vocab_file = hf_hub_download(repo_id=HF_REPO_ID, filename="vocab.txt")
        ckpt_file = hf_hub_download(repo_id=HF_REPO_ID, filename="model_last.pt")

        command = [
            sys.executable, infer_cli_path,
            "--model", model,
            "--ref_audio", ref_audio_path,
            "--ref_text", ref_text,
            "--gen_text", gen_text,
            "--speed", str(speed),
            "--vocoder_name", vocoder_name,
            "--vocab_file", vocab_file,
            "--ckpt_file", ckpt_file,
        ]

        # Force UTF-8 stdio in the child only; the server's own environment
        # is no longer mutated on every request.
        child_env = dict(os.environ, PYTHONIOENCODING='utf-8')

        # Snapshot existing outputs so a stale .wav from an earlier run is
        # never returned as this run's result.
        before = _wav_mtimes(tests_dir)
        subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace',  # undecodable child output must not mask the real result
            env=child_env,
        )
        after = _wav_mtimes(tests_dir)

        produced = [name for name, mtime in after.items() if before.get(name) != mtime]
        if produced:
            latest_wav = max(produced, key=after.get)
            return True, os.path.join(tests_dir, latest_wav)
        return False, "Không tìm thấy file âm thanh trong thư mục tests"
    except subprocess.CalledProcessError as e:
        # Some failures only print to stdout; never return an empty message.
        return False, e.stderr or e.stdout or str(e)
    except Exception as e:
        return False, str(e)


# =========================
# Routes
# =========================
@app.route('/')
def home():
    """Return a static message confirming the API is up and how to use it."""
    return "F5-TTS API is running. Use POST /api/generate to generate audio."
@app.route('/api/generate', methods=['POST'])
def generate_speech():
    """Generate speech from a reference voice sample and target text.

    Expects multipart form data:
        ref_audio (file, required): reference voice sample.
        ref_text (str, optional): forwarded as the reference text; default ''.
        gen_text (str, required): text to synthesize; must be non-blank.
        model (str, optional): model name; default ``F5TTS_Base``.
        speed (float, optional): finite positive speed; default ``1.2``.

    Returns:
        The generated WAV file on success; ``{"error": ...}`` with status 400
        for invalid input, or with status 500 if synthesis fails.
    """
    # Local imports: the file-level import block is outside this handler.
    import math
    import tempfile

    if 'ref_audio' not in request.files:
        return {"error": "Missing ref_audio"}, 400

    ref_audio = request.files['ref_audio']
    ref_text = request.form.get('ref_text', '')
    gen_text = request.form.get('gen_text', '')
    model = request.form.get('model', 'F5TTS_Base')

    if not gen_text.strip():
        return {"error": "Missing gen_text"}, 400

    # Malformed client input is a 400, not an unhandled ValueError (500).
    try:
        speed = float(request.form.get('speed', 1.2))
    except ValueError:
        return {"error": "Invalid speed"}, 400
    if not (math.isfinite(speed) and speed > 0):
        return {"error": "Invalid speed"}, 400

    # Unique temp file per request: a fixed filename let concurrent requests
    # overwrite each other's reference audio.
    fd, ref_audio_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        ref_audio.save(ref_audio_path)
        success, result = run_f5_tts(ref_audio_path, ref_text, gen_text, model, speed)
    finally:
        # Always clean up, even if saving or synthesis raises; cleanup
        # failure is deliberately best-effort.
        try:
            os.remove(ref_audio_path)
        except OSError:
            pass

    if success:
        return send_file(result, mimetype='audio/wav')
    return {"error": result}, 500


# =========================
# Main
# =========================
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)