File size: 4,460 Bytes
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37e5d1
c1a7f73
d37e5d1
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37e5d1
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import os
from huggingface_hub import upload_folder, upload_file, hf_hub_download
from rich.console import Console
from rich.panel import Panel
from rich import box, style
from rich.table import Table

CONSOLE = Console(width=120)


def upload():
    """Upload ``args.folder_path`` and/or ``args.file_path`` to ``args.repo_id``.

    Relies on the module-level ``args``, ``token`` and ``ignore_patterns``
    that the ``__main__`` section assigns before calling this function.

    * When ``args.folder_path`` is set, the folder is sent with
      ``upload_folder`` to ``args.path_in_repo``, skipping every path matched
      by ``ignore_patterns``.
    * When ``args.file_path`` is set, the file is sent with ``upload_file``
      and stored under its base name with ``repo_type='model'``.

    Both may be set, in which case the folder is uploaded first. A summary
    panel is printed after each successful upload.

    Raises:
        Exception: whatever the upload call raised, re-raised unchanged after
            a failure message has been printed.
    """
    if args.folder_path:
        try:
            # ``token=None`` is the library default, so passing it
            # unconditionally is equivalent to omitting the argument.
            upload_folder(
                repo_id=args.repo_id,
                folder_path=args.folder_path,
                ignore_patterns=ignore_patterns,
                path_in_repo=args.path_in_repo,
                token=token,
            )
        except Exception as e:
            CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
            raise
        else:
            # Rendered outside the ``try`` so that a display problem is not
            # misreported as an upload failure.
            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
            table.add_row(f"Model id {args.repo_id}", str(args.folder_path))
            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed DO NOT forget specify the model id in methods! :tada:[/bold]", expand=False))

    if args.file_path:
        try:
            upload_file(
                path_or_fileobj=args.file_path,
                # Only the file name is kept; local directories are dropped.
                path_in_repo=os.path.basename(args.file_path),
                repo_id=args.repo_id,
                repo_type='model',
                token=token,
            )
        except Exception as e:
            CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
            raise
        else:
            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
            table.add_row(f"Model id {args.repo_id}", str(args.file_path))
            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! :tada:[/bold]", expand=False))


def download():
    """Download ``args.file_path`` from ``args.repo_id`` and optionally copy it.

    Relies on the module-level ``args`` and ``token`` that the ``__main__``
    section assigns before calling this function. The file is fetched with
    ``hf_hub_download``; when ``args.save_path`` is given, the downloaded file
    is additionally copied to ``<save_path>/<file_path>``.

    Returns:
        The path returned by ``hf_hub_download``.

    Raises:
        RuntimeError: if ``args.file_path`` is not set.
        Exception: whatever the download or the copy raised, re-raised
            unchanged after a failure message has been printed.
    """
    # The CLI accepts --folder_path alone, but a download needs a file name.
    if not args.file_path:
        raise RuntimeError('Download requires --file_path (the file name inside the repo)!')

    try:
        # ``token=None`` is the library default, so passing it
        # unconditionally is equivalent to omitting the argument.
        ckpt_path = hf_hub_download(
            repo_id=args.repo_id,
            filename=args.file_path,
            token=token,
        )
        table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
        table.add_row(f"Model id {args.repo_id}", str(args.file_path))
        CONSOLE.print(Panel(table, title=f"[bold][green]:tada: Download completed to {ckpt_path}! :tada:[/bold]", expand=False))

        if args.save_path is not None:
            import shutil

            destination = os.path.join(args.save_path, args.file_path)
            # ``file_path`` may contain sub-folders (e.g. "ckpt/model.bin"),
            # so create the destination's own parent, not just ``save_path``.
            os.makedirs(os.path.dirname(destination) or os.curdir, exist_ok=True)
            shutil.copy(ckpt_path, destination)

    except Exception as e:
        CONSOLE.print(f"[bold][yellow]:tada: Download failed due to {e}.")
        raise

    return ckpt_path


if __name__ == "__main__":
    # Command-line entry point. ``args``, ``token`` and ``ignore_patterns``
    # are deliberately module-level names: upload() and download() read them
    # as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", type=str, required=True)
    parser.add_argument("--upload", action="store_true")
    parser.add_argument("--download", action="store_true")
    parser.add_argument("--folder_path", type=str, default=None, required=False)
    parser.add_argument("--file_path", type=str, default=None, required=False)
    parser.add_argument("--save_path", type=str, default=None, required=False)
    parser.add_argument("--token", type=str, default=None, required=False)
    parser.add_argument("--path_in_repo", type=str, default=None, required=False)
    args = parser.parse_args()

    # The command-line token wins over the ``hf_token`` environment variable.
    # An empty string from either source is normalised to None so it is never
    # forwarded as a real token.
    token = args.token or os.getenv("hf_token") or None
    # Patterns excluded from folder uploads.
    ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"]

    if not (args.folder_path or args.file_path):
        raise RuntimeError('Choose either folder path or file path please!')

    # Exactly two non-empty components are required, so "user/" and "/name"
    # are rejected as well as ids with no slash or several slashes.
    repo_parts = args.repo_id.split('/')
    if len(repo_parts) != 2 or not all(repo_parts):
        raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use in [use-id]/[repo-name] format')
    CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}")

    # Without an action flag the script used to exit silently having done
    # nothing; fail loudly instead.
    if not (args.upload or args.download):
        raise RuntimeError('Nothing to do: pass --upload and/or --download!')

    if args.upload:
        upload()

    if args.download:
        download()