Spaces:
Sleeping
Sleeping
File size: 2,844 Bytes
5ce1fe8 941850e 5ce1fe8 941850e 5ce1fe8 589e655 5ce1fe8 589e655 5ce1fe8 589e655 5ce1fe8 3147eb6 d8e6dc5 3147eb6 d8e6dc5 3147eb6 5ce1fe8 589e655 3147eb6 5ce1fe8 234de07 5ce1fe8 159e07d 5ce1fe8 589e655 e924ab6 5ce1fe8 159e07d 5ce1fe8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import sys
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import huggingface_hub
from project_settings import project_path
def get_args():
    """Parse the command-line options for the Hugging Face download script.

    Every option is a plain string. The defaults target the
    ``luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2``
    repository (model file ``exp/cpu_jit_epoch_10_avg_2_torch_1.7.1.pt``,
    tokens file ``data/lang_char/tokens.txt``). The download root defaults to
    ``<project_path>/pretrained_models``.

    Alternative values for the WeNet models can be passed on the command line:
    ``--repo_id csukuangfj/wenet-chinese-model`` (or
    ``csukuangfj/wenet-english-model``), ``--model_filename final.zip``,
    ``--model_sub_folder .``, ``--tokens_filename units.txt`` and
    ``--tokens_sub_folder .``.

    Returns:
        argparse.Namespace with the attributes ``repo_id``,
        ``model_filename``, ``model_sub_folder``, ``tokens_filename``,
        ``tokens_sub_folder`` and ``pretrained_model_dir``.
    """
    # (flag, default) pairs, registered in this order so --help lists them
    # the same way as before.
    string_options = (
        ("--repo_id", "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2"),
        ("--model_filename", "cpu_jit_epoch_10_avg_2_torch_1.7.1.pt"),
        ("--model_sub_folder", "exp"),
        ("--tokens_filename", "tokens.txt"),
        ("--tokens_sub_folder", "data/lang_char"),
        ("--pretrained_model_dir", (project_path / "pretrained_models").as_posix()),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in string_options:
        cli.add_argument(flag, default=fallback, type=str)
    return cli.parse_args()
def _local_folder_for_repo(repo_id):
    """Return the relative local folder name used for a Hugging Face ``repo_id``.

    A one-part id (``"name"``) maps to ``"name"``. A two-part id
    (``"owner/name"``) maps to ``"owner/name"``. In both cases only the first
    30 characters of the repository name are kept; the owner part is not
    truncated.

    Raises:
        AssertionError: if ``repo_id`` does not split into one or two parts.
    """
    parts = Path(repo_id).parts
    if len(parts) == 1:
        return parts[0][:30]
    if len(parts) == 2:
        owner, name = parts
        return "{}/{}".format(owner, name[:30])
    raise AssertionError("repo_id parts count invalid: {}".format(len(parts)))


def main():
    """Download the model file and the tokens file of a Hugging Face repo.

    Options come from ``get_args()``. Both files are saved below
    ``<pretrained_model_dir>/huggingface/<folder>``, where ``<folder>`` is
    produced by ``_local_folder_for_repo``. The local path returned by each
    ``hf_hub_download`` call is printed.

    Raises:
        AssertionError: if ``--repo_id`` has an unsupported number of parts.
    """
    args = get_args()

    # Derive (and thereby validate) the folder before touching the
    # filesystem, so a bad repo id does not leave an empty directory behind.
    folder = _local_folder_for_repo(args.repo_id)
    local_model_dir = Path(args.pretrained_model_dir) / "huggingface" / folder
    # parents=True also creates pretrained_model_dir and any missing ancestors.
    local_model_dir.mkdir(parents=True, exist_ok=True)

    # Both downloads must run: the tokens file is needed alongside the model.
    downloads = (
        ("model", args.model_filename, args.model_sub_folder),
        ("tokens", args.tokens_filename, args.tokens_sub_folder),
    )
    for label, filename, sub_folder in downloads:
        print("download {}".format(label))
        local_path = huggingface_hub.hf_hub_download(
            repo_id=args.repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir.as_posix(),
        )
        print(local_path)


if __name__ == "__main__":
    main()
|