gaunernst committed
Commit 18df9dc · 1 Parent(s): 36e3234

update conversion script

Files changed (1):
  1. convert_flax.py +88 -3
convert_flax.py CHANGED
@@ -49,10 +49,76 @@ def find_scales(w: jnp.ndarray, dim: int, pbar: bool = True):
     return scales.squeeze(dim + 1)
 
 
+def convert_siglip(params):
+    state_dict = dict()
+
+    def convert_layer(prefix: str, layer: dict[str, jnp.ndarray]):
+        bias = layer["bias"]
+
+        if "kernel" in layer:
+            w = layer["kernel"]
+            if w.ndim == 2:  # linear layer
+                w = w.T
+
+            elif w.ndim == 3:  # attn projection
+                # qkv projection - (dim, num_heads, head_dim)
+                if bias.ndim == 2:
+                    w = flatten(w, 1, 2).T
+                    bias = bias.reshape(-1)
+
+                # o projection - (num_heads, head_dim, dim)
+                elif bias.ndim == 1:
+                    w = flatten(w, 0, 1).T
+
+            elif w.ndim == 4:  # conv2d layer
+                w = w.transpose(3, 2, 0, 1)
+
+            else:
+                raise RuntimeError(f"Unsupported {w.shape=}")
+
+        elif "scale" in layer:  # layer norm
+            w = layer["scale"]
+
+        else:
+            raise RuntimeError
+
+        state_dict[f"{prefix}weight"] = w
+        state_dict[f"{prefix}bias"] = bias
+
+    convert_layer("embeddings.patch_embedding.", params["embedding"])
+    state_dict["embeddings.position_embedding.weight"] = params["pos_embedding"].squeeze(0)
+
+    trunk = params["Transformer"]
+    convert_layer("post_layernorm.", trunk["encoder_norm"])
+
+    layer_idx = 0
+    while f"encoderblock_{layer_idx}" in trunk:
+        prefix = f"encoder.layers.{layer_idx}."
+        encoder_layer = trunk[f"encoderblock_{layer_idx}"]
+
+        convert_layer(f"{prefix}layer_norm1.", encoder_layer["LayerNorm_0"])
+        convert_layer(f"{prefix}layer_norm2.", encoder_layer["LayerNorm_1"])
+
+        attn_layer = encoder_layer["MultiHeadDotProductAttention_0"]
+        convert_layer(f"{prefix}self_attn.q_proj.", attn_layer["query"])
+        convert_layer(f"{prefix}self_attn.k_proj.", attn_layer["key"])
+        convert_layer(f"{prefix}self_attn.v_proj.", attn_layer["value"])
+        convert_layer(f"{prefix}self_attn.out_proj.", attn_layer["out"])
+
+        mlp_layer = encoder_layer["MlpBlock_0"]
+        convert_layer(f"{prefix}mlp.fc1.", mlp_layer["Dense_0"])
+        convert_layer(f"{prefix}mlp.fc2.", mlp_layer["Dense_1"])
+
+        layer_idx += 1
+
+    return state_dict
+
+
 # convert to HF format first, then apply quantization
 def convert_to_hf(params):
     state_dict = dict()
-    # TODO: output projection
+
+    # NOTE: all gemma3 models use tied embeddings, even for the 27B version.
     state_dict["model.embed_tokens.weight"] = params["embedder"]["input_embedding"]
     state_dict["model.norm.weight"] = params["final_norm"]["scale"]
 
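Editorial note: convert_siglip calls a flatten(x, start, end) helper that is not part of this hunk; presumably it is defined earlier in convert_flax.py and merges the dimensions from start to end inclusive, the way torch.flatten does. The transposes then map Flax's (in_features, out_features) linear kernels and (H, W, in_channels, out_channels) conv kernels onto the (out, in) and (out, in, H, W) layouts PyTorch expects. A minimal sketch under that assumption:

import jax.numpy as jnp

def flatten(x: jnp.ndarray, start: int, end: int) -> jnp.ndarray:
    # Assumed behaviour: merge dims start..end (inclusive), like torch.flatten(x, start, end).
    return x.reshape(x.shape[:start] + (-1,) + x.shape[end + 1:])

# A Flax qkv kernel of shape (dim, num_heads, head_dim) flattens to (dim, num_heads * head_dim);
# the trailing .T in convert_layer then gives the (out_features, in_features) layout of a
# PyTorch Linear weight.
w = jnp.zeros((64, 8, 16))
print(flatten(w, 1, 2).T.shape)  # (128, 64)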
 
 
@@ -86,6 +152,22 @@ def convert_to_hf(params):
 
         layer_idx += 1
 
+    # vision tower
+    if "vision_encoder" in params:
+        # HF append unused tokens for no reason???
+        state_dict["model.embed_tokens.weight"] = jnp.pad(state_dict["model.embed_tokens.weight"], ((0, 64), (0, 0)))
+
+        for k in list(state_dict.keys()):
+            state_dict[f"language_model.{k}"] = state_dict.pop(k)
+
+        prefix = "multi_modal_projector.mm_"
+        state_dict[f"{prefix}input_projection_weight"] = params["embedder"]["mm_input_projection"]["w"]
+        state_dict[f"{prefix}soft_emb_norm.weight"] = params["embedder"]["mm_soft_embedding_norm"]["scale"]
+
+        siglip_state_dict = convert_siglip(params["vision_encoder"]["siglip_encoder"])
+        for k, v in siglip_state_dict.items():
+            state_dict[f"vision_tower.vision_model.{k}"] = v
+
     return state_dict
 
 
173
 
 
@@ -93,8 +175,11 @@ def convert_awq(state_dict: dict[str, jnp.ndarray]):
     awq_state_dict = dict()
 
     for k, v in tqdm(state_dict.items(), total=len(state_dict)):
-        # AWQ doesn't support INT4 embeddings
-        if k == "model.embed_tokens.weight" or v.ndim == 1:
+        if (
+            k.endswith("model.embed_tokens.weight")  # AWQ doesn't support INT4 embeddings
+            or k.startswith(("vision_tower", "multi_modal_projector"))  # vision tower is not quantized
+            or v.ndim == 1
+        ):
             awq_state_dict[k] = v.astype(jnp.bfloat16)
             continue
 
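Editorial note: for orientation, a rough sketch of how these functions presumably fit together in the rest of convert_flax.py; load_flax_params and save_checkpoint below are hypothetical placeholders, not functions from this commit:

# Hedged usage sketch; the loader and writer are hypothetical stand-ins.
params = load_flax_params("path/to/gemma3_flax_ckpt")   # hypothetical loader

state_dict = convert_to_hf(params)        # Flax param tree -> HF-style flat state dict
awq_state_dict = convert_awq(state_dict)  # INT4-quantize linear weights; embeddings,
                                          # 1-D tensors, the vision tower and the
                                          # projector stay in bfloat16

save_checkpoint(awq_state_dict, "gemma3-awq")  # hypothetical writer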