gaunernst committed
Commit 8271f98 · 1 Parent(s): 18df9dc

update script

Files changed (1)
  1. convert_flax.py +144 -84
convert_flax.py CHANGED
@@ -1,15 +1,18 @@
 import argparse
+import json
 from pathlib import Path

+import jax
 import jax.numpy as jnp
 import numpy as np
+import orbax.checkpoint as ocp
 from safetensors.flax import save_file
 from tqdm import tqdm

-from gemma import gm
+SIGLIP_PREFIX = "SigLiPFromPatches_0/siglip_encoder"


-def flatten(x: jnp.ndarray, start: int = 0, end: int = -1):
+def flatten(x: np.ndarray, start: int = 0, end: int = -1):
     if start < 0:
         start += x.ndim
     if end < 0:
@@ -18,27 +21,27 @@ def flatten(x: jnp.ndarray, start: int = 0, end: int = -1):
     return x.reshape(new_shape)


-def unflatten(x: jnp.ndarray, dim: int, sizes: tuple[int, ...]):
+def unflatten(x: np.ndarray, dim: int, sizes: tuple[int, ...]):
     new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
     return x.reshape(new_shape)


 # correct quantization parameters mean quantization error = 0 (or close to 0)
-def check_groups(groups: jnp.ndarray, scales: jnp.ndarray, dim: int):
+def check_groups(groups: np.ndarray, scales: np.ndarray, dim: int):
     # groups: (a, b, c, 32, d, e, f)
     # scales: (a, b, c, 1, d, e, f)
     inv_scale = 1.0 / scales.clip(1e-12)
-    q_group = jnp.round(groups * inv_scale)
-    max_diff = jnp.abs(q_group * scales - groups).max(dim, keepdims=True)
+    q_group = np.round(groups * inv_scale)
+    max_diff = np.abs(q_group * scales - groups).max(dim, keepdims=True)
    return max_diff < 1e-6, max_diff


-def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
+def find_scales(w: np.ndarray, dim: int):
     w = unflatten(w, dim, (-1, 32))
     group_range = w.max(dim + 1, keepdims=True) - w.min(dim + 1, keepdims=True)

     scales = np.zeros_like(group_range)
-    for q in tqdm(range(15, 0, -1), disable=not pbar):
+    for q in range(15, 0, -1):
         try_scale = group_range / q
         ok, _ = check_groups(w, try_scale, dim + 1)
         scales[ok] = try_scale[ok]
@@ -49,10 +52,10 @@ def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
     return scales.squeeze(dim + 1)


-def convert_siglip(params):
+def convert_siglip(params, num_layers: int):
     state_dict = dict()

-    def convert_layer(prefix: str, layer: dict[str, jnp.ndarray]):
+    def convert_layer(prefix: str, layer: dict[str, np.ndarray]):
         bias = layer["bias"]

         if "kernel" in layer:
@@ -85,96 +88,129 @@ def convert_siglip(params):
         state_dict[f"{prefix}weight"] = w
         state_dict[f"{prefix}bias"] = bias

-    convert_layer("embeddings.patch_embedding.", params["embedding"])
-    state_dict["embeddings.position_embedding.weight"] = params["pos_embedding"].squeeze(0)
+    convert_layer("embeddings.patch_embedding.", params[f"{SIGLIP_PREFIX}/embedding"])
+    state_dict["embeddings.position_embedding.weight"] = params[SIGLIP_PREFIX]["pos_embedding"].squeeze(0)
+    convert_layer("post_layernorm.", params[f"{SIGLIP_PREFIX}/Transformer/encoder_norm"])

-    trunk = params["Transformer"]
-    convert_layer("post_layernorm.", trunk["encoder_norm"])
-
-    layer_idx = 0
-    while f"encoderblock_{layer_idx}" in trunk:
+    for layer_idx in range(num_layers):
         prefix = f"encoder.layers.{layer_idx}."
-        encoder_layer = trunk[f"encoderblock_{layer_idx}"]
+        layer_prefix = f"{SIGLIP_PREFIX}/Transformer/encoderblock_{layer_idx}/"

-        convert_layer(f"{prefix}layer_norm1.", encoder_layer["LayerNorm_0"])
-        convert_layer(f"{prefix}layer_norm2.", encoder_layer["LayerNorm_1"])
+        convert_layer(f"{prefix}layer_norm1.", params[f"{layer_prefix}LayerNorm_0"])
+        convert_layer(f"{prefix}layer_norm2.", params[f"{layer_prefix}LayerNorm_1"])

-        attn_layer = encoder_layer["MultiHeadDotProductAttention_0"]
-        convert_layer(f"{prefix}self_attn.q_proj.", attn_layer["query"])
-        convert_layer(f"{prefix}self_attn.k_proj.", attn_layer["key"])
-        convert_layer(f"{prefix}self_attn.v_proj.", attn_layer["value"])
-        convert_layer(f"{prefix}self_attn.out_proj.", attn_layer["out"])
+        attn_prefix = f"{layer_prefix}MultiHeadDotProductAttention_0/"
+        convert_layer(f"{prefix}self_attn.q_proj.", params[f"{attn_prefix}query"])
+        convert_layer(f"{prefix}self_attn.k_proj.", params[f"{attn_prefix}key"])
+        convert_layer(f"{prefix}self_attn.v_proj.", params[f"{attn_prefix}value"])
+        convert_layer(f"{prefix}self_attn.out_proj.", params[f"{attn_prefix}out"])

-        mlp_layer = encoder_layer["MlpBlock_0"]
-        convert_layer(f"{prefix}mlp.fc1.", mlp_layer["Dense_0"])
-        convert_layer(f"{prefix}mlp.fc2.", mlp_layer["Dense_1"])
+        mlp_prefix = f"{layer_prefix}MlpBlock_0/"
+        convert_layer(f"{prefix}mlp.fc1.", params[f"{mlp_prefix}Dense_0"])
+        convert_layer(f"{prefix}mlp.fc2.", params[f"{mlp_prefix}Dense_1"])

-        layer_idx += 1
-
     return state_dict


 # convert to HF format first, then apply quantization
-def convert_to_hf(params):
-    state_dict = dict()
+def convert_to_hf(path: Path):
+    path = path.absolute()  # orbax only works with absolute path
+    ckpt = ocp.StandardCheckpointer()
+    metadata = dict(ckpt.metadata(path))
+    metadata = jax.tree.map(ocp.utils.to_shape_dtype_struct, metadata)
+
+    num_layers = num_siglip_layers = 0
+    while f"transformer/layer_{num_layers}/attn/_key_norm" in metadata:
+        num_layers += 1
+    while f"{SIGLIP_PREFIX}/Transformer/encoderblock_{num_siglip_layers}/LayerNorm_0" in metadata:
+        num_siglip_layers += 1
+    print(f"{num_layers=}")
+    print(f"{num_siglip_layers=}")
+
+    def load_params(*keys: tuple[str, ...], prefix: str | None = None):
+        # load params with specific keys and params starts with prefix
+        f1 = lambda k: tuple(subkey.key for subkey in k) in keys
+        f2 = lambda k: k[0].key.startswith(prefix)
+
+        # set to None to not load that weights
+        pytree = jax.tree.map_with_path(lambda k, v: v if f1(k) or f2(k) else None, metadata)
+        return ckpt.restore(path, pytree)

     # NOTE: all gemma3 models use tied embeddings, even for the 27B version.
-    state_dict["model.embed_tokens.weight"] = params["embedder"]["input_embedding"]
-    state_dict["model.norm.weight"] = params["final_norm"]["scale"]
-
-    layer_idx = 0
-    while f"layer_{layer_idx}" in params:
-        prefix = f"model.layers.{layer_idx}."
-        layer_params = params[f"layer_{layer_idx}"]
-        state_dict[f"{prefix}input_layernorm.weight"] = layer_params["pre_attention_norm"]["scale"]
-        state_dict[f"{prefix}post_attention_layernorm.weight"] = layer_params["post_attention_norm"]["scale"]
-        state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = layer_params["pre_ffw_norm"]["scale"]
-        state_dict[f"{prefix}post_feedforward_layernorm.weight"] = layer_params["post_ffw_norm"]["scale"]
-
-        prefix = f"model.layers.{layer_idx}.self_attn."
-        attn_params = layer_params["attn"]
-        state_dict[f"{prefix}q_norm.weight"] = attn_params["_query_norm"]["scale"]
-        state_dict[f"{prefix}k_norm.weight"] = attn_params["_key_norm"]["scale"]
+    params = load_params(
+        ("transformer/final_norm", "scale"),
+        prefix="transformer/embedder",
+    )
+    state_dict = dict()

-        # (num_heads, hidden_size, head_dim) -> (num_heads * head_dim, hidden_size)
-        state_dict[f"{prefix}q_proj.weight"] = flatten(attn_params["q_einsum"]["w"].transpose(0, 2, 1), end=1)
-        state_dict[f"{prefix}k_proj.weight"] = flatten(attn_params["kv_einsum"]["w"][0].transpose(0, 2, 1), end=1)
-        state_dict[f"{prefix}v_proj.weight"] = flatten(attn_params["kv_einsum"]["w"][1].transpose(0, 2, 1), end=1)
+    if num_siglip_layers > 0:
+        # HF append unused tokens for no reason???
+        embed = params["transformer/embedder"]["input_embedding"]
+        params["transformer/embedder"]["input_embedding"] = np.pad(embed, ((0, 64), (0, 0)))
+        gemma_prefix = "language_model."

-        # (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
-        state_dict[f"{prefix}o_proj.weight"] = flatten(attn_params["attn_vec_einsum"]["w"], end=1).T
+        prefix = "multi_modal_projector.mm_"
+        jax_prefix = "transformer/embedder/"
+        state_dict[f"{prefix}input_projection_weight"] = params[f"{jax_prefix}mm_input_projection"]["w"]
+        state_dict[f"{prefix}soft_emb_norm.weight"] = params[f"{jax_prefix}mm_soft_embedding_norm"]["scale"]

-        prefix = f"model.layers.{layer_idx}.mlp."
-        mlp_params = layer_params["mlp"]
-        state_dict[f"{prefix}gate_proj.weight"] = mlp_params["gating_einsum"][0]  # NOTE: may need to transpose?
-        state_dict[f"{prefix}up_proj.weight"] = mlp_params["gating_einsum"][1]
-        state_dict[f"{prefix}down_proj.weight"] = mlp_params["linear"].T
+    else:
+        gemma_prefix = ""

-        layer_idx += 1
+    state_dict[f"{gemma_prefix}model.embed_tokens.weight"] = params["transformer/embedder"]["input_embedding"]
+    state_dict[f"{gemma_prefix}model.norm.weight"] = params["transformer/final_norm"]["scale"]

-    # vision tower
-    if "vision_encoder" in params:
-        # HF append unused tokens for no reason???
-        state_dict["model.embed_tokens.weight"] = jnp.pad(state_dict["model.embed_tokens.weight"], ((0, 64), (0, 0)))
+    yield state_dict

-        for k in list(state_dict.keys()):
-            state_dict[f"language_model.{k}"] = state_dict.pop(k)
+    for layer_idx in range(num_layers):
+        jax_prefix = f"transformer/layer_{layer_idx}/"
+        params = load_params(prefix=jax_prefix)

-        prefix = "multi_modal_projector.mm_"
-        state_dict[f"{prefix}input_projection_weight"] = params["embedder"]["mm_input_projection"]["w"]
-        state_dict[f"{prefix}soft_emb_norm.weight"] = params["embedder"]["mm_soft_embedding_norm"]["scale"]
+        state_dict = dict()
+        prefix = f"{gemma_prefix}model.layers.{layer_idx}."
+        state_dict[f"{prefix}input_layernorm.weight"] = params[f"{jax_prefix}pre_attention_norm"]["scale"]
+        state_dict[f"{prefix}post_attention_layernorm.weight"] = params[f"{jax_prefix}post_attention_norm"]["scale"]
+        state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = params[f"{jax_prefix}pre_ffw_norm"]["scale"]
+        state_dict[f"{prefix}post_feedforward_layernorm.weight"] = params[f"{jax_prefix}post_ffw_norm"]["scale"]
+
+        prefix = f"{gemma_prefix}model.layers.{layer_idx}.self_attn."
+        jax_prefix = f"transformer/layer_{layer_idx}/attn/"
+        state_dict[f"{prefix}q_norm.weight"] = params[f"{jax_prefix}_query_norm"]["scale"]
+        state_dict[f"{prefix}k_norm.weight"] = params[f"{jax_prefix}_key_norm"]["scale"]
+
+        # (num_heads, hidden_size, head_dim) -> (num_heads * head_dim, hidden_size)
+        state_dict[f"{prefix}q_proj.weight"] = flatten(params[f"{jax_prefix}q_einsum"]["w"].transpose(0, 2, 1), end=1)
+        state_dict[f"{prefix}k_proj.weight"] = flatten(
+            params[f"{jax_prefix}kv_einsum"]["w"][0].transpose(0, 2, 1), end=1
+        )
+        state_dict[f"{prefix}v_proj.weight"] = flatten(
+            params[f"{jax_prefix}kv_einsum"]["w"][1].transpose(0, 2, 1), end=1
+        )
+
+        # (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
+        state_dict[f"{prefix}o_proj.weight"] = flatten(params[f"{jax_prefix}attn_vec_einsum"]["w"], end=1).T
+
+        prefix = f"{gemma_prefix}model.layers.{layer_idx}.mlp."
+        jax_prefix = f"transformer/layer_{layer_idx}/mlp/"
+        state_dict[f"{prefix}gate_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][0]
+        state_dict[f"{prefix}up_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][1]
+        state_dict[f"{prefix}down_proj.weight"] = params[f"{jax_prefix}linear"]["w"].T
+
+        yield state_dict

-        siglip_state_dict = convert_siglip(params["vision_encoder"]["siglip_encoder"])
+    # vision tower
+    if num_siglip_layers > 0:
+        params = load_params(prefix=SIGLIP_PREFIX)
+        siglip_state_dict = convert_siglip(params, num_siglip_layers)
         for k, v in siglip_state_dict.items():
             state_dict[f"vision_tower.vision_model.{k}"] = v
-
-    return state_dict
+        yield state_dict


-def convert_awq(state_dict: dict[str, jnp.ndarray]):
+def convert_awq(state_dict: dict[str, np.ndarray]):
     awq_state_dict = dict()

-    for k, v in tqdm(state_dict.items(), total=len(state_dict)):
+    for k, v in state_dict.items():
         if (
             k.endswith("model.embed_tokens.weight")  # AWQ doesn't support INT4 embeddings
             or k.startswith(("vision_tower", "multi_modal_projector"))  # vision tower is not quantized
@@ -186,10 +222,8 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
         assert v.ndim == 2
         v = v.T  # AWQ transpose the weight

-        # use numpy since jnp is very slow, likely due to bad memory management on CUDA
-        v = np.asarray(v)
         K, N = v.shape
-        scales = find_scales(v, dim=0, pbar=False)  # (K/32, N)
+        scales = find_scales(v, dim=0)  # (K/32, N)
         inv_scale = 1 / scales.clip(1e-12)
         qweight = np.round(v.reshape(K // 32, 32, N) * inv_scale[:, None])

@@ -216,7 +250,7 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
         prefix = k.removesuffix(".weight")
         awq_state_dict[f"{prefix}.qweight"] = qweight_packed
         awq_state_dict[f"{prefix}.qzeros"] = np.full((K // 32, N // 8), 0x8888_8888, dtype=np.uint32).view(np.int32)
-        awq_state_dict[f"{prefix}.scales"] = jnp.asarray(scales).astype(jnp.bfloat16)
+        awq_state_dict[f"{prefix}.scales"] = scales.astype(jnp.bfloat16)

     return awq_state_dict

@@ -224,11 +258,37 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--ckpt_dir", required=True, type=Path)
-    parser.add_argument("--save_path", required=True, type=Path)
+    parser.add_argument("--save_dir", required=True, type=Path)
     args = parser.parse_args()

-    params = gm.ckpts.load_params(args.ckpt_dir.absolute())
-    state_dict = convert_to_hf(params)
-    awq_state_dict = convert_awq(state_dict)
-    args.save_path.parent.mkdir(parents=True, exist_ok=True)
-    save_file(awq_state_dict, args.save_path)
+    args.save_dir.mkdir(parents=True, exist_ok=True)
+
+    total_size = 0
+    weight_map = dict()
+
+    state_dict = dict()
+    size = 0
+    shard_idx = 0
+    filename = f"model-{shard_idx + 1:05d}.safetensors"
+    for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
+        sub_state_dict = convert_awq(sub_state_dict)
+
+        for k, v in sub_state_dict.items():
+            state_dict[k] = v
+            size += v.nbytes
+
+            total_size += v.nbytes
+            weight_map[k] = filename
+
+        if size > 5e9:
+            save_file(state_dict, args.save_dir / filename)
+            state_dict = dict()
+            size = 0
+            shard_idx += 1
+            filename = f"model-{shard_idx + 1:05d}.safetensors"
+
+    save_file(state_dict, args.save_dir / filename)
+    json.dump(
+        dict(metadata=dict(total_size=total_size), weight_map=weight_map),
+        open(args.save_dir / "model.safetensors.index.json", "w"),
+    )
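
After this commit the script loads the orbax checkpoint piece by piece, converts and quantizes each piece, and writes sharded safetensors files (model-00001.safetensors, model-00002.safetensors, ...) plus a model.safetensors.index.json containing total_size and a weight_map, following the Hugging Face sharded-checkpoint layout. Below is a minimal sketch, not part of the commit, of how the output directory could be sanity-checked afterwards; the save_dir value is a placeholder for whatever was passed as --save_dir, and the snippet assumes only the index format written above.

import json
from pathlib import Path

from safetensors.flax import load_file

save_dir = Path("gemma3-awq")  # placeholder: the directory passed as --save_dir

index = json.load(open(save_dir / "model.safetensors.index.json"))
weight_map = index["weight_map"]  # tensor name -> shard filename

# load every shard once and tally what was actually written
loaded_bytes = 0
loaded_names = set()
for filename in sorted(set(weight_map.values())):
    for name, tensor in load_file(save_dir / filename).items():
        loaded_names.add(name)
        loaded_bytes += tensor.nbytes

missing = [name for name in weight_map if name not in loaded_names]
print(f"tensors in index: {len(weight_map)}, loaded: {len(loaded_names)}, missing: {len(missing)}")
print(f"index total_size: {index['metadata']['total_size']}, loaded bytes: {loaded_bytes}")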