File size: 2,543 Bytes
b3c4c5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import sys
import safetensors

# Ordered substring rewrites applied to every checkpoint key.
# Order matters: the specific '*_norm' names must be rewritten before the
# generic 'norm', and 'attn_norm' before the bare 'attn'.
KEY_REWRITES = (
    ('model.', ''),
    ('layers.', 'blocks.'),
    ('lm_head', 'head'),
    ('ffn_norm', 'ln2'),
    ('attn_norm', 'ln1'),
    ('pre_norm', 'ln0'),
    ('g_norm', 'ln_x'),
    ('norm', 'ln_out'),
    ('attn', 'att'),
    ('r_proj', 'receptance'),
    ('k_proj', 'key'),
    ('v_proj', 'value'),
    ('o_proj', 'output'),
)

# Suffix rewrites for the '_lora.lora.N' parameters. They run after the
# transpose check, which still needs to see the '_lora.lora.' marker.
LORA_REWRITES = (
    ('_lora.lora.2.bias', '0'),
    ('_lora.lora.2.weight', '2'),
    ('_lora.lora.0.weight', '1'),
)

# Names given to the slices of an 'att.x_x' tensor, in dim-0 order.
MIX_NAMES = ('r', 'w', 'k', 'v', 'a', 'g')

# Embedding tables concatenated (in this order, along dim 0) into 'emb.weight'.
# The original author's note gives the resulting layout as
# | semantic 8193 | tts_tag 3 | global 4096 | text 65536 | -- those row counts
# are for one particular checkpoint and are not checked here.
EMBEDDING_KEYS = (
    'embeddings.weight',
    'tts_tag_embedder.weight',
    'global_embedder.weight',
    'text_embedder.weight',
)

INPUT_SUFFIX = '.safetensors'


def load_safetensors(path):
    """Read every tensor stored in the safetensors file at *path* into a dict."""
    with safetensors.safe_open(path, framework="pt") as f:
        return {k: f.get_tensor(k) for k in f.keys()}


def convert_state_dict(w):
    """Return a new dict holding the tensors of *w* under converted key names.

    For each ``(key, tensor)`` pair:

    * the key is rewritten with ``KEY_REWRITES`` and then ``LORA_REWRITES``;
    * if the key (after ``KEY_REWRITES``) contains ``'_lora.lora.'`` and
      ``'weight'``, the tensor is transposed on dims 0 and 1;
    * if the final key contains ``'att.x_x'``, the tensor is split into
      size-1 slices along dim 0 and the first six are stored as
      ``x_r, x_w, x_k, x_v, x_a, x_g``; otherwise it is stored as-is.

    Keys that no rewrite changed are reported on stdout. *w* is not modified.
    """
    w_new = {}
    for k_orig, v in w.items():
        k = k_orig
        for old, new in KEY_REWRITES:
            k = k.replace(old, new)
        if '_lora.lora.' in k and 'weight' in k:
            v = v.transpose(0, 1)
        for old, new in LORA_REWRITES:
            k = k.replace(old, new)

        if k == k_orig:
            print("untouched key: ", k)

        if 'att.x_x' in k:
            tensors = torch.split(v, 1, dim=0)
            # Index (rather than zip) so a tensor with fewer than six rows
            # still fails loudly instead of silently dropping names.
            for i, name in enumerate(MIX_NAMES):
                w_new[k.replace('x_x', f'x_{name}')] = tensors[i]
        else:
            w_new[k] = v
    return w_new


def merge_embeddings(w_new):
    """Replace the four embedding tables in *w_new* by a single 'emb.weight'.

    The tables named in ``EMBEDDING_KEYS`` are concatenated along dim 0 and
    then removed from the dict. Mutates *w_new* in place.

    Raises:
        KeyError: if any of the expected embedding keys is missing.
    """
    w_new['emb.weight'] = torch.cat([w_new[name] for name in EMBEDDING_KEYS], dim=0)
    for name in EMBEDDING_KEYS:
        del w_new[name]


def pad_head_rows(head, target_rows):
    """Return *head* with zero rows appended so that it has *target_rows* rows.

    The padding is allocated with ``head.new_zeros`` so it shares the dtype
    and device of *head*; a plain ``torch.zeros`` would be float32 on the
    default device and change the dtype of (or fail to concatenate with) a
    half-precision head.

    Raises:
        ValueError: if *head* already has more than *target_rows* rows.
    """
    missing = target_rows - head.shape[0]
    if missing < 0:
        raise ValueError(
            f"head has {head.shape[0]} rows, more than the target of {target_rows}"
        )
    padding = head.new_zeros((missing, head.shape[1]))
    return torch.cat([head, padding], dim=0)


def derive_output_path(input_path, suffix):
    """Return the output path for *input_path* with *suffix* appended.

    A trailing '.safetensors' is stripped first, so 'm.safetensors' with
    '_converted.pth' gives 'm_converted.pth'. Unlike a blanket
    ``str.replace``, a path without that extension still gets a distinct
    output name (the input file is never overwritten) and directory names
    containing '.safetensors' are left alone.
    """
    stem = input_path
    if stem.endswith(INPUT_SUFFIX):
        stem = stem[:-len(INPUT_SUFFIX)]
    return stem + suffix


def main(argv=None):
    """Convert the safetensors checkpoint named on the command line.

    Writes two files next to the input: '<stem>_converted.pth' with the
    renamed tensors and merged embedding, and '<stem>_padded.pth' which
    additionally has 'head.weight' zero-padded to as many rows as
    'emb.weight'.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} <model.safetensors>")
    input_path = argv[1]

    w_new = convert_state_dict(load_safetensors(input_path))
    merge_embeddings(w_new)

    print(w_new.keys())
    print(w_new['head.weight'].shape)
    torch.save(w_new, derive_output_path(input_path, '_converted.pth'))

    w_new['head.weight'] = pad_head_rows(w_new['head.weight'],
                                         w_new['emb.weight'].shape[0])
    print(w_new['head.weight'].shape)
    torch.save(w_new, derive_output_path(input_path, '_padded.pth'))


if __name__ == "__main__":
    main()