|
from typing import Dict, List, Optional, Union, Tuple, BinaryIO |
|
import os |
|
import sys |
|
import json |
|
import shutil |
|
import tempfile |
|
import copy |
|
from tqdm.auto import tqdm |
|
from functools import partial |
|
from urllib.parse import urlparse |
|
from pathlib import Path |
|
import requests |
|
from hashlib import sha256 |
|
from filelock import FileLock |
|
import importlib_metadata |
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
import fnmatch |
|
|
|
|
|
# Library version string; http_user_agent() embeds it in the user agent.
__version__ = "4.0.0"

# Version of the installed "torch" distribution, read from package metadata.
_torch_version = importlib_metadata.version("torch")
|
|
|
# Cache root: $HF_HOME when set, otherwise "<XDG_CACHE_HOME>/huggingface",
# where an unset XDG_CACHE_HOME is replaced by "~/.cache". "~" is expanded last.
hf_cache_home = os.path.expanduser(
    os.environ.get(
        "HF_HOME",
        os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"),
    )
)
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Each environment variable, when unset, falls back to the value on the line
# before it, so TRANSFORMERS_CACHE overrides the two older variable names.
PYTORCH_PRETRAINED_BERT_CACHE = os.environ.get("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.environ.get("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.environ.get("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
|
|
# Short mirror names accepted by hf_bucket_url(), mapped to their base URLs.
PRESET_MIRROR_DICT = dict(
    tuna="https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    bfsu="https://mirrors.bfsu.edu.cn/hugging-face-models",
)

# URL template filled in with model_id, revision and filename.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"

# File names for model weights and model configuration.
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.json"
|
|
|
|
|
def is_torch_available():
    """Tell callers whether PyTorch support is enabled; always returns ``True``."""
    return True
|
|
|
|
|
def is_tf_available():
    """Tell callers whether TensorFlow support is enabled; always returns ``False``."""
    return False
|
|
|
|
|
def is_remote_url(url_or_filename):
    """Return ``True`` when ``url_or_filename`` parses with an ``http`` or ``https`` scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
|
|
|
|
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """
    Stream the body of ``url`` into ``temp_file`` while displaying a progress bar.

    Args:
        url: Address fetched with a GET request.
        temp_file: Writable binary file object that receives the downloaded bytes.
        proxies: Forwarded unchanged to ``requests.get``.
        resume_size: Number of bytes already stored locally. When positive, a
            ``Range`` header requests only the bytes from that offset onwards and
            the progress bar starts at that value.
        headers: Optional request headers. The caller's mapping is copied, never
            modified.

    Raises:
        requests.exceptions.HTTPError: If the response carries an error status.
    """
    # Work on a private copy so that adding "Range" never leaks into the caller's dict.
    headers = copy.deepcopy(headers)
    if resume_size > 0:
        if headers is None:
            # Resuming without caller-supplied headers previously failed when
            # assigning into None.
            headers = {}
        headers["Range"] = "bytes=%d-" % (resume_size,)

    # Both context managers guarantee the connection and the progress bar are
    # released even when the request or the write fails part-way through.
    with requests.get(url, stream=True, proxies=proxies, headers=headers) as r:
        r.raise_for_status()
        content_length = r.headers.get("Content-Length")
        # Content-Length only covers the bytes still to come, hence the offset.
        total = resume_size + int(content_length) if content_length is not None else None
        with tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
            disable=False,
        ) as progress:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # skip empty chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
|
|
|
|
|
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Build a cache file name from ``url`` and, optionally, ``etag``.

    The name is the SHA-256 hex digest of the UTF-8 encoded URL. A non-empty
    ``etag`` adds a ``.`` followed by the digest of the etag. URLs ending in
    ``.h5`` keep that suffix at the end of the name.
    """
    digests = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode("utf-8")).hexdigest())

    name = ".".join(digests)
    return name + ".h5" if url.endswith(".h5") else name
|
|
|
|
|
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """
    Compose the download URL of ``filename`` for the model ``model_id``.

    Args:
        model_id: Model identifier. An identifier without ``/`` is treated as the
            legacy format when a mirror is used.
        filename: Name of the file inside the model repository.
        subfolder: Optional folder prepended to ``filename`` with a ``/``.
        revision: Revision placed in the URL; ``None`` means ``"main"``. Ignored
            when a mirror is used.
        mirror: Key of ``PRESET_MIRROR_DICT`` or a base URL used as given.

    Returns:
        The URL as a string.
    """
    remote_name = filename if subfolder is None else f"{subfolder}/{filename}"

    if not mirror:
        return HUGGINGFACE_CO_PREFIX.format(
            model_id=model_id,
            revision="main" if revision is None else revision,
            filename=remote_name,
        )

    # Unknown mirror names are taken to be the endpoint itself.
    endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
    # Legacy ids (no "/") are joined to the file name with "-", others with "/".
    joiner = "/" if "/" in model_id else "-"
    return f"{endpoint}/{model_id}{joiner}{remote_name}"
|
|
|
|
|
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the user-agent string sent with HTTP requests.

    The string lists the library and Python versions, then the torch version
    when ``is_torch_available()`` is true. ``user_agent`` is appended last: a
    dict contributes ``key/value`` pairs, a string is added verbatim. Segments
    are separated by ``"; "``.
    """
    segments = ["transformers/{}; python/{}".format(__version__, sys.version.split()[0])]
    if is_torch_available():
        segments.append(f"torch/{_torch_version}")
    if is_tf_available():
        # NOTE(review): `_tf_version` is not defined in the visible source; this
        # branch is unreachable while is_tf_available() returns False.
        segments.append(f"tensorflow/{_tf_version}")

    if isinstance(user_agent, dict):
        # Added as a single segment so an empty dict still yields a trailing "; ".
        segments.append("; ".join("{}/{}".format(k, v) for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        segments.append(user_agent)
    return "; ".join(segments)
|
|
|
|
|
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Return the local path of ``url``, downloading it into the cache if needed.

    Args:
        url: Remote address of the file.
        cache_dir: Cache directory (``str`` or ``Path``); defaults to
            ``TRANSFORMERS_CACHE``. It is created when missing.
        force_download: Download again even when the file is already cached.
        proxies: Forwarded to the HTTP requests.
        etag_timeout: Timeout passed to the HEAD request that reads the ETag.
        resume_download: Append to a ``<cache_path>.incomplete`` file left by an
            earlier attempt instead of starting from scratch.
        user_agent: Extra user-agent information, see ``http_user_agent``.
        use_auth_token: A token string, or a truthy value to use the token
            returned by ``HfFolder.get_token()``.
        local_files_only: Never touch the network; only look in the cache.

    Returns:
        Path of the cached file.

    Raises:
        EnvironmentError: ``use_auth_token`` is truthy but no stored token exists.
        OSError: The server answered without an ETag.
        FileNotFoundError: ``local_files_only`` is set and nothing is cached.
        ValueError: The ETag could not be fetched and nothing is cached.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        # NOTE(review): `HfFolder` is not imported in the visible part of this
        # file; confirm it is provided elsewhere, otherwise this is a NameError.
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            # Prefer the header describing the linked resource, then the regular ETag.
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # On a redirect, download straight from the target of the HEAD request.
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Offline or too slow: etag stays None and the cache is consulted below.
            pass

    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    # Without an ETag (offline, or local_files_only) the best we can do is reuse
    # something already cached for this URL.
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path

        # Any cached file sharing the URL hash qualifies, whatever its etag part;
        # metadata and lock side-files are excluded.
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if matching_files:
            return os.path.join(cache_dir, matching_files[-1])

        if local_files_only:
            raise FileNotFoundError(
                "Cannot find the requested files in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
        )

    # Fast path that avoids taking the lock when the file is already there.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # Check again: the file may have appeared while waiting for the lock.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"
            # A file opened in append mode is itself a context manager exposing
            # `.name`, so no custom wrapper is needed.
            temp_file_manager = partial(open, incomplete_path, "ab")
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to a temporary file and move it into place only once complete,
        # so an interrupted download never leaves a truncated cache entry.
        temp_file = None
        try:
            with temp_file_manager() as temp_file:
                http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
        except BaseException:
            # A failed fresh download leaves a useless temporary file behind
            # (delete=False); remove it. A partial ".incomplete" file is kept on
            # purpose so that a later call can resume from it.
            if not resume_download and temp_file is not None:
                try:
                    os.remove(temp_file.name)
                except OSError:
                    pass
            raise

        os.replace(temp_file.name, cache_path)

        # Record where the file came from next to it.
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
|
|
|
|
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Resolve a URL or local path to a local path, optionally extracting archives.

    A remote URL is fetched through ``get_from_cache``; an existing local path
    is returned as is.

    Args:
        url_or_filename: ``http``/``https`` URL or local path (``str`` or ``Path``).
        cache_dir: Cache directory; defaults to ``TRANSFORMERS_CACHE``.
        force_download: Forwarded to ``get_from_cache``.
        proxies: Forwarded to ``get_from_cache``.
        resume_download: Forwarded to ``get_from_cache``.
        user_agent: Forwarded to ``get_from_cache``.
        extract_compressed_file: When the resolved file is a zip or tar archive,
            extract it next to the file and return the extraction directory.
        force_extract: Extract again even if a non-empty extraction directory exists.
        use_auth_token: Forwarded to ``get_from_cache``.
        local_files_only: Forwarded to ``get_from_cache``.

    Returns:
        Path of the file, or of the extraction directory when an archive was
        extracted.

    Raises:
        EnvironmentError: The local file does not exist, or the archive format
            could not be identified.
        ValueError: The input is neither a supported URL nor a local path.
    """
    # Function-scope imports: the archive helpers are only used here and are not
    # provided by the module's import block.
    import tarfile
    from zipfile import ZipFile, is_zipfile

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary).
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # Existing local path.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Has a scheme that is not http(s).
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not extract_compressed_file:
        return output_path

    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    # Extract into "<file name with dots replaced by dashes>-extracted" beside the file.
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Serialize extraction between processes working on the same archive.
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        # SECURITY NOTE(review): extractall() trusts member names; an archive from
        # an untrusted source can write outside the target directory.
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # `with` closes the archive even when extraction fails.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
|
|
|
|
|
def get_parameter_dtype(parameter: Union[nn.Module]):
    """
    Return the dtype of the first parameter of the module ``parameter``.

    When the module yields no parameters, the dtype of the first tensor-valued
    attribute reported by ``_named_members`` is returned instead.
    """
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        # Tensors stored as plain attributes in the module's instance dict.
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    members = parameter._named_members(get_members_fn=find_tensor_attributes)
    _, first_tensor = next(members)
    return first_tensor.dtype
|
|
|
|
|
def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
    """
    Convert a 2-D attention mask into an additive mask with two extra axes.

    The mask is reshaped from ``(a, b)`` to ``(a, 1, 1, b)``, cast to ``dtype``
    and mapped through ``(1.0 - mask) * -10000.0``, so entries equal to 1 become
    0 and entries equal to 0 become -10000.
    """
    assert attention_mask.dim() == 2

    # Insert two singleton axes after the first one, then cast.
    broadcastable = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=dtype)
    return (1.0 - broadcastable) * -10000.0
|
|