File size: 4,028 Bytes
3ab8901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import contextlib
import contextvars
import functools
import threading
from typing import (
    Callable,
    ContextManager,
    NamedTuple,
    Optional,
    TypeVar,
    Union,
)

import torch

__all__ = ["no_init_or_tensor"]


Model = TypeVar("Model")


def no_init_or_tensor(
    loading_code: Optional[Callable[..., Model]] = None
) -> Union[Model, ContextManager]:
    """
    Run model-loading code with weight initialization switched off.

    There are two ways to use this function: hand it a zero-argument
    callable, which is invoked immediately while initialization is
    suppressed, or call it with no argument and use the result in a
    `with`-statement wrapped around the model-loading code.

    Args:
        loading_code: A callable to invoke with weight initialization
            suppressed, or None (the default) to obtain a context manager.

    Returns:
        Whatever `loading_code` returns, when a callable was given.

        When `loading_code` is None, a context manager suitable for use
        in a `with`-statement.

    Raises:
        TypeError: If `loading_code` is neither None nor callable.

    Examples:
        As a context manager::

            from transformers import AutoConfig, AutoModelForCausalLM
            config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
            with no_init_or_tensor():
                model = AutoModelForCausalLM.from_config(config)

        Or, directly passing a callable::

            from transformers import AutoConfig, AutoModelForCausalLM
            config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
            model = no_init_or_tensor(
                lambda: AutoModelForCausalLM.from_config(config)
            )
    """
    # Guard clause: reject anything that is neither None nor callable before
    # touching the implementation class at all.
    if loading_code is not None and not callable(loading_code):
        raise TypeError(
            "no_init_or_tensor() expected a callable to evaluate,"
            " or None if being used as a context manager;"
            f' got an object of type "{type(loading_code).__name__}" instead.'
        )

    suppressed = _NoInitOrTensorImpl.context_manager()
    if loading_code is None:
        # Caller wants to drive the `with`-statement themselves.
        return suppressed

    with suppressed:
        return loading_code()


class _NoInitOrTensorImpl:
    # Implementation of the re-entrant context manager behind
    # no_init_or_tensor().
    # Activity is tracked per execution context with a ContextVar, and the
    # global monkey-patching is reference-counted under a lock, so that
    # overlapping users in other threads / tasks share one set of patches.
    # This class essentially acts as a namespace.
    # It is not instantiable, because modifications to torch functions
    # inherently affect the global scope, and thus there is no worthwhile data
    # to store in the class instance scope.

    # Module classes whose reset_parameters() is silenced while active.
    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    # (module class, original reset_parameters) pairs captured when this
    # class is defined; used to undo the patches.
    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
    # Reference to torch.empty captured when this class is defined.
    _ORIGINAL_EMPTY = torch.empty

    # True only inside the execution context that entered the manager.
    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
    # Number of outermost (non-nested) activations currently in flight.
    _count_active: int = 0
    _count_active_lock = threading.Lock()

    @classmethod
    @contextlib.contextmanager
    def context_manager(cls):
        """
        Suppress `reset_parameters()` on the patched modules for the
        duration of the `with`-block.

        Nested use inside an already-active context is a no-op. The first
        outermost activation installs the patches; the last one to exit
        restores the originals, even if the body raised.
        """
        if cls.is_active.get():
            # Already active in this context: nothing to install or undo.
            yield
            return

        with cls._count_active_lock:
            cls._count_active += 1
            if cls._count_active == 1:
                for mod in cls._MODULES:
                    mod.reset_parameters = cls._disable(mod.reset_parameters)
                # torch.empty is (re)bound to the reference captured at
                # class-definition time. No wrapper is installed, so tensor
                # allocation itself is not redirected here.
                # NOTE(review): an earlier comment on this line described
                # mapping torch.empty to the meta device -- confirm that
                # dropping that behaviour was intentional.
                torch.empty = cls._ORIGINAL_EMPTY
        reset_token = cls.is_active.set(True)

        try:
            yield
        finally:
            try:
                cls.is_active.reset(reset_token)
            finally:
                # Always release our share of the patches, even if resetting
                # the ContextVar failed; otherwise the counter would never
                # return to zero and the patches would never be removed.
                with cls._count_active_lock:
                    cls._count_active -= 1
                    if cls._count_active == 0:
                        torch.empty = cls._ORIGINAL_EMPTY
                        for mod, original in cls._MODULE_ORIGINALS:
                            mod.reset_parameters = original

    @staticmethod
    def _disable(func):
        """
        Wrap `func` so that it does nothing inside an active context.

        Outside an active context the wrapper forwards all arguments to
        `func` and returns its result; inside one it returns None without
        calling `func`. The wrapper keeps `func`'s name, docstring and
        other metadata.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Behaves as normal except in an active context
            if not _NoInitOrTensorImpl.is_active.get():
                return func(*args, **kwargs)
            return None

        return wrapper

    # Calling the class raises TypeError: it is a namespace, not a type to
    # instantiate.
    __init__ = None