RS2002 committed
Commit e715770 · verified · 1 Parent(s): 0c4c7a9

Upload 2 files

Files changed (2)
  1. Octuple.pkl +3 -0
  2. model.py +287 -0
Octuple.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8afdebe6a0040bb98b998050e43916d6739b137d4872a31faa78b534e82e008
+ size 43862
model.py ADDED
@@ -0,0 +1,287 @@
+ import math
+ import pickle
+ import random
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+ from transformers import BartConfig, BartModel
+
+
+ class Embeddings(nn.Module):
+     """Embedding lookup scaled by sqrt(d_model), as in the original Transformer."""
+
+     def __init__(self, n_token, d_model):
+         super().__init__()
+         self.lut = nn.Embedding(n_token, d_model)
+         self.d_model = d_model
+
+     def forward(self, x):
+         return self.lut(x) * math.sqrt(self.d_model)
+
+
+ class PianoBart(nn.Module):
+     def __init__(self, bartConfig, e2w, w2e):
+         super().__init__()
+
+         self.bart = BartModel(bartConfig)
+         self.hidden_size = bartConfig.d_model
+         self.bartConfig = bartConfig
+
+         # token types: 0 Measure (bar index), 1 Position (position within the bar), 2 Program (instrument),
+         # 3 Pitch, 4 Duration, 5 Velocity, 6 TimeSig (time signature), 7 Tempo
+         self.n_tokens = []  # vocabulary size of each attribute
+         self.classes = ['Bar', 'Position', 'Instrument', 'Pitch', 'Duration', 'Velocity', 'TimeSig', 'Tempo']
+         for key in self.classes:
+             self.n_tokens.append(len(e2w[key]))
+         self.emb_sizes = [256] * 8
+         self.e2w = e2w
+         self.w2e = w2e
+
+         # special-token ids, e.g. for deciding whether the current input_ids is a <PAD> token
+         self.bar_pad_word = self.e2w['Bar']['Bar <PAD>']
+         self.mask_word_np = np.array([self.e2w[etype]['%s <MASK>' % etype] for etype in self.classes], dtype=np.int64)
+         self.pad_word_np = np.array([self.e2w[etype]['%s <PAD>' % etype] for etype in self.classes], dtype=np.int64)
+         self.sos_word_np = np.array([self.e2w[etype]['%s <SOS>' % etype] for etype in self.classes], dtype=np.int64)
+         self.eos_word_np = np.array([self.e2w[etype]['%s <EOS>' % etype] for etype in self.classes], dtype=np.int64)
+
+         # word_emb: embeddings that turn token ids into vectors
+         self.word_emb = []
+         for i, key in enumerate(self.classes):  # embed each attribute into 256 dimensions; the embedding weights are learnable
+             self.word_emb.append(Embeddings(self.n_tokens[i], self.emb_sizes[i]))
+         self.word_emb = nn.ModuleList(self.word_emb)
+
+         # linear layer to merge embeddings from different token types
+         self.encoder_linear = nn.Linear(np.sum(self.emb_sizes), bartConfig.d_model)
+         self.decoder_linear = self.encoder_linear
+         self.decoder_emb = None
+         # self.decoder_linear = nn.Linear(np.sum(self.emb_sizes), bartConfig.d_model)
+
+     def forward(self, input_ids_encoder, input_ids_decoder=None, encoder_attention_mask=None, decoder_attention_mask=None, output_hidden_states=True, generate=False):
+         # `generate` is accepted for interface compatibility but is not used here
+         encoder_embs = []
+         decoder_embs = []
+         for i, key in enumerate(self.classes):
+             encoder_embs.append(self.word_emb[i](input_ids_encoder[..., i]))
+             if self.decoder_emb is None and input_ids_decoder is not None:
+                 decoder_embs.append(self.word_emb[i](input_ids_decoder[..., i]))
+         if self.decoder_emb is not None and input_ids_decoder is not None:
+             decoder_embs.append(self.decoder_emb(input_ids_decoder))
+         encoder_embs = torch.cat(encoder_embs, dim=-1)
+         emb_linear_encoder = self.encoder_linear(encoder_embs)
+         if input_ids_decoder is not None:
+             decoder_embs = torch.cat(decoder_embs, dim=-1)
+             emb_linear_decoder = self.decoder_linear(decoder_embs)
+         # feed to BART; the attention masks hide <PAD> tokens (<PAD> pads sequences to a common length at the end)
+         if input_ids_decoder is not None:
+             y = self.bart(inputs_embeds=emb_linear_encoder, decoder_inputs_embeds=emb_linear_decoder, attention_mask=encoder_attention_mask, decoder_attention_mask=decoder_attention_mask, output_hidden_states=output_hidden_states)
+         else:
+             y = self.bart.encoder(inputs_embeds=emb_linear_encoder, attention_mask=encoder_attention_mask)
+         return y
+
+     def get_rand_tok(self):
+         rand = [0] * 8
+         for i in range(8):
+             rand[i] = random.choice(range(self.n_tokens[i]))
+         return np.array(rand)
+
+     def change_decoder_embedding(self, new_embedding, new_linear=None):
+         self.decoder_emb = new_embedding
+         if new_linear is not None:
+             self.decoder_linear = new_linear
+
+
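+ # Shape sketch (illustrative note, not part of the original file): with batch size B
+ # and sequence length L, input_ids_encoder has shape (B, L, 8), one id per Octuple
+ # attribute. Each attribute embedding yields (B, L, 256), the concatenation
+ # (B, L, 2048), and encoder_linear projects it to (B, L, d_model) before it enters
+ # BART as inputs_embeds.
+
+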
+ class PianoBartLM(nn.Module):
+     def __init__(self, pianobart: PianoBart):
+         super().__init__()
+         self.pianobart = pianobart
+         self.mask_lm = MLM(self.pianobart.e2w, self.pianobart.n_tokens, self.pianobart.hidden_size)
+
+     def forward(self, input_ids_encoder, input_ids_decoder=None, encoder_attention_mask=None, decoder_attention_mask=None, generate=False, device_num=-1):
+         if not generate:
+             x = self.pianobart(input_ids_encoder, input_ids_decoder, encoder_attention_mask, decoder_attention_mask)
+             return self.mask_lm(x)
+         else:
+             if input_ids_encoder.shape[0] != 1:
+                 raise ValueError("generation currently supports a batch size of 1")
+             if device_num == -1:
+                 device = torch.device('cpu')
+             else:
+                 device = torch.device('cuda:' + str(device_num))
+             pad = torch.from_numpy(self.pianobart.pad_word_np)
+             input_ids_decoder = pad.repeat(input_ids_encoder.shape[0], input_ids_encoder.shape[1], 1).to(device)
+             result = pad.repeat(input_ids_encoder.shape[0], input_ids_encoder.shape[1], 1).to(device)
+             decoder_attention_mask = torch.zeros_like(encoder_attention_mask).to(device)
+             input_ids_decoder[:, 0, :] = torch.tensor(self.pianobart.sos_word_np)
+             decoder_attention_mask[:, 0] = 1
+             for i in range(input_ids_encoder.shape[1]):
+                 # pbar = tqdm.tqdm(range(input_ids_encoder.shape[1]), disable=False)
+                 # for i in pbar:
+                 x = self.mask_lm(self.pianobart(input_ids_encoder, input_ids_decoder, encoder_attention_mask, decoder_attention_mask))
+                 # outputs = []
+                 # for j, etype in enumerate(self.pianobart.e2w):
+                 #     output = np.argmax(x[j].cpu().detach().numpy(), axis=-1)
+                 #     outputs.append(output)
+                 # outputs = np.stack(outputs, axis=-1)
+                 # outputs = torch.from_numpy(outputs)
+                 # outputs = self.sample(x)
+                 # if i != input_ids_encoder.shape[1] - 1:
+                 #     input_ids_decoder[:, i + 1, :] = outputs[:, i, :]
+                 #     decoder_attention_mask[:, i + 1] += 1
+                 # result[:, i, :] = outputs[:, i, :]
+                 current_output = self.sample(x, i)
+                 # print(current_output)
+                 if i != input_ids_encoder.shape[1] - 1:
+                     input_ids_decoder[:, i + 1, :] = current_output
+                     decoder_attention_mask[:, i + 1] += 1
+                 # to save time, stop generating once any attribute samples a special token (id >= its <PAD> id)
+                 if (current_output >= pad).any():
+                     break
+                 result[:, i, :] = current_output
+             return result
+
+     def sample(self, x, index):  # Adaptive Sampling Policy from the CP Transformer
+         # token types: 0 Measure (bar index), 1 Position (position within the bar), 2 Program (instrument),
+         # 3 Pitch, 4 Duration, 5 Velocity, 6 TimeSig (time signature), 7 Tempo
+         t = [1.2, 1.2, 5, 1, 2, 5, 5, 1.2]  # per-type temperature
+         p = [1, 1, 1, 0.9, 0.9, 1, 1, 0.9]  # per-type nucleus threshold
+         result = []
+         for j, etype in enumerate(self.pianobart.e2w):
+             y = x[j]
+             y = y[:, index, :]
+             y = sampling(y, p[j], t[j])
+             result.append(y)
+         return torch.tensor(result)
+
+
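+ # Hypothetical usage sketch (variable names and shapes are assumptions, not from
+ # this commit):
+ #     lm = PianoBartLM(pianobart)
+ #     out = lm(enc_ids, encoder_attention_mask=enc_mask, generate=True, device_num=-1)
+ # with enc_ids a (1, L, 8) LongTensor and enc_mask a (1, L) mask: the decoder is
+ # seeded with <SOS>, then filled position by position with sampled octuple tokens.
+
+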
+ # -- nucleus -- #
+ def nucleus(probs, p):
+     probs /= (sum(probs) + 1e-5)
+     sorted_probs = np.sort(probs)[::-1]
+     sorted_index = np.argsort(probs)[::-1]
+     cusum_sorted_probs = np.cumsum(sorted_probs)
+     after_threshold = cusum_sorted_probs > p
+     if sum(after_threshold) > 0:
+         last_index = np.where(after_threshold)[0][0] + 1
+         candi_index = sorted_index[:last_index]
+     else:
+         candi_index = sorted_index[0:1]
+     candi_probs = [probs[i] for i in candi_index]
+     candi_probs /= sum(candi_probs)
+     word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
+     return word
+
+
+ def sampling(logit, p=None, t=1.0):
+     logit = logit.squeeze()
+     probs = torch.softmax(logit / t, dim=-1)
+     probs = probs.cpu().detach().numpy()
+     cur_word = nucleus(probs, p=p)
+     return cur_word
+
+
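+ # Worked example of the nucleus cutoff: for probs = [0.5, 0.3, 0.15, 0.05] and
+ # p = 0.9, the sorted cumulative sums are [0.5, 0.8, 0.95, 1.0]; the threshold is
+ # first crossed at the third entry, so sampling is restricted to the top-3
+ # candidates, with their probabilities renormalized.
+
+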
+ class MLM(nn.Module):
+     def __init__(self, e2w, n_tokens, hidden_size):
+         super().__init__()
+         # one projection head per token type, each predicting over that type's vocabulary
+         self.proj = []
+         for i, etype in enumerate(e2w):
+             self.proj.append(nn.Linear(hidden_size, n_tokens[i]))
+         self.proj = nn.ModuleList(self.proj)
+         self.e2w = e2w
+
+     def forward(self, y):
+         y = y.last_hidden_state
+         ys = []
+         for i, etype in enumerate(self.e2w):
+             ys.append(self.proj[i](y))
+         return ys
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, input_dim, da, r):
+         '''
+         Args:
+             input_dim (int): last dimension of the (batch, seq, input_dim) input
+             da (int): number of features in the hidden layer of the self-attention
+             r (int): number of aspects of the self-attention
+         '''
+         super().__init__()
+         self.ws1 = nn.Linear(input_dim, da, bias=False)
+         self.ws2 = nn.Linear(da, r, bias=False)
+
+     def forward(self, h):
+         attn_mat = F.softmax(self.ws2(torch.tanh(self.ws1(h))), dim=1)
+         attn_mat = attn_mat.permute(0, 2, 1)
+         return attn_mat
+
+
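+ # Shape walk-through (illustrative note): for h of shape (B, L, input_dim),
+ # ws2(tanh(ws1(h))) is (B, L, r); the softmax over dim=1 normalizes across the L
+ # positions, and the permute returns attention weights of shape (B, r, L), ready
+ # for torch.bmm(attn_mat, h) -> (B, r, input_dim).
+
+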
+ class SequenceClassification(nn.Module):
+     def __init__(self, pianobart, class_num, hs, da=128, r=4):
+         super().__init__()
+         self.pianobart = pianobart
+         self.attention = SelfAttention(hs, da, r)
+         self.classifier = nn.Sequential(
+             nn.Dropout(0.1),
+             nn.Linear(hs * r, 256),
+             nn.ReLU(),
+             nn.Linear(256, class_num)
+         )
+
+     def forward(self, input_ids_encoder, encoder_attention_mask=None):
+         # y_shift = torch.zeros_like(input_ids_encoder)
+         # y_shift[:, 1:, :] = input_ids_encoder[:, :-1, :]
+         # y_shift[:, 0, :] = torch.tensor(self.pianobart.sos_word_np)
+         # attn_shift = torch.zeros_like(encoder_attention_mask)
+         # attn_shift[:, 1:] = encoder_attention_mask[:, :-1]
+         # attn_shift[:, 0] = encoder_attention_mask[:, 0]
+         # x = self.pianobart(input_ids_encoder=input_ids_encoder, input_ids_decoder=y_shift, encoder_attention_mask=encoder_attention_mask, decoder_attention_mask=attn_shift)
+
+         x = self.pianobart(input_ids_encoder=input_ids_encoder, input_ids_decoder=input_ids_encoder, encoder_attention_mask=encoder_attention_mask, decoder_attention_mask=encoder_attention_mask)
+
+         x = x.last_hidden_state
+         attn_mat = self.attention(x)
+         m = torch.bmm(attn_mat, x)
+         flatten = m.view(m.size()[0], -1)
+         res = self.classifier(flatten)
+         return res
+
+
+ class TokenClassification(nn.Module):
+     def __init__(self, pianobart, class_num, hs):
+         super().__init__()
+         self.pianobart = pianobart
+         self.classifier = nn.Sequential(
+             nn.Dropout(0.1),
+             nn.Linear(hs, 256),
+             nn.ReLU(),
+             nn.Linear(256, class_num)
+         )
+
+     def forward(self, input_ids_encoder, input_ids_decoder, encoder_attention_mask=None, decoder_attention_mask=None):
+         x = self.pianobart(input_ids_encoder, input_ids_decoder, encoder_attention_mask, decoder_attention_mask)
+         x = x.last_hidden_state
+         res = self.classifier(x)
+         return res
+
+
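+ # Shape note (illustrative): the classifier head is applied per position, so for
+ # hidden states of shape (B, L, hs) the returned logits have shape (B, L, class_num).
+
+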
+ class PianoBART(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, max_position_embeddings=1024, hidden_size=1024, layers=8, heads=8, ffn_dims=2048):
+         super().__init__()
+         # Octuple.pkl stores the (e2w, w2e) vocabulary dictionaries
+         with open("./Octuple.pkl", 'rb') as f:
+             self.e2w, self.w2e = pickle.load(f)
+         self.config = BartConfig(max_position_embeddings=max_position_embeddings,
+                                  d_model=hidden_size,
+                                  encoder_layers=layers,
+                                  encoder_ffn_dim=ffn_dims,
+                                  encoder_attention_heads=heads,
+                                  decoder_layers=layers,
+                                  decoder_ffn_dim=ffn_dims,
+                                  decoder_attention_heads=heads)
+         self.model = PianoBart(bartConfig=self.config, e2w=self.e2w, w2e=self.w2e)
+
+     def forward(self, input_ids_encoder, input_ids_decoder=None, encoder_attention_mask=None, decoder_attention_mask=None, output_hidden_states=True, generate=False):
+         return self.model(input_ids_encoder, input_ids_decoder, encoder_attention_mask, decoder_attention_mask, output_hidden_states, generate=generate)
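+
+ # Hypothetical loading sketch ("<user>/<repo>" is a placeholder, not from this
+ # commit); PyTorchModelHubMixin supplies from_pretrained/push_to_hub for this class:
+ #     model = PianoBART.from_pretrained("<user>/<repo>")
+ #     out = model(enc_ids, dec_ids, enc_mask, dec_mask)  # Seq2SeqModelOutput from BartModel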