fix: warn user to install mamba_ssm package (#1019)
Files changed:
- docker/Dockerfile +2 -2
- requirements.txt +3 -1
- setup.py +7 -7
- src/axolotl/models/mamba/__init__.py +12 -0
docker/Dockerfile

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .[deepspeed,flash-attn]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
     fi
 
 # So we can test the Docker image
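The extras list baked into the image now includes mamba-ssm in both branches of the conditional. As a quick sanity check after building, a standalone snippet like the one below (not part of this change, and assuming the standard import names deepspeed, flash_attn, and mamba_ssm) can be run inside the container to confirm the extras are importable:

# Smoke test for the built image: confirm the optional extras installed by the
# Dockerfile are importable. Uses only the standard library.
import importlib.util
import sys


def main() -> int:
    missing = [
        name
        for name in ("deepspeed", "flash_attn", "mamba_ssm")
        if importlib.util.find_spec(name) is None
    ]
    if missing:
        print(f"missing extras: {', '.join(missing)}")
        return 1
    print("all expected extras are importable")
    return 0


if __name__ == "__main__":
    sys.exit(main())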
requirements.txt

@@ -1,5 +1,5 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
-packaging
+packaging==23.2
 peft==0.7.0
 transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
@@ -34,6 +34,8 @@ fschat==0.2.34
 gradio==3.50.2
 tensorboard
 
+mamba-ssm==1.1.1
+
 # remote filesystems
 s3fs
 gcsfs
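requirements.txt now pins packaging to 23.2 and adds mamba-ssm==1.1.1, with a blank line keeping the new pin separate from the remote-filesystem block. A minimal standard-library sketch for checking the installed version against that pin could look like this; the "1.1.1" literal simply mirrors the pin above, and both distribution-name spellings are tried because the package may be recorded under either:

# Compare the installed mamba-ssm distribution against the pin in requirements.txt.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = "1.1.1"  # mirrors the pin added above

installed = None
for dist_name in ("mamba-ssm", "mamba_ssm"):  # distribution may be registered under either spelling
    try:
        installed = version(dist_name)
        break
    except PackageNotFoundError:
        continue

if installed is None:
    print("mamba-ssm is not installed; install it with `pip install -e .[mamba-ssm]`")
elif installed == EXPECTED:
    print(f"mamba-ssm {installed} matches the pinned version")
else:
    print(f"mamba-ssm {installed} differs from the pinned {EXPECTED}")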
setup.py

@@ -11,17 +11,17 @@ def parse_requirements():
     with open("./requirements.txt", encoding="utf-8") as requirements_file:
         lines = [r.strip() for r in requirements_file.readlines()]
         for line in lines:
+            is_extras = (
+                "flash-attn" in line
+                or "flash-attention" in line
+                or "deepspeed" in line
+                or "mamba-ssm" in line
+            )
             if line.startswith("--extra-index-url"):
                 # Handle custom index URLs
                 _, url = line.split()
                 _dependency_links.append(url)
-            elif (
-                "flash-attn" not in line
-                and "flash-attention" not in line
-                and "deepspeed" not in line
-                and line
-                and line[0] != "#"
-            ):
+            elif not is_extras and line and line[0] != "#":
                 # Handle standard packages
                 _install_requires.append(line)
 
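The refactor replaces the long negated elif condition with an is_extras flag, so adding another optional dependency only needs one more "or" clause. The standalone sketch below reruns that classification over a few sample lines to show that mamba-ssm==1.1.1 is treated as an extra and kept out of install_requires; it is an illustration of the filtering logic, not the project's setup.py:

# Illustration of the classification in parse_requirements(): extras lines are
# skipped, --extra-index-url lines become dependency links, comments and blanks
# are ignored, and everything else lands in install_requires.
SAMPLE_LINES = [
    "--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/",
    "packaging==23.2",
    "peft==0.7.0",
    "mamba-ssm==1.1.1",
    "# remote filesystems",
    "",
]

install_requires = []
dependency_links = []
for line in SAMPLE_LINES:
    is_extras = (
        "flash-attn" in line
        or "flash-attention" in line
        or "deepspeed" in line
        or "mamba-ssm" in line
    )
    if line.startswith("--extra-index-url"):
        _, url = line.split()
        dependency_links.append(url)
    elif not is_extras and line and line[0] != "#":
        install_requires.append(line)

print(install_requires)   # ['packaging==23.2', 'peft==0.7.0']
print(dependency_links)   # ['https://huggingface.github.io/autogptq-index/whl/cu118/']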
src/axolotl/models/mamba/__init__.py

@@ -2,8 +2,20 @@
 Modeling module for Mamba models
 """
 
+import importlib
+
+
+def check_mamba_ssm_installed():
+    mamba_ssm_spec = importlib.util.find_spec("mamba_ssm")
+    if mamba_ssm_spec is None:
+        raise ImportError(
+            "MambaLMHeadModel requires mamba_ssm. Please install it with `pip install -e .[mamba-ssm]`"
+        )
+
 
 def fix_mamba_attn_for_loss():
+    check_mamba_ssm_installed()
+
     from mamba_ssm.models import mixer_seq_simple
 
     from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
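With the guard in place, calling fix_mamba_attn_for_loss() without mamba_ssm installed raises an ImportError that points the user at pip install -e .[mamba-ssm] instead of failing on the bare import. The snippet below is a self-contained sketch that mirrors the guard's logic (it re-implements it rather than importing axolotl) to show the message a user would see:

# Mirror of the guard added above: find_spec returns None when mamba_ssm is
# absent, and the raised ImportError carries the install hint.
import importlib.util


def check_mamba_ssm_installed():
    if importlib.util.find_spec("mamba_ssm") is None:
        raise ImportError(
            "MambaLMHeadModel requires mamba_ssm. "
            "Please install it with `pip install -e .[mamba-ssm]`"
        )


try:
    check_mamba_ssm_installed()
except ImportError as err:
    print(f"caught: {err}")
else:
    print("mamba_ssm is available")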