|
import argparse |
|
import os |
|
from huggingface_hub import upload_folder, upload_file, hf_hub_download |
|
from rich.console import Console |
|
from rich.panel import Panel |
|
from rich import box, style |
|
from rich.table import Table |
|
|
|
CONSOLE = Console(width=120) |
|
|
|
|
|
def upload(): |
|
|
|
if args.folder_path: |
|
|
|
try: |
|
if token is not None: |
|
upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns, path_in_repo=args.path_in_repo, token=token) |
|
else: |
|
upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns, path_in_repo=args.path_in_repo) |
|
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True)) |
|
table.add_row(f"Model id {args.repo_id}", str(args.folder_path)) |
|
CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed DO NOT forget specify the model id in methods! :tada:[/bold]", expand=False)) |
|
|
|
except Exception as e: |
|
CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.") |
|
raise e |
|
|
|
if args.file_path: |
|
|
|
try: |
|
if token is not None: |
|
upload_file( |
|
path_or_fileobj=args.file_path, |
|
path_in_repo=os.path.basename(args.file_path), |
|
repo_id=args.repo_id, |
|
repo_type='model', |
|
token=token |
|
) |
|
else: |
|
upload_file( |
|
path_or_fileobj=args.file_path, |
|
path_in_repo=os.path.basename(args.file_path), |
|
repo_id=args.repo_id, |
|
repo_type='model', |
|
) |
|
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True)) |
|
table.add_row(f"Model id {args.repo_id}", str(args.file_path)) |
|
CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! :tada:[/bold]", expand=False)) |
|
|
|
except Exception as e: |
|
CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.") |
|
raise e |
|
|
|
|
|
def download(): |
|
|
|
try: |
|
if token is not None: |
|
ckpt_path = hf_hub_download( |
|
repo_id=args.repo_id, |
|
filename=args.file_path, |
|
token=token |
|
) |
|
else: |
|
ckpt_path = hf_hub_download( |
|
repo_id=args.repo_id, |
|
filename=args.file_path, |
|
) |
|
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True)) |
|
table.add_row(f"Model id {args.repo_id}", str(args.file_path)) |
|
CONSOLE.print(Panel(table, title=f"[bold][green]:tada: Download completed to {ckpt_path}! :tada:[/bold]", expand=False)) |
|
|
|
if args.save_path is not None: |
|
os.makedirs(args.save_path, exist_ok=True) |
|
import shutil |
|
shutil.copy(ckpt_path, os.path.join(args.save_path, args.file_path)) |
|
|
|
except Exception as e: |
|
CONSOLE.print(f"[bold][yellow]:tada: Download failed due to {e}.") |
|
raise e |
|
|
|
return ckpt_path |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--repo_id", type=str, default=None, required=True) |
|
parser.add_argument("--upload", action="store_true") |
|
parser.add_argument("--download", action="store_true") |
|
parser.add_argument("--folder_path", type=str, default=None, required=False) |
|
parser.add_argument("--file_path", type=str, default=None, required=False) |
|
parser.add_argument("--save_path", type=str, default=None, required=False) |
|
parser.add_argument("--token", type=str, default=None, required=False) |
|
parser.add_argument("--path_in_repo", type=str, default=None, required=False) |
|
args = parser.parse_args() |
|
|
|
token = args.token or os.getenv("hf_token", None) |
|
ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"] |
|
|
|
if not (args.folder_path or args.file_path): |
|
raise RuntimeError(f'Choose either folder path or file path please!') |
|
|
|
if len(args.repo_id.split('/')) != 2: |
|
raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use in [use-id]/[repo-name] format') |
|
CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}") |
|
|
|
if args.upload: |
|
upload() |
|
|
|
if args.download: |
|
download() |
|
|