# amateur_voice / server_fastapi.py
# cockolo terada
# Upload folder using huggingface_hub
# 2d48e71 verified
# (Hugging Face Hub page residue: raw / history blame / 13.2 kB)
"""
API server for TTS
TODO: server_editor.pyใจ็ตฑๅˆใ™ใ‚‹?
"""
import argparse
import os
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
from urllib.parse import unquote

import GPUtil
import psutil
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Query, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from scipy.io import wavfile

from config import get_config
from style_bert_vits2.constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_LENGTH,
    DEFAULT_LINE_SPLIT,
    DEFAULT_NOISE,
    DEFAULT_NOISEW,
    DEFAULT_SDP_RATIO,
    DEFAULT_SPLIT_INTERVAL,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
    Languages,
)
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone
from style_bert_vits2.nlp.japanese.normalizer import normalize_text
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.tts_model import TTSModel, TTSModelHolder
config = get_config()
ln = config.server_config.language
# pyopenjtalk_worker ใ‚’่ตทๅ‹•
## pyopenjtalk_worker ใฏ TCP ใ‚ฝใ‚ฑใƒƒใƒˆใ‚ตใƒผใƒใƒผใฎใŸใ‚ใ€ใ“ใ“ใง่ตทๅ‹•ใ™ใ‚‹
pyopenjtalk.initialize_worker()
# dict_data/ ไปฅไธ‹ใฎ่พžๆ›ธใƒ‡ใƒผใ‚ฟใ‚’ pyopenjtalk ใซ้ฉ็”จ
update_dict()
# ไบ‹ๅ‰ใซ BERT ใƒขใƒ‡ใƒซ/ใƒˆใƒผใ‚ฏใƒŠใ‚คใ‚ถใƒผใ‚’ใƒญใƒผใƒ‰ใ—ใฆใŠใ
## ใ“ใ“ใงใƒญใƒผใƒ‰ใ—ใชใใฆใ‚‚ๅฟ…่ฆใซใชใฃใŸ้š›ใซ่‡ชๅ‹•ใƒญใƒผใƒ‰ใ•ใ‚Œใ‚‹ใŒใ€ๆ™‚้–“ใŒใ‹ใ‹ใ‚‹ใŸใ‚ไบ‹ๅ‰ใซใƒญใƒผใƒ‰ใ—ใฆใŠใ„ใŸๆ–นใŒไฝ“้จ“ใŒ่‰ฏใ„
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)
bert_models.load_model(Languages.EN)
bert_models.load_tokenizer(Languages.EN)
bert_models.load_model(Languages.ZH)
bert_models.load_tokenizer(Languages.ZH)
def raise_validation_error(msg: str, param: str):
logger.warning(f"Validation error: {msg}")
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=[dict(type="invalid_params", msg=msg, loc=["query", param])],
)
class AudioResponse(Response):
media_type = "audio/wav"
loaded_models: list[TTSModel] = []
def load_models(model_holder: TTSModelHolder):
global loaded_models
loaded_models = []
for model_name, model_paths in model_holder.model_files_dict.items():
model = TTSModel(
model_path=model_paths[0],
config_path=model_holder.root_dir / model_name / "config.json",
style_vec_path=model_holder.root_dir / model_name / "style_vectors.npy",
device=model_holder.device,
)
# ่ตทๅ‹•ๆ™‚ใซๅ…จใฆใฎใƒขใƒ‡ใƒซใ‚’่ชญใฟ่พผใ‚€ใฎใฏๆ™‚้–“ใŒใ‹ใ‹ใ‚Šใƒกใƒขใƒชใ‚’้ฃŸใ†ใฎใงใ‚„ใ‚ใ‚‹
# model.load()
loaded_models.append(model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
parser.add_argument(
"--dir", "-d", type=str, help="Model directory", default=config.assets_root
)
args = parser.parse_args()
if args.cpu:
device = "cpu"
else:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_dir = Path(args.dir)
model_holder = TTSModelHolder(model_dir, device)
if len(model_holder.model_names) == 0:
logger.error(f"Models not found in {model_dir}.")
sys.exit(1)
logger.info("Loading models...")
load_models(model_holder)
limit = config.server_config.limit
if limit < 1:
limit = None
else:
logger.info(
f"The maximum length of the text is {limit}. If you want to change it, modify config.yml. Set limit to -1 to remove the limit."
)
app = FastAPI()
allow_origins = config.server_config.origins
if allow_origins:
logger.warning(
f"CORS allow_origins={config.server_config.origins}. If you don't want, modify config.yml"
)
app.add_middleware(
CORSMiddleware,
allow_origins=config.server_config.origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# app.logger = logger
# โ†‘ๅŠนใ„ใฆใ„ใชใ•ใใ†ใ€‚loggerใ‚’ใฉใ†ใ‚„ใฃใฆไธŠๆ›ธใใ™ใ‚‹ใ‹ใฏใ‚ˆใๅˆ†ใ‹ใ‚‰ใชใ‹ใฃใŸใ€‚
@app.api_route("/voice", methods=["GET", "POST"], response_class=AudioResponse)
async def voice(
request: Request,
text: str = Query(..., min_length=1, max_length=limit, description="ใ‚ปใƒชใƒ•"),
encoding: str = Query(None, description="textใ‚’URLใƒ‡ใ‚ณใƒผใƒ‰ใ™ใ‚‹(ex, `utf-8`)"),
model_name: str = Query(
None,
description="ใƒขใƒ‡ใƒซๅ(model_idใ‚ˆใ‚Šๅ„ชๅ…ˆ)ใ€‚model_assetsๅ†…ใฎใƒ‡ใ‚ฃใƒฌใ‚ฏใƒˆใƒชๅใ‚’ๆŒ‡ๅฎš",
),
model_id: int = Query(
0, description="ใƒขใƒ‡ใƒซIDใ€‚`GET /models/info`ใฎkeyใฎๅ€คใ‚’ๆŒ‡ๅฎšใใ ใ•ใ„"
),
speaker_name: str = Query(
None,
description="่ฉฑ่€…ๅ(speaker_idใ‚ˆใ‚Šๅ„ชๅ…ˆ)ใ€‚esd.listใฎ2ๅˆ—็›ฎใฎๆ–‡ๅญ—ๅˆ—ใ‚’ๆŒ‡ๅฎš",
),
speaker_id: int = Query(
0, description="่ฉฑ่€…IDใ€‚model_assets>[model]>config.jsonๅ†…ใฎspk2idใ‚’็ขบ่ช"
),
sdp_ratio: float = Query(
DEFAULT_SDP_RATIO,
description="SDP(Stochastic Duration Predictor)/DPๆททๅˆๆฏ”ใ€‚ๆฏ”็އใŒ้ซ˜ใใชใ‚‹ใปใฉใƒˆใƒผใƒณใฎใฐใ‚‰ใคใใŒๅคงใใใชใ‚‹",
),
noise: float = Query(
DEFAULT_NOISE,
description="ใ‚ตใƒณใƒ—ใƒซใƒŽใ‚คใ‚บใฎๅ‰ฒๅˆใ€‚ๅคงใใใ™ใ‚‹ใปใฉใƒฉใƒณใƒ€ใƒ ๆ€งใŒ้ซ˜ใพใ‚‹",
),
noisew: float = Query(
DEFAULT_NOISEW,
description="SDPใƒŽใ‚คใ‚บใ€‚ๅคงใใใ™ใ‚‹ใปใฉ็™บ้Ÿณใฎ้–“้š”ใซใฐใ‚‰ใคใใŒๅ‡บใ‚„ใ™ใใชใ‚‹",
),
length: float = Query(
DEFAULT_LENGTH,
description="่ฉฑ้€Ÿใ€‚ๅŸบๆบ–ใฏ1ใงๅคงใใใ™ใ‚‹ใปใฉ้Ÿณๅฃฐใฏ้•ทใใชใ‚Š่ชญใฟไธŠใ’ใŒ้…ใพใ‚‹",
),
language: Languages = Query(ln, description="textใฎ่จ€่ชž"),
auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="ๆ”น่กŒใงๅˆ†ใ‘ใฆ็”Ÿๆˆ"),
split_interval: float = Query(
DEFAULT_SPLIT_INTERVAL, description="ๅˆ†ใ‘ใŸๅ ดๅˆใซๆŒŸใ‚€็„ก้Ÿณใฎ้•ทใ•๏ผˆ็ง’๏ผ‰"
),
assist_text: Optional[str] = Query(
None,
description="ใ“ใฎใƒ†ใ‚ญใ‚นใƒˆใฎ่ชญใฟไธŠใ’ใจไผผใŸๅฃฐ้Ÿณใƒปๆ„Ÿๆƒ…ใซใชใ‚Šใ‚„ใ™ใใชใ‚‹ใ€‚ใŸใ ใ—ๆŠ‘ๆšใ‚„ใƒ†ใƒณใƒ็ญ‰ใŒ็Š ็‰ฒใซใชใ‚‹ๅ‚พๅ‘ใŒใ‚ใ‚‹",
),
assist_text_weight: float = Query(
DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textใฎๅผทใ•"
),
style: Optional[str] = Query(DEFAULT_STYLE, description="ใ‚นใ‚ฟใ‚คใƒซ"),
style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="ใ‚นใ‚ฟใ‚คใƒซใฎๅผทใ•"),
reference_audio_path: Optional[str] = Query(
None, description="ใ‚นใ‚ฟใ‚คใƒซใ‚’้Ÿณๅฃฐใƒ•ใ‚กใ‚คใƒซใง่กŒใ†"
),
):
"""Infer text to speech(ใƒ†ใ‚ญใ‚นใƒˆใ‹ใ‚‰ๆ„Ÿๆƒ…ไป˜ใ้Ÿณๅฃฐใ‚’็”Ÿๆˆใ™ใ‚‹)"""
logger.info(
f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}"
)
if request.method == "GET":
logger.warning(
"The GET method is not recommended for this endpoint due to various restrictions. Please use the POST method."
)
if model_id >= len(
model_holder.model_names
): # /models/refresh ใŒใ‚ใ‚‹ใŸใ‚Query(le)ใง่กจ็พไธๅฏ
raise_validation_error(f"model_id={model_id} not found", "model_id")
if model_name:
# load_models() ใฎ ๅ‡ฆ็†ๅ†…ๅฎนใŒ i ใฎๆญฃๅฝ“ๆ€งใ‚’ๆ‹…ไฟใ—ใฆใ„ใ‚‹ใ“ใจใซๆณจๆ„
model_ids = [i for i, x in enumerate(model_holder.models_info) if x.name == model_name]
if not model_ids:
raise_validation_error(
f"model_name={model_name} not found", "model_name"
)
# ไปŠใฎๅฎŸ่ฃ…ใงใฏใƒ‡ใ‚ฃใƒฌใ‚ฏใƒˆใƒชๅใŒ้‡่ค‡ใ™ใ‚‹ใ“ใจใฏ็„กใ„ใฏใšใ ใŒ...
if len(model_ids) > 1:
raise_validation_error(
f"model_name={model_name} is ambiguous", "model_name"
)
model_id = model_ids[0]
model = loaded_models[model_id]
if speaker_name is None:
if speaker_id not in model.id2spk.keys():
raise_validation_error(
f"speaker_id={speaker_id} not found", "speaker_id"
)
else:
if speaker_name not in model.spk2id.keys():
raise_validation_error(
f"speaker_name={speaker_name} not found", "speaker_name"
)
speaker_id = model.spk2id[speaker_name]
if style not in model.style2id.keys():
raise_validation_error(f"style={style} not found", "style")
assert style is not None
if encoding is not None:
text = unquote(text, encoding=encoding)
sr, audio = model.infer(
text=text,
language=language,
speaker_id=speaker_id,
reference_audio_path=reference_audio_path,
sdp_ratio=sdp_ratio,
noise=noise,
noise_w=noisew,
length=length,
line_split=auto_split,
split_interval=split_interval,
assist_text=assist_text,
assist_text_weight=assist_text_weight,
use_assist_text=bool(assist_text),
style=style,
style_weight=style_weight,
)
logger.success("Audio data generated and sent successfully")
with BytesIO() as wavContent:
wavfile.write(wavContent, sr, audio)
return Response(content=wavContent.getvalue(), media_type="audio/wav")
@app.post("/g2p")
def g2p(text: str):
return g2kata_tone(normalize_text(text))
@app.get("/models/info")
def get_loaded_models_info():
"""ใƒญใƒผใƒ‰ใ•ใ‚ŒใŸใƒขใƒ‡ใƒซๆƒ…ๅ ฑใฎๅ–ๅพ—"""
result: dict[str, dict[str, Any]] = dict()
for model_id, model in enumerate(loaded_models):
result[str(model_id)] = {
"config_path": model.config_path,
"model_path": model.model_path,
"device": model.device,
"spk2id": model.spk2id,
"id2spk": model.id2spk,
"style2id": model.style2id,
}
return result
@app.post("/models/refresh")
def refresh():
"""ใƒขใƒ‡ใƒซใ‚’ใƒ‘ใ‚นใซ่ฟฝๅŠ /ๅ‰Š้™คใ—ใŸ้š›ใชใฉใซ่ชญใฟ่พผใพใ›ใ‚‹"""
model_holder.refresh()
load_models(model_holder)
return get_loaded_models_info()
@app.get("/status")
def get_status():
"""ๅฎŸ่กŒ็’ฐๅขƒใฎใ‚นใƒ†ใƒผใ‚ฟใ‚นใ‚’ๅ–ๅพ—"""
cpu_percent = psutil.cpu_percent(interval=1)
memory_info = psutil.virtual_memory()
memory_total = memory_info.total
memory_available = memory_info.available
memory_used = memory_info.used
memory_percent = memory_info.percent
gpuInfo = []
devices = ["cpu"]
for i in range(torch.cuda.device_count()):
devices.append(f"cuda:{i}")
gpus = GPUtil.getGPUs()
for gpu in gpus:
gpuInfo.append(
{
"gpu_id": gpu.id,
"gpu_load": gpu.load,
"gpu_memory": {
"total": gpu.memoryTotal,
"used": gpu.memoryUsed,
"free": gpu.memoryFree,
},
}
)
return {
"devices": devices,
"cpu_percent": cpu_percent,
"memory_total": memory_total,
"memory_available": memory_available,
"memory_used": memory_used,
"memory_percent": memory_percent,
"gpu": gpuInfo,
}
@app.get("/tools/get_audio", response_class=AudioResponse)
def get_audio(
request: Request, path: str = Query(..., description="local wav path")
):
"""wavใƒ‡ใƒผใ‚ฟใ‚’ๅ–ๅพ—ใ™ใ‚‹"""
logger.info(
f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
)
if not os.path.isfile(path):
raise_validation_error(f"path={path} not found", "path")
if not path.lower().endswith(".wav"):
raise_validation_error(f"wav file not found in {path}", "path")
return FileResponse(path=path, media_type="audio/wav")
logger.info(f"server listen: http://127.0.0.1:{config.server_config.port}")
logger.info(f"API docs: http://127.0.0.1:{config.server_config.port}/docs")
logger.info(
f"Input text length limit: {limit}. You can change it in server.limit in config.yml"
)
uvicorn.run(
app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
)