update script
Browse files- convert_flax.py +10 -23
convert_flax.py
CHANGED
|
@@ -127,20 +127,8 @@ def convert_to_hf(path: Path):
|
|
| 127 |
print(f"{num_layers=}")
|
| 128 |
print(f"{num_siglip_layers=}")
|
| 129 |
|
| 130 |
-
def load_params(*keys: tuple[str, ...], prefix: str | None = None):
|
| 131 |
-
# load params with specific keys and params starts with prefix
|
| 132 |
-
f1 = lambda k: tuple(subkey.key for subkey in k) in keys
|
| 133 |
-
f2 = lambda k: k[0].key.startswith(prefix)
|
| 134 |
-
|
| 135 |
-
# set to None to not load that weights
|
| 136 |
-
pytree = jax.tree.map_with_path(lambda k, v: v if f1(k) or f2(k) else None, metadata)
|
| 137 |
-
return ckpt.restore(path, pytree)
|
| 138 |
-
|
| 139 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
| 140 |
-
params = load_params(
|
| 141 |
-
("transformer/final_norm", "scale"),
|
| 142 |
-
prefix="transformer/embedder",
|
| 143 |
-
)
|
| 144 |
state_dict = dict()
|
| 145 |
|
| 146 |
if num_siglip_layers > 0:
|
|
@@ -164,7 +152,6 @@ def convert_to_hf(path: Path):
|
|
| 164 |
|
| 165 |
for layer_idx in range(num_layers):
|
| 166 |
jax_prefix = f"transformer/layer_{layer_idx}/"
|
| 167 |
-
params = load_params(prefix=jax_prefix)
|
| 168 |
|
| 169 |
state_dict = dict()
|
| 170 |
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
|
@@ -200,7 +187,6 @@ def convert_to_hf(path: Path):
|
|
| 200 |
|
| 201 |
# vision tower
|
| 202 |
if num_siglip_layers > 0:
|
| 203 |
-
params = load_params(prefix=SIGLIP_PREFIX)
|
| 204 |
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
| 205 |
for k, v in siglip_state_dict.items():
|
| 206 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
|
@@ -272,21 +258,22 @@ if __name__ == "__main__":
|
|
| 272 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
| 273 |
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
| 274 |
sub_state_dict = convert_awq(sub_state_dict)
|
|
|
|
| 275 |
|
| 276 |
-
|
| 277 |
-
state_dict[k] = v
|
| 278 |
-
size += v.nbytes
|
| 279 |
-
|
| 280 |
-
total_size += v.nbytes
|
| 281 |
-
weight_map[k] = filename
|
| 282 |
-
|
| 283 |
-
if size > 5e9:
|
| 284 |
save_file(state_dict, args.save_dir / filename)
|
| 285 |
state_dict = dict()
|
| 286 |
size = 0
|
| 287 |
shard_idx += 1
|
| 288 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
save_file(state_dict, args.save_dir / filename)
|
| 291 |
json.dump(
|
| 292 |
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|
|
|
|
| 127 |
print(f"{num_layers=}")
|
| 128 |
print(f"{num_siglip_layers=}")
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
# NOTE: all gemma3 models use tied embeddings, even for the 27B version.
|
| 131 |
+
params = ckpt.restore(path)
|
|
|
|
|
|
|
|
|
|
| 132 |
state_dict = dict()
|
| 133 |
|
| 134 |
if num_siglip_layers > 0:
|
|
|
|
| 152 |
|
| 153 |
for layer_idx in range(num_layers):
|
| 154 |
jax_prefix = f"transformer/layer_{layer_idx}/"
|
|
|
|
| 155 |
|
| 156 |
state_dict = dict()
|
| 157 |
prefix = f"{gemma_prefix}model.layers.{layer_idx}."
|
|
|
|
| 187 |
|
| 188 |
# vision tower
|
| 189 |
if num_siglip_layers > 0:
|
|
|
|
| 190 |
siglip_state_dict = convert_siglip(params, num_siglip_layers)
|
| 191 |
for k, v in siglip_state_dict.items():
|
| 192 |
state_dict[f"vision_tower.vision_model.{k}"] = v
|
|
|
|
| 258 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
| 259 |
for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
|
| 260 |
sub_state_dict = convert_awq(sub_state_dict)
|
| 261 |
+
new_size = sum(v.nbytes for v in sub_state_dict.values())
|
| 262 |
|
| 263 |
+
if size + new_size > 5e9:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
save_file(state_dict, args.save_dir / filename)
|
| 265 |
state_dict = dict()
|
| 266 |
size = 0
|
| 267 |
shard_idx += 1
|
| 268 |
filename = f"model-{shard_idx + 1:05d}.safetensors"
|
| 269 |
|
| 270 |
+
# assume that new_size < 5e9
|
| 271 |
+
size += new_size
|
| 272 |
+
total_size += new_size
|
| 273 |
+
for k, v in sub_state_dict.items():
|
| 274 |
+
state_dict[k] = v
|
| 275 |
+
weight_map[k] = filename
|
| 276 |
+
|
| 277 |
save_file(state_dict, args.save_dir / filename)
|
| 278 |
json.dump(
|
| 279 |
dict(metadata=dict(total_size=total_size), weight_map=weight_map),
|