Maxime
committed on
add noisy embedding (#721)
* add noisy embedding
* fix format
* Update README.md
* Update README.md
* linter issues
* caseus fixes
---------
Co-authored-by: Maxime <[email protected]>
README.md
CHANGED
@@ -672,6 +672,11 @@ adam_epsilon:
 # Gradient clipping max norm
 max_grad_norm:
 
+# Augmentation techniques
+# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
+# currently only supported on Llama and Mistral
+noisy_embedding_alpha:
+
 # Whether to bettertransformers
 flash_optimum:
 # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
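The new `noisy_embedding_alpha` option enables NEFTune-style noise: for an embedding batch with sequence length L and hidden size d, uniform noise drawn from [-α/√(L·d), +α/√(L·d)] is added to the token embeddings during training only. A minimal standalone sketch of that computation (illustrative only; the helper name is mine, not part of this diff):

```python
import torch


def neft_noise(embeds: torch.Tensor, noise_alpha: float = 5.0) -> torch.Tensor:
    """Add NEFT-style uniform noise to a (batch, seq_len, hidden_size) tensor."""
    # noise magnitude alpha / sqrt(L * d), as in the paper and the patches below
    mag_norm = noise_alpha / (embeds.size(1) * embeds.size(2)) ** 0.5
    return embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)


# example: noise a random batch of embeddings
embeds = torch.randn(2, 16, 64)
noised = neft_noise(embeds, noise_alpha=5.0)
```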
src/axolotl/monkeypatch/llama_embeddings_hijack.py
ADDED
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.llama.modeling_llama
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+        transformers.models.llama.modeling_llama.LlamaModel.post_init
+    )
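As a rough usage sketch (my own check, not part of the PR, and assuming a transformers version where `LlamaModel.__init__` calls `post_init()`), the patch can be exercised on a tiny randomly initialized model; because the wrapper tests `model.training`, the same input should yield noised embeddings in train mode and a plain lookup in eval mode:

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

# apply the monkeypatch before constructing the model so post_init gets wrapped
replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

config = LlamaConfig(
    vocab_size=128,
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
)
model = LlamaModel(config)
input_ids = torch.randint(0, 128, (1, 8))

model.train()
noised = model.embed_tokens.forward(input_ids)  # noise added while training
model.eval()
clean = model.embed_tokens.forward(input_ids)  # plain embedding lookup

print(torch.allclose(noised, clean))  # expected: False
```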
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
ADDED
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.mistral.modeling_mistral
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+        transformers.models.mistral.modeling_mistral.MistralModel.post_init
+    )
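The Mistral patch is the Llama patch with only the target module and class swapped, which is why both files carry the `pylint: disable=duplicate-code` marker. A hypothetical refactor (not in this PR; the names are mine) could factor the shared logic into one helper parameterized by the model class:

```python
import torch


def _noised_embed(orig_embed, noise_alpha, model):
    """Wrap an embedding forward so uniform noise is added only while training."""

    def new_func(input_ids):
        if model.training:
            embed_init = orig_embed(input_ids)
            mag_norm = noise_alpha / (embed_init.size(1) * embed_init.size(2)) ** 0.5
            return embed_init + torch.zeros_like(embed_init).uniform_(
                -mag_norm, mag_norm
            )
        return orig_embed(input_ids)

    return new_func


def patch_embeddings_with_uniform_distribution(model_cls, noise_alpha=5):
    """Wrap `model_cls.post_init` so each new instance gets a noised
    `embed_tokens.forward`; works for any class exposing `embed_tokens`,
    e.g. LlamaModel or MistralModel."""
    orig_post_init = model_cls.post_init

    def new_post_init(self):
        orig_post_init(self)
        self.embed_tokens.forward = _noised_embed(
            self.embed_tokens.forward, noise_alpha, self
        )

    model_cls.post_init = new_post_init
```

With such a helper, `patch_embeddings_with_uniform_distribution(LlamaModel)` and `patch_embeddings_with_uniform_distribution(MistralModel)` would cover both of the near-identical modules.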
src/axolotl/utils/models.py
CHANGED
@@ -180,6 +180,26 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
+    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.llama_embeddings_hijack import (
+            replace_llama_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_llama_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
+    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.mistral_embeddings_hijack import (
+            replace_mistral_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_mistral_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
             replace_llama_rope_with_xpos_rope,
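One ordering detail worth noting (my inference from how the hook works, not spelled out in the PR): the patch wraps `post_init`, which `transformers` models invoke at the end of `__init__`, so `load_model` has to apply it before the model object is constructed. A schematic of that ordering, with the checkpoint name and the real `from_pretrained` arguments standing in as placeholders:

```python
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

noise_alpha = 5  # as it would arrive from cfg.noisy_embedding_alpha

# 1. patch first, so LlamaModel.post_init is already wrapped ...
replace_llama_embeddings_with_uniform_distribution(noise_alpha=noise_alpha)

# 2. ... then build the model; its embed_tokens.forward is wrapped on creation
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")
```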