import os
import sys
import glob
import json
import torch
import shutil
import numpy as np
from pathlib import Path
#from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import safe_open, save_file
from typing import Any, Dict, List, Optional, Union


# interpolation from mergekit
# thanks charles!
def normalize(v: np.ndarray, eps: float):
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v


def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    return (1 - t) * v0 + t * v1


def slerp(
    t: Union[float, np.ndarray],
    v0: Union[np.ndarray, torch.Tensor],
    v1: Union[np.ndarray, torch.Tensor],
    DOT_THRESHOLD: float = 0.9995,
    eps: float = 1e-8,
):
    """
    Spherical linear interpolation

    From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    """
    is_torch = False
    if not isinstance(v0, np.ndarray):
        is_torch = True
        v0 = v0.detach().cpu().float().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        v1 = v1.detach().cpu().float().numpy()

    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)

    # Normalize the vectors to get the directions and angles
    v0 = normalize(v0, eps)
    v1 = normalize(v1, eps)

    # Dot product with the normalized vectors (can't use np.dot in W)
    dot = np.sum(v0 * v1)

    # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
    if np.abs(dot) > DOT_THRESHOLD:
        res = lerp(t, v0_copy, v1_copy)
        return maybe_torch(res, is_torch)

    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)

    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)

    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    res = s0 * v0_copy + s1 * v1_copy

    return maybe_torch(res, is_torch)


def maybe_torch(v: np.ndarray, is_torch: bool):
    if is_torch:
        return torch.from_numpy(v)
    return v
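
# Quick sanity check for the interpolation helpers above (an illustrative addition,
# not part of the merge logic): both lerp and slerp should reproduce their endpoints
# at t=0 and t=1.
_v0 = torch.tensor([1.0, 0.0])
_v1 = torch.tensor([0.0, 1.0])
assert torch.allclose(lerp(0.0, _v0, _v1), _v0)
assert torch.allclose(slerp(1.0, _v0, _v1), _v1)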

# move layer indices backwards to make room for inserted layer
def move_layer_back(model_dict, num_hidden_layers, layer_keys, layer_num, t):
    # just rename the keys
    print(f"move_layer_back {layer_keys[layer_num]}")
    d = []
    for k in layer_keys[layer_num]:
        tensor = model_dict[k]
        # loop backwards through the layers, increasing the index
        # by one until the insertion layer has been reached
        # model.layers.0.mlp.down_proj -> model.layers.1.mlp.down_proj
        # .weight + .bias (for qwen)
        if k.startswith(f'model.layers.{layer_num}.'):
            tensor_suffix = k[len(f'model.layers.{layer_num}.'):]
            tensor_cur_prefix = f'model.layers.{layer_num}.'
            tensor_next_prefix = f'model.layers.{layer_num+1}.'
            tensor_prev_prefix = f'model.layers.{layer_num-1}.'
            model_dict[tensor_next_prefix + tensor_suffix] = tensor
            del model_dict[k]
            d.append(tensor_next_prefix + tensor_suffix)
    #print(layer_keys[layer_num])
    layer_keys[layer_num+1] = d
    #print(layer_keys[layer_num+1])
    #import pprint
    #pprint.pp(model_dict)


# given a dict of tensors, a key, and layer_num,
# return the tensor at the previous layer's version of key
def get_prev_tensor(model_dict, key, layer_num):
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        cur_prefix = f'model.layers.{layer_num}.'
        prev_prefix = f'model.layers.{layer_num-1}.'
        return model_dict[prev_prefix + suffix]
    return None


# given a dict of tensors, a key, and layer_num,
# return the tensor at the next layer's version of key
def get_next_tensor(model_dict, key, layer_num):
    if key.startswith(f'model.layers.{layer_num}.'):
        suffix = key[len(f'model.layers.{layer_num}.'):]
        cur_prefix = f'model.layers.{layer_num}.'
        next_prefix = f'model.layers.{layer_num+1}.'
        return model_dict[next_prefix + suffix]
    return None


def insert_layer(model_dict, num_hidden_layers, layer_keys, layer_num, t=0.5, out_scale=0.4, scale=None):
    print(f"inserting layer between {layer_num-1} and {layer_num} [t={t}]")
    # need to move all layers after the insertion point
    for i in range(num_hidden_layers, layer_num, -1):
        #print(i)
        move_layer_back(model_dict, num_hidden_layers, layer_keys, i - 1, t)

    # now merge layer+1 with layer-1 and save to layer
    # (because everything got moved back)
    for k in layer_keys[layer_num]:
        #print(k)
        tensor = get_next_tensor(model_dict, k, layer_num)
        prev_tensor = get_prev_tensor(model_dict, k, layer_num)
        merge_tensor = lerp(t, prev_tensor, tensor)
        if scale is not None:
            merge_tensor = merge_tensor * scale
        print(f"merging {layer_num-1} w/ {layer_num+1}")
        #merge_tensor = slerp(t, prev_tensor, tensor)
        # scale down the output projections and biases of the inserted layer
        if k.endswith("mlp.down_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith("self_attn.o_proj.weight"):
            merge_tensor = merge_tensor * out_scale
        if k.endswith(".bias"):
            merge_tensor = merge_tensor * out_scale
        model_dict[k] = merge_tensor


def get_dtype_size_in_bytes(tensor):
    dtype = tensor.dtype
    if dtype in (torch.float16, torch.bfloat16):
        size_in_bytes = tensor.numel() * 2
    elif dtype == torch.float32:
        size_in_bytes = tensor.numel() * 4
    elif dtype == torch.float64:
        size_in_bytes = tensor.numel() * 8
    elif dtype == torch.int32:
        size_in_bytes = tensor.numel() * 4
    elif dtype == torch.int64:
        size_in_bytes = tensor.numel() * 8
    elif dtype == torch.bool:
        size_in_bytes = tensor.numel() * 1
    else:
        # fall back to the element size torch reports for any other dtype
        size_in_bytes = tensor.numel() * tensor.element_size()
    return size_in_bytes


model_name = 'BAAI/Emu3-Gen'
dir_name = './'
#dir_name = None

conf = {}
with open(Path(dir_name or model_name) / 'config.json') as f:
    conf = json.load(f)

st_dict = {}
tensor_dict = {}
if (Path(dir_name) / 'model.safetensors.index.json').is_file():
    with open(Path(dir_name or model_name) / 'model.safetensors.index.json') as f:
        st_index = json.load(f)
    tensors = st_index['weight_map'].keys()
    files = []
    for name in tensors:
        if st_index['weight_map'][name] not in files:
            files.append(st_index['weight_map'][name])
    #print(files)
    for st in files:
        tensor_dict = safe_open(st, framework='pt')
        for k in tensor_dict.keys():
            st_dict[k] = tensor_dict.get_tensor(k)
    #print(st_dict)
elif (Path(dir_name) / 'model.safetensors').is_file():
    model_fn = 'model.safetensors'
    tensor_dict = safe_open(model_fn, framework='pt')
    for k in tensor_dict.keys():
        st_dict[k] = tensor_dict.get_tensor(k)
    file_dict = {'model.safetensors': st_dict}
else:
    print("please convert to safetensors")
    sys.exit(-1)

print(conf)
num_hidden_layers = conf['num_hidden_layers']
print(num_hidden_layers)

model = {}
#sys.exit(-1)
#for k in tensor_dict.keys():
#    model[k] = tensor_dict.get_tensor(k)
#print(tensor_dict.keys())
#import pprint
#pprint.pp(model)

#layer = 0
layer_keys = {}
for layer in range(num_hidden_layers):
    #layer_keys[layer] = [k for k in sorted(tensor_dict.keys()) if k.startswith(f'model.layers.{layer}.')]
    layer_keys[layer] = [k for k in sorted(st_dict.keys()) if k.startswith(f'model.layers.{layer}.')]

for k in layer_keys.keys():
    print(f"Layer {k}")
    print(layer_keys[k])
    print("")

insert_layer(st_dict, num_hidden_layers, layer_keys, 24, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 23, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 22, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 16, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 15, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 14, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 13, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 12, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 11, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 10, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 9, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 8, 0.5, 0.35, scale=None)
num_hidden_layers += 1
insert_layer(st_dict, num_hidden_layers, layer_keys, 7, 0.5, 0.35, scale=None)
num_hidden_layers += 1
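
# The explicit calls above could equivalently be driven by a list of insertion
# points; this commented sketch is only an illustration and is not executed:
# for _idx in [24, 23, 22, 16, 15, 14, 13, 12, 11, 11, 10, 9, 8, 7]:
#     insert_layer(st_dict, num_hidden_layers, layer_keys, _idx, 0.5, 0.35, scale=None)
#     num_hidden_layers += 1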

os.makedirs("original", exist_ok=True)
#shutil.copy("model.safetensors", "original")
shutil.copy("config.json", "original")

#save_file(st_dict, "model.safetensors", metadata={"format": "pt"})

max_shard_size = 5000000000
current_shard_size = 0
current_shard_index = 0
shard_dict = {}
current_shard = {}
shard_names = list(st_dict.keys())

byte_sum = 0
param_sum = 0
params = {k: st_dict[k].numel() for k in st_dict.keys()}
tensor_size = {k: get_dtype_size_in_bytes(st_dict[k]) for k in st_dict.keys()}
for p in params.keys():
    param_sum += params[p]
    byte_sum += tensor_size[p]
print(f"total params: {param_sum}")
print(f"total size in bytes: {byte_sum}")

# lm_head goes into the first shard; drop it from shard_names so it is not
# picked up again by the per-layer or remnant passes below
if 'lm_head.weight' in shard_names:
    tensor_name = 'lm_head.weight'
    current_shard[tensor_name] = st_dict[tensor_name]
    current_shard_size += tensor_size[tensor_name]
    shard_names.remove(tensor_name)

layers = {}
for i in range(num_hidden_layers):
    current_sizes = {}
    layers[i] = [k for k in shard_names if k.startswith(f"model.layers.{i}.")]
    for t in layers[i]:
        #current_shard[t] = st_dict[t]
        #size = get_dtype_size_in_bytes(st_dict[t])
        #current_sizes[t] = size
        current_sizes[t] = tensor_size[t]

z = [k for k in shard_names if k.startswith("model.layers.")]
z.append("lm_head.weight")
remnants = list(set(shard_names) - set(z))
print(f"remnant count: {len(remnants)}")
print(remnants)

layer_size = 0
for l in layers[0]:
    layer_size += tensor_size[l]
print(f"total size of tensors in a single layer: {layer_size}")

for i in range(num_hidden_layers):
    print(f"current_shard_size: {current_shard_size}")
    print(f"layer_size: {layer_size}")
    print(f"max_shard_size: {max_shard_size}")
    if current_shard_size + layer_size >= max_shard_size:
        print(current_shard.keys())
        # write shard
        print(f"writing xmodel-{current_shard_index}.safetensors")
        save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"})
        shard_dict[current_shard_index] = current_shard.copy()
        current_shard_size = 0
        current_shard_index += 1
        current_shard = {}
        print(f"wrote xmodel-{current_shard_index-1}.safetensors")
    for t in layers[i]:
        print(f"shard: {t}")
        current_shard[t] = st_dict[t]
        current_shard_size += tensor_size[t]
    print("")
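
# Optional consistency check (an illustrative addition): every model.layers.*
# tensor should now sit either in an already-written shard or in the shard that
# is still being filled.
_assigned = set(current_shard.keys())
for _shard in shard_dict.values():
    _assigned.update(_shard.keys())
_missing = [k for k in st_dict if k.startswith("model.layers.") and k not in _assigned]
print(f"layer tensors not yet assigned to a shard: {len(_missing)}")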
print(f"wrote xmodel-{current_shard_index}.safetensors") for t in layers[i]: print(f"shard: {t}") current_shard[t] = st_dict[t] current_shard_size += tensor_size[t] print("") print(shard_names) print("") print("") print(current_shard.keys()) # add remnants for x in remnants: remnant_size = get_dtype_size_in_bytes(st_dict[x]) if current_shard_size + remnant_size < max_shard_size: current_shard[x] = st_dict[x] for i in range(len(remnants)): if remnants[i] == tensor_name: del remnants[i] break # write shard print(f"writing xmodel-{current_shard_index}.safetensors") save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"}) shard_dict[current_shard_index] = current_shard.copy() current_shard_size = 0 current_shard_index += 1 current_shard = {} print(f"wrote xmodel-{current_shard_index}.safetensors") for x in remnants: current_shard[x] = st_dict[x] if len(remnants) > 0: # write shard print(f"writing xmodel-{current_shard_index}.safetensors") save_file(current_shard, f"xmodel-{current_shard_index}.safetensors", metadata={"format": "pt"}) shard_dict[current_shard_index] = current_shard.copy() current_shard_size = 0 current_shard_index += 1 #current_shard = {} print(f"wrote xmodel-{current_shard_index-1}.safetensors") # move safetensors to original print("Moving old safetensors to old/") unsorted_files = glob.glob("model-*-of-*.safetensors") files = sorted(unsorted_files) os.makedirs("old", exist_ok=True) shutil.copy("config.json", "old") for file in files: Path("old/" + file).unlink() shutil.move(file, "old") Path("old/model.safetensors.index.json").unlink() shutil.move("model.safetensors.index.json", "old") # move xmodel to safetensors for idx in range(current_shard_index): if Path(f"xmodel-{idx}.safetensors").is_file(): shutil.move(f"xmodel-{idx}.safetensors", f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors") # write safetensor index wmap = {} index = {} for idx in range(current_shard_index): #print(idx) ts = shard_dict[idx].keys() for tname in ts: wmap[tname] = f"model-{idx+1:05}-of-{current_shard_index:05}.safetensors" index['metadata'] = {'total_size': param_sum} index['weight_map'] = wmap with open("model.safetensors.index.json", "w") as f: json.dump(index, f, indent=4) conf['num_hidden_layers'] = num_hidden_layers with open(Path(dir_name or model_name) / 'config.json', "w") as f: json.dump(conf, f, indent=4)