import inspect
import json
import os
import warnings
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
from .hf_api import HfApi
from .repocard import ModelCard, ModelCardData
from .utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    SoftTemporaryDirectory,
    is_jsonable,
    is_safetensors_available,
    is_simple_optional_type,
    is_torch_available,
    logging,
    unwrap_simple_optional_type,
    validate_hf_hub_args,
)


if TYPE_CHECKING:
    from _typeshed import DataclassInstance

if is_torch_available():
    import torch  # type: ignore

if is_safetensors_available():
    from safetensors.torch import load_model as load_model_as_safetensor
    from safetensors.torch import save_model as save_model_as_safetensor


logger = logging.get_logger(__name__)

# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")
# Generic variable to represent an args type
ARGS_T = TypeVar("ARGS_T")
ENCODER_T = Callable[[ARGS_T], Any]
DECODER_T = Callable[[Any], ARGS_T]
CODER_T = Tuple[ENCODER_T, DECODER_T]


DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---

This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Library: {{ repo_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""


@dataclass
class MixinInfo:
    model_card_template: str
    model_card_data: ModelCardData
    repo_url: Optional[str] = None
    docs_url: Optional[str] = None


class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    has to be implemented by overriding [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good
    example of a mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, check out the [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
        language (`str` or `List[str]`, *optional*):
            Language supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate model card.
            E.g. "apache-2.0".
        license_name (`str`, *optional*):
            Name of the license. Used to generate model card.
            Only used if `license` is set to `other`.
            E.g. "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license. Used to generate model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g. "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
        tags (`List[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"]
        coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.

    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...         ModelHubMixin,
    ...         library_name="my-library",
    ...         tags=["x-custom-tag", "arxiv:2304.12244"],
    ...         repo_url="https://github.com/huggingface/my-cool-library",
    ...         docs_url="https://huggingface.co/docs/my-cool-library",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def from_pretrained(
    ...         cls: Type[T],
    ...         pretrained_model_name_or_path: Union[str, Path],
    ...         *,
    ...         force_download: bool = False,
    ...         resume_download: Optional[bool] = None,
    ...         proxies: Optional[Dict] = None,
    ...         token: Optional[Union[str, bool]] = None,
    ...         cache_dir: Optional[Union[str, Path]] = None,
    ...         local_files_only: bool = False,
    ...         revision: Optional[str] = None,
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["arxiv:2304.12244", "model_hub_mixin", "x-custom-tag"]
    >>> card.data.library_name
    "my-library"
    ```
    """

    _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
    # ^ optional config attribute automatically set in `from_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: Dict[str, inspect.Parameter]  # __init__ parameters
    _hub_mixin_jsonable_default_values: Dict[str, Any]  # default values for __init__ parameters
    _hub_mixin_jsonable_custom_types: Tuple[Type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: Dict[Type, CODER_T]  # encoders/decoders for custom types
    # ^ internal values to handle config

    def __init_subclass__(
        cls,
        *,
        # Generic info for model card
        repo_url: Optional[str] = None,
        docs_url: Optional[str] = None,
        # Model card template
        model_card_template: str = DEFAULT_MODEL_CARD,
        # Model card metadata
        language: Optional[List[str]] = None,
        library_name: Optional[str] = None,
        license: Optional[str] = None,
        license_name: Optional[str] = None,
        license_link: Optional[str] = None,
        pipeline_tag: Optional[str] = None,
        tags: Optional[List[str]] = None,
        # How to encode/decode arguments with custom type into a JSON config?
        coders: Optional[
            Dict[Type, CODER_T]
            # Key is a type.
            # Value is a tuple (encoder, decoder).
            # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
        ] = None,
        # Deprecated arguments
        languages: Optional[List[str]] = None,
    ) -> None:
        """Inspect __init__ signature only once when subclassing + handle modelcard."""
        super().__init_subclass__()

        # Will be reused when creating modelcard
        tags = tags or []
        tags.append("model_hub_mixin")

        # Initialize MixinInfo if not existent
        info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())

        # If parent class has a MixinInfo, inherit from it as a copy
        if hasattr(cls, "_hub_mixin_info"):
            # Inherit model card template from parent class if not explicitly set
            if model_card_template == DEFAULT_MODEL_CARD:
                info.model_card_template = cls._hub_mixin_info.model_card_template

            # Inherit from parent model card data
            info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())

            # Inherit other info
            info.docs_url = cls._hub_mixin_info.docs_url
            info.repo_url = cls._hub_mixin_info.repo_url
        cls._hub_mixin_info = info

        if languages is not None:
            warnings.warn(
                "The `languages` argument is deprecated. Use `language` instead. This will be removed in `huggingface_hub>=0.27.0`.",
                DeprecationWarning,
            )
            language = languages

        # Update MixinInfo with metadata
        if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
            info.model_card_template = model_card_template
        if repo_url is not None:
            info.repo_url = repo_url
        if docs_url is not None:
            info.docs_url = docs_url
        if language is not None:
            info.model_card_data.language = language
        if library_name is not None:
            info.model_card_data.library_name = library_name
        if license is not None:
            info.model_card_data.license = license
        if license_name is not None:
            info.model_card_data.license_name = license_name
        if license_link is not None:
            info.model_card_data.license_link = license_link
        if pipeline_tag is not None:
            info.model_card_data.pipeline_tag = pipeline_tag
        if tags is not None:
            if info.model_card_data.tags is not None:
                info.model_card_data.tags.extend(tags)
            else:
                info.model_card_data.tags = tags

        info.model_card_data.tags = sorted(set(info.model_card_data.tags))

        # Handle encoders/decoders for args
        cls._hub_mixin_coders = coders or {}
        cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

        # Inspect __init__ signature to handle config
        cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
        cls._hub_mixin_jsonable_default_values = {
            param.name: cls._encode_arg(param.default)
            for param in cls._hub_mixin_init_parameters.values()
            if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
        }
        cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters
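
        # Illustrative sketch (hypothetical `MyModel` class): `coders` lets a library declare how a
        # non-jsonable type such as `argparse.Namespace` is encoded to / decoded from config.json:
        #
        # >>> from argparse import Namespace
        # >>> class MyModel(
        # ...     ModelHubMixin,
        # ...     coders={Namespace: (lambda x: vars(x), lambda data: Namespace(**data))},
        # ... ):
        # ...     def __init__(self, config: Namespace): ...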

    def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
        """Create a new instance of the class and handle config.

        3 cases:
        - If `self._hub_mixin_config` is already set, do nothing.
        - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
        - Otherwise, build `self._hub_mixin_config` from default values and passed values.
        """
        instance = super().__new__(cls)

        # If `config` is already set, return early
        if instance._hub_mixin_config is not None:
            return instance

        # Infer passed values
        passed_values = {
            **{
                key: value
                for key, value in zip(
                    # [1:] to skip `self` parameter
                    list(cls._hub_mixin_init_parameters)[1:],
                    args,
                )
            },
            **kwargs,
        }

        # If config passed as dataclass => set it and return early
        if is_dataclass(passed_values.get("config")):
            instance._hub_mixin_config = passed_values["config"]
            return instance

        # Otherwise, build config from default + passed values
        init_config = {
            # default values
            **cls._hub_mixin_jsonable_default_values,
            # passed values
            **{
                key: cls._encode_arg(value)  # Encode custom types as jsonable value
                for key, value in passed_values.items()
                if instance._is_jsonable(value)  # Only if jsonable or we have a custom encoder
            },
        }
        passed_config = init_config.pop("config", {})

        # Populate `init_config` with provided config
        if isinstance(passed_config, dict):
            init_config.update(passed_config)

        # Set `config` attribute and return
        if init_config != {}:
            instance._hub_mixin_config = init_config
        return instance
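
        # Illustrative example (hypothetical subclass with `__init__(self, size: int = 512)`):
        #
        # >>> model = MyModel(size=256)
        # >>> model._hub_mixin_config
        # {'size': 256}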

    @classmethod
    def _is_jsonable(cls, value: Any) -> bool:
        """Check if a value is JSON serializable."""
        if isinstance(value, cls._hub_mixin_jsonable_custom_types):
            return True
        return is_jsonable(value)

    @classmethod
    def _encode_arg(cls, arg: Any) -> Any:
        """Encode an argument into a JSON serializable format."""
        for type_, (encoder, _) in cls._hub_mixin_coders.items():
            if isinstance(arg, type_):
                if arg is None:
                    return None
                return encoder(arg)
        return arg

    @classmethod
    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
        """Decode a JSON serializable value into an argument."""
        if is_simple_optional_type(expected_type):
            if value is None:
                return None
            expected_type = unwrap_simple_optional_type(expected_type)
        # Dataclass => handle it
        if is_dataclass(expected_type):
            return _load_dataclass(expected_type, value)  # type: ignore[return-value]
        # Otherwise => check custom decoders
        for type_, (_, decoder) in cls._hub_mixin_coders.items():
            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                return decoder(value)
        # Otherwise => don't decode
        return value
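
        # Round-trip sketch (assuming the hypothetical `Namespace` coder registered above):
        #
        # >>> MyModel._encode_arg(Namespace(size=256))
        # {'size': 256}
        # >>> MyModel._decode_arg(Namespace, {"size": 256})
        # Namespace(size=256)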

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        Returns:
            `str` or `None`: URL of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Remove config.json if it already exists. After `_save_pretrained` we don't want to overwrite config.json
        # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
        # an existing config.json if it was not saved by `_save_pretrained`.
        config_path = save_directory / CONFIG_NAME
        config_path.unlink(missing_ok=True)

        # save model weights/files (framework-specific)
        self._save_pretrained(save_directory)

        # save config (if provided and if not serialized yet in `_save_pretrained`)
        if config is None:
            config = self._hub_mixin_config
        if config is not None:
            if is_dataclass(config):
                config = asdict(config)  # type: ignore[arg-type]
            if not config_path.exists():
                config_str = json.dumps(config, sort_keys=True, indent=2)
                config_path.write_text(config_str)

        # save model card
        model_card_path = save_directory / "README.md"
        model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
        if not model_card_path.exists():  # do not overwrite if already exists
            self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md")

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
        return None
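
        # Usage sketch (hypothetical instance and directory):
        #
        # >>> model.save_pretrained("my-awesome-model")                    # weights + config.json + README.md
        # >>> model.save_pretrained("my-awesome-model", push_to_hub=True)  # also pushes; repo_id defaults to the folder name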

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Override this method in a subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`Path`):
                Path to directory in which the model weights and configuration will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Hugging Face Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                    [`~ModelHubMixin.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`Dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
        """
        model_id = str(pretrained_model_name_or_path)
        config_file: Optional[str] = None
        if os.path.isdir(model_id):
            if CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, CONFIG_NAME)
            else:
                logger.warning(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                logger.info(f"{CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

        # Read config
        config = None
        if config_file is not None:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)

            # Decode custom types in config
            for key, value in config.items():
                if key in cls._hub_mixin_init_parameters:
                    expected_type = cls._hub_mixin_init_parameters[key].annotation
                    if expected_type is not inspect.Parameter.empty:
                        config[key] = cls._decode_arg(expected_type, value)

            # Populate model_kwargs from config
            for param in cls._hub_mixin_init_parameters.values():
                if param.name not in model_kwargs and param.name in config:
                    model_kwargs[param.name] = config[param.name]

            # Check if `config` argument was passed at init
            if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
                # Decode `config` argument if it was passed
                config_annotation = cls._hub_mixin_init_parameters["config"].annotation
                config = cls._decode_arg(config_annotation, config)

                # Forward config to model initialization
                model_kwargs["config"] = config

            # Inject config if `**kwargs` are expected
            if is_dataclass(cls):
                for key in cls.__dataclass_fields__:
                    if key not in model_kwargs and key in config:
                        model_kwargs[key] = config[key]
            elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
                for key, value in config.items():
                    if key not in model_kwargs:
                        model_kwargs[key] = value

            # Finally, also inject if `_from_pretrained` expects it
            if cls._hub_mixin_inject_config and "config" not in model_kwargs:
                model_kwargs["config"] = config

        instance = cls._from_pretrained(
            model_id=str(model_id),
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
        )

        # Implicitly set the config as instance attribute if not already set by the class
        # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
        if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
            instance._hub_mixin_config = config

        return instance
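
        # Usage sketch (hypothetical repo id): config.json, if present, is downloaded, decoded and
        # used to populate both `model_kwargs` and `instance._hub_mixin_config`.
        #
        # >>> model = MyModel.from_pretrained("username/my-awesome-model", revision="v1.0")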

    @classmethod
    def _from_pretrained(
        cls: Type[T],
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ) -> T:
        """Override this method in a subclass to define how to load your model from pretrained weights.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using `model_kwargs`. For example, [`PyTorchModelHubMixin._from_pretrained`] takes a `map_location`
        parameter to specify the device on which the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Hugging Face Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments forwarded from [`~ModelHubMixin.from_pretrained`] to this method.
        """
        raise NotImplementedError
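
        # Minimal sketch of a subclass implementation (illustrative; the "save.pkl" filename and the
        # `_load_from_file` helper are hypothetical):
        #
        # >>> @classmethod
        # ... def _from_pretrained(cls, *, model_id, revision, cache_dir, force_download, proxies,
        # ...                      resume_download, local_files_only, token, **model_kwargs):
        # ...     if os.path.isdir(model_id):  # local directory
        # ...         file = os.path.join(model_id, "save.pkl")
        # ...     else:  # download from the Hub
        # ...         file = hf_hub_download(
        # ...             repo_id=model_id, filename="save.pkl", revision=revision, cache_dir=cache_dir,
        # ...             force_download=force_download, proxies=proxies, resume_download=resume_download,
        # ...             token=token, local_files_only=local_files_only,
        # ...         )
        # ...     return cls._load_from_file(file, **model_kwargs)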

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: bool = False,
        token: Optional[str] = None,
        branch: Optional[str] = None,
        create_pr: Optional[bool] = None,
        allow_patterns: Optional[Union[List[str], str]] = None,
        ignore_patterns: Optional[Union[List[str], str]] = None,
        delete_patterns: Optional[Union[List[str], str]] = None,
        model_card_kwargs: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
        details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `huggingface-cli login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`bool`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`List[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            model_card_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.

        Returns:
            The url of the commit of your model in the given repository.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp:
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )
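
        # Usage sketch (hypothetical repo id): files are first written with `save_pretrained` to a
        # temporary folder, then uploaded in a single commit via `HfApi.upload_folder`.
        #
        # >>> model.push_to_hub("username/my-awesome-model", private=True, create_pr=True)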

    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        """Generate a model card for the model, based on the class-level template and metadata."""
        card = ModelCard.from_template(
            card_data=self._hub_mixin_info.model_card_data,
            template_str=self._hub_mixin_info.model_card_template,
            repo_url=self._hub_mixin_info.repo_url,
            docs_url=self._hub_mixin_info.docs_url,
            **kwargs,
        )
        return card


class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...         nn.Module,
    ...         PyTorchModelHubMixin,
    ...         library_name="keras-nlp",
    ...         repo_url="https://github.com/keras-team/keras-nlp",
    ...         docs_url="https://keras.io/keras_nlp/",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(vocab_size, output_size)
    ...
    ...     def forward(self, x):
    ...         return self.linear(x + self.param)
    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.hidden_size
    256
    ```
    """

    def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None:
        tags = tags or []
        tags.append("pytorch_model_hub_mixin")
        kwargs["tags"] = tags
        return super().__init_subclass__(*args, **kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: Optional[bool],
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""
        model = cls(**model_kwargs)
        if os.path.isdir(model_id):
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        else:
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_safetensor(model, model_file, map_location, strict)
            except EntryNotFoundError:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=PYTORCH_WEIGHTS_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
                return cls._load_as_pickle(model, model_file, map_location, strict)
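
        # Usage sketch (hypothetical repo id): safetensors weights are tried first, with a fallback to
        # pickled weights if the safetensors file is missing from the repo. `map_location` and `strict`
        # are forwarded through `from_pretrained`'s **model_kwargs:
        #
        # >>> model = MyModel.from_pretrained("username/my-awesome-model", map_location="cuda", strict=True)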

    @classmethod
    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore
        return model

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
        if map_location != "cpu":
            # TODO: remove this once https://github.com/huggingface/safetensors/pull/449 is merged.
            logger.warning(
                "Loading model weights on devices other than 'cpu' is not supported natively."
                " This means that the model is loaded on 'cpu' first and then copied to the device."
                " This leads to a slower loading time."
                " Support for loading directly on other devices is planned to be added in future releases."
                " See https://github.com/huggingface/huggingface_hub/pull/2086 for more details."
            )
            model.to(map_location)  # type: ignore [attr-defined]
        return model


def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
    """Load a dataclass instance from a dictionary.

    Fields not expected by the dataclass are ignored.
    """
    return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__})
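
# Illustrative example (hypothetical `Cfg` dataclass): unexpected keys are silently dropped.
#
# >>> @dataclass
# ... class Cfg:
# ...     size: int = 512
# >>> _load_dataclass(Cfg, {"size": 256, "unexpected": 1})
# Cfg(size=256)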