Adoetz committed
Commit bcd030b · verified · 1 Parent(s): b30d9e8

Update TTS/utils/manage.py

Files changed (1)
  1. TTS/utils/manage.py +616 -621
TTS/utils/manage.py CHANGED
@@ -1,621 +1,616 @@
- import json
- import os
- import re
- import tarfile
- import zipfile
- from pathlib import Path
- from shutil import copyfile, rmtree
- from typing import Dict, List, Tuple
-
- import fsspec
- import requests
- from tqdm import tqdm
-
- from TTS.config import load_config, read_json_with_comments
- from TTS.utils.generic_utils import get_user_data_dir
-
- LICENSE_URLS = {
-     "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
-     "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
-     "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
-     "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
-     "mit": "https://choosealicense.com/licenses/mit/",
-     "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
-     "apache2": "https://choosealicense.com/licenses/apache-2.0/",
-     "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
-     "cpml": "https://coqui.ai/cpml.txt",
- }
-
-
- class ModelManager(object):
-     tqdm_progress = None
-     """Manage TTS models defined in .models.json.
-     It provides an interface to list and download
-     models defined in '.models.json'
-
-     Models are downloaded under '.TTS' folder in the user's
-     home path.
-
-     Args:
-         models_file (str): path to .models.json file. Defaults to None.
-         output_prefix (str): prefix to `tts` to download models. Defaults to None
-         progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
-         verbose (bool): print info. Defaults to True.
-     """
-
-     def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
-         super().__init__()
-         self.progress_bar = progress_bar
-         self.verbose = verbose
-         if output_prefix is None:
-             self.output_prefix = get_user_data_dir("tts")
-         else:
-             self.output_prefix = os.path.join(output_prefix, "tts")
-         self.models_dict = None
-         if models_file is not None:
-             self.read_models_file(models_file)
-         else:
-             # try the default location
-             path = Path(__file__).parent / "../.models.json"
-             self.read_models_file(path)
-
-     def read_models_file(self, file_path):
-         """Read .models.json as a dict
-
-         Args:
-             file_path (str): path to .models.json.
-         """
-         self.models_dict = read_json_with_comments(file_path)
-
-     def _list_models(self, model_type, model_count=0):
-         if self.verbose:
-             print("\n Name format: type/language/dataset/model")
-         model_list = []
-         for lang in self.models_dict[model_type]:
-             for dataset in self.models_dict[model_type][lang]:
-                 for model in self.models_dict[model_type][lang][dataset]:
-                     model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
-                     output_path = os.path.join(self.output_prefix, model_full_name)
-                     if self.verbose:
-                         if os.path.exists(output_path):
-                             print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
-                         else:
-                             print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
-                     model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
-                     model_count += 1
-         return model_list
-
-     def _list_for_model_type(self, model_type):
-         models_name_list = []
-         model_count = 1
-         models_name_list.extend(self._list_models(model_type, model_count))
-         return models_name_list
-
-     def list_models(self):
-         models_name_list = []
-         model_count = 1
-         for model_type in self.models_dict:
-             model_list = self._list_models(model_type, model_count)
-             models_name_list.extend(model_list)
-         return models_name_list
-
-     def model_info_by_idx(self, model_query):
-         """Print the description of the model from .models.json file using model_idx
-
-         Args:
-             model_query (str): <model_type>/<model_idx>
-         """
-         model_name_list = []
-         model_type, model_query_idx = model_query.split("/")
-         try:
-             model_query_idx = int(model_query_idx)
-             if model_query_idx <= 0:
-                 print("> model_query_idx should be a positive integer!")
-                 return
-         except:
-             print("> model_query_idx should be an integer!")
-             return
-         model_count = 0
-         if model_type in self.models_dict:
-             for lang in self.models_dict[model_type]:
-                 for dataset in self.models_dict[model_type][lang]:
-                     for model in self.models_dict[model_type][lang][dataset]:
-                         model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
-                         model_count += 1
-         else:
-             print(f"> model_type {model_type} does not exist in the list.")
-             return
-         if model_query_idx > model_count:
-             print(f"model query idx exceeds the number of available models [{model_count}] ")
-         else:
-             model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
-             print(f"> model type : {model_type}")
-             print(f"> language supported : {lang}")
-             print(f"> dataset used : {dataset}")
-             print(f"> model name : {model}")
-             if "description" in self.models_dict[model_type][lang][dataset][model]:
-                 print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
-             else:
-                 print("> description : coming soon")
-             if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
-                 print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
-
-     def model_info_by_full_name(self, model_query_name):
-         """Print the description of the model from .models.json file using model_full_name
-
-         Args:
-             model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
-         """
-         model_type, lang, dataset, model = model_query_name.split("/")
-         if model_type in self.models_dict:
-             if lang in self.models_dict[model_type]:
-                 if dataset in self.models_dict[model_type][lang]:
-                     if model in self.models_dict[model_type][lang][dataset]:
-                         print(f"> model type : {model_type}")
-                         print(f"> language supported : {lang}")
-                         print(f"> dataset used : {dataset}")
-                         print(f"> model name : {model}")
-                         if "description" in self.models_dict[model_type][lang][dataset][model]:
-                             print(
-                                 f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
-                             )
-                         else:
-                             print("> description : coming soon")
-                         if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
-                             print(
-                                 f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
-                             )
-                     else:
-                         print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
-                 else:
-                     print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
-             else:
-                 print(f"> lang {lang} does not exist for {model_type}.")
-         else:
-             print(f"> model_type {model_type} does not exist in the list.")
-
-     def list_tts_models(self):
-         """Print all `TTS` models and return a list of model names
-
-         Format is `language/dataset/model`
-         """
-         return self._list_for_model_type("tts_models")
-
-     def list_vocoder_models(self):
-         """Print all the `vocoder` models and return a list of model names
-
-         Format is `language/dataset/model`
-         """
-         return self._list_for_model_type("vocoder_models")
-
-     def list_vc_models(self):
-         """Print all the voice conversion models and return a list of model names
-
-         Format is `language/dataset/model`
-         """
-         return self._list_for_model_type("voice_conversion_models")
-
-     def list_langs(self):
-         """Print all the available languages"""
-         print(" Name format: type/language")
-         for model_type in self.models_dict:
-             for lang in self.models_dict[model_type]:
-                 print(f" >: {model_type}/{lang} ")
-
-     def list_datasets(self):
-         """Print all the datasets"""
-         print(" Name format: type/language/dataset")
-         for model_type in self.models_dict:
-             for lang in self.models_dict[model_type]:
-                 for dataset in self.models_dict[model_type][lang]:
-                     print(f" >: {model_type}/{lang}/{dataset}")
-
-     @staticmethod
-     def print_model_license(model_item: Dict):
-         """Print the license of a model
-
-         Args:
-             model_item (dict): model item in the models.json
-         """
-         if "license" in model_item and model_item["license"].strip() != "":
-             print(f" > Model's license - {model_item['license']}")
-             if model_item["license"].lower() in LICENSE_URLS:
-                 print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
-             else:
-                 print(" > Check https://opensource.org/licenses for more info.")
-         else:
-             print(" > Model's license - No license information available")
-
-     def _download_github_model(self, model_item: Dict, output_path: str):
-         if isinstance(model_item["github_rls_url"], list):
-             self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
-         else:
-             self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
-
-     def _download_hf_model(self, model_item: Dict, output_path: str):
-         if isinstance(model_item["hf_url"], list):
-             self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
-         else:
-             self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
-
-     def download_fairseq_model(self, model_name, output_path):
-         URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
-         _, lang, _, _ = model_name.split("/")
-         model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
-         self._download_tar_file(model_download_uri, output_path, self.progress_bar)
-
-     @staticmethod
-     def set_model_url(model_item: Dict):
-         model_item["model_url"] = None
-         if "github_rls_url" in model_item:
-             model_item["model_url"] = model_item["github_rls_url"]
-         elif "hf_url" in model_item:
-             model_item["model_url"] = model_item["hf_url"]
-         elif "fairseq" in model_item["model_name"]:
-             model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
-         elif "xtts" in model_item["model_name"]:
-             model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
-         return model_item
-
-     def _set_model_item(self, model_name):
-         # fetch model info from the dict
-         if "fairseq" in model_name:
-             model_type = "tts_models"
-             lang = model_name.split("/")[1]
-             model_item = {
-                 "model_type": "tts_models",
-                 "license": "CC BY-NC 4.0",
-                 "default_vocoder": None,
-                 "author": "fairseq",
-                 "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
-             }
-             model_item["model_name"] = model_name
-         elif "xtts" in model_name and len(model_name.split("/")) != 4:
-             # loading xtts models with only model name (e.g. xtts_v2.0.2)
-             # check model name has the version number with regex
-             version_regex = r"v\d+\.\d+\.\d+"
-             if re.search(version_regex, model_name):
-                 model_version = model_name.split("_")[-1]
-             else:
-                 model_version = "main"
-             model_type = "tts_models"
-             lang = "multilingual"
-             dataset = "multi-dataset"
-             model = model_name
-             model_item = {
-                 "default_vocoder": None,
-                 "license": "CPML",
-                 "contact": "[email protected]",
-                 "tos_required": True,
-                 "hf_url": [
-                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
-                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
-                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
-                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
-                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
-                 ],
-             }
-         else:
-             # get model from models.json
-             model_type, lang, dataset, model = model_name.split("/")
-             model_item = self.models_dict[model_type][lang][dataset][model]
-             model_item["model_type"] = model_type
-
-         model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
-         md5hash = model_item["model_hash"] if "model_hash" in model_item else None
-         model_item = self.set_model_url(model_item)
-         return model_item, model_full_name, model, md5hash
-
-     @staticmethod
-     def ask_tos(model_full_path):
-         """Ask the user to agree to the terms of service"""
-         tos_path = os.path.join(model_full_path, "tos_agreed.txt")
-         print(" > You must confirm the following:")
-         print(' | > "I have purchased a commercial license from Coqui: [email protected]"')
-         print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
-         answer = input(" | | > ")
-         if answer.lower() == "y":
-             with open(tos_path, "w", encoding="utf-8") as f:
-                 f.write("I have read, understood and agreed to the Terms and Conditions.")
-             return True
-         return False
-
-     @staticmethod
-     def tos_agreed(model_item, model_full_path):
-         """Check if the user has agreed to the terms of service"""
-         if "tos_required" in model_item and model_item["tos_required"]:
-             tos_path = os.path.join(model_full_path, "tos_agreed.txt")
-             if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
-                 return True
-             return False
-         return True
-
-     def create_dir_and_download_model(self, model_name, model_item, output_path):
-         os.makedirs(output_path, exist_ok=True)
-         # handle TOS
-         if not self.tos_agreed(model_item, output_path):
-             if not self.ask_tos(output_path):
-                 os.rmdir(output_path)
-                 raise Exception(" [!] You must agree to the terms of service to use this model.")
-         print(f" > Downloading model to {output_path}")
-         try:
-             if "fairseq" in model_name:
-                 self.download_fairseq_model(model_name, output_path)
-             elif "github_rls_url" in model_item:
-                 self._download_github_model(model_item, output_path)
-             elif "hf_url" in model_item:
-                 self._download_hf_model(model_item, output_path)
-
-         except requests.RequestException as e:
-             print(f" > Failed to download the model file to {output_path}")
-             rmtree(output_path)
-             raise e
-         self.print_model_license(model_item=model_item)
-
-     def check_if_configs_are_equal(self, model_name, model_item, output_path):
-         with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
-             config_local = json.load(f)
-         remote_url = None
-         for url in model_item["hf_url"]:
-             if "config.json" in url:
-                 remote_url = url
-                 break
-
-         with fsspec.open(remote_url, "r", encoding="utf-8") as f:
-             config_remote = json.load(f)
-
-         if not config_local == config_remote:
-             print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
-             self.create_dir_and_download_model(model_name, model_item, output_path)
-
-     def download_model(self, model_name):
-         """Download model files given the full model name.
-         Model name is in the format
-         'type/language/dataset/model'
-         e.g. 'tts_models/en/ljspeech/tacotron'
-
-         Every model must have the following files:
-             - *.pth : pytorch model checkpoint file.
-             - config.json : model config file.
-             - scale_stats.npy (if it exists): scale values for preprocessing.
-
-         Args:
-             model_name (str): model name as explained above.
-         """
-         model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
-         # set the model specific output path
-         output_path = os.path.join(self.output_prefix, model_full_name)
-         if os.path.exists(output_path):
-             if md5sum is not None:
-                 md5sum_file = os.path.join(output_path, "hash.md5")
-                 if os.path.isfile(md5sum_file):
-                     with open(md5sum_file, mode="r") as f:
-                         if not f.read() == md5sum:
-                             print(f" > {model_name} has been updated, clearing model cache...")
-                             self.create_dir_and_download_model(model_name, model_item, output_path)
-                         else:
-                             print(f" > {model_name} is already downloaded.")
-                 else:
-                     print(f" > {model_name} has been updated, clearing model cache...")
-                     self.create_dir_and_download_model(model_name, model_item, output_path)
-             # if the configs are different, redownload it
-             # ToDo: we need a better way to handle it
-             if "xtts" in model_name:
-                 try:
-                     self.check_if_configs_are_equal(model_name, model_item, output_path)
-                 except:
-                     pass
-             else:
-                 print(f" > {model_name} is already downloaded.")
-         else:
-             self.create_dir_and_download_model(model_name, model_item, output_path)
-
-         # find downloaded files
-         output_model_path = output_path
-         output_config_path = None
-         if (
-             model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
-         ):  # TODO: This is stupid but don't care for now.
-             output_model_path, output_config_path = self._find_files(output_path)
-             # update paths in the config.json
-             self._update_paths(output_path, output_config_path)
-         return output_model_path, output_config_path, model_item
-
-     @staticmethod
-     def _find_files(output_path: str) -> Tuple[str, str]:
-         """Find the model and config files in the output path
-
-         Args:
-             output_path (str): path to the model files
-
-         Returns:
-             Tuple[str, str]: path to the model file and config file
-         """
-         model_file = None
-         config_file = None
-         for file_name in os.listdir(output_path):
-             if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
-                 model_file = os.path.join(output_path, file_name)
-             elif file_name == "config.json":
-                 config_file = os.path.join(output_path, file_name)
-         if model_file is None:
-             raise ValueError(" [!] Model file not found in the output path")
-         if config_file is None:
-             raise ValueError(" [!] Config file not found in the output path")
-         return model_file, config_file
-
-     @staticmethod
-     def _find_speaker_encoder(output_path: str) -> str:
-         """Find the speaker encoder file in the output path
-
-         Args:
-             output_path (str): path to the model files
-
-         Returns:
-             str: path to the speaker encoder file
-         """
-         speaker_encoder_file = None
-         for file_name in os.listdir(output_path):
-             if file_name in ["model_se.pth", "model_se.pth.tar"]:
-                 speaker_encoder_file = os.path.join(output_path, file_name)
-         return speaker_encoder_file
-
-     def _update_paths(self, output_path: str, config_path: str) -> None:
-         """Update paths for certain files in config.json after download.
-
-         Args:
-             output_path (str): local path the model is downloaded to.
-             config_path (str): local config.json path.
-         """
-         output_stats_path = os.path.join(output_path, "scale_stats.npy")
-         output_d_vector_file_path = os.path.join(output_path, "speakers.json")
-         output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
-         output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
-         output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
-         speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
-         speaker_encoder_model_path = self._find_speaker_encoder(output_path)
-
-         # update the scale_stats.npy file path in the model config.json
-         self._update_path("audio.stats_path", output_stats_path, config_path)
-
-         # update the speakers.json file path in the model config.json to the current path
-         self._update_path("d_vector_file", output_d_vector_file_path, config_path)
-         self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
-         self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
-         self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)
-
-         # update the speaker_ids.json file path in the model config.json to the current path
-         self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
-         self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
-         self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
-         self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)
-
-         # update the speaker_encoder file path in the model config.json to the current path
-         self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
-         self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
-         self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
-         self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
-
-     @staticmethod
-     def _update_path(field_name, new_path, config_path):
-         """Update the path in the model config.json for the current environment after download"""
-         if new_path and os.path.exists(new_path):
-             config = load_config(config_path)
-             field_names = field_name.split(".")
-             if len(field_names) > 1:
-                 # field name points to a sub-level field
-                 sub_conf = config
-                 for fd in field_names[:-1]:
-                     if fd in sub_conf:
-                         sub_conf = sub_conf[fd]
-                     else:
-                         return
-                 if isinstance(sub_conf[field_names[-1]], list):
-                     sub_conf[field_names[-1]] = [new_path]
-                 else:
-                     sub_conf[field_names[-1]] = new_path
-             else:
-                 # field name points to a top-level field
-                 if field_name not in config:
-                     return
-                 if isinstance(config[field_name], list):
-                     config[field_name] = [new_path]
-                 else:
-                     config[field_name] = new_path
-             config.save_json(config_path)
-
-     @staticmethod
-     def _download_zip_file(file_url, output_folder, progress_bar):
-         """Download a zip release and extract it"""
-         # download the file
-         r = requests.get(file_url, stream=True)
-         # extract the file
-         try:
-             total_size_in_bytes = int(r.headers.get("content-length", 0))
-             block_size = 1024  # 1 Kibibyte
-             if progress_bar:
-                 ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-             temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
-             with open(temp_zip_name, "wb") as file:
-                 for data in r.iter_content(block_size):
-                     if progress_bar:
-                         ModelManager.tqdm_progress.update(len(data))
-                     file.write(data)
-             with zipfile.ZipFile(temp_zip_name) as z:
-                 z.extractall(output_folder)
-             os.remove(temp_zip_name)  # delete zip after extract
-         except zipfile.BadZipFile:
-             print(f" > Error: Bad zip file - {file_url}")
-             raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
-         # move the files to the outer path
-         for file_path in z.namelist():
-             src_path = os.path.join(output_folder, file_path)
-             if os.path.isfile(src_path):
-                 dst_path = os.path.join(output_folder, os.path.basename(file_path))
-                 if src_path != dst_path:
-                     copyfile(src_path, dst_path)
-         # remove redundant (hidden or not) folders
-         for file_path in z.namelist():
-             if os.path.isdir(os.path.join(output_folder, file_path)):
-                 rmtree(os.path.join(output_folder, file_path))
-
-     @staticmethod
-     def _download_tar_file(file_url, output_folder, progress_bar):
-         """Download a tar.gz release and extract it"""
-         # download the file
-         r = requests.get(file_url, stream=True)
-         # extract the file
-         try:
-             total_size_in_bytes = int(r.headers.get("content-length", 0))
-             block_size = 1024  # 1 Kibibyte
-             if progress_bar:
-                 ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-             temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
-             with open(temp_tar_name, "wb") as file:
-                 for data in r.iter_content(block_size):
-                     if progress_bar:
-                         ModelManager.tqdm_progress.update(len(data))
-                     file.write(data)
-             with tarfile.open(temp_tar_name) as t:
-                 t.extractall(output_folder)
-                 tar_names = t.getnames()
-             os.remove(temp_tar_name)  # delete tar after extract
-         except tarfile.ReadError:
-             print(f" > Error: Bad tar file - {file_url}")
-             raise tarfile.ReadError  # pylint: disable=raise-missing-from
-         # move the files to the outer path
-         for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
-             src_path = os.path.join(output_folder, tar_names[0], file_path)
-             dst_path = os.path.join(output_folder, os.path.basename(file_path))
-             if src_path != dst_path:
-                 copyfile(src_path, dst_path)
-         # remove the extracted folder
-         rmtree(os.path.join(output_folder, tar_names[0]))
-
-     @staticmethod
-     def _download_model_files(file_urls, output_folder, progress_bar):
-         """Download the individual release files"""
-         for file_url in file_urls:
-             # download the file
-             r = requests.get(file_url, stream=True)
-             # extract the file
-             base_filename = file_url.split("/")[-1]
-             temp_zip_name = os.path.join(output_folder, base_filename)
-             total_size_in_bytes = int(r.headers.get("content-length", 0))
-             block_size = 1024  # 1 Kibibyte
-             with open(temp_zip_name, "wb") as file:
-                 if progress_bar:
-                     ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-                 for data in r.iter_content(block_size):
-                     if progress_bar:
-                         ModelManager.tqdm_progress.update(len(data))
-                     file.write(data)
-
-     @staticmethod
-     def _check_dict_key(my_dict, key):
-         if key in my_dict.keys() and my_dict[key] is not None:
-             if not isinstance(key, str):
-                 return True
-             if isinstance(key, str) and len(my_dict[key]) > 0:
-                 return True
-         return False
 
+ import json
+ import os
+ import re
+ import tarfile
+ import zipfile
+ from pathlib import Path
+ from shutil import copyfile, rmtree
+ from typing import Dict, List, Tuple
+
+ import fsspec
+ import requests
+ from tqdm import tqdm
+
+ from TTS.config import load_config, read_json_with_comments
+ from TTS.utils.generic_utils import get_user_data_dir
+
+ LICENSE_URLS = {
+     "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
+     "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
+     "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
+     "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
+     "mit": "https://choosealicense.com/licenses/mit/",
+     "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
+     "apache2": "https://choosealicense.com/licenses/apache-2.0/",
+     "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
+     "cpml": "https://coqui.ai/cpml.txt",
+ }
+
+
+ class ModelManager(object):
+     tqdm_progress = None
+     """Manage TTS models defined in .models.json.
+     It provides an interface to list and download
+     models defined in '.models.json'
+
+     Models are downloaded under '.TTS' folder in the user's
+     home path.
+
+     Args:
+         models_file (str): path to .models.json file. Defaults to None.
+         output_prefix (str): prefix to `tts` to download models. Defaults to None
+         progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
+         verbose (bool): print info. Defaults to True.
+     """
+
+     def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
+         super().__init__()
+         self.progress_bar = progress_bar
+         self.verbose = verbose
+         if output_prefix is None:
+             self.output_prefix = get_user_data_dir("tts")
+         else:
+             self.output_prefix = os.path.join(output_prefix, "tts")
+         self.models_dict = None
+         if models_file is not None:
+             self.read_models_file(models_file)
+         else:
+             # try the default location
+             path = Path(__file__).parent / "../.models.json"
+             self.read_models_file(path)
+
+     def read_models_file(self, file_path):
+         """Read .models.json as a dict
+
+         Args:
+             file_path (str): path to .models.json.
+         """
+         self.models_dict = read_json_with_comments(file_path)
+
+     def _list_models(self, model_type, model_count=0):
+         if self.verbose:
+             print("\n Name format: type/language/dataset/model")
+         model_list = []
+         for lang in self.models_dict[model_type]:
+             for dataset in self.models_dict[model_type][lang]:
+                 for model in self.models_dict[model_type][lang][dataset]:
+                     model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
+                     output_path = os.path.join(self.output_prefix, model_full_name)
+                     if self.verbose:
+                         if os.path.exists(output_path):
+                             print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
+                         else:
+                             print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
+                     model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+                     model_count += 1
+         return model_list
+
+     def _list_for_model_type(self, model_type):
+         models_name_list = []
+         model_count = 1
+         models_name_list.extend(self._list_models(model_type, model_count))
+         return models_name_list
+
+     def list_models(self):
+         models_name_list = []
+         model_count = 1
+         for model_type in self.models_dict:
+             model_list = self._list_models(model_type, model_count)
+             models_name_list.extend(model_list)
+         return models_name_list
+
+     def model_info_by_idx(self, model_query):
+         """Print the description of the model from .models.json file using model_idx
+
+         Args:
+             model_query (str): <model_type>/<model_idx>
+         """
+         model_name_list = []
+         model_type, model_query_idx = model_query.split("/")
+         try:
+             model_query_idx = int(model_query_idx)
+             if model_query_idx <= 0:
+                 print("> model_query_idx should be a positive integer!")
+                 return
+         except:
+             print("> model_query_idx should be an integer!")
+             return
+         model_count = 0
+         if model_type in self.models_dict:
+             for lang in self.models_dict[model_type]:
+                 for dataset in self.models_dict[model_type][lang]:
+                     for model in self.models_dict[model_type][lang][dataset]:
+                         model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
+                         model_count += 1
+         else:
+             print(f"> model_type {model_type} does not exist in the list.")
+             return
+         if model_query_idx > model_count:
+             print(f"model query idx exceeds the number of available models [{model_count}] ")
+         else:
+             model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
+             print(f"> model type : {model_type}")
+             print(f"> language supported : {lang}")
+             print(f"> dataset used : {dataset}")
+             print(f"> model name : {model}")
+             if "description" in self.models_dict[model_type][lang][dataset][model]:
+                 print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
+             else:
+                 print("> description : coming soon")
+             if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
+                 print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
+
+     def model_info_by_full_name(self, model_query_name):
+         """Print the description of the model from .models.json file using model_full_name
+
+         Args:
+             model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
+         """
+         model_type, lang, dataset, model = model_query_name.split("/")
+         if model_type in self.models_dict:
+             if lang in self.models_dict[model_type]:
+                 if dataset in self.models_dict[model_type][lang]:
+                     if model in self.models_dict[model_type][lang][dataset]:
+                         print(f"> model type : {model_type}")
+                         print(f"> language supported : {lang}")
+                         print(f"> dataset used : {dataset}")
+                         print(f"> model name : {model}")
+                         if "description" in self.models_dict[model_type][lang][dataset][model]:
+                             print(
+                                 f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
+                             )
+                         else:
+                             print("> description : coming soon")
+                         if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
+                             print(
+                                 f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
+                             )
+                     else:
+                         print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
+                 else:
+                     print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
+             else:
+                 print(f"> lang {lang} does not exist for {model_type}.")
+         else:
+             print(f"> model_type {model_type} does not exist in the list.")
+
+     def list_tts_models(self):
+         """Print all `TTS` models and return a list of model names
+
+         Format is `language/dataset/model`
+         """
+         return self._list_for_model_type("tts_models")
+
+     def list_vocoder_models(self):
+         """Print all the `vocoder` models and return a list of model names
+
+         Format is `language/dataset/model`
+         """
+         return self._list_for_model_type("vocoder_models")
+
+     def list_vc_models(self):
+         """Print all the voice conversion models and return a list of model names
+
+         Format is `language/dataset/model`
+         """
+         return self._list_for_model_type("voice_conversion_models")
+
+     def list_langs(self):
+         """Print all the available languages"""
+         print(" Name format: type/language")
+         for model_type in self.models_dict:
+             for lang in self.models_dict[model_type]:
+                 print(f" >: {model_type}/{lang} ")
+
+     def list_datasets(self):
+         """Print all the datasets"""
+         print(" Name format: type/language/dataset")
+         for model_type in self.models_dict:
+             for lang in self.models_dict[model_type]:
+                 for dataset in self.models_dict[model_type][lang]:
+                     print(f" >: {model_type}/{lang}/{dataset}")
+
+     @staticmethod
+     def print_model_license(model_item: Dict):
+         """Print the license of a model
+
+         Args:
+             model_item (dict): model item in the models.json
+         """
+         if "license" in model_item and model_item["license"].strip() != "":
+             print(f" > Model's license - {model_item['license']}")
+             if model_item["license"].lower() in LICENSE_URLS:
+                 print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
+             else:
+                 print(" > Check https://opensource.org/licenses for more info.")
+         else:
+             print(" > Model's license - No license information available")
+
+     def _download_github_model(self, model_item: Dict, output_path: str):
+         if isinstance(model_item["github_rls_url"], list):
+             self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
+         else:
+             self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
+
+     def _download_hf_model(self, model_item: Dict, output_path: str):
+         if isinstance(model_item["hf_url"], list):
+             self._download_model_files(model_item["hf_url"], output_path, self.progress_bar)
+         else:
+             self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar)
+
+     def download_fairseq_model(self, model_name, output_path):
+         URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
+         _, lang, _, _ = model_name.split("/")
+         model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
+         self._download_tar_file(model_download_uri, output_path, self.progress_bar)
+
+     @staticmethod
+     def set_model_url(model_item: Dict):
+         model_item["model_url"] = None
+         if "github_rls_url" in model_item:
+             model_item["model_url"] = model_item["github_rls_url"]
+         elif "hf_url" in model_item:
+             model_item["model_url"] = model_item["hf_url"]
+         elif "fairseq" in model_item["model_name"]:
+             model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
+         elif "xtts" in model_item["model_name"]:
+             model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
+         return model_item
+
+     def _set_model_item(self, model_name):
+         # fetch model info from the dict
+         if "fairseq" in model_name:
+             model_type = "tts_models"
+             lang = model_name.split("/")[1]
+             model_item = {
+                 "model_type": "tts_models",
+                 "license": "CC BY-NC 4.0",
+                 "default_vocoder": None,
+                 "author": "fairseq",
+                 "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
+             }
+             model_item["model_name"] = model_name
+         elif "xtts" in model_name and len(model_name.split("/")) != 4:
+             # loading xtts models with only model name (e.g. xtts_v2.0.2)
+             # check model name has the version number with regex
+             version_regex = r"v\d+\.\d+\.\d+"
+             if re.search(version_regex, model_name):
+                 model_version = model_name.split("_")[-1]
+             else:
+                 model_version = "main"
+             model_type = "tts_models"
+             lang = "multilingual"
+             dataset = "multi-dataset"
+             model = model_name
+             model_item = {
+                 "default_vocoder": None,
+                 "license": "CPML",
+                 "contact": "[email protected]",
+                 "tos_required": True,
+                 "hf_url": [
+                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
+                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
+                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
+                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
+                     f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
+                 ],
+             }
+         else:
+             # get model from models.json
+             model_type, lang, dataset, model = model_name.split("/")
+             model_item = self.models_dict[model_type][lang][dataset][model]
+             model_item["model_type"] = model_type
+
+         model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
+         md5hash = model_item["model_hash"] if "model_hash" in model_item else None
+         model_item = self.set_model_url(model_item)
+         return model_item, model_full_name, model, md5hash
+
+     @staticmethod
+     def ask_tos(model_full_path):
+         """Automatically agree to the terms of service without user input."""
+         tos_path = os.path.join(model_full_path, "tos_agreed.txt")
+         # Automatically agree to the terms
+         with open(tos_path, "w", encoding="utf-8") as f:
+             f.write("I have read, understood and agreed to the Terms and Conditions.")
+         return True
+
+     @staticmethod
+     def tos_agreed(model_item, model_full_path):
+         """Check if the user has agreed to the terms of service"""
+         if "tos_required" in model_item and model_item["tos_required"]:
+             tos_path = os.path.join(model_full_path, "tos_agreed.txt")
+             if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
+                 return True
+             return False
+         return True
+
+     def create_dir_and_download_model(self, model_name, model_item, output_path):
+         os.makedirs(output_path, exist_ok=True)
+         # handle TOS
+         if not self.tos_agreed(model_item, output_path):
+             if not self.ask_tos(output_path):
+                 os.rmdir(output_path)
+                 raise Exception(" [!] You must agree to the terms of service to use this model.")
+         print(f" > Downloading model to {output_path}")
+         try:
+             if "fairseq" in model_name:
+                 self.download_fairseq_model(model_name, output_path)
+             elif "github_rls_url" in model_item:
+                 self._download_github_model(model_item, output_path)
+             elif "hf_url" in model_item:
+                 self._download_hf_model(model_item, output_path)
+
+         except requests.RequestException as e:
+             print(f" > Failed to download the model file to {output_path}")
+             rmtree(output_path)
+             raise e
+         self.print_model_license(model_item=model_item)
+
+     def check_if_configs_are_equal(self, model_name, model_item, output_path):
+         with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
+             config_local = json.load(f)
+         remote_url = None
+         for url in model_item["hf_url"]:
+             if "config.json" in url:
+                 remote_url = url
+                 break
+
+         with fsspec.open(remote_url, "r", encoding="utf-8") as f:
+             config_remote = json.load(f)
+
+         if not config_local == config_remote:
+             print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
+             self.create_dir_and_download_model(model_name, model_item, output_path)
+
+     def download_model(self, model_name):
+         """Download model files given the full model name.
+         Model name is in the format
+         'type/language/dataset/model'
+         e.g. 'tts_models/en/ljspeech/tacotron'
+
+         Every model must have the following files:
+             - *.pth : pytorch model checkpoint file.
+             - config.json : model config file.
+             - scale_stats.npy (if it exists): scale values for preprocessing.
+
+         Args:
+             model_name (str): model name as explained above.
+         """
+         model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
+         # set the model specific output path
+         output_path = os.path.join(self.output_prefix, model_full_name)
+         if os.path.exists(output_path):
+             if md5sum is not None:
+                 md5sum_file = os.path.join(output_path, "hash.md5")
+                 if os.path.isfile(md5sum_file):
+                     with open(md5sum_file, mode="r") as f:
+                         if not f.read() == md5sum:
+                             print(f" > {model_name} has been updated, clearing model cache...")
+                             self.create_dir_and_download_model(model_name, model_item, output_path)
+                         else:
+                             print(f" > {model_name} is already downloaded.")
+                 else:
+                     print(f" > {model_name} has been updated, clearing model cache...")
+                     self.create_dir_and_download_model(model_name, model_item, output_path)
+             # if the configs are different, redownload it
+             # ToDo: we need a better way to handle it
+             if "xtts" in model_name:
+                 try:
+                     self.check_if_configs_are_equal(model_name, model_item, output_path)
+                 except:
+                     pass
+             else:
+                 print(f" > {model_name} is already downloaded.")
+         else:
+             self.create_dir_and_download_model(model_name, model_item, output_path)
+
+         # find downloaded files
+         output_model_path = output_path
+         output_config_path = None
+         if (
+             model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
+         ):  # TODO: This is stupid but don't care for now.
+             output_model_path, output_config_path = self._find_files(output_path)
+             # update paths in the config.json
+             self._update_paths(output_path, output_config_path)
+         return output_model_path, output_config_path, model_item
+
+     @staticmethod
+     def _find_files(output_path: str) -> Tuple[str, str]:
+         """Find the model and config files in the output path
+
+         Args:
+             output_path (str): path to the model files
+
+         Returns:
+             Tuple[str, str]: path to the model file and config file
+         """
+         model_file = None
+         config_file = None
+         for file_name in os.listdir(output_path):
+             if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
+                 model_file = os.path.join(output_path, file_name)
+             elif file_name == "config.json":
+                 config_file = os.path.join(output_path, file_name)
+         if model_file is None:
+             raise ValueError(" [!] Model file not found in the output path")
+         if config_file is None:
+             raise ValueError(" [!] Config file not found in the output path")
+         return model_file, config_file
+
+     @staticmethod
+     def _find_speaker_encoder(output_path: str) -> str:
+         """Find the speaker encoder file in the output path
+
+         Args:
+             output_path (str): path to the model files
+
+         Returns:
+             str: path to the speaker encoder file
+         """
+         speaker_encoder_file = None
+         for file_name in os.listdir(output_path):
+             if file_name in ["model_se.pth", "model_se.pth.tar"]:
+                 speaker_encoder_file = os.path.join(output_path, file_name)
+         return speaker_encoder_file
+
+     def _update_paths(self, output_path: str, config_path: str) -> None:
+         """Update paths for certain files in config.json after download.
+
+         Args:
+             output_path (str): local path the model is downloaded to.
+             config_path (str): local config.json path.
+         """
+         output_stats_path = os.path.join(output_path, "scale_stats.npy")
+         output_d_vector_file_path = os.path.join(output_path, "speakers.json")
+         output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
+         output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
+         output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
+         speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
+         speaker_encoder_model_path = self._find_speaker_encoder(output_path)
+
+         # update the scale_stats.npy file path in the model config.json
+         self._update_path("audio.stats_path", output_stats_path, config_path)
+
+         # update the speakers.json file path in the model config.json to the current path
+         self._update_path("d_vector_file", output_d_vector_file_path, config_path)
+         self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
+         self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
+         self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)
+
+         # update the speaker_ids.json file path in the model config.json to the current path
+         self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
+         self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
+         self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
+         self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)
+
+         # update the speaker_encoder file path in the model config.json to the current path
+         self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+         self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
+         self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+         self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
+
+     @staticmethod
+     def _update_path(field_name, new_path, config_path):
+         """Update the path in the model config.json for the current environment after download"""
+         if new_path and os.path.exists(new_path):
+             config = load_config(config_path)
+             field_names = field_name.split(".")
+             if len(field_names) > 1:
+                 # field name points to a sub-level field
+                 sub_conf = config
+                 for fd in field_names[:-1]:
+                     if fd in sub_conf:
+                         sub_conf = sub_conf[fd]
+                     else:
+                         return
+                 if isinstance(sub_conf[field_names[-1]], list):
+                     sub_conf[field_names[-1]] = [new_path]
+                 else:
+                     sub_conf[field_names[-1]] = new_path
+             else:
+                 # field name points to a top-level field
+                 if field_name not in config:
+                     return
+                 if isinstance(config[field_name], list):
+                     config[field_name] = [new_path]
+                 else:
+                     config[field_name] = new_path
+             config.save_json(config_path)
+
+     @staticmethod
+     def _download_zip_file(file_url, output_folder, progress_bar):
+         """Download a zip release and extract it"""
+         # download the file
+         r = requests.get(file_url, stream=True)
+         # extract the file
+         try:
+             total_size_in_bytes = int(r.headers.get("content-length", 0))
+             block_size = 1024  # 1 Kibibyte
+             if progress_bar:
+                 ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+             temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
+             with open(temp_zip_name, "wb") as file:
+                 for data in r.iter_content(block_size):
+                     if progress_bar:
+                         ModelManager.tqdm_progress.update(len(data))
+                     file.write(data)
+             with zipfile.ZipFile(temp_zip_name) as z:
+                 z.extractall(output_folder)
+             os.remove(temp_zip_name)  # delete zip after extract
+         except zipfile.BadZipFile:
+             print(f" > Error: Bad zip file - {file_url}")
+             raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
+         # move the files to the outer path
+         for file_path in z.namelist():
+             src_path = os.path.join(output_folder, file_path)
+             if os.path.isfile(src_path):
+                 dst_path = os.path.join(output_folder, os.path.basename(file_path))
+                 if src_path != dst_path:
+                     copyfile(src_path, dst_path)
+         # remove redundant (hidden or not) folders
+         for file_path in z.namelist():
+             if os.path.isdir(os.path.join(output_folder, file_path)):
+                 rmtree(os.path.join(output_folder, file_path))
+
+     @staticmethod
+     def _download_tar_file(file_url, output_folder, progress_bar):
+         """Download a tar.gz release and extract it"""
+         # download the file
+         r = requests.get(file_url, stream=True)
+         # extract the file
+         try:
+             total_size_in_bytes = int(r.headers.get("content-length", 0))
+             block_size = 1024  # 1 Kibibyte
+             if progress_bar:
+                 ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+             temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
+             with open(temp_tar_name, "wb") as file:
+                 for data in r.iter_content(block_size):
+                     if progress_bar:
+                         ModelManager.tqdm_progress.update(len(data))
+                     file.write(data)
+             with tarfile.open(temp_tar_name) as t:
+                 t.extractall(output_folder)
+                 tar_names = t.getnames()
+             os.remove(temp_tar_name)  # delete tar after extract
+         except tarfile.ReadError:
+             print(f" > Error: Bad tar file - {file_url}")
+             raise tarfile.ReadError  # pylint: disable=raise-missing-from
+         # move the files to the outer path
+         for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
+             src_path = os.path.join(output_folder, tar_names[0], file_path)
+             dst_path = os.path.join(output_folder, os.path.basename(file_path))
+             if src_path != dst_path:
+                 copyfile(src_path, dst_path)
+         # remove the extracted folder
+         rmtree(os.path.join(output_folder, tar_names[0]))
+
+     @staticmethod
+     def _download_model_files(file_urls, output_folder, progress_bar):
+         """Download the individual release files"""
+         for file_url in file_urls:
+             # download the file
+             r = requests.get(file_url, stream=True)
+             # extract the file
+             base_filename = file_url.split("/")[-1]
+             temp_zip_name = os.path.join(output_folder, base_filename)
+             total_size_in_bytes = int(r.headers.get("content-length", 0))
+             block_size = 1024  # 1 Kibibyte
+             with open(temp_zip_name, "wb") as file:
+                 if progress_bar:
+                     ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+                 for data in r.iter_content(block_size):
+                     if progress_bar:
+                         ModelManager.tqdm_progress.update(len(data))
+                     file.write(data)
+
+     @staticmethod
+     def _check_dict_key(my_dict, key):
+         if key in my_dict.keys() and my_dict[key] is not None:
+             if not isinstance(key, str):
+                 return True
+             if isinstance(key, str) and len(my_dict[key]) > 0:
+                 return True
+         return False