Spaces:
Runtime error
Runtime error
Commit
·
0926815
1
Parent(s):
d65e5f6
fix(model): improve 8-bit quantization configuration
Browse files
Update 8-bit loading to use proper BitsAndBytesConfig with compute dtype
specification and fix torch_dtype handling when consumed by quantization.
Add bitsandbytes dependency and remove unused dependencies.
- llava/model/builder.py +7 -2
- requirements.txt +4 -5
llava/model/builder.py
CHANGED
|
@@ -46,7 +46,10 @@ def load_pretrained_model(
|
|
| 46 |
kwargs["device_map"] = {"": device}
|
| 47 |
|
| 48 |
if load_8bit:
|
| 49 |
-
kwargs["load_in_8bit"] = True
|
|
|
|
|
|
|
|
|
|
| 50 |
elif load_4bit:
|
| 51 |
kwargs["load_in_4bit"] = True
|
| 52 |
kwargs["quantization_config"] = BitsAndBytesConfig(
|
|
@@ -158,4 +161,6 @@ def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
|
|
| 158 |
except AttributeError:
|
| 159 |
raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")
|
| 160 |
|
| 161 |
-
|
|
|
|
|
|
|
|
|
| 46 |
kwargs["device_map"] = {"": device}
|
| 47 |
|
| 48 |
if load_8bit:
|
| 49 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 50 |
+
load_in_8bit=True,
|
| 51 |
+
bnb_8bit_compute_dtype=torch.float16,
|
| 52 |
+
)
|
| 53 |
elif load_4bit:
|
| 54 |
kwargs["load_in_4bit"] = True
|
| 55 |
kwargs["quantization_config"] = BitsAndBytesConfig(
|
|
|
|
| 161 |
except AttributeError:
|
| 162 |
raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")
|
| 163 |
|
| 164 |
+
# Handle case where torch_dtype might be consumed by quantization config
|
| 165 |
+
torch_dtype = kwargs.pop("torch_dtype", torch.float16)
|
| 166 |
+
config.model_dtype = torch_dtype.__str__()
|
requirements.txt
CHANGED
|
@@ -4,14 +4,11 @@ hydra-core
|
|
| 4 |
loguru
|
| 5 |
Pillow
|
| 6 |
pydub
|
| 7 |
-
torch
|
| 8 |
-
torchvision
|
| 9 |
|
| 10 |
|
| 11 |
# Transformers and training utilities
|
| 12 |
transformers==4.46.0
|
| 13 |
pytorchvideo==0.1.5
|
| 14 |
-
deepspeed==0.15.4
|
| 15 |
accelerate==0.34.2
|
| 16 |
numpy==1.26.4
|
| 17 |
opencv-python-headless==4.8.0.76
|
|
@@ -27,8 +24,10 @@ jiwer
|
|
| 27 |
einops
|
| 28 |
wandb
|
| 29 |
kaldiio
|
| 30 |
-
peft
|
| 31 |
|
| 32 |
# Compatibility fix
|
| 33 |
protobuf==3.20.*
|
| 34 |
-
triton==3.1.0
|
|
|
|
|
|
|
|
|
| 4 |
loguru
|
| 5 |
Pillow
|
| 6 |
pydub
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
# Transformers and training utilities
|
| 10 |
transformers==4.46.0
|
| 11 |
pytorchvideo==0.1.5
|
|
|
|
| 12 |
accelerate==0.34.2
|
| 13 |
numpy==1.26.4
|
| 14 |
opencv-python-headless==4.8.0.76
|
|
|
|
| 24 |
einops
|
| 25 |
wandb
|
| 26 |
kaldiio
|
| 27 |
+
peft
|
| 28 |
|
| 29 |
# Compatibility fix
|
| 30 |
protobuf==3.20.*
|
| 31 |
+
triton==3.1.0
|
| 32 |
+
|
| 33 |
+
bitsandbytes
|