import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
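    """Load a pretrained checkpoint's state dict from a local directory or the HF hub.

    Supports both regular PyTorch (.bin) and safetensors checkpoints, single-file or
    sharded. If ``dtype`` is not fp32, tensors are loaded onto CPU first, converted,
    and only then moved to ``device`` to save GPU memory.
    """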
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
    # Prefer local checkpoint files: regular PyTorch weights first, then safetensors,
    # checking the single-file checkpoint before the sharded index in each case.
    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    else:  # Try loading from HF hub instead of from local files
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved_archive_file is None:
            resolved_archive_file = cached_file(
                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")
    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)
    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
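

# Minimal usage sketch (not part of the original module); the model name, device,
# and dtype below are illustrative assumptions.
if __name__ == "__main__":
    example_device = "cuda" if torch.cuda.is_available() else "cpu"
    sd = state_dict_from_pretrained("gpt2", device=example_device, dtype=torch.float16)
    print(f"Loaded {len(sd)} tensors from 'gpt2'")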