cockolo terada commited on
Commit
10529b8
·
verified ·
1 Parent(s): 9c68426

Update initialize.py

Browse files
Files changed (1) hide show
  1. initialize.py +98 -131
initialize.py CHANGED
@@ -1,147 +1,114 @@
1
- import argparse
2
  import json
3
- import shutil
4
  from pathlib import Path
5
-
6
- import yaml
7
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- from style_bert_vits2.logging import logger
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
11
 
12
- def download_bert_models():
 
13
  with open("bert/bert_models.json", encoding="utf-8") as fp:
14
  models = json.load(fp)
15
- for k, v in models.items():
16
- local_path = Path("bert").joinpath(k)
17
- for file in v["files"]:
18
- if not Path(local_path).joinpath(file).exists():
19
- logger.info(f"Downloading {k} {file}")
20
- hf_hub_download(v["repo_id"], file, local_dir=local_path)
21
-
22
-
23
- def download_slm_model():
24
- local_path = Path("slm/wavlm-base-plus/")
25
- file = "pytorch_model.bin"
26
- if not Path(local_path).joinpath(file).exists():
27
- logger.info(f"Downloading wavlm-base-plus {file}")
28
- hf_hub_download("microsoft/wavlm-base-plus", file, local_dir=local_path)
29
-
 
30
 
31
- def download_pretrained_models():
 
 
32
  files = ["G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"]
33
- local_path = Path("pretrained")
34
  for file in files:
35
- if not Path(local_path).joinpath(file).exists():
36
- logger.info(f"Downloading pretrained {file}")
37
- hf_hub_download(
38
- "litagin/Style-Bert-VITS2-1.0-base", file, local_dir=local_path
39
- )
40
-
41
 
42
- def download_jp_extra_pretrained_models():
 
 
43
  files = ["G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"]
44
- local_path = Path("pretrained_jp_extra")
45
  for file in files:
46
- if not Path(local_path).joinpath(file).exists():
47
- logger.info(f"Downloading JP-Extra pretrained {file}")
48
- hf_hub_download(
49
- "litagin/Style-Bert-VITS2-2.0-base-JP-Extra", file, local_dir=local_path
50
- )
51
-
52
-
53
- #def download_default_models():
54
- # files = [
55
- # "jvnv-F1-jp/config.json",
56
- # "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors",
57
- # "jvnv-F1-jp/style_vectors.npy",
58
- # "jvnv-F2-jp/config.json",
59
- # "jvnv-F2-jp/jvnv-F2_e166_s20000.safetensors",
60
- # "jvnv-F2-jp/style_vectors.npy",
61
- # "jvnv-M1-jp/config.json",
62
- # "jvnv-M1-jp/jvnv-M1-jp_e158_s14000.safetensors",
63
- # "jvnv-M1-jp/style_vectors.npy",
64
- # "jvnv-M2-jp/config.json",
65
- # "jvnv-M2-jp/jvnv-M2-jp_e159_s17000.safetensors",
66
- # "jvnv-M2-jp/style_vectors.npy",
67
- # ]
68
- # for file in files:
69
- # if not Path(f"model_assets/{file}").exists():
70
- # logger.info(f"Downloading {file}")
71
- # hf_hub_download(
72
- # "litagin/style_bert_vits2_jvnv",
73
- # file,
74
- # local_dir="model_assets",
75
- # )
76
- # additional_files = {
77
- # "litagin/sbv2_koharune_ami": [
78
- # "koharune-ami/config.json",
79
- # "koharune-ami/style_vectors.npy",
80
- # "koharune-ami/koharune-ami.safetensors",
81
- # ],
82
- # "litagin/sbv2_amitaro": [
83
- # "amitaro/config.json",
84
- # "amitaro/style_vectors.npy",
85
- # "amitaro/amitaro.safetensors",
86
- # ],
87
- # }
88
- # for repo_id, files in additional_files.items():
89
- # for file in files:
90
- # if not Path(f"model_assets/{file}").exists():
91
- # logger.info(f"Downloading {file}")
92
- # hf_hub_download(
93
- # repo_id,
94
- # file,
95
- # local_dir="model_assets",
96
- # )
97
-
98
-
99
- def main():
100
- parser = argparse.ArgumentParser()
101
- parser.add_argument("--skip_default_models", action="store_true")
102
- parser.add_argument("--only_infer", action="store_true")
103
- parser.add_argument(
104
- "--dataset_root",
105
- type=str,
106
- help="Dataset root path (default: Data)",
107
- default=None,
108
- )
109
- parser.add_argument(
110
- "--assets_root",
111
- type=str,
112
- help="Assets root path (default: model_assets)",
113
- default=None,
114
- )
115
- args = parser.parse_args()
116
-
117
- download_bert_models()
118
-
119
- # if not args.skip_default_models:
120
- # download_default_models()
121
- if not args.only_infer:
122
- download_slm_model()
123
- download_pretrained_models()
124
- download_jp_extra_pretrained_models()
125
-
126
- # If configs/paths.yml not exists, create it
127
- default_paths_yml = Path("configs/default_paths.yml")
128
- paths_yml = Path("configs/paths.yml")
129
- if not paths_yml.exists():
130
- shutil.copy(default_paths_yml, paths_yml)
131
-
132
- if args.dataset_root is None and args.assets_root is None:
133
- return
134
-
135
- # Change default paths if necessary
136
- with open(paths_yml, encoding="utf-8") as f:
137
- yml_data = yaml.safe_load(f)
138
- if args.assets_root is not None:
139
- yml_data["assets_root"] = args.assets_root
140
- if args.dataset_root is not None:
141
- yml_data["dataset_root"] = args.dataset_root
142
- with open(paths_yml, "w", encoding="utf-8") as f:
143
- yaml.dump(yml_data, f, allow_unicode=True)
144
-
145
 
146
  if __name__ == "__main__":
147
- main()
 
 
1
+ import os
2
  import json
 
3
  from pathlib import Path
 
 
4
  from huggingface_hub import hf_hub_download
5
+ import logging
6
+
7
+ # loggerをセットアップ (元のコードに合わせて)
8
+ # from style_bert_vits2.logging import logger でも可
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def create_symlink_if_not_exists(repo_id, filename, link_dir):
13
+ """
14
+ Hugging Face Hubからファイルをキャッシュにダウンロードし、
15
+ 指定された場所にシンボリックリンクを作成するヘルパー関数。
16
+ """
17
+ # リンクを配置したいパスを定義 (例: pretrained/G_0.safetensors)
18
+ link_path = Path(link_dir) / filename
19
+
20
+ # リンクが既に存在する場合はスキップ
21
+ if link_path.exists():
22
+ return
23
 
24
+ # リンク先の親ディレクトリがなければ作成
25
+ link_path.parent.mkdir(parents=True, exist_ok=True)
26
+
27
+ logger.info(f"Downloading {repo_id}/{filename}")
28
+
29
+ # Hugging Face Hubからファイルをダウンロード(キャッシュ優先)し、実際のパスを取得
30
+ try:
31
+ actual_file_path = hf_hub_download(repo_id=repo_id, filename=filename)
32
+ except Exception as e:
33
+ logger.error(f"Failed to download {filename}: {e}")
34
+ return
35
 
36
+ # シンボリックリンクを作成
37
+ logger.info(f"Creating symlink: {link_path} -> {actual_file_path}")
38
+ os.symlink(actual_file_path, link_path)
39
 
40
+ def setup_bert_models():
41
+ logger.info("Setting up BERT models...")
42
  with open("bert/bert_models.json", encoding="utf-8") as fp:
43
  models = json.load(fp)
44
+
45
+ for model_name, model_info in models.items():
46
+ repo_id = model_info["repo_id"]
47
+ # BERTモデルは 'bert/モデル名/' というサブディレクトリに配置
48
+ link_dir = Path("bert") / model_name
49
+ for file in model_info["files"]:
50
+ create_symlink_if_not_exists(repo_id, file, link_dir)
51
+
52
+ def setup_slm_model():
53
+ logger.info("Setting up SLM model...")
54
+ # SLMモデルは 'slm/wavlm-base-plus/' というサブディレクトリに配置
55
+ create_symlink_if_not_exists(
56
+ repo_id="microsoft/wavlm-base-plus",
57
+ filename="pytorch_model.bin",
58
+ link_dir="slm/wavlm-base-plus/"
59
+ )
60
 
61
+ def setup_pretrained_models():
62
+ logger.info("Setting up Pretrained models...")
63
+ repo_id = "litagin/Style-Bert-VITS2-1.0-base"
64
  files = ["G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"]
 
65
  for file in files:
66
+ create_symlink_if_not_exists(repo_id, file, "pretrained")
 
 
 
 
 
67
 
68
+ def setup_jp_extra_pretrained_models():
69
+ logger.info("Setting up JP-Extra Pretrained models...")
70
+ repo_id = "litagin/Style-Bert-VITS2-2.0-base-JP-Extra"
71
  files = ["G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"]
 
72
  for file in files:
73
+ create_symlink_if_not_exists(repo_id, file, "pretrained_jp_extra")
74
+
75
+ def setup_default_models():
76
+ logger.info("Setting up default speaker models...")
77
+ # 先ほどの回答で示した `model_assets` 以下のモデル
78
+ models_to_link = {
79
+ "litagin/style_bert_vits2_jvnv": [
80
+ "jvnv-F1-jp/config.json", "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors", "jvnv-F1-jp/style_vectors.npy",
81
+ # ... 他のjvnvモデルも同様に追加
82
+ ],
83
+ "litagin/sbv2_koharune_ami": ["koharune-ami/config.json", "koharune-ami/style_vectors.npy", "koharune-ami/koharune-ami.safetensors"],
84
+ "litagin/sbv2_amitaro": ["amitaro/config.json", "amitaro/style_vectors.npy", "amitaro/amitaro.safetensors"],
85
+ }
86
+
87
+ for repo_id, files_in_repo in models_to_link.items():
88
+ for file_path in files_in_repo:
89
+ # ここではファイル名だけでなくディレクトリ構造も含むパスを渡す
90
+ create_symlink_if_not_exists(repo_id, file_path, "model_assets")
91
+
92
+ def main(skip_default_models=False, only_infer=False):
93
+ """
94
+ アプリケーションが必要とする全てのモデルのシンボリックリンクをセットアップする。
95
+ 元のスクリプトの引数も模倣。
96
+ """
97
+ logger.info("Starting model setup...")
98
+
99
+ # 常に必要なモデル
100
+ setup_bert_models()
101
+ if not skip_default_models:
102
+ setup_default_models()
103
+
104
+ # 推論時のみ不要なモデル
105
+ if not only_infer:
106
+ setup_slm_model()
107
+ setup_pretrained_models()
108
+ setup_jp_extra_pretrained_models()
109
+
110
+ logger.info("All models are set up and ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  if __name__ == "__main__":
113
+ # Gradioアプリから呼び出す場合は引数なしで実行
114
+ main()