PierrunoYT committed
Commit 0926815 · Parent: d65e5f6

fix(model): improve 8-bit quantization configuration


Update 8-bit loading to use a proper BitsAndBytesConfig with an explicit
compute dtype, and fix torch_dtype handling for the case where the
quantization config has already consumed it. Add the bitsandbytes
dependency and remove unused dependencies.

Files changed (2)
  1. llava/model/builder.py +7 -2
  2. requirements.txt +4 -5
llava/model/builder.py CHANGED
```diff
@@ -46,7 +46,10 @@ def load_pretrained_model(
     kwargs["device_map"] = {"": device}

     if load_8bit:
-        kwargs["load_in_8bit"] = True
+        kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_8bit=True,
+            bnb_8bit_compute_dtype=torch.float16,
+        )
     elif load_4bit:
         kwargs["load_in_4bit"] = True
         kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -158,4 +161,6 @@ def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
     except AttributeError:
         raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")

-    config.model_dtype = kwargs.pop("torch_dtype").__str__()
+    # Handle case where torch_dtype might be consumed by quantization config
+    torch_dtype = kwargs.pop("torch_dtype", torch.float16)
+    config.model_dtype = torch_dtype.__str__()
```
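For context, this is roughly how a `quantization_config` assembled this way is consumed downstream. The sketch below is an illustration, not code from the commit: it assumes the kwargs built in `load_pretrained_model` are eventually forwarded to a Hugging Face `from_pretrained` call, and `AutoModelForCausalLM` plus the `"model-path"` string are placeholders for whatever builder.py actually instantiates.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 8-bit config the commit builds inside load_pretrained_model.
# bnb_8bit_compute_dtype mirrors the commit verbatim; note that
# BitsAndBytesConfig documents a compute dtype only for the 4-bit path
# (bnb_4bit_compute_dtype), so this kwarg may be warned about as unused.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16,
)

kwargs = {
    "device_map": {"": "cuda"},  # set earlier in load_pretrained_model
    "quantization_config": quantization_config,
}

# "model-path" is a placeholder checkpoint path.
model = AutoModelForCausalLM.from_pretrained("model-path", **kwargs)
```

One consequence of routing the dtype through the quantization config: a bare `kwargs.pop("torch_dtype")` in `prepare_config_for_eval` raises KeyError once the key has been consumed, which is exactly what the `pop(..., torch.float16)` default in the second hunk guards against.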
requirements.txt CHANGED
```diff
@@ -4,14 +4,11 @@ hydra-core
 loguru
 Pillow
 pydub
-torch
-torchvision


 # Transformers and training utilities
 transformers==4.46.0
 pytorchvideo==0.1.5
-deepspeed==0.15.4
 accelerate==0.34.2
 numpy==1.26.4
 opencv-python-headless==4.8.0.76
@@ -27,8 +24,10 @@ jiwer
 einops
 wandb
 kaldiio
-peft==0.14.0
+peft

 # Compatibility fix
 protobuf==3.20.*
-triton==3.1.0
+triton==3.1.0
+
+bitsandbytes
```
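bitsandbytes becomes a direct dependency because the 8-bit path now always constructs a BitsAndBytesConfig. As an illustrative guard (an assumption, not part of the commit), callers can fail fast when the prerequisites for 8-bit loading are missing:

```python
import torch

def can_load_8bit() -> bool:
    """Best-effort check that 8-bit quantized loading is possible:
    a CUDA device must be present and bitsandbytes importable.
    Hypothetical helper for illustration only."""
    if not torch.cuda.is_available():
        return False
    try:
        import bitsandbytes  # noqa: F401
    except ImportError:
        return False
    return True
```

With the updated requirements.txt, `pip install -r requirements.txt` pulls in an unpinned bitsandbytes alongside the other packages.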