reorg a bit
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
```diff
@@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     try:
         from flash_attn.ops.rms_norm import RMSNorm

-        LOG.info("patching with flash_attn.ops.rms_norm")
-
         class LlamaRMSNorm(RMSNorm):
             """Patched LLamaRMSNorm"""

             def __init__(self, hidden_size, eps=1e-6):
                 super().__init__(hidden_size, eps=eps)

+        LOG.info("patching with flash_attn.ops.rms_norm")
         transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
     except ImportError:
         LOG.info(
```
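For context, a monkeypatch like this only takes effect if it runs before the model is constructed, because `from_pretrained` instantiates whatever `LlamaRMSNorm` class is installed on `transformers.models.llama.modeling_llama` at that moment. A minimal usage sketch follows; the call site and model name are assumptions for illustration, not part of this commit:

```python
# Hypothetical call site -- not part of this commit. Apply the monkeypatch
# before loading the model so transformers builds the patched LlamaRMSNorm.
from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn(packed=False)  # patch first
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # then load
```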