Train parameters exclusively in specific ranges (#1390)
Browse files* Train parameters exclusively in specific ranges
* Fix the style and update docs
* Update yaml example
- examples/mistral/mixtral.yml +6 -6
- src/axolotl/train.py +2 -2
- src/axolotl/utils/freeze.py +199 -11
- tests/test_freeze.py +285 -0
examples/mistral/mixtral.yml
CHANGED
|
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
|
|
| 16 |
|
| 17 |
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
| 18 |
unfrozen_parameters:
|
| 19 |
-
# - lm_head
|
| 20 |
-
# - model.embed_tokens
|
| 21 |
-
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
| 22 |
-
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
| 23 |
-
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
| 24 |
-
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
| 25 |
|
| 26 |
model_config:
|
| 27 |
output_router_logits: true
|
|
|
|
| 16 |
|
| 17 |
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
| 18 |
unfrozen_parameters:
|
| 19 |
+
# - ^lm_head.weight$
|
| 20 |
+
# - ^model.embed_tokens.weight$[:32000]
|
| 21 |
+
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
| 22 |
+
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
| 23 |
+
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
| 24 |
+
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
| 25 |
|
| 26 |
model_config:
|
| 27 |
output_router_logits: true
|
src/axolotl/train.py
CHANGED
|
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
|
| 19 |
from axolotl.common.cli import TrainerCliArgs
|
| 20 |
from axolotl.logging_config import configure_logging
|
| 21 |
from axolotl.utils.dict import DictDefault
|
| 22 |
-
from axolotl.utils.freeze import
|
| 23 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 24 |
from axolotl.utils.trainer import setup_trainer
|
| 25 |
|
|
@@ -99,7 +99,7 @@ def train(
|
|
| 99 |
safe_serialization = cfg.save_safetensors is True
|
| 100 |
|
| 101 |
if cfg.unfrozen_parameters:
|
| 102 |
-
|
| 103 |
|
| 104 |
trainer = setup_trainer(
|
| 105 |
cfg,
|
|
|
|
| 19 |
from axolotl.common.cli import TrainerCliArgs
|
| 20 |
from axolotl.logging_config import configure_logging
|
| 21 |
from axolotl.utils.dict import DictDefault
|
| 22 |
+
from axolotl.utils.freeze import freeze_layers_except
|
| 23 |
from axolotl.utils.models import load_model, load_tokenizer
|
| 24 |
from axolotl.utils.trainer import setup_trainer
|
| 25 |
|
|
|
|
| 99 |
safe_serialization = cfg.save_safetensors is True
|
| 100 |
|
| 101 |
if cfg.unfrozen_parameters:
|
| 102 |
+
freeze_layers_except(model, cfg.unfrozen_parameters)
|
| 103 |
|
| 104 |
trainer = setup_trainer(
|
| 105 |
cfg,
|
src/axolotl/utils/freeze.py
CHANGED
|
@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
|
|
| 3 |
"""
|
| 4 |
import logging
|
| 5 |
import re
|
|
|
|
| 6 |
|
| 7 |
from axolotl.utils.distributed import is_main_process
|
| 8 |
|
| 9 |
LOG = logging.getLogger("axolotl.utils.freeze")
|
| 10 |
|
| 11 |
|
| 12 |
-
def
|
| 13 |
"""
|
| 14 |
Freezes all layers of the given model except for the layers that match given regex patterns.
|
| 15 |
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
|
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
|
|
| 17 |
Parameters:
|
| 18 |
- model (nn.Module): The PyTorch model to be modified.
|
| 19 |
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
Returns:
|
| 22 |
None; the model is modified in place.
|
| 23 |
"""
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
| 27 |
-
]
|
| 28 |
|
| 29 |
-
|
| 30 |
-
for param in model.parameters():
|
| 31 |
-
param.requires_grad = False
|
| 32 |
|
| 33 |
# Unfreeze layers that match the regex patterns
|
| 34 |
for name, param in model.named_parameters():
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
param.requires_grad = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
import logging
|
| 5 |
import re
|
| 6 |
+
from typing import Callable, List, Tuple
|
| 7 |
|
| 8 |
from axolotl.utils.distributed import is_main_process
|
| 9 |
|
| 10 |
LOG = logging.getLogger("axolotl.utils.freeze")
|
| 11 |
|
| 12 |
|
def freeze_layers_except(model, regex_patterns):
    """
    Freeze every parameter of the given model except those whose names match
    one of the given regex patterns, optionally limiting training to an index
    range of a matched parameter.

    Dots in the patterns are escaped and matched literally, since they are
    reserved as layer-name separators. Anchor a pattern with "^" and "$" to
    match the entire parameter name; otherwise it matches any prefix of the
    name. A trailing "[start:end]" range specifier is optional and is NOT part
    of the compiled regex, so "$" must come before the range if the whole name
    should be matched, e.g.
    ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (str or list of str): Pattern(s) selecting the parameters
      to keep unfrozen.

    Returns:
    None; the model is modified in place.
    """
    if isinstance(regex_patterns, str):
        regex_patterns = [regex_patterns]

    layer_patterns = [LayerNamePattern(raw) for raw in regex_patterns]

    for name, param in model.named_parameters():
        # Default to frozen; any matching pattern flips it back on.
        param.requires_grad = False

        matched_ranges = []
        for layer_pattern in layer_patterns:
            if not layer_pattern.match(name):
                continue

            param.requires_grad = True

            if layer_pattern.range is not None:
                matched_ranges.append(layer_pattern.range)

        merged_unfrozen_ranges = _merge_ranges(matched_ranges, len(param))

        if param.requires_grad and is_main_process():
            range_suffix = (
                f" with ranges {merged_unfrozen_ranges}"
                if merged_unfrozen_ranges
                else ""
            )
            LOG.debug(f"Unfrozen {name}{range_suffix}")

        if not merged_unfrozen_ranges:
            continue

        # The range list we need is actually the inverted of the merged ranges:
        # gradients outside the trainable ranges are zeroed by the hook.
        ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))

        param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))

    if is_main_process() and all(
        not param.requires_grad for param in model.parameters()
    ):
        LOG.warning("All parameters are frozen. Model will not be trained.")
| 71 |
+
def _invert_ranges(
|
| 72 |
+
given_ranges: List[Tuple[int, int]], layer_size: int
|
| 73 |
+
) -> List[Tuple[int, int]]:
|
| 74 |
+
"""
|
| 75 |
+
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
|
| 76 |
+
|
| 77 |
+
Parameters:
|
| 78 |
+
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
| 79 |
+
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
| 80 |
+
Returns:
|
| 81 |
+
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
| 82 |
+
"""
|
| 83 |
+
if not given_ranges:
|
| 84 |
+
return [(0, layer_size)]
|
| 85 |
+
|
| 86 |
+
inverted_ranges = []
|
| 87 |
+
current_start = 0
|
| 88 |
+
|
| 89 |
+
for start, end in sorted(given_ranges):
|
| 90 |
+
if start > current_start:
|
| 91 |
+
inverted_ranges.append((current_start, start))
|
| 92 |
+
current_start = max(current_start, end)
|
| 93 |
+
|
| 94 |
+
# Handle the case where the last given range does not reach the end of the total_size
|
| 95 |
+
if current_start < layer_size:
|
| 96 |
+
inverted_ranges.append((current_start, layer_size))
|
| 97 |
+
|
| 98 |
+
return inverted_ranges
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _merge_ranges(
|
| 102 |
+
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
| 103 |
+
) -> List[Tuple[int, int]]:
|
| 104 |
+
"""
|
| 105 |
+
Merges overlapping ranges and sorts the given ranges.
|
| 106 |
+
|
| 107 |
+
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
|
| 108 |
+
as tuples, where the first element is the start index (inclusive) and the second element is the end
|
| 109 |
+
index (exclusive). The end index can be None, indicating that the range extends to the end of the
|
| 110 |
+
sequence.
|
| 111 |
+
|
| 112 |
+
Parameters:
|
| 113 |
+
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
|
| 114 |
+
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
|
| 118 |
+
"""
|
| 119 |
+
# End of each range can be determined now since we have the total size
|
| 120 |
+
processed_ranges = [
|
| 121 |
+
(start, end if end is not None else layer_size) for start, end in given_ranges
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
# No need to merge if there's only one or no ranges
|
| 125 |
+
if len(processed_ranges) <= 1:
|
| 126 |
+
return processed_ranges
|
| 127 |
+
|
| 128 |
+
sorted_ranges = sorted(processed_ranges)
|
| 129 |
+
|
| 130 |
+
merged_ranges = [sorted_ranges[0]]
|
| 131 |
+
for start, end in sorted_ranges[1:]:
|
| 132 |
+
prev_start, prev_end = merged_ranges[-1]
|
| 133 |
+
if start <= prev_end:
|
| 134 |
+
merged_ranges[-1] = (prev_start, max(prev_end, end))
|
| 135 |
+
else:
|
| 136 |
+
merged_ranges.append((start, end))
|
| 137 |
+
|
| 138 |
+
return merged_ranges
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
|
| 142 |
+
"""
|
| 143 |
+
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
|
| 144 |
+
|
| 145 |
+
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
|
| 146 |
+
two integers representing the start and end indices of the range.
|
| 147 |
+
|
| 148 |
+
Parameters:
|
| 149 |
+
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
- Callable: A hook function to be used with `register_hook` on parameters.
|
| 153 |
+
|
| 154 |
+
Example usage:
|
| 155 |
+
```
|
| 156 |
+
ranges_to_freeze = [(0, 10), (20, 30)]
|
| 157 |
+
hook = _create_freeze_parameters_hook(ranges_to_freeze)
|
| 158 |
+
model.register_hook(hook)
|
| 159 |
+
```
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def freeze_parameters_hook(gradients):
|
| 163 |
+
for start, end in ranges_to_freeze:
|
| 164 |
+
gradients[start:end].zero_()
|
| 165 |
+
|
| 166 |
+
return freeze_parameters_hook
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class LayerNamePattern:
|
| 170 |
+
"""
|
| 171 |
+
Represents a regex pattern for layer names, potentially including a parameter index range.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, pattern: str):
|
| 175 |
+
"""
|
| 176 |
+
Initializes a new instance of the LayerNamePattern class.
|
| 177 |
+
|
| 178 |
+
Parameters:
|
| 179 |
+
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
|
| 180 |
+
"""
|
| 181 |
+
self.raw_pattern = pattern
|
| 182 |
+
name_pattern, self.range = self._parse_pattern(pattern)
|
| 183 |
+
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
| 184 |
+
|
| 185 |
+
def match(self, name: str) -> bool:
|
| 186 |
+
"""
|
| 187 |
+
Checks if the given layer name matches the regex pattern.
|
| 188 |
+
|
| 189 |
+
Parameters:
|
| 190 |
+
- name (str): The layer name to check.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
- bool: True if the layer name matches the pattern, False otherwise.
|
| 194 |
+
"""
|
| 195 |
+
return self.name_regex.match(name) is not None
|
| 196 |
+
|
| 197 |
+
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
| 198 |
+
"""
|
| 199 |
+
Extracts the range pattern from the given pattern.
|
| 200 |
+
|
| 201 |
+
Parameters:
|
| 202 |
+
- pattern (str): The pattern to extract the range from.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
|
| 206 |
+
"""
|
| 207 |
+
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
|
| 208 |
+
if not match:
|
| 209 |
+
return pattern, None
|
| 210 |
+
|
| 211 |
+
base_pattern, start_part, end_part = match.groups()
|
| 212 |
+
|
| 213 |
+
if end_part is None and start_part.isdecimal():
|
| 214 |
+
index = int(start_part)
|
| 215 |
+
return base_pattern, (index, index + 1)
|
| 216 |
+
|
| 217 |
+
# [:end] or [start:] or [start:end]
|
| 218 |
+
start = int(start_part) if start_part else 0
|
| 219 |
+
end = int(end_part) if end_part else None
|
| 220 |
+
|
| 221 |
+
if end is not None and start >= end:
|
| 222 |
+
raise ValueError(
|
| 223 |
+
f"Invalid range in layer name pattern: {pattern}."
|
| 224 |
+
"End of range must be greater than start."
|
| 225 |
+
)
|
| 226 |
+
return base_pattern, (start, end)
|
tests/test_freeze.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains unit tests for the `freeze_layers_except` function.
|
| 3 |
+
|
| 4 |
+
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
| 5 |
+
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import unittest
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from axolotl.utils.freeze import freeze_layers_except
|
| 14 |
+
|
| 15 |
+
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
| 16 |
+
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
| 17 |
+
|
| 18 |
+
|
class TestFreezeLayersExcept(unittest.TestCase):
    """
    A test case class for the `freeze_layers_except` function.

    Each test freezes/unfreezes layers of a small two-layer model and checks
    both the resulting requires_grad flags and, for range patterns, the row
    gradients produced by a backward pass.
    """

    def setUp(self):
        self.model = _TestModel()

    def _assert_trainability(self, *, features_trainable, classifier_trainable):
        """Assert the requires_grad state of both layers of the test model."""
        self.assertEqual(
            self.model.features.layer.weight.requires_grad,
            features_trainable,
            f"model.features.layer should be {'trainable' if features_trainable else 'frozen'}.",
        )
        self.assertEqual(
            self.model.classifier.weight.requires_grad,
            classifier_trainable,
            f"model.classifier should be {'trainable' if classifier_trainable else 'frozen'}.",
        )

    def test_freeze_layers_with_dots_in_name(self):
        freeze_layers_except(self.model, ["features.layer"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)

    def test_freeze_layers_without_dots_in_name(self):
        # NOTE(review): the original assertion messages here were swapped
        # ("trainable"/"frozen" reversed relative to the asserted state);
        # the helper now derives the message from the expected state.
        freeze_layers_except(self.model, ["classifier"])
        self._assert_trainability(features_trainable=False, classifier_trainable=True)

    def test_freeze_layers_regex_patterns(self):
        # The second pattern cannot match because only characters 'a' to 'c'
        # are allowed after the word 'class', whereas it would need to match 'i'.
        freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)

    def test_all_layers_frozen(self):
        freeze_layers_except(self.model, [])
        self._assert_trainability(features_trainable=False, classifier_trainable=False)

    def test_all_layers_unfrozen(self):
        freeze_layers_except(self.model, ["features.layer", "classifier"])
        self._assert_trainability(features_trainable=True, classifier_trainable=True)

    def test_freeze_layers_with_range_pattern_start_end(self):
        freeze_layers_except(self.model, ["features.layer[1:5]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ZERO] + [ONE_TO_TEN] * 4 + [ZERO] * 5)

    def test_freeze_layers_with_range_pattern_single_index(self):
        freeze_layers_except(self.model, ["features.layer[5]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ZERO] * 5 + [ONE_TO_TEN] + [ZERO] * 4)

    def test_freeze_layers_with_range_pattern_start_omitted(self):
        freeze_layers_except(self.model, ["features.layer[:5]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ONE_TO_TEN] * 5 + [ZERO] * 5)

    def test_freeze_layers_with_range_pattern_end_omitted(self):
        freeze_layers_except(self.model, ["features.layer[4:]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ZERO] * 4 + [ONE_TO_TEN] * 6)

    def test_freeze_layers_with_range_pattern_merge_included(self):
        # [5:6] is fully contained in [4:] and must merge into it.
        freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ZERO] * 4 + [ONE_TO_TEN] * 6)

    def test_freeze_layers_with_range_pattern_merge_intersect(self):
        # Overlapping ranges [4:7] and [6:8] must merge into [4:8].
        freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output([ZERO] * 4 + [ONE_TO_TEN] * 4 + [ZERO] * 2)

    def test_freeze_layers_with_range_pattern_merge_separate(self):
        # Disjoint ranges must stay separate after merging.
        freeze_layers_except(
            self.model,
            ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
        )
        self._assert_trainability(features_trainable=True, classifier_trainable=False)
        self._assert_gradient_output(
            [ZERO, ONE_TO_TEN, ZERO, ONE_TO_TEN, ZERO, ONE_TO_TEN] + [ZERO] * 4
        )

    def _assert_gradient_output(self, expected):
        """Run a forward/backward pass and compare the weight's row gradients
        against *expected* (a list of per-row gradient rows)."""
        input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)

        self.model.features.layer.weight.grad = None  # Reset gradients
        output = self.model.features.layer(input_tensor)
        loss = output.sum()
        loss.backward()

        expected_grads = torch.tensor(expected)
        torch.testing.assert_close(
            self.model.features.layer.weight.grad, expected_grads
        )
class _SubLayerModule(nn.Module):
    # Minimal nested module so that the test model has a parameter whose
    # name contains a dot ("features.layer.weight").
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)
class _TestModel(nn.Module):
    # Small two-branch model used by the tests: a nested Linear reachable as
    # "features.layer" and a top-level Linear reachable as "classifier".
    def __init__(self):
        super().__init__()
        self.features = _SubLayerModule()
        self.classifier = nn.Linear(10, 2)
# Allow running this test module directly with `python test_freeze.py`.
if __name__ == "__main__":
    unittest.main()