""" |
Style-Bert-VITS2-Editor用のサーバー。 |
次のリポジトリ |
https://github.com/litagin02/Style-Bert-VITS2-Editor |
をビルドしてできあがったファイルをWebフォルダに入れて実行する。 |
TODO: リファクタリングやドキュメンテーションやAPI整理、辞書周りの改善などが必要。 |
""" |
import argparse
import io
import shutil
import sys
import webbrowser
import zipfile
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Optional

import numpy as np
import requests
import torch
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from scipy.io import wavfile

from config import get_path_config
from style_bert_vits2.constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_NOISE,
    DEFAULT_NOISEW,
    DEFAULT_SDP_RATIO,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
    VERSION,
    Languages,
)
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone
from style_bert_vits2.nlp.japanese.normalizer import normalize_text
from style_bert_vits2.nlp.japanese.user_dict import (
    apply_word,
    delete_word,
    read_dict,
    rewrite_word,
    update_dict,
)
from style_bert_vits2.tts_model import TTSModelHolder, TTSModelInfo
# Directory holding the built editor front end; it is mounted at "/" on start-up.
STATIC_DIR = Path("static")
# Marker file storing the `published_at` timestamp of the last downloaded release.
LAST_DOWNLOAD_FILE = STATIC_DIR / "last_download.txt"
def download_static_files(user, repo, asset_name):
    """Download and unpack the latest editor build zip when a newer one exists.

    Looks up the latest GitHub release of ``user/repo``. When it is newer than
    the recorded download and contains an asset named ``asset_name``, the
    static directory is emptied, the asset is downloaded and extracted into
    it, and the release timestamp is recorded. Otherwise a message is logged
    and the existing files are left untouched.
    """
    logger.info("Checking for new release...")
    release = get_latest_release(user, repo)
    if release is None:
        logger.warning(
            "Failed to fetch the latest release. Proceeding without static files."
        )
        return

    if not new_release_available(release):
        logger.info("No new release available. Proceeding with existing static files.")
        return

    logger.info("New release available. Downloading static files...")
    zip_url = get_asset_url(release, asset_name)
    if not zip_url:
        logger.warning("Asset not found. Proceeding without static files.")
        return

    # Start from an empty directory so files of an older build do not linger.
    if STATIC_DIR.exists():
        shutil.rmtree(STATIC_DIR)
    STATIC_DIR.mkdir(parents=True, exist_ok=True)
    download_and_extract(zip_url, STATIC_DIR)
    save_last_download(release)
def get_latest_release(user, repo, timeout=10.0):
    """Return the latest GitHub release metadata for ``user/repo``.

    Args:
        user: Owner of the GitHub repository.
        repo: Name of the GitHub repository.
        timeout: Seconds passed to ``requests.get`` as its timeout, so an
            unresponsive network cannot block the caller indefinitely.

    Returns:
        The decoded JSON body of the "latest release" API response, or
        ``None`` when the request fails or the body is not valid JSON.
    """
    url = f"https://api.github.com/repos/{user}/{repo}/releases/latest"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decoding errors on requests versions whose
        # decode error is not a RequestException subclass.
        return None
def get_asset_url(release, asset_name):
    """Return the download URL of the release asset named ``asset_name``.

    ``release`` is a mapping with an ``"assets"`` list whose items carry
    ``"name"`` and ``"browser_download_url"`` keys. The first asset whose name
    matches wins; ``None`` is returned when no asset matches.
    """
    matching_urls = (
        asset["browser_download_url"]
        for asset in release["assets"]
        if asset["name"] == asset_name
    )
    return next(matching_urls, None)
def download_and_extract(url, extract_to: Path, timeout=60.0):
    """Download the zip archive at ``url`` and unpack it into ``extract_to``.

    When the archive unpacks to exactly one top-level directory, the contents
    of that directory are moved up into ``extract_to`` and the now-empty
    directory is removed. A warning is logged if no ``index.html`` ends up
    directly inside ``extract_to``.

    Args:
        url: Location of the zip archive.
        extract_to: Existing directory that receives the extracted files.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled connection cannot hang the caller forever.

    Raises:
        requests.RequestException: If the download fails or times out.
        zipfile.BadZipFile: If the downloaded payload is not a zip archive.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(extract_to)

    extracted = list(extract_to.iterdir())
    if len(extracted) == 1 and extracted[0].is_dir():
        # The archive wrapped everything in a single folder: flatten it.
        wrapper_dir = extracted[0]
        for entry in wrapper_dir.iterdir():
            entry.rename(extract_to / entry.name)
        wrapper_dir.rmdir()

    if not (extract_to / "index.html").exists():
        logger.warning("index.html not found in the extracted files.")
def new_release_available(latest_release):
    """Tell whether ``latest_release`` is newer than the last downloaded one.

    The timestamp of the previous download is read from ``LAST_DOWNLOAD_FILE``
    and compared with ``latest_release["published_at"]``. A trailing ``Z`` is
    rewritten to ``+00:00`` in both values before parsing.

    Returns:
        ``True`` when the marker file is missing or unreadable as a timestamp,
        or when the release was published after the recorded download;
        ``False`` otherwise.
    """
    if not LAST_DOWNLOAD_FILE.exists():
        return True

    with open(LAST_DOWNLOAD_FILE) as file:
        recorded = file.read().strip()
    try:
        last_download = datetime.fromisoformat(recorded.replace("Z", "+00:00"))
    except ValueError:
        # An empty or corrupted marker must not abort start-up; downloading
        # again rewrites the marker with a valid timestamp.
        logger.warning(
            f"Could not parse {LAST_DOWNLOAD_FILE}. Treating the latest release as new."
        )
        return True

    published = datetime.fromisoformat(
        latest_release["published_at"].replace("Z", "+00:00")
    )
    return published > last_download
def save_last_download(latest_release):
    """Write the release's ``published_at`` value to ``LAST_DOWNLOAD_FILE``."""
    with LAST_DOWNLOAD_FILE.open("w") as marker:
        marker.write(latest_release["published_at"])
# One-time initialization performed when this module is loaded.
# Start the pyopenjtalk worker first.
# NOTE(review): presumably the calls below depend on the worker running — confirm.
pyopenjtalk.initialize_worker()
# NOTE(review): presumably (re)applies the user dictionary to pyopenjtalk — confirm.
update_dict()
# Load the Japanese BERT model and tokenizer up front.
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)
class AudioResponse(Response):
    """Response subclass whose media type is ``audio/wav``.

    Used as ``response_class`` of the synthesis routes.
    """

    media_type = "audio/wav"
# Origins handed to the CORS middleware as `allow_origins`.
# NOTE(review): the two empty strings look like placeholders or stripped
# entries — confirm which origins were intended here.
origins = [
    "http://localhost:3000",
    "http://localhost:8000",
    "",
    "",
]
path_config = get_path_config()

# Command-line options. They are parsed at module level, i.e. whenever this
# module is loaded, not only when it is run as a script.
parser = argparse.ArgumentParser()
# Directory handed to TTSModelHolder; defaults to the configured assets root.
parser.add_argument("--model_dir", type=str, default=path_config.assets_root)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--port", type=int, default=8000)
# Open the editor URL in a web browser before the server starts.
parser.add_argument("--inbrowser", action="store_true")
# Optional limits enforced by the synthesis routes (None disables the check).
parser.add_argument("--line_length", type=int, default=None)
parser.add_argument("--line_count", type=int, default=None)
# Skip checking for / downloading the editor front end.
parser.add_argument("--skip_static_files", action="store_true")
args = parser.parse_args()

# Fall back to the CPU when CUDA was requested but is not available.
device = args.device
if device == "cuda" and not torch.cuda.is_available():
    device = "cpu"
model_dir = Path(args.model_dir)
port = int(args.port)
skip_static_files = bool(args.skip_static_files)

model_holder = TTSModelHolder(model_dir, device)
# Exit with status 1 when the holder reports no models.
if len(model_holder.model_names) == 0:
    logger.error(f"Models not found in {model_dir}.")
    sys.exit(1)
app = FastAPI()
# Let the origins listed in `origins` call the API with credentials, using any
# HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# The routes below are registered on this router, which is later included in
# the app under the "/api" prefix.
router = APIRouter()
@router.get("/version")
def version() -> str:
    """Return the value of the ``VERSION`` constant."""
    return VERSION
class MoraTone(BaseModel):
    """A mora paired with its tone value."""

    # Mora text; the /g2p route fills it with the katakana from `g2kata_tone`.
    mora: str
    # Tone value belonging to the mora.
    tone: int
class TextRequest(BaseModel):
    """Request body carrying a single text string."""

    text: str
@router.post("/g2p")
async def read_item(item: TextRequest):
    """Convert the request text into a list of mora/tone pairs.

    The text is normalized and then passed to ``g2kata_tone``. A failure in
    either step is reported to the client as HTTP 400.
    """
    try:
        normalized = normalize_text(item.text)
        kata_tones = g2kata_tone(normalized)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to convert {item.text} to katakana and tone, {e}",
        )
    else:
        return [MoraTone(mora=kata, tone=tone) for kata, tone in kata_tones]
@router.post("/normalize")
async def normalize(item: TextRequest):
    """Return the result of ``normalize_text`` applied to the request text."""
    return normalize_text(item.text)
@router.get("/models_info", response_model=list[TTSModelInfo])
def models_info():
    """Return ``model_holder.models_info``."""
    return model_holder.models_info
class SynthesisRequest(BaseModel):
    """Parameters for synthesizing one line of text."""

    # Model name and model file path, both passed to `model_holder.get_model`.
    model: str
    modelFile: str
    # Text to synthesize.
    text: str
    # Mora/tone pairs; converted with `kata_tone2phone_tone` and passed to the
    # model as `given_tone`.
    moraToneList: list[MoraTone]
    style: str = DEFAULT_STYLE
    styleWeight: float = DEFAULT_STYLE_WEIGHT
    # Assist text; an empty string turns `use_assist_text` off.
    assistText: str = ""
    assistTextWeight: float = DEFAULT_ASSIST_TEXT_WEIGHT
    # The model receives `length = 1 / speed`.
    speed: float = 1.0
    noise: float = DEFAULT_NOISE
    noisew: float = DEFAULT_NOISEW
    sdpRatio: float = DEFAULT_SDP_RATIO
    language: Languages = Languages.JP
    # Seconds of silence appended after this line by /multi_synthesis
    # (not after the last line).
    silenceAfter: float = 0.5
    pitchScale: float = 1.0
    intonationScale: float = 1.0
    # Key into the model's `spk2id`; /synthesis uses speaker id 0 when None.
    speaker: Optional[str] = None
@router.post("/synthesis", response_class=AudioResponse)
def synthesis(request: SynthesisRequest):
    """Synthesize one line of text and return it as WAV audio.

    Args:
        request: Synthesis parameters for the line.

    Returns:
        A response with media type ``audio/wav`` containing the audio.

    Raises:
        HTTPException: 400 when the text is longer than ``--line_length``,
            when ``speed`` is not positive, or when the speaker is unknown;
            500 when the model cannot be loaded.
    """
    if args.line_length is not None and len(request.text) > args.line_length:
        raise HTTPException(
            status_code=400,
            detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
        )
    # `speed` is inverted below; reject non-positive values here instead of
    # failing later with a ZeroDivisionError (HTTP 500) or a negative length.
    if request.speed <= 0:
        raise HTTPException(
            status_code=400,
            detail="speed must be greater than 0.",
        )
    try:
        model = model_holder.get_model(
            model_name=request.model, model_path_str=request.modelFile
        )
    except Exception as e:
        logger.error(e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {request.model} from {request.modelFile}, {e}",
        ) from e
    text = request.text
    kata_tone_list = [
        (mora_tone.mora, mora_tone.tone) for mora_tone in request.moraToneList
    ]
    phone_tone = kata_tone2phone_tone(kata_tone_list)
    # Only the tone of each pair is handed to the model.
    tone = [t for _, t in phone_tone]
    try:
        sid = 0 if request.speaker is None else model.spk2id[request.speaker]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Speaker {request.speaker} not found in {model.spk2id}",
        ) from None
    sr, audio = model.infer(
        text=text,
        language=request.language,
        sdp_ratio=request.sdpRatio,
        noise=request.noise,
        noise_w=request.noisew,
        length=1 / request.speed,
        given_tone=tone,
        style=request.style,
        style_weight=request.styleWeight,
        assist_text=request.assistText,
        assist_text_weight=request.assistTextWeight,
        use_assist_text=bool(request.assistText),
        line_split=False,
        pitch_scale=request.pitchScale,
        intonation_scale=request.intonationScale,
        speaker_id=sid,
    )
    # Encode the audio as WAV in memory.
    with BytesIO() as wavContent:
        wavfile.write(wavContent, sr, audio)
        return Response(content=wavContent.getvalue(), media_type="audio/wav")
class MultiSynthesisRequest(BaseModel):
    """Request body for /multi_synthesis: the lines to synthesize, in order."""

    lines: list[SynthesisRequest]
@router.post("/multi_synthesis", response_class=AudioResponse)
def multi_synthesis(request: MultiSynthesisRequest):
    """Synthesize several lines and return them joined into one WAV file.

    Each line is synthesized with its own parameters. After every line except
    the last, ``silenceAfter`` seconds of silence are inserted (negative
    values are treated as 0).

    Args:
        request: The lines to synthesize, in order.

    Returns:
        A response with media type ``audio/wav`` containing the joined audio.

    Raises:
        HTTPException: 400 when there are more lines than ``--line_count``,
            when there are no lines at all, when a line is longer than
            ``--line_length``, when a line's ``speed`` is not positive, or
            when a line names an unknown speaker; 500 when a model cannot be
            loaded.
    """
    lines = request.lines
    if args.line_count is not None and len(lines) > args.line_count:
        raise HTTPException(
            status_code=400,
            detail=f"行数は{args.line_count}行以下にしてください。",
        )
    # np.concatenate fails on an empty list; report it as a client error.
    if not lines:
        raise HTTPException(
            status_code=400,
            detail="At least one line is required.",
        )
    # Validate every line before running any inference, so a bad later line
    # does not waste the work already done for earlier ones.
    for req in lines:
        if args.line_length is not None and len(req.text) > args.line_length:
            raise HTTPException(
                status_code=400,
                detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
            )
        # `speed` is inverted below; zero would raise ZeroDivisionError.
        if req.speed <= 0:
            raise HTTPException(
                status_code=400,
                detail="speed must be greater than 0.",
            )
    audios = []
    sr = None
    last_index = len(lines) - 1
    for i, req in enumerate(lines):
        try:
            model = model_holder.get_model(
                model_name=req.model, model_path_str=req.modelFile
            )
        except Exception as e:
            logger.error(e)
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load model {req.model} from {req.modelFile}, {e}",
            ) from e
        # Resolve the speaker the same way /synthesis does, so the `speaker`
        # field is honored here as well.
        try:
            sid = 0 if req.speaker is None else model.spk2id[req.speaker]
        except KeyError:
            raise HTTPException(
                status_code=400,
                detail=f"Speaker {req.speaker} not found in {model.spk2id}",
            ) from None
        text = req.text
        kata_tone_list = [
            (mora_tone.mora, mora_tone.tone) for mora_tone in req.moraToneList
        ]
        phone_tone = kata_tone2phone_tone(kata_tone_list)
        # Only the tone of each pair is handed to the model.
        tone = [t for _, t in phone_tone]
        sr, audio = model.infer(
            text=text,
            language=req.language,
            sdp_ratio=req.sdpRatio,
            noise=req.noise,
            noise_w=req.noisew,
            length=1 / req.speed,
            given_tone=tone,
            style=req.style,
            style_weight=req.styleWeight,
            assist_text=req.assistText,
            assist_text_weight=req.assistTextWeight,
            use_assist_text=bool(req.assistText),
            line_split=False,
            pitch_scale=req.pitchScale,
            intonation_scale=req.intonationScale,
            speaker_id=sid,
        )
        audios.append(audio)
        if i < last_index:
            # Clamp to 0: np.zeros rejects a negative length.
            silence = max(int(sr * req.silenceAfter), 0)
            audios.append(np.zeros(silence, dtype=np.int16))
    # NOTE(review): the WAV header uses the sample rate of the last line;
    # lines from models with different sample rates are not resampled.
    audio = np.concatenate(audios)
    with BytesIO() as wavContent:
        wavfile.write(wavContent, sr, audio)
        return Response(content=wavContent.getvalue(), media_type="audio/wav")
class UserDictWordRequest(BaseModel):
    """Request body describing a user-dictionary word.

    The fields are passed unchanged to ``apply_word`` / ``rewrite_word``.
    """

    surface: str
    pronunciation: str
    accent_type: int
    priority: int = 5
@router.get("/user_dict")
def get_user_dict():
    """Return the result of ``read_dict()``."""
    return read_dict()
@router.post("/user_dict_word")
def add_user_dict_word(request: UserDictWordRequest):
    """Add a word to the user dictionary.

    The word is passed to ``apply_word`` and ``update_dict`` is called
    afterwards. The value returned by ``apply_word`` is sent back under the
    ``"uuid"`` key with HTTP status 201.
    """
    word_uuid = apply_word(
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    update_dict()
    body = {"uuid": word_uuid}
    return JSONResponse(content=body, status_code=status.HTTP_201_CREATED)
@router.put("/user_dict_word/{uuid}")
def update_user_dict_word(uuid: str, request: UserDictWordRequest):
    """Rewrite the user-dictionary word identified by ``uuid``.

    The new values are passed to ``rewrite_word`` and ``update_dict`` is
    called afterwards. The given ``uuid`` is echoed back with HTTP status 200.
    """
    rewrite_word(
        word_uuid=uuid,
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    update_dict()
    body = {"uuid": uuid}
    return JSONResponse(content=body, status_code=status.HTTP_200_OK)
@router.delete("/user_dict_word/{uuid}")
def delete_user_dict_word(uuid: str):
    """Delete the user-dictionary word identified by ``uuid``.

    Calls ``delete_word`` followed by ``update_dict`` and echoes the given
    ``uuid`` back with HTTP status 200.
    """
    delete_word(uuid)
    update_dict()
    return JSONResponse(status_code=status.HTTP_200_OK, content={"uuid": uuid})
# Expose every route registered on `router` under the "/api" prefix.
app.include_router(router, prefix="/api")

if __name__ == "__main__":
    if not skip_static_files:
        download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip")
    # StaticFiles raises when its directory does not exist, which would abort
    # start-up even though the download step logs that it is "proceeding
    # without static files". Serve the API alone in that case.
    if STATIC_DIR.is_dir():
        app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
    else:
        logger.warning(
            f"Static directory {STATIC_DIR} not found. Serving the API without the editor."
        )
    if args.inbrowser:
        webbrowser.open(f"http://localhost:{port}")
    # NOTE(review): the empty host string is kept as is — confirm which
    # interface(s) the server is meant to bind to.
    uvicorn.run(app, host="", port=port)