Spaces:
Running
Running
File size: 9,419 Bytes
10529b8 2d48e71 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 2d48e71 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 2d48e71 10529b8 2d48e71 10529b8 2d48e71 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 9c5f6d1 10529b8 2d48e71 10529b8 2d48e71 9c5f6d1 2d48e71 10529b8 2d48e71 9c5f6d1 10529b8 a2c557a 10529b8 a2c557a 9c5f6d1 4146122 9c5f6d1 10529b8 9c5f6d1 10529b8 2d48e71 10529b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
import logging
# Set up the module logger (mirrors the original code).
# `from style_bert_vits2.logging import logger` would also work.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_symlink_if_not_exists(
    repo_id: str,
    repo_filepath: str,
    local_subpath: str,
    link_dir: "str | os.PathLike[str]",
) -> None:
    """Download a file from the Hugging Face Hub and symlink it into place.

    The path inside the repository and the local link path are passed
    separately, so the link may live at a different relative location than
    the file has in the repository.

    Args:
        repo_id: Hub repository to download from.
        repo_filepath: Path of the file inside the repository; this is what
            gets passed to ``hf_hub_download``.
        local_subpath: Path of the link to create, relative to ``link_dir``.
        link_dir: Local directory under which the link is created.

    Returns:
        None. Nothing is done when ``link_dir / local_subpath`` already
        resolves to an existing file. A failed download is logged and no
        link is created.
    """
    # e.g. link_dir="model_assets", local_subpath="ANNYUI/config.json"
    #   -> "model_assets/ANNYUI/config.json"
    link_path = Path(link_dir) / local_subpath

    # exists() follows symlinks: a link whose target is still present counts
    # as "already set up", while a dangling link does not.
    if link_path.exists():
        return

    # Make sure the directory that will hold the link is there.
    link_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading {repo_id}/{repo_filepath}")
    try:
        # The download uses the full in-repository path (repo_filepath).
        actual_file_path = hf_hub_download(repo_id=repo_id, filename=repo_filepath)
    except Exception as e:
        # Best effort: one missing file must not abort the remaining setup.
        logger.error(f"Failed to download {repo_filepath}: {e}")
        return

    # If a symlink is still sitting at link_path here, it is dangling (the
    # exists() check above failed). os.symlink would raise FileExistsError on
    # it, so drop the stale link before recreating it.
    if link_path.is_symlink():
        link_path.unlink()

    logger.info(f"Creating symlink: {link_path} -> {actual_file_path}")
    os.symlink(actual_file_path, link_path)
def setup_bert_models():
    """Link every BERT model file listed in ``bert/bert_models.json``.

    The JSON file is read as a mapping from model name to an object holding a
    ``repo_id`` and a list of ``files``. Each file is linked under
    ``bert/<model name>/`` at the same relative path it has in the repository.
    """
    logger.info("Setting up BERT models...")

    with open("bert/bert_models.json", encoding="utf-8") as manifest:
        catalog = json.load(manifest)

    for name, entry in catalog.items():
        source_repo = entry["repo_id"]
        # Each BERT model gets its own 'bert/<model name>/' subdirectory.
        target_dir = Path("bert") / name
        for relative_path in entry["files"]:
            create_symlink_if_not_exists(
                source_repo, relative_path, relative_path, target_dir
            )
def setup_slm_model():
    """Link ``pytorch_model.bin`` from ``microsoft/wavlm-base-plus``.

    The link is created inside the ``slm/wavlm-base-plus/`` directory.
    """
    logger.info("Setting up SLM model...")

    weights_name = "pytorch_model.bin"
    # The file keeps the same name locally as it has in the repository.
    create_symlink_if_not_exists(
        "microsoft/wavlm-base-plus",
        weights_name,
        weights_name,
        "slm/wavlm-base-plus/",
    )
def setup_pretrained_models():
    """Link the base checkpoints from ``litagin/Style-Bert-VITS2-1.0-base``.

    ``G_0``, ``D_0`` and ``DUR_0`` are linked, in that order, into the
    ``pretrained`` directory under their repository file names.
    """
    logger.info("Setting up Pretrained models...")

    source_repo = "litagin/Style-Bert-VITS2-1.0-base"
    for checkpoint in ("G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"):
        create_symlink_if_not_exists(source_repo, checkpoint, checkpoint, "pretrained")
def setup_jp_extra_pretrained_models():
    """Link the checkpoints from ``litagin/Style-Bert-VITS2-2.0-base-JP-Extra``.

    ``G_0``, ``D_0`` and ``WD_0`` are linked, in that order, into the
    ``pretrained_jp_extra`` directory under their repository file names.
    """
    logger.info("Setting up JP-Extra Pretrained models...")

    source_repo = "litagin/Style-Bert-VITS2-2.0-base-JP-Extra"
    for checkpoint in ("G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"):
        create_symlink_if_not_exists(
            source_repo, checkpoint, checkpoint, "pretrained_jp_extra"
        )
def setup_default_models():
    """Link the default speaker models into ``model_assets``.

    Every file in the manifest below is downloaded from its repository. For
    ``teradakokoro/voice_models`` the leading ``CO``/``FN``/``other`` path
    component is dropped from the local path, so ``CO/ANNYUI/config.json``
    is linked as ``model_assets/ANNYUI/config.json``.
    """
    logger.info("Setting up default speaker models...")

    # The "litagin/style_bert_vits2_jvnv" entry has been removed from this
    # manifest.
    manifest = {
        "teradakokoro/voice_models": [
            # CO Models
            "CO/ANNYUI/config.json", "CO/ANNYUI/style_settings.json", "CO/ANNYUI/style_vectors.npy", "CO/ANNYUI/ANNYUI_e101_s21000.safetensors",
            "CO/ASTK/config.json", "CO/ASTK/style_settings.json", "CO/ASTK/style_vectors.npy", "CO/ASTK/ASTK_e501_s28000.safetensors",
            "CO/AZS/config.json", "CO/AZS/style_settings.json", "CO/AZS/style_vectors.npy", "CO/AZS/AZS_e442_s34000.safetensors",
            "CO/BNKRG/config.json", "CO/BNKRG/style_settings.json", "CO/BNKRG/style_vectors.npy", "CO/BNKRG/BNKRG_e222_s31000.safetensors",
            "CO/ESK/config.json", "CO/ESK/style_settings.json", "CO/ESK/style_vectors.npy", "CO/ESK/ESK_e42_s11000.safetensors",
            "CO/HNS/config.json", "CO/HNS/style_settings.json", "CO/HNS/style_vectors.npy", "CO/HNS/HNS_e1000_s11000.safetensors",
            "CO/HSI/config.json", "CO/HSI/style_settings.json", "CO/HSI/style_vectors.npy", "CO/HSI/HSI_e209_s26000.safetensors",
            "CO/KNN/config.json", "CO/KNN/style_settings.json", "CO/KNN/style_vectors.npy", "CO/KNN/KNN_e68_s15000.safetensors",
            "CO/MZ/config.json", "CO/MZ/style_settings.json", "CO/MZ/style_vectors.npy", "CO/MZ/MZ_e137_s14000.safetensors",
            "CO/NEL/config.json", "CO/NEL/style_settings.json", "CO/NEL/style_vectors.npy", "CO/NEL/NEL_e1000_s16000.safetensors",
            "CO/PSR/config.json", "CO/PSR/style_settings.json", "CO/PSR/style_vectors.npy", "CO/PSR/PSR_e142_s25000.safetensors",
            "CO/RI/config.json", "CO/RI/style_settings.json", "CO/RI/style_vectors.npy", "CO/RI/RI_e94_s20000.safetensors",
            "CO/SKRNBU/config.json", "CO/SKRNBU/style_settings.json", "CO/SKRNBU/style_vectors.npy", "CO/SKRNBU/SKRNBU_e135_s25000.safetensors",
            "CO/SNNN/config.json", "CO/SNNN/style_settings.json", "CO/SNNN/style_vectors.npy", "CO/SNNN/SNNN_e201_s11000.safetensors",
            "CO/SRMY/config.json", "CO/SRMY/style_settings.json", "CO/SRMY/style_vectors.npy", "CO/SRMY/SRMY_e1000_s20000.safetensors",
            "CO/TIS/config.json", "CO/TIS/style_settings.json", "CO/TIS/style_vectors.npy", "CO/TIS/TIS_e596_s28000.safetensors",
            "CO/UDK/config.json", "CO/UDK/style_settings.json", "CO/UDK/style_vectors.npy", "CO/UDK/UDK_e472_s25000.safetensors",
            "CO/VVAN/config.json", "CO/VVAN/style_settings.json", "CO/VVAN/style_vectors.npy", "CO/VVAN/VVAN_e728_s32000.safetensors",
            "CO/YZY/config.json", "CO/YZY/style_settings.json", "CO/YZY/style_vectors.npy", "CO/YZY/YZY_e1000_s15000.safetensors",
            "CO/SZ/config.json", "CO/SZ/style_settings.json", "CO/SZ/style_vectors.npy", "CO/SZ/SZ_e622_s23000.safetensors",
            "CO/RRM/config.json", "CO/RRM/style_settings.json", "CO/RRM/style_vectors.npy", "CO/RRM/RRM_e63_s20000.safetensors",
            # FN Models
            "FN/FN1/config.json", "FN/FN1/style_settings.json", "FN/FN1/style_vectors.npy", "FN/FN1/FN1.safetensors",
            "FN/FN10/config.json", "FN/FN10/style_settings.json", "FN/FN10/style_vectors.npy", "FN/FN10/FN10.safetensors",
            "FN/FN2/config.json", "FN/FN2/style_settings.json", "FN/FN2/style_vectors.npy", "FN/FN2/FN2.safetensors",
            "FN/FN3/config.json", "FN/FN3/style_settings.json", "FN/FN3/style_vectors.npy", "FN/FN3/FN3.safetensors",
            "FN/FN4/config.json", "FN/FN4/style_settings.json", "FN/FN4/style_vectors.npy", "FN/FN4/FN4.safetensors",
            "FN/FN5/config.json", "FN/FN5/style_settings.json", "FN/FN5/style_vectors.npy", "FN/FN5/FN5.safetensors",
            "FN/FN6/config.json", "FN/FN6/style_settings.json", "FN/FN6/style_vectors.npy", "FN/FN6/FN6.safetensors",
            "FN/FN7/config.json", "FN/FN7/style_settings.json", "FN/FN7/style_vectors.npy", "FN/FN7/FN7.safetensors",
            "FN/FN8/config.json", "FN/FN8/style_settings.json", "FN/FN8/style_vectors.npy", "FN/FN8/FN8.safetensors",
            "FN/FN9/config.json", "FN/FN9/style_settings.json", "FN/FN9/style_vectors.npy", "FN/FN9/FN9.safetensors",
            # other Models
            "other/whisper/config.json", "other/whisper/style_vectors.npy", "other/whisper/whisper.safetensors",
        ]
    }

    # Top-level repository folders that are dropped from the local path.
    category_prefixes = ("CO", "FN", "other")

    for repo_id, repo_files in manifest.items():
        # Only teradakokoro/voice_models has its leading folder rewritten.
        strips_category = repo_id == "teradakokoro/voice_models"

        for repo_filepath in repo_files:
            head, slash, remainder = repo_filepath.partition("/")
            if strips_category and slash and head in category_prefixes:
                # Drop the leading 'CO' / 'FN' / 'other' component.
                link_subpath = remainder
            else:
                # By default the local path equals the repository path.
                link_subpath = repo_filepath

            create_symlink_if_not_exists(
                repo_id=repo_id,
                repo_filepath=repo_filepath,  # full path, used for the download
                local_subpath=link_subpath,  # rewritten path, used for the link
                link_dir="model_assets",
            )
def main(skip_default_models=False, only_infer=False):
    """Set up the symlinks for every model the application needs.

    The arguments mirror those of the original script.

    Args:
        skip_default_models: When true, the default speaker models are not
            set up.
        only_infer: When true, the SLM model and both sets of pretrained
            checkpoints are not set up.
    """
    logger.info("Starting model setup...")

    # The BERT models are always required.
    steps = [setup_bert_models]

    if not skip_default_models:
        steps.append(setup_default_models)

    # These are only left out when running inference only.
    if not only_infer:
        steps.extend(
            (
                setup_slm_model,
                setup_pretrained_models,
                setup_jp_extra_pretrained_models,
            )
        )

    for step in steps:
        step()

    logger.info("All models are set up and ready.")
if __name__ == "__main__":
    # Run with no arguments, the same way the Gradio app invokes it.
    main()