from typing import Dict, List, Optional, Union, Tuple, BinaryIO |
import os |
import sys |
import json |
import shutil |
import tempfile |
import copy |
from tqdm.auto import tqdm |
from functools import partial |
from urllib.parse import urlparse |
from pathlib import Path |
import requests |
from hashlib import sha256 |
from filelock import FileLock |
import importlib_metadata |
import torch |
import torch.nn as nn |
from torch import Tensor |
import fnmatch |
# Library version reported in the HTTP user agent.
__version__ = "4.0.0"
# Installed torch version, read from package metadata.
_torch_version = importlib_metadata.version("torch")
# Root of the Hugging Face cache: $HF_HOME if set, else $XDG_CACHE_HOME/huggingface
# (with XDG_CACHE_HOME defaulting to ~/.cache).
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
# Transformers-specific subdirectory of the cache root.
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Short aliases accepted as `mirror` by hf_bucket_url; any other non-empty value is used
# verbatim as the mirror endpoint.
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# URL template for files served from the Hugging Face Hub.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
# Standard file names for PyTorch weights and model configuration.
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"
def is_torch_available():
    """Report PyTorch availability; this module imports torch unconditionally, so always True."""
    return True
def is_tf_available():
    """Report TensorFlow availability; TensorFlow support is disabled here, so always False."""
    return False
def is_remote_url(url_or_filename):
    """Return True when *url_or_filename* is an http:// or https:// URL, False otherwise."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """Stream *url* into *temp_file* while showing a tqdm progress bar.

    Args:
        url: URL to download.
        temp_file: writable binary file object that receives the response body.
        proxies: optional proxy mapping forwarded to ``requests.get``.
        resume_size: number of bytes already present in *temp_file*; when > 0 an HTTP
            ``Range`` header asks the server for the remaining bytes only.
        headers: optional extra request headers. The caller's dict is never mutated.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Copy so the Range header never leaks into the caller's dict; treat None as "no headers"
    # so resuming without explicit headers does not fail on item assignment.
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    # Context managers release the connection and the progress bar even if writing fails.
    with requests.get(url, stream=True, proxies=proxies, headers=headers) as r:
        r.raise_for_status()
        content_length = r.headers.get("Content-Length")
        # For a ranged request Content-Length only counts the remaining bytes.
        total = resume_size + int(content_length) if content_length is not None else None
        with tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
            disable=False,
        ) as progress:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # skip empty chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Derive a deterministic cache file name from *url* and an optional *etag*.

    The name is the SHA-256 hex digest of the URL, followed by ``.`` plus the SHA-256 hex
    digest of the etag when a non-empty etag is given, and a trailing ``.h5`` when the URL
    itself ends in ``.h5``.
    """

    def _digest(text: str) -> str:
        return sha256(text.encode("utf-8")).hexdigest()

    parts = [_digest(url)]
    if etag:
        parts.append(_digest(etag))
    if url.endswith(".h5"):
        parts.append("h5")
    return ".".join(parts)
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """Build the download URL of *filename* (optionally inside *subfolder*) for *model_id*.

    Without a mirror, the Hub URL template is filled in with *revision* (``"main"`` when None).
    With a truthy *mirror*, the endpoint is looked up in PRESET_MIRROR_DICT (falling back to
    the mirror string itself) and *revision* is not used. Model ids without a ``/`` are
    joined to the file with ``-``; namespaced ids are joined with ``/``.
    """
    path = filename if subfolder is None else f"{subfolder}/{filename}"

    if not mirror:
        return HUGGINGFACE_CO_PREFIX.format(
            model_id=model_id,
            revision="main" if revision is None else revision,
            filename=path,
        )

    endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
    # Ids without a namespace use the flat "legacy" layout on mirrors.
    separator = "-" if "/" not in model_id else "/"
    return f"{endpoint}/{model_id}{separator}{path}"
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """Build the User-Agent header value sent with hub requests.

    The string starts with ``transformers/VERSION; python/VERSION``, then the torch version
    (and the tensorflow version when enabled), then any caller-supplied info: a dict is
    rendered as ``key/value`` pairs and a str is appended as-is. All segments are separated
    by ``"; "``.
    """
    segments = [
        "transformers/{}".format(__version__),
        "python/{}".format(sys.version.split()[0]),
    ]
    if is_torch_available():
        segments.append(f"torch/{_torch_version}")
    if is_tf_available():
        # NOTE(review): `_tf_version` is not defined in this module; this branch would raise
        # NameError if is_tf_available() ever returned True.
        segments.append(f"tensorflow/{_tf_version}")
    if isinstance(user_agent, dict):
        segments.append("; ".join(f"{key}/{value}" for key, value in user_agent.items()))
    elif isinstance(user_agent, str):
        segments.append(user_agent)
    return "; ".join(segments)
def _find_cached_copy(cache_dir: str, filename: str, cache_path: str, local_files_only: bool) -> str:
    """Return a previously cached copy of a resource whose current ETag could not be fetched."""
    if os.path.exists(cache_path):
        return cache_path
    # Every cached version of a URL is named "<url hash>.<etag hash>[.h5]"; side files
    # (metadata, locks, partial downloads) must never be returned as the resource itself.
    url_hash = filename.split(".")[0]
    matching_files = [
        file
        for file in fnmatch.filter(os.listdir(cache_dir), url_hash + ".*")
        if not file.endswith((".json", ".lock", ".incomplete"))
    ]
    if matching_files:
        # os.listdir order is arbitrary; return the most recently written version.
        return max((os.path.join(cache_dir, file) for file in matching_files), key=os.path.getmtime)
    if local_files_only:
        raise FileNotFoundError(
            "Cannot find the requested files in the cached path and outgoing traffic has been"
            " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
            " to False."
        )
    raise ValueError(
        "Connection error, and we cannot find the requested files in the cached path."
        " Please try again or make sure your Internet connection is on."
    )


def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Return the local path of *url*'s content, downloading it into the cache when needed.

    A HEAD request fetches the resource's ETag. The ETag is folded into the cache file name
    (see ``url_to_filename``), so each remote version gets its own file. If the HEAD request
    hits a connection error or timeout, or ``local_files_only`` is set, the most recently
    written cached copy of *url* is returned instead.

    Args:
        url: remote http(s) URL.
        cache_dir: cache directory (str or Path); defaults to ``default_cache_path``.
        force_download: download again even if a file for the current ETag is cached.
        proxies: proxy mapping forwarded to ``requests``.
        etag_timeout: timeout in seconds for the HEAD request.
        resume_download: keep partial data in ``cache_path + ".incomplete"`` and resume from it.
        user_agent: extra user-agent info (dict or str) passed to ``http_user_agent``.
        use_auth_token: bearer token string, or True to look the token up via ``HfFolder``.
        local_files_only: never contact the network; only use cached files.

    Returns:
        Path of the cached file.

    Raises:
        EnvironmentError: ``use_auth_token=True`` but no stored token was found.
        OSError: the server response carries no ETag.
        requests.HTTPError: the HEAD or GET request returned an error status.
        FileNotFoundError: offline (``local_files_only``) and nothing is cached.
        ValueError: the network is unreachable and nothing is cached.
    """
    from contextlib import contextmanager

    if cache_dir is None:
        cache_dir = default_cache_path
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        # NOTE(review): HfFolder is neither defined nor imported in this module; this branch
        # raises NameError unless it is provided elsewhere — confirm.
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # Redirects are not followed, so the ETag above belongs to `url` itself; the
            # payload is then fetched from the redirect target.
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Unreachable: fall back to a cached copy below.
            pass

    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if etag is None:
        return _find_cached_copy(cache_dir, filename, cache_path, local_files_only)

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Serialize concurrent downloads of the same file.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # Another process may have completed the download while we waited for the lock.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            resume_size = os.stat(incomplete_path).st_size if os.path.exists(incomplete_path) else 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        temp_name = None
        try:
            with temp_file_manager() as temp_file:
                temp_name = temp_file.name
                http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
        except BaseException:
            # A fresh temp file is useless after a failed download (delete=False keeps it on
            # disk); an .incomplete file is deliberately kept so the download can resume.
            if not resume_download and temp_name is not None and os.path.exists(temp_name):
                os.remove(temp_name)
            raise

        # Rename only after the handle is closed so the replace also works on Windows.
        os.replace(temp_name, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
def _extract_archive(archive_path: str, force_extract: bool) -> str:
    """Extract a zip/tar archive into a sibling directory and return it; non-archives pass through."""
    import tarfile
    from zipfile import ZipFile, is_zipfile

    if not is_zipfile(archive_path) and not tarfile.is_tarfile(archive_path):
        return archive_path

    output_dir, output_file = os.path.split(archive_path)
    # e.g. "abc.tar.gz" -> "abc-tar-gz-extracted"
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    def _already_extracted() -> bool:
        return (
            os.path.isdir(output_path_extracted)
            and bool(os.listdir(output_path_extracted))
            and not force_extract
        )

    if _already_extracted():
        return output_path_extracted

    lock_path = archive_path + ".lock"
    with FileLock(lock_path):
        # Another process may have finished extracting while we waited for the lock.
        if _already_extracted():
            return output_path_extracted
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(archive_path):
            with ZipFile(archive_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(archive_path):
            with tarfile.open(archive_path) as tar_file:
                # Archives may come from the network: use the "data" filter (rejects absolute
                # paths, ".." members, links escaping the target, device files) when this
                # Python provides it; it is the default from Python 3.14 on.
                if hasattr(tarfile, "data_filter"):
                    tar_file.extractall(output_path_extracted, filter="data")
                else:
                    tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(archive_path))
    return output_path_extracted


def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """Resolve *url_or_filename* to a local path, downloading remote URLs into the cache.

    Args:
        url_or_filename: http(s) URL or local path (str or Path).
        cache_dir: cache directory (str or Path); defaults to ``default_cache_path``.
        force_download, proxies, resume_download, user_agent, use_auth_token, local_files_only:
            forwarded to ``get_from_cache`` for remote URLs.
        extract_compressed_file: if the resolved file is a zip or tar archive, extract it
            next to the file and return the extraction directory instead.
        force_extract: re-extract even if a non-empty extraction directory already exists.

    Returns:
        Local file path, or the extraction directory when an archive was extracted.

    Raises:
        EnvironmentError: a local path (no URL scheme) does not exist, or an archive format
            cannot be identified.
        ValueError: the input is neither a supported URL nor an existing local path.
    """
    if cache_dir is None:
        cache_dir = default_cache_path
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        return _extract_archive(output_path, force_extract)
    return output_path
def get_parameter_dtype(parameter: Union[nn.Module]):
    """Return the dtype of the first parameter of *parameter*.

    For a module without parameters, fall back to the first tensor stored directly in a
    module's instance ``__dict__``, as enumerated by torch's private ``_named_members``
    helper (presumably including submodules). StopIteration propagates when neither exists.
    """
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    def _tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        # Collect (name, tensor) pairs stored directly in the module's instance __dict__.
        return [(name, value) for name, value in vars(module).items() if torch.is_tensor(value)]

    _, first_tensor = next(parameter._named_members(get_members_fn=_tensor_attributes))
    return first_tensor.dtype
def get_extended_attention_mask(attention_mask: Tensor, dtype, mask_value: float = -10000.0) -> Tensor:
    """Turn a 0/1 attention mask into an additive mask with broadcastable singleton axes.

    Entries equal to 1 become 0.0 and entries equal to 0 become *mask_value*, computed as
    ``(1 - mask) * mask_value`` after casting the mask to *dtype*.

    Args:
        attention_mask: 2-D mask, reshaped to ``(d0, 1, 1, d1)``, presumably
            (batch, seq_len); or 3-D mask, reshaped to ``(d0, 1, d1, d2)``, presumably
            (batch, from_seq, to_seq). TODO confirm axis meaning against callers.
        dtype: dtype the mask is cast to before scaling.
        mask_value: additive value for masked-out positions (default -10000.0).

    Returns:
        4-D tensor of *dtype*.

    Raises:
        ValueError: if *attention_mask* is neither 2-D nor 3-D. An explicit raise is used
            because ``assert`` is stripped under ``python -O``, and invalid ranks would then
            silently produce wrongly shaped masks.
    """
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(f"attention_mask must be 2-D or 3-D, got shape {tuple(attention_mask.shape)}")
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    return (1.0 - extended_attention_mask) * mask_value