reorg a bit
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     try:
         from flash_attn.ops.rms_norm import RMSNorm
 
-        LOG.info("patching with flash_attn.ops.rms_norm")
-
         class LlamaRMSNorm(RMSNorm):
             """Patched LLamaRMSNorm"""

             def __init__(self, hidden_size, eps=1e-6):
                 super().__init__(hidden_size, eps=eps)
 
+        LOG.info("patching with flash_attn.ops.rms_norm")
         transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
     except ImportError:
         LOG.info(
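The reorg moves the log line so it is emitted right before the assignment that installs the patched class, i.e. only once the flash_attn import and the class definition have succeeded. Below is a minimal sketch (not part of this commit) of how the monkeypatch is typically applied before the model is constructed, so that the LlamaRMSNorm layers are built from the patched class; the checkpoint name is only an illustrative placeholder.

# Sketch only: assumes axolotl and flash-attn are installed.
import transformers

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# Apply the patch first; if flash_attn's fused RMSNorm is importable, this
# also swaps transformers.models.llama.modeling_llama.LlamaRMSNorm and logs
# "patching with flash_attn.ops.rms_norm" just before the assignment.
replace_llama_attn_with_flash_attn(packed=False)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, not from this commit
)

Since the RMSNorm modules are instantiated when the model is built, the patch has to run before from_pretrained; patching afterwards would leave already-constructed layers unchanged.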