CHUYEN_MP3 / app.py
mrsu0994
upload f5-tts source
391bf91
raw
history blame
3.04 kB
import os
import subprocess
import sys
import tempfile

from flask import Flask, request, send_file
from huggingface_hub import hf_hub_download
# Flask application object; the @app.route decorators below register handlers on it.
app = Flask(__name__)
# =========================
# F5-TTS runner function
# =========================
def _wav_mtimes(directory):
    """Map each ``.wav`` file name in *directory* to its mtime in nanoseconds.

    Returns an empty dict when the directory is missing or cannot be listed.
    """
    try:
        names = os.listdir(directory)
    except OSError:
        return {}
    mtimes = {}
    for name in names:
        if not name.endswith('.wav'):
            continue
        try:
            mtimes[name] = os.stat(os.path.join(directory, name)).st_mtime_ns
        except OSError:
            # The file vanished between listdir() and stat(); ignore it.
            continue
    return mtimes


def run_f5_tts(ref_audio_path, ref_text, gen_text, model="F5TTS_Base", speed=1.2, vocoder_name="vocos"):
    """Run the bundled F5-TTS inference CLI and locate the audio it produced.

    The vocabulary (``vocab.txt``) and checkpoint (``model_last.pt``) are
    fetched with ``hf_hub_download`` from the ``nguyensu27/TTS`` repository;
    errors raised by those downloads propagate to the caller.

    Args:
        ref_audio_path: Path of the reference audio, passed as ``--ref_audio``.
        ref_text: Reference transcript, passed as ``--ref_text``.
        gen_text: Text to synthesize, passed as ``--gen_text``.
        model: Model name, passed as ``--model``.
        speed: Speed value, stringified and passed as ``--speed``.
        vocoder_name: Vocoder name, passed as ``--vocoder_name``.

    Returns:
        ``(True, path)`` where ``path`` is the most recently modified ``.wav``
        file in the ``tests`` directory next to this file that was created or
        changed by this run, or ``(False, message)`` when the subprocess fails
        or no new audio file is found.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    infer_cli_path = os.path.join(current_dir, "src", "f5_tts", "infer", "infer_cli.py")
    tests_dir = os.path.join(current_dir, "tests")

    # Use huggingface_hub to download the model and vocab files from the
    # 'nguyensu27/TTS' repo.
    vocab_file = hf_hub_download(repo_id="nguyensu27/TTS", filename="vocab.txt")
    ckpt_file = hf_hub_download(repo_id="nguyensu27/TTS", filename="model_last.pt")

    # Each value is attached to its option with "=" so that text starting with
    # "-" is still taken as the option's value; option parsers such as
    # argparse would otherwise read it as another flag.
    command = [
        sys.executable,
        infer_cli_path,
        f"--model={model}",
        f"--ref_audio={ref_audio_path}",
        f"--ref_text={ref_text}",
        f"--gen_text={gen_text}",
        f"--speed={speed}",
        f"--vocoder_name={vocoder_name}",
        f"--vocab_file={vocab_file}",
        f"--ckpt_file={ckpt_file}",
    ]

    # Force UTF-8 stdio in the child only, without mutating this process's
    # environment.
    child_env = dict(os.environ, PYTHONIOENCODING='utf-8')

    # Snapshot existing outputs so a stale file left by an earlier run is not
    # mistaken for this run's result.
    wavs_before = _wav_mtimes(tests_dir)

    try:
        subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            env=child_env,
        )
    except subprocess.CalledProcessError as e:
        return False, e.stderr
    except Exception as e:
        return False, str(e)

    wavs_after = _wav_mtimes(tests_dir)
    fresh = [
        name for name, mtime in wavs_after.items()
        if wavs_before.get(name) != mtime
    ]
    if not fresh:
        return False, "Không tìm thấy file âm thanh trong thư mục tests"

    latest_wav = max(fresh, key=wavs_after.__getitem__)
    return True, os.path.join(tests_dir, latest_wav)
# =========================
# Routes
# =========================
@app.route('/')
def home():
    """Return a short status message pointing clients at POST /api/generate."""
    status_message = "F5-TTS API is running. Use POST /api/generate to generate audio."
    return status_message
@app.route('/api/generate', methods=['POST'])
def generate_speech():
    """Handle POST /api/generate: synthesize speech from an uploaded reference.

    Reads the uploaded file ``ref_audio`` and the form fields ``ref_text``,
    ``gen_text``, ``model`` (default ``F5TTS_Base``) and ``speed``
    (default ``1.2``), then delegates to ``run_f5_tts``.

    Returns:
        The generated file via ``send_file`` with mimetype ``audio/wav`` on
        success; an ``{"error": ...}`` payload with status 400 when
        ``ref_audio`` is missing or ``speed`` is not a positive finite number;
        an ``{"error": ...}`` payload with status 500 when generation fails.
    """
    if 'ref_audio' not in request.files:
        return {"error": "Missing ref_audio"}, 400

    ref_audio = request.files['ref_audio']
    ref_text = request.form.get('ref_text', '')
    gen_text = request.form.get('gen_text', '')
    model = request.form.get('model', 'F5TTS_Base')

    # A malformed speed is a client error, not a server failure.
    try:
        speed = float(request.form.get('speed', 1.2))
    except ValueError:
        return {"error": "Invalid speed"}, 400
    # The chained comparison is False for NaN, infinities, zero and negatives.
    if not 0 < speed < float('inf'):
        return {"error": "Invalid speed"}, 400

    # A unique temp file per request keeps concurrent requests from
    # overwriting or deleting each other's upload.
    fd, ref_audio_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        ref_audio.save(ref_audio_path)
        success, result = run_f5_tts(ref_audio_path, ref_text, gen_text, model, speed)
    finally:
        # Best-effort cleanup that must not mask an exception from above.
        try:
            os.remove(ref_audio_path)
        except OSError:
            app.logger.warning("Could not remove temporary file %s", ref_audio_path)

    if success:
        return send_file(result, mimetype='audio/wav')
    return {"error": result}, 500
# =========================
# Main
# =========================
if __name__ == "__main__":
    # Listen on the port given by the PORT environment variable, falling back
    # to 7860, and bind to all interfaces with debug mode off.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(
        host="0.0.0.0",
        port=listen_port,
        debug=False,
    )