File size: 9,419 Bytes
10529b8
2d48e71
 
 
10529b8
 
 
 
 
 
 
9c5f6d1
10529b8
9c5f6d1
 
10529b8
9c5f6d1
 
 
10529b8
 
 
 
2d48e71
10529b8
 
 
9c5f6d1
10529b8
9c5f6d1
10529b8
9c5f6d1
 
10529b8
9c5f6d1
10529b8
2d48e71
10529b8
 
 
2d48e71
10529b8
 
2d48e71
 
10529b8
 
 
 
 
 
9c5f6d1
10529b8
 
 
9c5f6d1
 
10529b8
 
 
9c5f6d1
 
 
10529b8
2d48e71
10529b8
 
 
2d48e71
 
9c5f6d1
2d48e71
10529b8
 
 
2d48e71
 
9c5f6d1
10529b8
 
 
a2c557a
10529b8
a2c557a
 
9c5f6d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4146122
 
9c5f6d1
 
 
 
 
 
 
 
 
 
 
 
 
 
10529b8
 
 
9c5f6d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10529b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d48e71
 
10529b8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
import logging

# Set up the module logger (mirrors the original code).
# Per the original note, `from style_bert_vits2.logging import logger` would also work.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_symlink_if_not_exists(repo_id, repo_filepath, local_subpath, link_dir):
    """
    Download a file from the Hugging Face Hub and symlink it into place.

    The path inside the repository and the local path are separate arguments,
    so a file can be linked under a different relative path than the one it
    has in the repository.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.
        repo_filepath: Path of the file inside the repository (used for the download).
        local_subpath: Path of the link, relative to ``link_dir``.
        link_dir: Directory under which the link is created,
            e.g. link_dir="model_assets", local_subpath="ANNYUI/config.json"
            -> "model_assets/ANNYUI/config.json".

    Returns:
        None. If the download fails the error is logged and no link is created.
    """
    link_path = Path(link_dir) / local_subpath

    # Nothing to do when the link (or a real file) already resolves.
    if link_path.exists():
        return

    # Path.exists() follows symlinks, so a link whose target has disappeared
    # (for example after the download cache was cleared) reports False above.
    # Remove it here; otherwise os.symlink() below fails with FileExistsError.
    if link_path.is_symlink():
        logger.warning(f"Removing dangling symlink: {link_path}")
        link_path.unlink()

    # Make sure the directory that will hold the link exists.
    link_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Downloading {repo_id}/{repo_filepath}")

    # Fetch via the Hub client using the full in-repo path. Failure is
    # deliberately best-effort: log it and skip this file.
    try:
        actual_file_path = hf_hub_download(repo_id=repo_id, filename=repo_filepath)
    except Exception as e:
        logger.error(f"Failed to download {repo_filepath}: {e}")
        return

    logger.info(f"Creating symlink: {link_path} -> {actual_file_path}")
    os.symlink(actual_file_path, link_path)

def setup_bert_models():
    """Link every file listed in ``bert/bert_models.json`` under ``bert/<model name>/``."""
    logger.info("Setting up BERT models...")
    manifest = json.loads(Path("bert/bert_models.json").read_text(encoding="utf-8"))

    for name, info in manifest.items():
        hub_repo = info["repo_id"]
        # Each BERT model gets its own sub-directory below 'bert/'.
        target_dir = Path("bert") / name
        for filename in info["files"]:
            # The repo path and the local path are the same for BERT files.
            create_symlink_if_not_exists(hub_repo, filename, filename, target_dir)

def setup_slm_model():
    """Link ``pytorch_model.bin`` from microsoft/wavlm-base-plus under ``slm/wavlm-base-plus/``."""
    logger.info("Setting up SLM model...")
    weights = "pytorch_model.bin"
    # Arguments in order: repo id, path in repo, local sub-path, link directory.
    create_symlink_if_not_exists(
        "microsoft/wavlm-base-plus",
        weights,
        weights,
        "slm/wavlm-base-plus/",
    )

def setup_pretrained_models():
    """Link the three base checkpoint files from the 1.0 base repo into ``pretrained/``."""
    logger.info("Setting up Pretrained models...")
    for checkpoint in ("G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"):
        create_symlink_if_not_exists(
            repo_id="litagin/Style-Bert-VITS2-1.0-base",
            repo_filepath=checkpoint,
            local_subpath=checkpoint,
            link_dir="pretrained",
        )

def setup_jp_extra_pretrained_models():
    """Link the three JP-Extra base checkpoint files into ``pretrained_jp_extra/``."""
    logger.info("Setting up JP-Extra Pretrained models...")
    for checkpoint in ("G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"):
        create_symlink_if_not_exists(
            repo_id="litagin/Style-Bert-VITS2-2.0-base-JP-Extra",
            repo_filepath=checkpoint,
            local_subpath=checkpoint,
            link_dir="pretrained_jp_extra",
        )

def setup_default_models():
    """
    Link the default speaker models into ``model_assets/``.

    Files are downloaded using their full path inside the repository. For the
    "teradakokoro/voice_models" repo the leading group directory ("CO", "FN"
    or "other") is dropped from the local path, so e.g. "CO/ANNYUI/config.json"
    is linked as "model_assets/ANNYUI/config.json".
    """
    logger.info("Setting up default speaker models...")
    # Mapping of repo id -> files to link.
    # (The "litagin/style_bert_vits2_jvnv" entry has been removed from this mapping.)
    models_to_link = {
        "teradakokoro/voice_models": [
            # CO Models
            "CO/ANNYUI/config.json", "CO/ANNYUI/style_settings.json", "CO/ANNYUI/style_vectors.npy", "CO/ANNYUI/ANNYUI_e101_s21000.safetensors",
            "CO/ASTK/config.json", "CO/ASTK/style_settings.json", "CO/ASTK/style_vectors.npy", "CO/ASTK/ASTK_e501_s28000.safetensors",
            "CO/AZS/config.json", "CO/AZS/style_settings.json", "CO/AZS/style_vectors.npy", "CO/AZS/AZS_e442_s34000.safetensors",
            "CO/BNKRG/config.json", "CO/BNKRG/style_settings.json", "CO/BNKRG/style_vectors.npy", "CO/BNKRG/BNKRG_e222_s31000.safetensors",
            "CO/ESK/config.json", "CO/ESK/style_settings.json", "CO/ESK/style_vectors.npy", "CO/ESK/ESK_e42_s11000.safetensors",
            "CO/HNS/config.json", "CO/HNS/style_settings.json", "CO/HNS/style_vectors.npy", "CO/HNS/HNS_e1000_s11000.safetensors",
            "CO/HSI/config.json", "CO/HSI/style_settings.json", "CO/HSI/style_vectors.npy", "CO/HSI/HSI_e209_s26000.safetensors",
            "CO/KNN/config.json", "CO/KNN/style_settings.json", "CO/KNN/style_vectors.npy", "CO/KNN/KNN_e68_s15000.safetensors",
            "CO/MZ/config.json", "CO/MZ/style_settings.json", "CO/MZ/style_vectors.npy", "CO/MZ/MZ_e137_s14000.safetensors",
            "CO/NEL/config.json", "CO/NEL/style_settings.json", "CO/NEL/style_vectors.npy", "CO/NEL/NEL_e1000_s16000.safetensors",
            "CO/PSR/config.json", "CO/PSR/style_settings.json", "CO/PSR/style_vectors.npy", "CO/PSR/PSR_e142_s25000.safetensors",
            "CO/RI/config.json", "CO/RI/style_settings.json", "CO/RI/style_vectors.npy", "CO/RI/RI_e94_s20000.safetensors",
            "CO/SKRNBU/config.json", "CO/SKRNBU/style_settings.json", "CO/SKRNBU/style_vectors.npy", "CO/SKRNBU/SKRNBU_e135_s25000.safetensors",
            "CO/SNNN/config.json", "CO/SNNN/style_settings.json", "CO/SNNN/style_vectors.npy", "CO/SNNN/SNNN_e201_s11000.safetensors",
            "CO/SRMY/config.json", "CO/SRMY/style_settings.json", "CO/SRMY/style_vectors.npy", "CO/SRMY/SRMY_e1000_s20000.safetensors",
            "CO/TIS/config.json", "CO/TIS/style_settings.json", "CO/TIS/style_vectors.npy", "CO/TIS/TIS_e596_s28000.safetensors",
            "CO/UDK/config.json", "CO/UDK/style_settings.json", "CO/UDK/style_vectors.npy", "CO/UDK/UDK_e472_s25000.safetensors",
            "CO/VVAN/config.json", "CO/VVAN/style_settings.json", "CO/VVAN/style_vectors.npy", "CO/VVAN/VVAN_e728_s32000.safetensors",
            "CO/YZY/config.json", "CO/YZY/style_settings.json", "CO/YZY/style_vectors.npy", "CO/YZY/YZY_e1000_s15000.safetensors",
            "CO/SZ/config.json", "CO/SZ/style_settings.json", "CO/SZ/style_vectors.npy", "CO/SZ/SZ_e622_s23000.safetensors",
            "CO/RRM/config.json", "CO/RRM/style_settings.json", "CO/RRM/style_vectors.npy", "CO/RRM/RRM_e63_s20000.safetensors",
            # FN Models
            "FN/FN1/config.json", "FN/FN1/style_settings.json", "FN/FN1/style_vectors.npy", "FN/FN1/FN1.safetensors",
            "FN/FN10/config.json", "FN/FN10/style_settings.json", "FN/FN10/style_vectors.npy", "FN/FN10/FN10.safetensors",
            "FN/FN2/config.json", "FN/FN2/style_settings.json", "FN/FN2/style_vectors.npy", "FN/FN2/FN2.safetensors",
            "FN/FN3/config.json", "FN/FN3/style_settings.json", "FN/FN3/style_vectors.npy", "FN/FN3/FN3.safetensors",
            "FN/FN4/config.json", "FN/FN4/style_settings.json", "FN/FN4/style_vectors.npy", "FN/FN4/FN4.safetensors",
            "FN/FN5/config.json", "FN/FN5/style_settings.json", "FN/FN5/style_vectors.npy", "FN/FN5/FN5.safetensors",
            "FN/FN6/config.json", "FN/FN6/style_settings.json", "FN/FN6/style_vectors.npy", "FN/FN6/FN6.safetensors",
            "FN/FN7/config.json", "FN/FN7/style_settings.json", "FN/FN7/style_vectors.npy", "FN/FN7/FN7.safetensors",
            "FN/FN8/config.json", "FN/FN8/style_settings.json", "FN/FN8/style_vectors.npy", "FN/FN8/FN8.safetensors",
            "FN/FN9/config.json", "FN/FN9/style_settings.json", "FN/FN9/style_vectors.npy", "FN/FN9/FN9.safetensors",
            # other Models
            "other/whisper/config.json", "other/whisper/style_vectors.npy", "other/whisper/whisper.safetensors",
        ]
    }

    # Top-level repo directories that are dropped from the local link path.
    group_dirs = ("CO", "FN", "other")

    for repo_id, files_in_repo in models_to_link.items():
        # Only this repository has its leading group directory stripped.
        strip_group = repo_id == "teradakokoro/voice_models"
        for repo_filepath in files_in_repo:
            head, slash, remainder = repo_filepath.partition("/")
            if strip_group and slash and head in group_dirs:
                # Local path is everything after the first path component.
                local_subpath = remainder
            else:
                # By default the local path mirrors the repository path.
                local_subpath = repo_filepath

            # Download with the full repo path; link under the (possibly shortened) local path.
            create_symlink_if_not_exists(repo_id, repo_filepath, local_subpath, "model_assets")

def main(skip_default_models=False, only_infer=False):
    """
    Set up the symlinks for every model the application needs.

    The arguments mirror those of the original script.

    Args:
        skip_default_models: When true, the default speaker models are not set up.
        only_infer: When true, the SLM and pretrained base models are not set up.
    """
    logger.info("Starting model setup...")

    # BERT models are always set up.
    steps = [setup_bert_models]
    if not skip_default_models:
        steps.append(setup_default_models)

    # These are skipped when only inference is requested.
    if not only_infer:
        steps.extend(
            (setup_slm_model, setup_pretrained_models, setup_jp_extra_pretrained_models)
        )

    for step in steps:
        step()

    logger.info("All models are set up and ready.")

if __name__ == "__main__":
    # When invoked from the Gradio app, run with no arguments
    main()