update conversion script
convert_flax.py  CHANGED  (+88 -3)
@@ -49,10 +49,76 @@ def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
     return scales.squeeze(dim + 1)


+def convert_siglip(params):
+    state_dict = dict()
+
+    def convert_layer(prefix: str, layer: dict[str, jnp.ndarray]):
+        bias = layer["bias"]
+
+        if "kernel" in layer:
+            w = layer["kernel"]
+            if w.ndim == 2:  # linear layer
+                w = w.T
+
+            elif w.ndim == 3:  # attn projection
+                # qkv projection - (dim, num_heads, head_dim)
+                if bias.ndim == 2:
+                    w = flatten(w, 1, 2).T
+                    bias = bias.reshape(-1)
+
+                # o projection - (num_heads, head_dim, dim)
+                elif bias.ndim == 1:
+                    w = flatten(w, 0, 1).T
+
+            elif w.ndim == 4:  # conv2d layer
+                w = w.transpose(3, 2, 0, 1)
+
+            else:
+                raise RuntimeError(f"Unsupported {w.shape=}")
+
+        elif "scale" in layer:  # layer norm
+            w = layer["scale"]
+
+        else:
+            raise RuntimeError
+
+        state_dict[f"{prefix}weight"] = w
+        state_dict[f"{prefix}bias"] = bias
+
+    convert_layer("embeddings.patch_embedding.", params["embedding"])
+    state_dict["embeddings.position_embedding.weight"] = params["pos_embedding"].squeeze(0)
+
+    trunk = params["Transformer"]
+    convert_layer("post_layernorm.", trunk["encoder_norm"])
+
+    layer_idx = 0
+    while f"encoderblock_{layer_idx}" in trunk:
+        prefix = f"encoder.layers.{layer_idx}."
+        encoder_layer = trunk[f"encoderblock_{layer_idx}"]
+
+        convert_layer(f"{prefix}layer_norm1.", encoder_layer["LayerNorm_0"])
+        convert_layer(f"{prefix}layer_norm2.", encoder_layer["LayerNorm_1"])
+
+        attn_layer = encoder_layer["MultiHeadDotProductAttention_0"]
+        convert_layer(f"{prefix}self_attn.q_proj.", attn_layer["query"])
+        convert_layer(f"{prefix}self_attn.k_proj.", attn_layer["key"])
+        convert_layer(f"{prefix}self_attn.v_proj.", attn_layer["value"])
+        convert_layer(f"{prefix}self_attn.out_proj.", attn_layer["out"])
+
+        mlp_layer = encoder_layer["MlpBlock_0"]
+        convert_layer(f"{prefix}mlp.fc1.", mlp_layer["Dense_0"])
+        convert_layer(f"{prefix}mlp.fc2.", mlp_layer["Dense_1"])
+
+        layer_idx += 1
+
+    return state_dict
+
+
 # convert to HF format first, then apply quantization
 def convert_to_hf(params):
     state_dict = dict()
-
+
+    # NOTE: all gemma3 models use tied embeddings, even for the 27B version.
     state_dict["model.embed_tokens.weight"] = params["embedder"]["input_embedding"]
     state_dict["model.norm.weight"] = params["final_norm"]["scale"]

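Note on the convert_siglip hunk above: convert_layer maps Flax Linen parameter layouts onto the PyTorch/HF convention. Linen Dense kernels are stored as (in, out) and need a transpose, attention kernels keep the heads as a separate axis, and conv kernels are HWIO and become OIHW via transpose(3, 2, 0, 1). The flatten helper it calls is not part of this diff; the sketch below shows what it is assumed to do (merge consecutive axes, like torch.flatten) together with a shape check for the attention projections. All sizes are invented for illustration.

import jax.numpy as jnp

# Hypothetical stand-in for the `flatten` helper used by convert_siglip (not shown in this diff):
# merge axes start..end (inclusive) into a single axis.
def flatten(x: jnp.ndarray, start: int, end: int) -> jnp.ndarray:
    new_shape = x.shape[:start] + (-1,) + x.shape[end + 1:]
    return x.reshape(new_shape)

# Shape check for the attention projections (sizes are made up for illustration).
dim, num_heads, head_dim = 48, 4, 16
qkv_kernel = jnp.zeros((dim, num_heads, head_dim))  # Flax layout: (dim, num_heads, head_dim)
out_kernel = jnp.zeros((num_heads, head_dim, dim))  # Flax layout: (num_heads, head_dim, dim)
print(flatten(qkv_kernel, 1, 2).T.shape)  # (64, 48) -> HF q/k/v_proj.weight: (num_heads*head_dim, dim)
print(flatten(out_kernel, 0, 1).T.shape)  # (48, 64) -> HF out_proj.weight: (dim, num_heads*head_dim)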
@@ -86,6 +152,22 @@ def convert_to_hf(params):

         layer_idx += 1

+    # vision tower
+    if "vision_encoder" in params:
+        # HF append unused tokens for no reason???
+        state_dict["model.embed_tokens.weight"] = jnp.pad(state_dict["model.embed_tokens.weight"], ((0, 64), (0, 0)))
+
+        for k in list(state_dict.keys()):
+            state_dict[f"language_model.{k}"] = state_dict.pop(k)
+
+        prefix = "multi_modal_projector.mm_"
+        state_dict[f"{prefix}input_projection_weight"] = params["embedder"]["mm_input_projection"]["w"]
+        state_dict[f"{prefix}soft_emb_norm.weight"] = params["embedder"]["mm_soft_embedding_norm"]["scale"]
+
+        siglip_state_dict = convert_siglip(params["vision_encoder"]["siglip_encoder"])
+        for k, v in siglip_state_dict.items():
+            state_dict[f"vision_tower.vision_model.{k}"] = v
+
     return state_dict

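On the vision-tower hunk above: when a vision encoder is present, the state dict is rearranged into the multimodal HF layout, with the text weights re-keyed under language_model., the SigLIP weights under vision_tower.vision_model., and the projector tensors under multi_modal_projector.. The jnp.pad call appends 64 rows to the embedding matrix so its first dimension matches the padded vocabulary of the HF checkpoint. A tiny illustration of that pad (sizes are made up for the example):

import jax.numpy as jnp

emb = jnp.ones((1000, 8))                 # (vocab_size, hidden_dim), invented sizes
padded = jnp.pad(emb, ((0, 64), (0, 0)))  # 64 extra rows on axis 0, hidden dim untouched
print(padded.shape)                       # (1064, 8)
print(float(padded[-1].sum()))            # 0.0 -- the appended rows are zeros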
@@ -93,8 +175,11 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
     awq_state_dict = dict()

     for k, v in tqdm(state_dict.items(), total=len(state_dict)):
-
-
+        if (
+            k.endswith("model.embed_tokens.weight")  # AWQ doesn't support INT4 embeddings
+            or k.startswith(("vision_tower", "multi_modal_projector"))  # vision tower is not quantized
+            or v.ndim == 1
+        ):
             awq_state_dict[k] = v.astype(jnp.bfloat16)
             continue

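On the convert_awq hunk above: the new guard keeps the embedding matrix (AWQ has no INT4 embeddings), everything under the vision tower and projector, and all 1-D tensors (norm scales, biases) in bf16; only the remaining tensors continue into the quantization path, which lies outside this hunk. As a standalone sanity check (not part of convert_flax.py), the same predicate can be used to list which keys stay in bf16 and which remain quantization candidates; the keys and shapes below are dummies.

import jax.numpy as jnp

def split_bf16_and_quant(state_dict: dict[str, jnp.ndarray]) -> tuple[list[str], list[str]]:
    # Apply the same skip condition as convert_awq to partition the keys.
    kept, quant = [], []
    for k, v in state_dict.items():
        if (
            k.endswith("model.embed_tokens.weight")
            or k.startswith(("vision_tower", "multi_modal_projector"))
            or v.ndim == 1
        ):
            kept.append(k)
        else:
            quant.append(k)
    return kept, quant

demo = {
    "language_model.model.embed_tokens.weight": jnp.zeros((1000, 8)),
    "language_model.model.layers.0.mlp.down_proj.weight": jnp.zeros((8, 32)),
    "language_model.model.layers.0.input_layernorm.weight": jnp.zeros((8,)),
    "vision_tower.vision_model.embeddings.patch_embedding.weight": jnp.zeros((8, 3, 14, 14)),
}
kept, quant = split_bf16_and_quant(demo)
print(kept)   # embeddings, layernorm, vision tower -> stay bf16
print(quant)  # only the 2-D language-model projection is a quantization candidate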