# respark/epoch2/convert_rwkv.py
# Uploaded by yueyulin via huggingface_hub (commit b3c4c5d, verified).
"""Convert a Spark/RWKV checkpoint from HF safetensors layout to RWKV's
.pth naming convention, then emit a second copy whose lm_head is padded
with zero rows to cover the merged embedding vocabulary.

Usage: python convert_rwkv.py model.safetensors
Writes model_converted.pth and model_padded.pth next to the input file.
"""
import sys

import torch


def load_safetensors(path):
    """Read every tensor from the safetensors file at *path* into a dict."""
    # Lazy import: safetensors is only required when actually loading.
    from safetensors import safe_open
    tensors = {}
    with safe_open(path, framework="pt") as f:
        for name in f.keys():
            tensors[name] = f.get_tensor(name)
    return tensors


def rename_key(key):
    """Map one HF-style parameter name to its RWKV-style equivalent.

    Replacement order matters: the specific norms (ffn_norm / attn_norm /
    pre_norm / g_norm) must be rewritten before the generic 'norm' rule,
    and 'attn' -> 'att' runs after 'attn_norm' has been consumed.
    Returns the key unchanged when no rule matches.
    """
    key = (key.replace('model.', '')
              .replace('layers.', 'blocks.')
              .replace('lm_head', 'head')
              .replace('ffn_norm', 'ln2')
              .replace('attn_norm', 'ln1')
              .replace('pre_norm', 'ln0')
              .replace('g_norm', 'ln_x')
              .replace('norm', 'ln_out')
              .replace('attn', 'att')
              .replace('r_proj', 'receptance')
              .replace('k_proj', 'key')
              .replace('v_proj', 'value')
              .replace('o_proj', 'output'))
    # LoRA pieces collapse to RWKV's numeric suffixes,
    # e.g. 'att.w_lora.lora.2.bias' -> 'att.w0'.
    return (key.replace('_lora.lora.2.bias', '0')
               .replace('_lora.lora.2.weight', '2')
               .replace('_lora.lora.0.weight', '1'))


def convert_state_dict(w):
    """Rename every entry of state dict *w*, transposing LoRA weight
    matrices and splitting fused 'att.x_x' tensors into six mix vectors."""
    out = {}
    for orig_key, tensor in w.items():
        key = rename_key(orig_key)
        # RWKV stores LoRA weights transposed relative to the HF layout.
        # (No rename rule adds or removes '_lora.lora.' or 'weight', so
        # testing the original key is equivalent to the original script's
        # test on the partially renamed key.)
        if '_lora.lora.' in orig_key and 'weight' in orig_key:
            tensor = tensor.transpose(0, 1)
        if key == orig_key:
            print("untouched key: ", key)
        if 'att.x_x' in key:
            # Fused token-mix tensor: assumed shape (6, dim) — one row per
            # r/w/k/v/a/g channel (TODO confirm against the exporter).
            rows = torch.split(tensor, 1, dim=0)
            for i, name in enumerate(('r', 'w', 'k', 'v', 'a', 'g')):
                out[key.replace('x_x', f'x_{name}')] = rows[i]
        else:
            out[key] = tensor
    return out


def main():
    input_path = sys.argv[1]
    w_new = convert_state_dict(load_safetensors(input_path))

    # Merge the four embedding tables into a single RWKV 'emb.weight':
    # | semantic 8193 | tts_tag 3 | global 4096 | text 65536 |
    w_new['emb.weight'] = torch.cat([w_new.pop('embeddings.weight'),
                                     w_new.pop('tts_tag_embedder.weight'),
                                     w_new.pop('global_embedder.weight'),
                                     w_new.pop('text_embedder.weight')],
                                    dim=0)
    print(w_new.keys())
    print(w_new['head.weight'].shape)
    torch.save(w_new, input_path.replace('.safetensors', '_converted.pth'))

    # Pad the output head with zero rows so it covers the merged vocabulary.
    # dtype/device must match head.weight: a bare torch.zeros() is float32,
    # and torch.cat would otherwise promote a half-precision head to fp32.
    head = w_new['head.weight']
    pad = torch.zeros(w_new['emb.weight'].shape[0] - head.shape[0],
                      head.shape[1], dtype=head.dtype, device=head.device)
    w_new['head.weight'] = torch.cat([head, pad], dim=0)
    print(w_new['head.weight'].shape)
    torch.save(w_new, input_path.replace('.safetensors', '_padded.pth'))


if __name__ == "__main__":
    main()