update script
Browse files- convert_flax.py +144 -84
convert_flax.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1 |
import argparse
|
|
|
2 |
from pathlib import Path
|
3 |
|
|
|
4 |
import jax.numpy as jnp
|
5 |
import numpy as np
|
|
|
6 |
from safetensors.flax import save_file
|
7 |
from tqdm import tqdm
|
8 |
|
9 |
-
|
10 |
|
11 |
|
12 |
-
def flatten(x:
|
13 |
if start < 0:
|
14 |
start += x.ndim
|
15 |
if end < 0:
|
@@ -18,27 +21,27 @@ def flatten(x: jnp.ndarray, start: int = 0, end: int = -1):
|
|
18 |
return x.reshape(new_shape)
|
19 |
|
20 |
|
21 |
-
def unflatten(x:
|
22 |
new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
|
23 |
return x.reshape(new_shape)
|
24 |
|
25 |
|
26 |
# correct quantization parameters mean quantization error = 0 (or close to 0)
|
27 |
-
def check_groups(groups:
|
28 |
# groups: (a, b, c, 32, d, e, f)
|
29 |
# scales: (a, b, c, 1, d, e, f)
|
30 |
inv_scale = 1.0 / scales.clip(1e-12)
|
31 |
-
q_group =
|
32 |
-
max_diff =
|
33 |
return max_diff < 1e-6, max_diff
|
34 |
|
35 |
|
36 |
-
def find_scales(w:
|
37 |
w = unflatten(w, dim, (-1, 32))
|
38 |
group_range = w.max(dim + 1, keepdims=True) - w.min(dim + 1, keepdims=True)
|
39 |
|
40 |
scales = np.zeros_like(group_range)
|
41 |
-
for q in
|
42 |
try_scale = group_range / q
|
43 |
ok, _ = check_groups(w, try_scale, dim + 1)
|
44 |
scales[ok] = try_scale[ok]
|
@@ -49,10 +52,10 @@ def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
|
|
49 |
return scales.squeeze(dim + 1)
|
50 |
|
51 |
|
52 |
-
def convert_siglip(params):
|
53 |
state_dict = dict()
|
54 |
|
55 |
-
def convert_layer(prefix: str, layer: dict[str,
|
56 |
bias = layer["bias"]
|
57 |
|
58 |
if "kernel" in layer:
|
@@ -85,96 +88,129 @@ def convert_siglip(params):
|
|
85 |
state_dict[f"{prefix}weight"] = w
|
86 |
state_dict[f"{prefix}bias"] = bias
|
87 |
|
88 |
-
convert_layer("embeddings.patch_embedding.", params["embedding"])
|
89 |
-
state_dict["embeddings.position_embedding.weight"] = params["pos_embedding"].squeeze(0)
|
|
|
90 |
|
91 |
-
|
92 |
-
convert_layer("post_layernorm.", trunk["encoder_norm"])
|
93 |
-
|
94 |
-
layer_idx = 0
|
95 |
-
while f"encoderblock_{layer_idx}" in trunk:
|
96 |
prefix = f"encoder.layers.{layer_idx}."
|
97 |
-
|
98 |
-
|
99 |
-
convert_layer(f"{prefix}layer_norm1.", encoder_layer["LayerNorm_0"])
|
100 |
-
convert_layer(f"{prefix}layer_norm2.", encoder_layer["LayerNorm_1"])
|
101 |
|
102 |
-
|
103 |
-
convert_layer(f"{prefix}
|
104 |
-
convert_layer(f"{prefix}self_attn.k_proj.", attn_layer["key"])
|
105 |
-
convert_layer(f"{prefix}self_attn.v_proj.", attn_layer["value"])
|
106 |
-
convert_layer(f"{prefix}self_attn.out_proj.", attn_layer["out"])
|
107 |
|
108 |
-
|
109 |
-
convert_layer(f"{prefix}
|
110 |
-
convert_layer(f"{prefix}
|
|
|
|
|
111 |
|
112 |
-
|
|
|
|
|
113 |
|
114 |
return state_dict
|
115 |
|
116 |
|
117 |
# convert to HF format first, then apply quantization
|
118 |
-
def convert_to_hf(
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
prefix = f"model.layers.{layer_idx}."
|
128 |
-
layer_params = params[f"layer_{layer_idx}"]
|
129 |
-
state_dict[f"{prefix}input_layernorm.weight"] = layer_params["pre_attention_norm"]["scale"]
|
130 |
-
state_dict[f"{prefix}post_attention_layernorm.weight"] = layer_params["post_attention_norm"]["scale"]
|
131 |
-
state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = layer_params["pre_ffw_norm"]["scale"]
|
132 |
-
state_dict[f"{prefix}post_feedforward_layernorm.weight"] = layer_params["post_ffw_norm"]["scale"]
|
133 |
-
|
134 |
-
prefix = f"model.layers.{layer_idx}.self_attn."
|
135 |
-
attn_params = layer_params["attn"]
|
136 |
-
state_dict[f"{prefix}q_norm.weight"] = attn_params["_query_norm"]["scale"]
|
137 |
-
state_dict[f"{prefix}k_norm.weight"] = attn_params["_key_norm"]["scale"]
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
143 |
|
144 |
-
|
145 |
-
|
|
|
|
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
state_dict[f"{prefix}gate_proj.weight"] = mlp_params["gating_einsum"][0] # NOTE: may need to transpose?
|
150 |
-
state_dict[f"{prefix}up_proj.weight"] = mlp_params["gating_einsum"][1]
|
151 |
-
state_dict[f"{prefix}down_proj.weight"] = mlp_params["linear"].T
|
152 |
|
153 |
-
|
|
|
154 |
|
155 |
-
|
156 |
-
if "vision_encoder" in params:
|
157 |
-
# HF append unused tokens for no reason???
|
158 |
-
state_dict["model.embed_tokens.weight"] = jnp.pad(state_dict["model.embed_tokens.weight"], ((0, 64), (0, 0)))
|
159 |
|
160 |
-
|
161 |
-
|
|
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
state_dict[f"{prefix}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
-
|
|
|
|
|
|
|
168 |
for k, v in siglip_state_dict.items():
|
169 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
170 |
-
|
171 |
-
return state_dict
|
172 |
|
173 |
|
174 |
-
def convert_awq(state_dict: dict[str,
|
175 |
awq_state_dict = dict()
|
176 |
|
177 |
-
for k, v in
|
178 |
if (
|
179 |
k.endswith("model.embed_tokens.weight") # AWQ doesn't support INT4 embeddings
|
180 |
or k.startswith(("vision_tower", "multi_modal_projector")) # vision tower is not quantized
|
@@ -186,10 +222,8 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
|
|
186 |
assert v.ndim == 2
|
187 |
v = v.T # AWQ transpose the weight
|
188 |
|
189 |
-
# use numpy since jnp is very slow, likely due to bad memory management on CUDA
|
190 |
-
v = np.asarray(v)
|
191 |
K, N = v.shape
|
192 |
-
scales = find_scales(v, dim=0
|
193 |
inv_scale = 1 / scales.clip(1e-12)
|
194 |
qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])
|
195 |
|
@@ -216,7 +250,7 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
|
|
216 |
prefix = k.removesuffix(".weight")
|
217 |
awq_state_dict[f"{prefix}.qweight"] = qweight_packed
|
218 |
awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
|
219 |
-
awq_state_dict[f"{prefix}.scales"] =
|
220 |
|
221 |
return awq_state_dict
|
222 |
|
@@ -224,11 +258,37 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
|
|
224 |
if __name__ == "__main__":
|
225 |
parser = argparse.ArgumentParser()
|
226 |
parser.add_argument("--ckpt_dir", required=True, type=Path)
|
227 |
-
parser.add_argument("--
|
228 |
args = parser.parse_args()
|
229 |
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
+
import json
|
3 |
from pathlib import Path
|
4 |
|
5 |
+
import jax
|
6 |
import jax.numpy as jnp
|
7 |
import numpy as np
|
8 |
+
import orbax.checkpoint as ocp
|
9 |
from safetensors.flax import save_file
|
10 |
from tqdm import tqdm
|
11 |
|
12 |
+
SIGLIP_PREFIX = "SigLiPFromPatches_0/siglip_encoder"
|
13 |
|
14 |
|
15 |
+
def flatten(x: np.ndarray, start: int = 0, end: int = -1):
|
16 |
if start < 0:
|
17 |
start += x.ndim
|
18 |
if end < 0:
|
|
|
21 |
return x.reshape(new_shape)
|
22 |
|
23 |
|
24 |
+
def unflatten(x: np.ndarray, dim: int, sizes: tuple[int, ...]):
|
25 |
new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
|
26 |
return x.reshape(new_shape)
|
27 |
|
28 |
|
29 |
# correct quantization parameters mean quantization error = 0 (or close to 0)
|
30 |
+
def check_groups(groups: np.ndarray, scales: np.ndarray, dim: int):
|
31 |
# groups: (a, b, c, 32, d, e, f)
|
32 |
# scales: (a, b, c, 1, d, e, f)
|
33 |
inv_scale = 1.0 / scales.clip(1e-12)
|
34 |
+
q_group = np.round(groups * inv_scale)
|
35 |
+
max_diff = np.abs(q_group * scales - groups).max(dim, keepdims=True)
|
36 |
return max_diff < 1e-6, max_diff
|
37 |
|
38 |
|
39 |
+
def find_scales(w: np.ndarray, dim: int):
|
40 |
w = unflatten(w, dim, (-1, 32))
|
41 |
group_range = w.max(dim + 1, keepdims=True) - w.min(dim + 1, keepdims=True)
|
42 |
|
43 |
scales = np.zeros_like(group_range)
|
44 |
+
for q in range(15, 0, -1):
|
45 |
try_scale = group_range / q
|
46 |
ok, _ = check_groups(w, try_scale, dim + 1)
|
47 |
scales[ok] = try_scale[ok]
|
|
|
52 |
return scales.squeeze(dim + 1)
|
53 |
|
54 |
|
55 |
+
def convert_siglip(params, num_layers: int):
|
56 |
state_dict = dict()
|
57 |
|
58 |
+
def convert_layer(prefix: str, layer: dict[str, np.ndarray]):
|
59 |
bias = layer["bias"]
|
60 |
|
61 |
if "kernel" in layer:
|
|
|
88 |
state_dict[f"{prefix}weight"] = w
|
89 |
state_dict[f"{prefix}bias"] = bias
|
90 |
|
91 |
+
convert_layer("embeddings.patch_embedding.", params[f"{SIGLIP_PREFIX}/embedding"])
|
92 |
+
state_dict["embeddings.position_embedding.weight"] = params[SIGLIP_PREFIX]["pos_embedding"].squeeze(0)
|
93 |
+
convert_layer("post_layernorm.", params[f"{SIGLIP_PREFIX}/Transformer/encoder_norm"])
|
94 |
|
95 |
+
for layer_idx in range(num_layers):
|
|
|
|
|
|
|
|
|
96 |
prefix = f"encoder.layers.{layer_idx}."
|
97 |
+
layer_prefix = f"{SIGLIP_PREFIX}/Transformer/encoderblock_{layer_idx}/"
|
|
|
|
|
|
|
98 |
|
99 |
+
convert_layer(f"{prefix}layer_norm1.", params[f"{layer_prefix}LayerNorm_0"])
|
100 |
+
convert_layer(f"{prefix}layer_norm2.", params[f"{layer_prefix}LayerNorm_1"])
|
|
|
|
|
|
|
101 |
|
102 |
+
attn_prefix = f"{layer_prefix}MultiHeadDotProductAttention_0/"
|
103 |
+
convert_layer(f"{prefix}self_attn.q_proj.", params[f"{attn_prefix}query"])
|
104 |
+
convert_layer(f"{prefix}self_attn.k_proj.", params[f"{attn_prefix}key"])
|
105 |
+
convert_layer(f"{prefix}self_attn.v_proj.", params[f"{attn_prefix}value"])
|
106 |
+
convert_layer(f"{prefix}self_attn.out_proj.", params[f"{attn_prefix}out"])
|
107 |
|
108 |
+
mlp_prefix = f"{layer_prefix}MlpBlock_0/"
|
109 |
+
convert_layer(f"{prefix}mlp.fc1.", params[f"{mlp_prefix}Dense_0"])
|
110 |
+
convert_layer(f"{prefix}mlp.fc2.", params[f"{mlp_prefix}Dense_1"])
|
111 |
|
112 |
return state_dict
|
113 |
|
114 |
|
115 |
# convert to HF format first, then apply quantization
|
116 |
+
def convert_to_hf(path: Path):
|
117 |
+
path = path.absolute() # orbax only works with absolute path
|
118 |
+
ckpt = ocp.StandardCheckpointer()
|
119 |
+
metadata = dict(ckpt.metadata(path))
|
120 |
+
metadata = jax.tree.map(ocp.utils.to_shape_dtype_struct, metadata)
|
121 |
+
|
122 |
+
num_layers = num_siglip_layers = 0
|
123 |
+
while f"transformer/layer_{num_layers}/attn/_key_norm" in metadata:
|
124 |
+
num_layers += 1
|
125 |
+
while f"{SIGLIP_PREFIX}/Transformer/encoderblock_{num_siglip_layers}/LayerNorm_0" in metadata:
|
126 |
+
num_siglip_layers += 1
|
127 |
+
print(f"{num_layers=}")
|
128 |
+
print(f"{num_siglip_layers=}")
|
129 |
+
|
130 |
+
def load_params(*keys: tuple[str, ...], prefix: str | None = None):
|
131 |
+
# load params with specific keys and params starts with prefix
|
132 |
+
f1 = lambda k: tuple(subkey.key for subkey in k) in keys
|
133 |
+
f2 = lambda k: k[0].key.startswith(prefix)
|
134 |
+
|
135 |
+
# set to None to not load that weights
|
136 |
+
pytree = jax.tree.map_with_path(lambda k, v: v if f1(k) or f2(k) else None, metadata)
|
137 |
+
return ckpt.restore(path, pytree)
|
138 |
|
139 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
140 |
+
params = load_params(
|
141 |
+
("transformer/final_norm", "scale"),
|
142 |
+
prefix="transformer/embedder",
|
143 |
+
)
|
144 |
+
state_dict = dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
if num_siglip_layers > 0:
|
147 |
+
# HF append unused tokens for no reason???
|
148 |
+
embed = params["transformer/embedder"]["input_embedding"]
|
149 |
+
params["transformer/embedder"]["input_embedding"] = np.pad(embed, ((0, 64), (0, 0)))
|
150 |
+
gemma_prefix = "language_model."
|
151 |
|
152 |
+
prefix = "multi_modal_projector.mm_"
|
153 |
+
jax_prefix = "transformer/embedder/"
|
154 |
+
state_dict[f"{prefix}input_projection_weight"] = params[f"{jax_prefix}mm_input_projection"]["w"]
|
155 |
+
state_dict[f"{prefix}soft_emb_norm.weight"] = params[f"{jax_prefix}mm_soft_embedding_norm"]["scale"]
|
156 |
|
157 |
+
else:
|
158 |
+
gemma_prefix = ""
|
|
|
|
|
|
|
159 |
|
160 |
+
state_dict[f"{gemma_prefix}model.embed_tokens.weight"] = params["transformer/embedder"]["input_embedding"]
|
161 |
+
state_dict[f"{gemma_prefix}model.norm.weight"] = params["transformer/final_norm"]["scale"]
|
162 |
|
163 |
+
yield state_dict
|
|
|
|
|
|
|
164 |
|
165 |
+
for layer_idx in range(num_layers):
|
166 |
+
jax_prefix = f"transformer/layer_{layer_idx}/"
|
167 |
+
params = load_params(prefix=jax_prefix)
|
168 |
|
169 |
+
state_dict = dict()
|
170 |
+
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
171 |
+
state_dict[f"{prefix}input_layernorm.weight"] = params[f"{jax_prefix}pre_attention_norm"]["scale"]
|
172 |
+
state_dict[f"{prefix}post_attention_layernorm.weight"] = params[f"{jax_prefix}post_attention_norm"]["scale"]
|
173 |
+
state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = params[f"{jax_prefix}pre_ffw_norm"]["scale"]
|
174 |
+
state_dict[f"{prefix}post_feedforward_layernorm.weight"] = params[f"{jax_prefix}post_ffw_norm"]["scale"]
|
175 |
+
|
176 |
+
prefix = f"{gemma_prefix}model.layers.{layer_idx}.self_attn."
|
177 |
+
jax_prefix = f"transformer/layer_{layer_idx}/attn/"
|
178 |
+
state_dict[f"{prefix}q_norm.weight"] = params[f"{jax_prefix}_query_norm"]["scale"]
|
179 |
+
state_dict[f"{prefix}k_norm.weight"] = params[f"{jax_prefix}_key_norm"]["scale"]
|
180 |
+
|
181 |
+
# (num_heads, hidden_size, head_dim) -> (num_heads * head_dim, hidden_size)
|
182 |
+
state_dict[f"{prefix}q_proj.weight"] = flatten(params[f"{jax_prefix}q_einsum"]["w"].transpose(0, 2, 1), end=1)
|
183 |
+
state_dict[f"{prefix}k_proj.weight"] = flatten(
|
184 |
+
params[f"{jax_prefix}kv_einsum"]["w"][0].transpose(0, 2, 1), end=1
|
185 |
+
)
|
186 |
+
state_dict[f"{prefix}v_proj.weight"] = flatten(
|
187 |
+
params[f"{jax_prefix}kv_einsum"]["w"][1].transpose(0, 2, 1), end=1
|
188 |
+
)
|
189 |
+
|
190 |
+
# (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
|
191 |
+
state_dict[f"{prefix}o_proj.weight"] = flatten(params[f"{jax_prefix}attn_vec_einsum"]["w"], end=1).T
|
192 |
+
|
193 |
+
prefix = f"{gemma_prefix}model.layers.{layer_idx}.mlp."
|
194 |
+
jax_prefix = f"transformer/layer_{layer_idx}/mlp/"
|
195 |
+
state_dict[f"{prefix}gate_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][0]
|
196 |
+
state_dict[f"{prefix}up_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][1]
|
197 |
+
state_dict[f"{prefix}down_proj.weight"] = params[f"{jax_prefix}linear"]["w"].T
|
198 |
+
|
199 |
+
yield state_dict
|
200 |
|
201 |
+
# vision tower
|
202 |
+
if num_siglip_layers > 0:
|
203 |
+
params = load_params(prefix=SIGLIP_PREFIX)
|
204 |
+
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
205 |
for k, v in siglip_state_dict.items():
|
206 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
207 |
+
yield state_dict
|
|
|
208 |
|
209 |
|
210 |
+
def convert_awq(state_dict: dict[str, np.ndarray]):
|
211 |
awq_state_dict = dict()
|
212 |
|
213 |
+
for k, v in state_dict.items():
|
214 |
if (
|
215 |
k.endswith("model.embed_tokens.weight") # AWQ doesn't support INT4 embeddings
|
216 |
or k.startswith(("vision_tower", "multi_modal_projector")) # vision tower is not quantized
|
|
|
222 |
assert v.ndim == 2
|
223 |
v = v.T # AWQ transpose the weight
|
224 |
|
|
|
|
|
225 |
K, N = v.shape
|
226 |
+
scales = find_scales(v, dim=0) # (K/32, N)
|
227 |
inv_scale = 1 / scales.clip(1e-12)
|
228 |
qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])
|
229 |
|
|
|
250 |
prefix = k.removesuffix(".weight")
|
251 |
awq_state_dict[f"{prefix}.qweight"] = qweight_packed
|
252 |
awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
|
253 |
+
awq_state_dict[f"{prefix}.scales"] = scales.astype(jnp.bfloat16)
|
254 |
|
255 |
return awq_state_dict
|
256 |
|
|
|
258 |
if __name__ == "__main__":
|
259 |
parser = argparse.ArgumentParser()
|
260 |
parser.add_argument("--ckpt_dir", required=True, type=Path)
|
261 |
+
parser.add_argument("--save_dir", required=True, type=Path)
|
262 |
args = parser.parse_args()
|
263 |
|
264 |
+
args.save_dir.mkdir(parents=True, exist_ok=True)
|
265 |
+
|
266 |
+
total_size = 0
|
267 |
+
weight_map = dict()
|
268 |
+
|
269 |
+
state_dict = dict()
|
270 |
+
size = 0
|
271 |
+
shard_idx = 0
|
272 |
+
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
273 |
+
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
274 |
+
sub_state_dict = convert_awq(sub_state_dict)
|
275 |
+
|
276 |
+
for k, v in sub_state_dict.items():
|
277 |
+
state_dict[k] = v
|
278 |
+
size += v.nbytes
|
279 |
+
|
280 |
+
total_size += v.nbytes
|
281 |
+
weight_map[k] = filename
|
282 |
+
|
283 |
+
if size > 5e9:
|
284 |
+
save_file(state_dict, args.save_dir / filename)
|
285 |
+
state_dict = dict()
|
286 |
+
size = 0
|
287 |
+
shard_idx += 1
|
288 |
+
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
289 |
+
|
290 |
+
save_file(state_dict, args.save_dir / filename)
|
291 |
+
json.dump(
|
292 |
+
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|
293 |
+
open(args.save_dir / "model.safetensors.index.json", "w"),
|
294 |
+
)
|