silk-road committed on
Commit 89f47a8 · 1 Parent(s): c53280c

Create models.py

Files changed (1)
models.py +527 -0
models.py ADDED
@@ -0,0 +1,527 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from simcse.modeling_glm import GLMModel, GLMPreTrainedModel
+import simcse.mse_loss
+
+import transformers
+from transformers import RobertaTokenizer, AutoModel, PreTrainedModel
+from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
+from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
+from transformers.activations import gelu
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
+
+glm_model = None
+
+def init_glm(path):
+    global glm_model
+    glm_model = GLMModel.from_pretrained(path, trust_remote_code=True).to("cuda:0")
+    for param in glm_model.parameters():
+        param.requires_grad = False
+
+
+
+class MLPLayer(nn.Module):
+    """
+    Head for getting sentence representations over RoBERTa/BERT's CLS representation.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        # project to 1536 dims to match the size of the reference (OpenAI) embeddings
+        self.fc = nn.Linear(config.hidden_size, 1536)
+        self.activation = nn.Tanh()
+
+    def forward(self, features, **kwargs):
+        x = self.dense(features)
+        x = self.fc(x)
+        x = self.activation(x)
+
+        return x
+
+class Similarity(nn.Module):
+    """
+    Dot product or cosine similarity
+    """
+
+    def __init__(self, temp):
+        super().__init__()
+        self.temp = temp
+        self.cos = nn.CosineSimilarity(dim=-1)
+
+    def forward(self, x, y):
+        return self.cos(x, y) / self.temp
+
+
+class Pooler(nn.Module):
+    """
+    Parameter-free poolers to get the sentence embedding
+    'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
+    'cls_before_pooler': [CLS] representation without the original MLP pooler.
+    'avg': average of the last layers' hidden states at each token.
+    'avg_top2': average of the last two layers.
+    'avg_first_last': average of the first and the last layers.
+    """
+
+    def __init__(self, pooler_type):
+        super().__init__()
+        self.pooler_type = pooler_type
+        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2",
+                                    "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type
+
+    def forward(self, attention_mask, outputs):
+        last_hidden = outputs.last_hidden_state
+        # pooler_output = outputs.pooler_output
+        hidden_states = outputs.hidden_states
+
+        if self.pooler_type in ['cls_before_pooler', 'cls']:
+            return last_hidden[:, 0]
+        elif self.pooler_type == "avg":
+            return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
+        elif self.pooler_type == "avg_first_last":
+            first_hidden = hidden_states[1]
+            last_hidden = hidden_states[-1]
+            pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(
+                1) / attention_mask.sum(-1).unsqueeze(-1)
+            return pooled_result
+        elif self.pooler_type == "avg_top2":
+            second_last_hidden = hidden_states[-2]
+            last_hidden = hidden_states[-1]
+            pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(
+                1) / attention_mask.sum(-1).unsqueeze(-1)
+            return pooled_result
+        else:
+            raise NotImplementedError
+
+
+def cl_init(cls, config):
+    """
+    Contrastive learning class init function.
+    """
+    cls.pooler_type = cls.model_args.pooler_type
+    cls.pooler = Pooler(cls.model_args.pooler_type)
+    if cls.model_args.pooler_type == "cls":
+        cls.mlp = MLPLayer(config)
+    cls.sim = Similarity(temp=cls.model_args.temp)
+    cls.init_weights()
+
+
+def cl_forward(cls,
+               encoder,
+               input_ids=None,
+               attention_mask=None,
+               token_type_ids=None,
+               position_ids=None,
+               head_mask=None,
+               inputs_embeds=None,
+               labels=None,
+               output_attentions=None,
+               output_hidden_states=None,
+               return_dict=None,
+               mlm_input_ids=None,
+               mlm_labels=None,
+               left_emb=None,
+               right_emb=None,
+               kl_loss=False
+               ):
+    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
+    ori_input_ids = input_ids
+    batch_size = input_ids.size(0)
+    # Number of sentences in one instance
+    # 2: pair instance; 3: pair instance with a hard negative
+    num_sent = input_ids.size(1)
+
+    mlm_outputs = None
+    # Flatten input for encoding
+    input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+    if token_type_ids is not None:
+        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+
+    if inputs_embeds is not None:
+        input_ids = None
+
+    # Get raw embeddings
+    outputs = encoder(
+        input_ids,
+        attention_mask=attention_mask,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        head_mask=head_mask,
+        inputs_embeds=inputs_embeds,
+        output_attentions=output_attentions,
+        output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+        return_dict=True,
+    )
+
+    # MLM auxiliary objective
+    if mlm_input_ids is not None:
+        mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
+        mlm_outputs = encoder(
+            mlm_input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+            return_dict=True,
+        )
+
+    # Pooling
+    print(outputs.last_hidden_state.shape)
+    pooler_output = cls.pooler(attention_mask, outputs)
+    print(pooler_output.shape)
+    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
+    # If using "cls", we add an extra MLP layer
+    # (same as BERT's original implementation) over the representation.
+    if cls.pooler_type == "cls":
+        pooler_output = cls.mlp(pooler_output)
+    # print("QAQ")
+
+    # Separate representation
+    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]
+
+    tensor_left = left_emb
+    tensor_right = right_emb
+
+    # Hard negative
+    if num_sent == 3:
+        z3 = pooler_output[:, 2]
+
+    # Gather all embeddings if using distributed training
+    if dist.is_initialized() and cls.training:
+        # Gather hard negative
+        if num_sent >= 3:
+            z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]
+            dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())
+            z3_list[dist.get_rank()] = z3
+            z3 = torch.cat(z3_list, 0)
+
+        # Dummy vectors for allgather
+        z1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())]
+        z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())]
+        # Allgather
+        dist.all_gather(tensor_list=z1_list, tensor=z1.contiguous())
+        dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous())
+
+        # Since allgather results do not have gradients, we replace the
+        # current process's corresponding embeddings with original tensors
+        z1_list[dist.get_rank()] = z1
+        z2_list[dist.get_rank()] = z2
+        # Get full batch embeddings: (bs x N, hidden)
+        z1 = torch.cat(z1_list, 0)
+        z2 = torch.cat(z2_list, 0)
+
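+    # Distillation objective computed below: an MSE term pulls z1/z2 toward the reference
+    # (OpenAI) embeddings left_emb/right_emb, and a symmetric KL term matches the row- and
+    # column-wise softmax of the in-batch cosine-similarity matrices of the two spaces.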
+    mse_loss = F.mse_loss(z1, tensor_left) + F.mse_loss(z2, tensor_right)
+
+    # softmax_row, softmax_col = simcse.mse_loss.giveMeMatrix(tensor_left, tensor_right)
+    # softmax_row_model, softmax_col_model = simcse.mse_loss.giveMeMatrix(z1,z2)
+    # ziang_labels = torch.tensor([i for i in range(8)], device='cuda:0')
+
+    # KL divergence loss
+    KL_loss = nn.KLDivLoss(reduction="batchmean")
+    beta = 5
+
+    # OpenAI embeddings: giveMeMatrix normalizes the left/right vectors and returns their product matrix
+    cos_sim_matrix_openai = simcse.mse_loss.giveMeMatrix(tensor_left, tensor_right)
+    beta_scaled_cos_sim_matrix_openai = beta * cos_sim_matrix_openai
+
+    # Our embeddings: giveMeMatrix normalizes the left/right vectors and returns their product matrix
+    cos_sim_matrix_data = simcse.mse_loss.giveMeMatrix(z1, z2)
+    beta_scaled_cos_sim_matrix_data = beta * cos_sim_matrix_data
+
+    beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(dim=1)
+    beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(dim=0)
+
+    beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(dim=1)
+    beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(dim=0)
+
+    # remove reduction="batchmean"
+    KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(), beta_scaled_cos_sim_matrix_openai_vertical)
+    KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(), beta_scaled_cos_sim_matrix_openai_horizontal)
+
+    KL_loss = (KL_vertical_loss + KL_horizontal_loss) / 2
+
+    # KL_row_loss = F.kl_div(softmax_row_model.log(), softmax_row, reduction='batchmean')
+    # KL_col_loss = F.kl_div(softmax_col_model.log(), softmax_col, reduction='batchmean')
+    # KL_loss = (KL_row_loss + KL_col_loss) / 2
+
+    ziang_loss = KL_loss + mse_loss
+
+    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
+
+    # Hard negative
+    if num_sent >= 3:
+        z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0))
+        cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)
+
+    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
+    loss_fct = nn.CrossEntropyLoss()
+
+    # Calculate loss with hard negatives
+    if num_sent == 3:
+        # Note that weights are actually logits of weights
+        z3_weight = cls.model_args.hard_negative_weight
+        weights = torch.tensor(
+            [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (
+                z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
+        ).to(cls.device)
+        cos_sim = cos_sim + weights
+
+    loss = loss_fct(cos_sim, labels)
+
+    # Calculate loss for MLM
+    if mlm_outputs is not None and mlm_labels is not None:
+        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
+        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
+        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
+        loss = loss + cls.model_args.mlm_weight * masked_lm_loss
+
+    if not return_dict:
+        output = (cos_sim,) + outputs[2:]
+        return ((loss,) + output) if loss is not None else output
+
+    # Note: the dict return path reports ziang_loss (KL + MSE distillation) rather than
+    # the contrastive/MLM loss computed above.
+    return SequenceClassifierOutput(
+        loss=ziang_loss,
+        logits=cos_sim,
+        hidden_states=outputs.hidden_states,
+    )
+
+
+def sentemb_forward(
+    cls,
+    encoder,
+    input_ids=None,
+    attention_mask=None,
+    token_type_ids=None,
+    position_ids=None,
+    head_mask=None,
+    inputs_embeds=None,
+    labels=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    return_dict=None,
+):
+    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
+
+    if inputs_embeds is not None:
+        input_ids = None
+
+    outputs = encoder(
+        input_ids,
+        attention_mask=attention_mask,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        head_mask=head_mask,
+        inputs_embeds=inputs_embeds,
+        output_attentions=output_attentions,
+        output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+        return_dict=True,
+    )
+
+    pooler_output = cls.pooler(attention_mask, outputs)
+    if cls.pooler_type == "cls" and not cls.model_args.mlp_only_train:
+        pooler_output = cls.mlp(pooler_output)
+
+    if not return_dict:
+        return (outputs[0], pooler_output) + outputs[2:]
+
+    return BaseModelOutputWithPoolingAndCrossAttentions(
+        pooler_output=pooler_output,
+        last_hidden_state=outputs.last_hidden_state,
+        hidden_states=outputs.hidden_states,
+    )
+
+
+class BertForCL(BertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config, *model_args, **model_kargs):
+        super().__init__(config)
+        self.model_args = model_kargs["model_args"]
+        self.bert = BertModel(config, add_pooling_layer=False)
+
+        if self.model_args.do_mlm:
+            self.lm_head = BertLMPredictionHead(config)
+
+        if self.model_args.init_embeddings_model:
+            if "glm" in self.model_args.init_embeddings_model:
+                init_glm(self.model_args.init_embeddings_model)
+                self.fc = nn.Linear(glm_model.config.hidden_size, config.hidden_size)
+            else:
+                raise NotImplementedError
+
+        cl_init(self, config)
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                sent_emb=False,
+                mlm_input_ids=None,
+                mlm_labels=None,
+                left_emb=None,
+                right_emb=None,
+                ):
+        if self.model_args.init_embeddings_model:
+            input_ids_for_glm = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+            attention_mask_for_glm = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+            token_type_ids_for_glm = None  # default so the GLM call below is defined when token_type_ids is None
+            if token_type_ids is not None:
+                token_type_ids_for_glm = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+
+            outputs_from_glm = glm_model(input_ids_for_glm,
+                                         attention_mask=attention_mask_for_glm,
+                                         token_type_ids=token_type_ids_for_glm,
+                                         position_ids=position_ids,
+                                         head_mask=head_mask,
+                                         inputs_embeds=inputs_embeds,
+                                         labels=labels,
+                                         output_attentions=output_attentions,
+                                         output_hidden_states=output_hidden_states,
+                                         return_dict=return_dict,
+                                         )
+
+            inputs_embeds = self.fc(outputs_from_glm.last_hidden_state)
+
+        if sent_emb:
+            return sentemb_forward(self, self.bert,
+                                   input_ids=input_ids,
+                                   attention_mask=attention_mask,
+                                   token_type_ids=token_type_ids,
+                                   position_ids=position_ids,
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds,
+                                   labels=labels,
+                                   output_attentions=output_attentions,
+                                   output_hidden_states=output_hidden_states,
+                                   return_dict=return_dict,
+                                   )
+        else:
+            return cl_forward(self, self.bert,
+                              input_ids=input_ids,
+                              attention_mask=attention_mask,
+                              token_type_ids=token_type_ids,
+                              position_ids=position_ids,
+                              head_mask=head_mask,
+                              inputs_embeds=inputs_embeds,
+                              labels=labels,
+                              output_attentions=output_attentions,
+                              output_hidden_states=output_hidden_states,
+                              return_dict=return_dict,
+                              mlm_input_ids=mlm_input_ids,
+                              mlm_labels=mlm_labels,
+                              left_emb=left_emb,
+                              right_emb=right_emb,
+                              )
+
+
+class RobertaForCL(RobertaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config, *model_args, **model_kargs):
+        super().__init__(config)
+        self.model_args = model_kargs["model_args"]
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+
+        if self.model_args.do_mlm:
+            self.lm_head = RobertaLMHead(config)
+
+        if self.model_args.init_embeddings_model:
+            if "glm" in self.model_args.init_embeddings_model:
+                init_glm(self.model_args.init_embeddings_model)
+                self.fc = nn.Linear(glm_model.config.hidden_size, config.hidden_size)
+            else:
+                raise NotImplementedError
+
+        cl_init(self, config)
+
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                sent_emb=False,
+                mlm_input_ids=None,
+                mlm_labels=None,
+                left_emb=None,
+                right_emb=None,
+                ):
+
+        if self.model_args.init_embeddings_model and not sent_emb:
+            input_ids_for_glm = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
+            attention_mask_for_glm = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
+            token_type_ids_for_glm = None  # default so the GLM call below is defined when token_type_ids is None
+            if token_type_ids is not None:
+                token_type_ids_for_glm = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)
+
+            outputs_from_glm = glm_model(input_ids_for_glm,
+                                         attention_mask=attention_mask_for_glm,
+                                         token_type_ids=token_type_ids_for_glm,
+                                         position_ids=position_ids,
+                                         head_mask=head_mask,
+                                         inputs_embeds=inputs_embeds,
+                                         labels=labels,
+                                         output_attentions=output_attentions,
+                                         output_hidden_states=output_hidden_states,
+                                         return_dict=return_dict,
+                                         )
+
+            inputs_embeds = self.fc(outputs_from_glm.last_hidden_state)
+
+        if sent_emb:
+            return sentemb_forward(self, self.roberta,
+                                   input_ids=input_ids,
+                                   attention_mask=attention_mask,
+                                   token_type_ids=token_type_ids,
+                                   position_ids=position_ids,
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds,
+                                   labels=labels,
+                                   output_attentions=output_attentions,
+                                   output_hidden_states=output_hidden_states,
+                                   return_dict=return_dict,
+                                   )
+        else:
+            return cl_forward(self, self.roberta,
+                              input_ids=input_ids,
+                              attention_mask=attention_mask,
+                              token_type_ids=token_type_ids,
+                              position_ids=position_ids,
+                              head_mask=head_mask,
+                              inputs_embeds=inputs_embeds,
+                              labels=labels,
+                              output_attentions=output_attentions,
+                              output_hidden_states=output_hidden_states,
+                              return_dict=return_dict,
+                              mlm_input_ids=mlm_input_ids,
+                              mlm_labels=mlm_labels,
+                              left_emb=left_emb,
+                              right_emb=right_emb,
+                              )
+
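
For orientation, a minimal, untested usage sketch (not part of the commit): it builds a throwaway model_args namespace whose fields (pooler_type, temp, do_mlm, mlm_weight, hard_negative_weight, mlp_only_train, init_embeddings_model) are simply the attributes the code above reads, and it uses a placeholder BERT checkpoint.

from types import SimpleNamespace
import torch
from transformers import AutoTokenizer

# Hypothetical arguments; every field below is read somewhere in models.py.
model_args = SimpleNamespace(
    pooler_type="cls",
    temp=0.05,
    do_mlm=False,
    mlm_weight=0.1,
    hard_negative_weight=0.0,
    mlp_only_train=False,
    init_embeddings_model=None,  # or a local path containing "glm" to enable the frozen GLM front-end
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
model = BertForCL.from_pretrained("bert-base-uncased", model_args=model_args)

# Inference path: sent_emb=True routes through sentemb_forward and skips all losses.
batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(**batch, sent_emb=True)
embedding = out.pooler_output  # (batch, 1536) after the MLP projection head

The training path (sent_emb=False) instead expects input_ids shaped (batch, num_sent, len) plus 1536-dimensional left_emb/right_emb reference embeddings for the MSE and KL terms.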