|
import inspect
import json
import os
import warnings
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .repocard import ModelCard, ModelCardData
from .utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    SoftTemporaryDirectory,
    is_jsonable,
    is_safetensors_available,
    is_simple_optional_type,
    is_torch_available,
    logging,
    unwrap_simple_optional_type,
    validate_hf_hub_args,
)
|
if TYPE_CHECKING:
    from _typeshed import DataclassInstance

if is_torch_available():
    import torch

if is_safetensors_available():
    from safetensors.torch import load_model as load_model_as_safetensor
    from safetensors.torch import save_model as save_model_as_safetensor


logger = logging.get_logger(__name__)
|
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")

# Type aliases for the custom encoders/decoders registered via `coders`
ARGS_T = TypeVar("ARGS_T")
ENCODER_T = Callable[[ARGS_T], Any]
DECODER_T = Callable[[Any], ARGS_T]
CODER_T = Tuple[ENCODER_T, DECODER_T]
|
DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Library: {{ repo_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
|
@dataclass
class MixinInfo:
    model_card_template: str
    model_card_data: ModelCardData
    repo_url: Optional[str] = None
    docs_url: Optional[str] = None

|
class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    has to be implemented by overriding [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a
    good example of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more
    instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, check out the [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate the model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate the model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate the model card. Defaults to a generic template.
        language (`str` or `List[str]`, *optional*):
            Language(s) supported by the library. Used to generate the model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate the model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate the model card.
            E.g: "apache-2.0"
        license_name (`str`, *optional*):
            Name of the license. Used to generate the model card.
            Only used if `license` is set to `other`.
            E.g: "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license. Used to generate the model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate the model card. E.g. "text-classification".
        tags (`List[str]`, *optional*):
            Tags to be added to the model card. Used to generate the model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc. See the second example below.
|
    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...     ModelHubMixin,
    ...     library_name="my-library",
    ...     tags=["x-custom-tag", "arxiv:2304.12244"],
    ...     repo_url="https://github.com/huggingface/my-cool-library",
    ...     docs_url="https://huggingface.co/docs/my-cool-library",
    ...     # ^ optional metadata to generate model card
    ... ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def from_pretrained(
    ...         cls: Type[T],
    ...         pretrained_model_name_or_path: Union[str, Path],
    ...         *,
    ...         force_download: bool = False,
    ...         resume_download: Optional[bool] = None,
    ...         proxies: Optional[Dict] = None,
    ...         token: Optional[Union[str, bool]] = None,
    ...         cache_dir: Optional[Union[str, Path]] = None,
    ...         local_files_only: bool = False,
    ...         revision: Optional[str] = None,
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["arxiv:2304.12244", "model_hub_mixin", "x-custom-tag"]
    >>> card.data.library_name
    "my-library"
    ```
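
    A second, illustrative sketch showing the `coders` argument. `MyModelWithNamespace` and its
    encoder/decoder lambdas are hypothetical, not part of the library; any non-jsonable type can be
    registered this way:

    ```python
    >>> import argparse
    >>> from huggingface_hub import ModelHubMixin

    >>> class MyModelWithNamespace(
    ...     ModelHubMixin,
    ...     coders={
    ...         argparse.Namespace: (
    ...             lambda ns: vars(ns),                      # encoder: Namespace -> jsonable dict
    ...             lambda data: argparse.Namespace(**data),  # decoder: dict -> Namespace
    ...         )
    ...     },
    ... ):
    ...     def __init__(self, config: argparse.Namespace):
    ...         self.config = config

    # The namespace is encoded into config.json on save and rebuilt on load
    # (decoding relies on the `config: argparse.Namespace` annotation in `__init__`)
    >>> model = MyModelWithNamespace(argparse.Namespace(size=256))
    ```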
|
""" |
|
|
|
    _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
    # ^ optional config attribute automatically set in `save_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate the model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # `__init__` parameters
    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for `__init__` parameters
    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
|
    def __init_subclass__(
        cls,
        *,
        # Generic info about the library integrating the mixin
        repo_url: Optional[str] = None,
        docs_url: Optional[str] = None,
        # Model card template
        model_card_template: str = DEFAULT_MODEL_CARD,
        # Model card metadata
        language: Optional[List[str]] = None,
        library_name: Optional[str] = None,
        license: Optional[str] = None,
        license_name: Optional[str] = None,
        license_link: Optional[str] = None,
        pipeline_tag: Optional[str] = None,
        tags: Optional[List[str]] = None,
        # How to encode/decode arguments of custom types into a JSON config?
        coders: Optional[
            Dict[Type, CODER_T]
            # Key is a type.
            # Value is a tuple (encoder, decoder).
            # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
        ] = None,
        # Deprecated arguments
        languages: Optional[List[str]] = None,
    ) -> None:
"""Inspect __init__ signature only once when subclassing + handle modelcard.""" |
|
super().__init_subclass__() |
|
|
|
|
|
tags = tags or [] |
|
tags.append("model_hub_mixin") |
|
|
|
|
|
info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) |
|
|
|
|
|
if hasattr(cls, "_hub_mixin_info"): |
|
|
|
if model_card_template == DEFAULT_MODEL_CARD: |
|
info.model_card_template = cls._hub_mixin_info.model_card_template |
|
|
|
|
|
info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) |
|
|
|
|
|
info.docs_url = cls._hub_mixin_info.docs_url |
|
info.repo_url = cls._hub_mixin_info.repo_url |
|
cls._hub_mixin_info = info |
|
|
|
if languages is not None: |
|
warnings.warn( |
|
"The `languages` argument is deprecated. Use `language` instead. This will be removed in `huggingface_hub>=0.27.0`.", |
|
DeprecationWarning, |
|
) |
|
language = languages |
|
|
|
|
|
if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD: |
|
info.model_card_template = model_card_template |
|
if repo_url is not None: |
|
info.repo_url = repo_url |
|
if docs_url is not None: |
|
info.docs_url = docs_url |
|
if language is not None: |
|
info.model_card_data.language = language |
|
if library_name is not None: |
|
info.model_card_data.library_name = library_name |
|
if license is not None: |
|
info.model_card_data.license = license |
|
if license_name is not None: |
|
info.model_card_data.license_name = license_name |
|
if license_link is not None: |
|
info.model_card_data.license_link = license_link |
|
if pipeline_tag is not None: |
|
info.model_card_data.pipeline_tag = pipeline_tag |
|
if tags is not None: |
|
if info.model_card_data.tags is not None: |
|
info.model_card_data.tags.extend(tags) |
|
else: |
|
info.model_card_data.tags = tags |
|
|
|
info.model_card_data.tags = sorted(set(info.model_card_data.tags)) |
|
|
|
|
|
cls._hub_mixin_coders = coders or {} |
|
cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) |
|
|
|
|
|
cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters) |
|
cls._hub_mixin_jsonable_default_values = { |
|
param.name: cls._encode_arg(param.default) |
|
for param in cls._hub_mixin_init_parameters.values() |
|
if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default) |
|
} |
|
cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters |
|
|
|
    def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
        """Create a new instance of the class and handle config.

        3 cases:
        - If `self._hub_mixin_config` is already set, do nothing.
        - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
        - Otherwise, build `self._hub_mixin_config` from default values and passed values.
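
        Example (an illustrative sketch; `MyModel` is a hypothetical subclass with
        `def __init__(self, size: int = 512)`):

        ```python
        >>> model = MyModel(size=256)
        >>> model._hub_mixin_config  # built from defaults + passed values
        {'size': 256}
        ```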
|
""" |
|
instance = super().__new__(cls) |
|
|
|
|
|
if instance._hub_mixin_config is not None: |
|
return instance |
|
|
|
|
|
passed_values = { |
|
**{ |
|
key: value |
|
for key, value in zip( |
|
|
|
list(cls._hub_mixin_init_parameters)[1:], |
|
args, |
|
) |
|
}, |
|
**kwargs, |
|
} |
|
|
|
|
|
if is_dataclass(passed_values.get("config")): |
|
instance._hub_mixin_config = passed_values["config"] |
|
return instance |
|
|
|
|
|
init_config = { |
|
|
|
**cls._hub_mixin_jsonable_default_values, |
|
|
|
**{ |
|
key: cls._encode_arg(value) |
|
for key, value in passed_values.items() |
|
if instance._is_jsonable(value) |
|
}, |
|
} |
|
passed_config = init_config.pop("config", {}) |
|
|
|
|
|
if isinstance(passed_config, dict): |
|
init_config.update(passed_config) |
|
|
|
|
|
if init_config != {}: |
|
instance._hub_mixin_config = init_config |
|
return instance |
|
|
|
    @classmethod
    def _is_jsonable(cls, value: Any) -> bool:
        """Check if a value is JSON serializable."""
        if isinstance(value, cls._hub_mixin_jsonable_custom_types):
            return True
        return is_jsonable(value)

    @classmethod
    def _encode_arg(cls, arg: Any) -> Any:
        """Encode an argument into a JSON serializable format."""
        for type_, (encoder, _) in cls._hub_mixin_coders.items():
            if isinstance(arg, type_):
                if arg is None:
                    return None
                return encoder(arg)
        return arg
|
    @classmethod
    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
        """Decode a JSON serializable value into an argument."""
        if is_simple_optional_type(expected_type):
            if value is None:
                return None
            expected_type = unwrap_simple_optional_type(expected_type)

        # Dataclasses are handled natively
        if is_dataclass(expected_type):
            return _load_dataclass(expected_type, value)

        # Otherwise, look for a custom decoder registered in `coders`
        for type_, (_, decoder) in cls._hub_mixin_coders.items():
            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                return decoder(value)

        return value
|
    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in a local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
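
        Example (illustrative; `MyModel` stands for any subclass of `ModelHubMixin`):

        ```python
        >>> model = MyModel(size=256)

        # Writes the weights, config.json and README.md into the directory
        >>> model.save_pretrained("path/to/my-model")

        # Save and push to the Hub in one call
        >>> model.save_pretrained("path/to/my-model", push_to_hub=True, repo_id="username/my-model")
        ```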
|
""" |
|
save_directory = Path(save_directory) |
|
save_directory.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
|
config_path = save_directory / CONFIG_NAME |
|
config_path.unlink(missing_ok=True) |
|
|
|
|
|
self._save_pretrained(save_directory) |
|
|
|
|
|
if config is None: |
|
config = self._hub_mixin_config |
|
if config is not None: |
|
if is_dataclass(config): |
|
config = asdict(config) |
|
if not config_path.exists(): |
|
config_str = json.dumps(config, sort_keys=True, indent=2) |
|
config_path.write_text(config_str) |
|
|
|
|
|
model_card_path = save_directory / "README.md" |
|
model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} |
|
if not model_card_path.exists(): |
|
self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") |
|
|
|
|
|
if push_to_hub: |
|
kwargs = push_to_hub_kwargs.copy() |
|
if config is not None: |
|
kwargs["config"] = config |
|
if repo_id is None: |
|
repo_id = save_directory.name |
|
return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) |
|
return None |
|
|
|
    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
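
        Example (an illustrative sketch of a typical override; `save_weights` is a placeholder for
        your framework's own serialization utility, not a real function):

        ```python
        def _save_pretrained(self, save_directory: Path) -> None:
            # Serialize the weights with your framework's own format
            save_weights(self, str(save_directory / "model.bin"))
        ```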
|
""" |
|
raise NotImplementedError |
|
|
|
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Hugging Face Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on the `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
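
        Example (illustrative; `MyModel` stands for any subclass of `ModelHubMixin`):

        ```python
        # From a repo on the Hub
        >>> model = MyModel.from_pretrained("username/my-model")

        # From a local directory previously populated by `save_pretrained`
        >>> model = MyModel.from_pretrained("path/to/my-model")
        ```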
|
""" |
|
model_id = str(pretrained_model_name_or_path) |
|
config_file: Optional[str] = None |
|
if os.path.isdir(model_id): |
|
if CONFIG_NAME in os.listdir(model_id): |
|
config_file = os.path.join(model_id, CONFIG_NAME) |
|
else: |
|
logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}") |
|
else: |
|
try: |
|
config_file = hf_hub_download( |
|
repo_id=model_id, |
|
filename=CONFIG_NAME, |
|
revision=revision, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
proxies=proxies, |
|
resume_download=resume_download, |
|
token=token, |
|
local_files_only=local_files_only, |
|
) |
|
except HfHubHTTPError as e: |
|
logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") |
|
|
|
|
|
config = None |
|
if config_file is not None: |
|
with open(config_file, "r", encoding="utf-8") as f: |
|
config = json.load(f) |
|
|
|
|
|
for key, value in config.items(): |
|
if key in cls._hub_mixin_init_parameters: |
|
expected_type = cls._hub_mixin_init_parameters[key].annotation |
|
if expected_type is not inspect.Parameter.empty: |
|
config[key] = cls._decode_arg(expected_type, value) |
|
|
|
|
|
for param in cls._hub_mixin_init_parameters.values(): |
|
if param.name not in model_kwargs and param.name in config: |
|
model_kwargs[param.name] = config[param.name] |
|
|
|
|
|
if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: |
|
|
|
config_annotation = cls._hub_mixin_init_parameters["config"].annotation |
|
config = cls._decode_arg(config_annotation, config) |
|
|
|
|
|
model_kwargs["config"] = config |
|
|
|
|
|
if is_dataclass(cls): |
|
for key in cls.__dataclass_fields__: |
|
if key not in model_kwargs and key in config: |
|
model_kwargs[key] = config[key] |
|
elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()): |
|
for key, value in config.items(): |
|
if key not in model_kwargs: |
|
model_kwargs[key] = value |
|
|
|
|
|
if cls._hub_mixin_inject_config and "config" not in model_kwargs: |
|
model_kwargs["config"] = config |
|
|
|
instance = cls._from_pretrained( |
|
model_id=str(model_id), |
|
revision=revision, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
proxies=proxies, |
|
resume_download=resume_download, |
|
local_files_only=local_files_only, |
|
token=token, |
|
**model_kwargs, |
|
) |
|
|
|
|
|
|
|
if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})): |
|
instance._hub_mixin_config = config |
|
|
|
return instance |
|
|
|
    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Overwrite this method in subclass to define how to load your model from pretrained.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a
        `map_location` parameter to set on which device the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on the `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments passed along from [`~ModelHubMixin.from_pretrained`].
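
        Example (an illustrative sketch of a typical override; `load_weights` is a placeholder for
        your framework's own deserialization utility, not a real function):

        ```python
        @classmethod
        def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download, proxies,
                             resume_download, local_files_only, token, **model_kwargs):
            # Download the weights file from the Hub (or resolve it from the local cache)
            weights_file = hf_hub_download(
                repo_id=model_id, filename="model.bin", revision=revision, cache_dir=cache_dir,
                force_download=force_download, proxies=proxies, token=token,
                local_files_only=local_files_only,
            )
            model = cls(**model_kwargs)
            load_weights(model, weights_file)  # framework-specific deserialization
            return model
        ```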
|
""" |
|
raise NotImplementedError |
|
|
|
    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: bool = False,
        token: Optional[str] = None,
        branch: Optional[str] = None,
        create_pr: Optional[bool] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        delete_patterns: Optional[Union[List[str], str]] = None,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the Hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See the [`upload_folder`] reference for
        more details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`bool`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.

        Returns:
            The url of the commit of your model in the given repository.
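
        Example (illustrative; `MyModel` stands for any subclass of `ModelHubMixin`, and the returned
        url is truncated):

        ```python
        >>> model = MyModel(size=256)
        >>> model.push_to_hub("username/my-model", commit_message="Initial commit")
        'https://huggingface.co/username/my-model/commit/...'
        ```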
|
""" |
|
api = HfApi(token=token) |
|
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id |
|
|
|
|
|
with SoftTemporaryDirectory() as tmp: |
|
saved_path = Path(tmp) / repo_id |
|
self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) |
|
return api.upload_folder( |
|
repo_id=repo_id, |
|
repo_type="model", |
|
folder_path=saved_path, |
|
commit_message=commit_message, |
|
revision=branch, |
|
create_pr=create_pr, |
|
allow_patterns=allow_patterns, |
|
ignore_patterns=ignore_patterns, |
|
delete_patterns=delete_patterns, |
|
) |
|
|
|
    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        """Generate a [`ModelCard`] from the mixin metadata and the model card template."""
        card = ModelCard.from_template(
            card_data=self._hub_mixin_info.model_card_data,
            template_str=self._hub_mixin_info.model_card_template,
            repo_url=self._hub_mixin_info.repo_url,
            docs_url=self._hub_mixin_info.docs_url,
            **kwargs,
        )
        return card
|
|
class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...     nn.Module,
    ...     PyTorchModelHubMixin,
    ...     library_name="keras-nlp",
    ...     repo_url="https://github.com/keras-team/keras-nlp",
    ...     docs_url="https://keras.io/keras_nlp/",
    ...     # ^ optional metadata to generate model card
    ... ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.hidden_size = hidden_size
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(vocab_size, output_size)
    ...
    ...     def forward(self, x):
    ...         return self.linear(x + self.param)

    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.hidden_size
    256
    ```
    """
|
    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
        tags = tags or []
        tags.append("pytorch_model_hub_mixin")
        kwargs["tags"] = tags
        return super().__init_subclass__(*args, **kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        model_to_save = self.module if hasattr(self, "module") else self  # unwrap DataParallel-like wrappers
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        model = cls(**model_kwargs)
        if os.path.isdir(model_id):
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_safetensor(model, model_file, map_location, strict)
            except EntryNotFoundError:
                # Fall back to pickled weights if no safetensors file is found
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=PYTORCH_WEIGHTS_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_pickle(model, model_file, map_location, strict)
|
    @classmethod
    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict)
        model.eval()
        return model
|
    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        load_model_as_safetensor(model, model_file, strict=strict)
        if map_location != "cpu":
            logger.warning(
                "Loading model weights on other devices than 'cpu' is not supported natively."
                " This means that the model is loaded on 'cpu' first and then copied to the device."
                " This leads to a slower loading time."
                " Support for loading directly on other devices is planned to be added in future releases."
                " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
            )
            model.to(map_location)  # copy to the requested device
        return model
|
|
def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
    """Load a dataclass instance from a dictionary.

    Fields not expected by the dataclass are ignored.
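
    Example (illustrative; `Config` is a hypothetical dataclass):

    ```python
    >>> @dataclass
    ... class Config:
    ...     size: int = 512

    >>> _load_dataclass(Config, {"size": 256, "unknown_field": "dropped"})
    Config(size=256)
    ```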
|
""" |
|
return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__}) |
|
|