import os
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import requests
from tqdm.auto import tqdm as base_tqdm
from tqdm.contrib.concurrent import thread_map

from .constants import (
    DEFAULT_ETAG_TIMEOUT,
    DEFAULT_REVISION,
    HF_HUB_CACHE,
    HF_HUB_ENABLE_HF_TRANSFER,
    REPO_TYPES,
)
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
from .utils import (
    GatedRepoError,
    LocalEntryNotFoundError,
    OfflineModeIsEnabled,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    filter_repo_objects,
    logging,
    validate_hf_hub_args,
)
from .utils import tqdm as hf_tqdm


logger = logging.get_logger(__name__)


@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: float = DEFAULT_ETAG_TIMEOUT,
    force_download: bool = False,
    token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    tqdm_class: Optional[base_tqdm] = None,
    headers: Optional[Dict[str, str]] = None,
    endpoint: Optional[str] = None,
    # Deprecated args
    local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
    resume_download: Optional[bool] = None,
) -> str:
    """Download repo files.

    Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files
    from a repo, because you don't know a priori which ones you will need. All files are nested inside a folder in
    order to keep their actual filename relative to that folder. You can also filter which files to download using
    `allow_patterns` and `ignore_patterns`.

    If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this
    option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of
    `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the
    main cache-system, it's optimized for regularly pulling the latest version of a repository.

    An alternative would be to clone the repo, but this requires git and git-lfs to be installed and properly
    configured. It is also not possible to filter which files to download when cloning a repository using git.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if
            downloading from a model. Default is `None`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a commit hash.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored.
        local_dir (`str` or `Path`, *optional*):
            If provided, the downloaded files will be placed under this directory.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching the ETag, how many seconds to wait for the server to send data before giving up. Passed to
            `requests.request`.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether the file should be downloaded even if it already exists in the local cache.
        token (`str`, `bool`, *optional*):
            A token to be used for the download.
            - If `True`, the token is read from the HuggingFace config folder.
            - If a string, it's used as the authentication token.
        headers (`dict`, *optional*):
            Additional headers to include in the request. Those headers take precedence over the others.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the local cached file if it exists.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*):
            Number of concurrent threads to download files (1 thread = 1 file download). Defaults to 8.
        tqdm_class (`tqdm`, *optional*):
            If provided, overwrites the default behavior for the progress bar. Passed argument must inherit from
            `tqdm.auto.tqdm` or at least mimic its behavior. Note that the `tqdm_class` is not passed to each
            individual download. Defaults to the custom HF progress bar that can be disabled by setting the
            `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.

    Returns:
        `str`: folder path of the repo snapshot.

    Raises:
        [`~utils.RepositoryNotFoundError`]
            If the repository to download from cannot be found. This may be because it doesn't exist, or because it
            is set to `private` and you do not have access.
        [`~utils.RevisionNotFoundError`]
            If the revision to download from cannot be found.
        [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
            If `token=True` and the token cannot be found.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If the ETag cannot be determined.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If some parameter value is invalid.
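
    Example (a minimal usage sketch; the repo ids and file patterns below are illustrative):

    ```python
    >>> from huggingface_hub import snapshot_download

    >>> # Download a full snapshot of a model repo into the cache and get the snapshot path.
    >>> snapshot_download(repo_id="gpt2")

    >>> # Download only the JSON and text files of a dataset repo.
    >>> snapshot_download(repo_id="username/my-dataset", repo_type="dataset", allow_patterns=["*.json", "*.txt"])
    ```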
    """
    if cache_dir is None:
        cache_dir = HF_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
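
    # Cache layout note: `repo_folder_name` yields a per-repo folder (e.g. `models--<org>--<name>`) under
    # `cache_dir`. Inside it, `refs/<revision>` maps a branch or tag name to a commit hash, and
    # `snapshots/<commit_hash>/` holds the file tree of a given revision. Both are used below to resolve
    # revisions when the Hub cannot be reached.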
    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
    api_call_error: Optional[Exception] = None
    if not local_files_only:
        # try/except logic to handle different errors => taken from `hf_hub_download`
        try:
            # if we have internet connection we want to list files to download
            api = HfApi(
                library_name=library_name,
                library_version=library_version,
                user_agent=user_agent,
                endpoint=endpoint,
                headers=headers,
            )
            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
        except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
            # Actually raise for those subclasses of ConnectionError
            raise
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            OfflineModeIsEnabled,
        ) as error:
            # Internet connection is down
            # => will try to use local files only
            api_call_error = error
        except RevisionNotFoundError:
            # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
            raise
        except requests.HTTPError as error:
            # Multiple reasons for an http error:
            # - Repository is private and invalid/missing token sent
            # - Repository is gated and invalid/missing token sent
            # - Hub is down (error 500 or 504)
            # => let's switch to 'local_files_only=True' to check if the files are already cached.
            #    (if it's not the case, the error will be re-raised)
            api_call_error = error

    # At this stage, if `repo_info` is None it means either:
    # - internet connection is down
    # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
    # - repo is private/gated and invalid/missing token sent
    # - Hub is down
    # => let's look if we can find the appropriate folder in the cache:
    #    - if the specified revision is a commit hash, look inside "snapshots".
    #    - if the specified revision is a branch or tag, look inside "refs".
    if repo_info is None:
        # Try to get which commit hash corresponds to the specified revision
        commit_hash = None
        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            ref_path = os.path.join(storage_folder, "refs", revision)
            if os.path.exists(ref_path):
                # retrieve commit_hash from refs file
                with open(ref_path) as f:
                    commit_hash = f.read()

        # Try to locate snapshot folder for this commit hash
        if commit_hash is not None:
            snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
            if os.path.exists(snapshot_folder):
                # Snapshot folder exists => let's return it
                # (but we can't check if all the files are actually there)
                return snapshot_folder

        # If we couldn't find the appropriate folder on disk, raise an error.
        if local_files_only:
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
                "'local_files_only=False' as input."
            )
        elif isinstance(api_call_error, OfflineModeIsEnabled):
            raise LocalEntryNotFoundError(
                "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
                "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
                "'HF_HUB_OFFLINE=0' as environment variable."
            ) from api_call_error
        elif isinstance(api_call_error, (RepositoryNotFoundError, GatedRepoError)):
            # Repo not found => let's raise the actual error
            raise api_call_error
        else:
            # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
            raise LocalEntryNotFoundError(
                "An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
                " snapshot folder for the specified revision on the local disk. Please check your internet connection"
                " and try again."
            ) from api_call_error

    # At this stage, internet connection is up and running
    # => let's download the files!
    assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
    assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    # if the passed revision is not identical to the commit_hash,
    # then revision has to be a branch name or tag name.
    # In that case store a ref.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # we pass the commit_hash to hf_hub_download
    # so no network call happens if we already
    # have the file locally.
    def _inner_hf_hub_download(repo_file: str):
        return hf_hub_download(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            endpoint=endpoint,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=local_dir_use_symlinks,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            force_download=force_download,
            token=token,
            headers=headers,
        )

    if HF_HUB_ENABLE_HF_TRANSFER:
        # when using hf_transfer we don't want extra parallelism
        # on top of the one hf_transfer provides
        for file in filtered_repo_files:
            _inner_hf_hub_download(file)
    else:
        thread_map(
            _inner_hf_hub_download,
            filtered_repo_files,
            desc=f"Fetching {len(filtered_repo_files)} files",
            max_workers=max_workers,
            # Users can pass their own tqdm class or use the default one from `huggingface_hub.utils`
            tqdm_class=tqdm_class or hf_tqdm,
        )

    if local_dir is not None:
        return str(os.path.realpath(local_dir))
    return snapshot_folder