silk-road committed on
Commit 68ac936 · 1 Parent(s): 3fdabda

Update models.py

Files changed (1)
  1. models.py +51 -29
models.py CHANGED
@@ -6,6 +6,8 @@ import torch.distributed as dist
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
+ from argparse import Namespace
+

class MLPLayer(nn.Module):
    """
@@ -24,10 +26,12 @@ class MLPLayer(nn.Module):
        x = self.activation(x)
        return x

+
class Similarity(nn.Module):
    """
    Dot product or cosine similarity
    """
+
    def __init__(self, temp):
        super().__init__()
        self.temp = temp
@@ -80,9 +84,11 @@ class Pooler(nn.Module):

def mse_loss_mat(tensor_left, tensor_right):
    cos_sim_matrix = torch.matmul(tensor_left, tensor_right.t())
-   cos_sim_matrix /= torch.matmul(torch.norm(tensor_left, dim=1, keepdim=True), torch.norm(tensor_right, dim=1, keepdim=True).t())
+   cos_sim_matrix /= torch.matmul(torch.norm(tensor_left, dim=1, keepdim=True),
+                                  torch.norm(tensor_right, dim=1, keepdim=True).t())
    return cos_sim_matrix

+
def cl_init(cls, config):
    """
    Contrastive learning class init function.
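Note: with the normalization re-wrapped here, mse_loss_mat returns the pairwise cosine-similarity matrix between the two sets of embeddings. A minimal standalone sketch of the same computation (function and variable names are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

def cosine_sim_matrix(left, right):
    # Dot products divided by the outer product of the row norms,
    # i.e. exactly the pairwise cosine similarity computed by mse_loss_mat.
    return F.normalize(left, dim=1) @ F.normalize(right, dim=1).t()

left, right = torch.randn(4, 8), torch.randn(4, 8)
sim = cosine_sim_matrix(left, right)  # shape (4, 4), values in [-1, 1]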
@@ -116,10 +122,10 @@ def cl_forward(cls,
    num_sent = input_ids.size(1)

    mlm_outputs = None
-   input_ids = input_ids.view((-1, input_ids.size(-1)))
-   attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
+   input_ids = input_ids.view((-1, input_ids.size(-1)))
+   attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
    if token_type_ids is not None:
-       token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))
+       token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))

    if inputs_embeds is not None:
        input_ids = None
@@ -133,7 +139,8 @@ def cl_forward(cls,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
-       output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+       output_hidden_states=True if cls.model_args.pooler_type in [
+           'avg_top2', 'avg_first_last'] else False,
        return_dict=True,
    )

@@ -148,7 +155,8 @@ def cl_forward(cls,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
-           output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+           output_hidden_states=True if cls.model_args.pooler_type in [
+               'avg_top2', 'avg_first_last'] else False,
            return_dict=True,
        )

@@ -156,7 +164,8 @@ def cl_forward(cls,
    print(outputs.last_hidden_state.shape)
    pooler_output = cls.pooler(attention_mask, outputs)
    print(pooler_output.shape)
-   pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
+   pooler_output = pooler_output.view(
+       (batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
    # If using "cls", we add an extra MLP layer
    # (same as BERT's original implementation) over the representation.
    if cls.pooler_type == "cls":
@@ -175,7 +184,8 @@ def cl_forward(cls,
    if dist.is_initialized() and cls.training:
        # Gather hard negative
        if num_sent >= 3:
-           z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]
+           z3_list = [torch.zeros_like(z3)
+                      for _ in range(dist.get_world_size())]
            dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())
            z3_list[dist.get_rank()] = z3
            z3 = torch.cat(z3_list, 0)
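Context for the gather above (the logic is unchanged, only re-wrapped): torch.distributed.all_gather returns tensors that do not carry gradients, so the local rank's own tensor is written back into the gathered list to keep its gradient in the contrastive loss. A minimal sketch of that pattern, assuming the process group is already initialized:

import torch
import torch.distributed as dist

def gather_keep_local_grad(z):
    # all_gather fills z_list with detached copies from every rank ...
    z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list=z_list, tensor=z.contiguous())
    # ... so the local tensor is substituted back to preserve its gradient.
    z_list[dist.get_rank()] = z
    return torch.cat(z_list, 0)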
@@ -194,9 +204,9 @@ def cl_forward(cls,
        # Get full batch embeddings: (bs x N, hidden)
        z1 = torch.cat(z1_list, 0)
        z2 = torch.cat(z2_list, 0)
-
+
    mse_loss = F.mse_loss(z1, tensor_left) + F.mse_loss(z2, tensor_right)
-
+
    """
    this is KL div loss
    """
@@ -210,17 +220,24 @@ def cl_forward(cls,
    cos_sim_matrix_data = mse_loss_mat(z1, z2)
    beta_scaled_cos_sim_matrix_data = beta * cos_sim_matrix_data

-   beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(dim=1)
-   beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(dim=0)
+   beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(
+       dim=1)
+   beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(
+       dim=0)

-   beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(dim=1)
-   beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(dim=0)
+   beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(
+       dim=1)
+   beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(
+       dim=0)

-   KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(), beta_scaled_cos_sim_matrix_openai_vertical)
-   KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(), beta_scaled_cos_sim_matrix_openai_horizontal)
+   # remove reduction="batchmean"
+   KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(
+   ), beta_scaled_cos_sim_matrix_openai_vertical)
+   KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(
+   ), beta_scaled_cos_sim_matrix_openai_horizontal)

    KL_loss = (KL_vertical_loss + KL_horizontal_loss) / 2
-
+
    ziang_loss = KL_loss + mse_loss

    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
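For context, the re-wrapped lines above take a row-wise ("vertical") and column-wise ("horizontal") softmax over the teacher ("openai") and student ("data") similarity matrices, then penalize their divergence in both directions. A self-contained sketch of the same pattern (the beta value and matrix contents are illustrative; KL_loss is assumed to be a torch.nn.KLDivLoss instance, consistent with the "remove reduction" comment in the diff):

import torch
import torch.nn as nn

kl = nn.KLDivLoss()  # default "mean" reduction, since reduction="batchmean" was removed
beta = 5.0           # illustrative scaling factor

teacher_sim = torch.randn(8, 8)  # e.g. cosine similarities from the teacher embeddings
student_sim = torch.randn(8, 8)  # cosine similarities from the model's own embeddings

t_v, t_h = (beta * teacher_sim).softmax(dim=1), (beta * teacher_sim).softmax(dim=0)
s_v, s_h = (beta * student_sim).softmax(dim=1), (beta * student_sim).softmax(dim=0)

# KLDivLoss expects log-probabilities as input and probabilities as target.
kl_loss = (kl(s_v.log(), t_v) + kl(s_h.log(), t_h)) / 2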
@@ -239,7 +256,7 @@ def cl_forward(cls,
        z3_weight = cls.model_args.hard_negative_weight
        weights = torch.tensor(
            [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (
-               z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
+               z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
        ).to(cls.device)
        cos_sim = cos_sim + weights
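For reference, the weights tensor built above only adds hard_negative_weight on the diagonal of the hard-negative block of cos_sim, i.e. where sentence i is scored against its own hard negative. A standalone sketch of that construction (sizes and the weight value are illustrative):

import torch

bs = 4                            # rows of cos_sim
z1_z2_cols, z1_z3_cols = bs, bs   # cos_sim has z1_z2_cols + z1_z3_cols columns
z3_weight = 1.0                   # stands in for cls.model_args.hard_negative_weight

weights = torch.tensor(
    [[0.0] * z1_z2_cols + [0.0] * i + [z3_weight] + [0.0] * (z1_z3_cols - i - 1)
     for i in range(z1_z3_cols)]
)
# Row i is all zeros except column z1_z2_cols + i, which holds z3_weight.
print(weights)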
 
@@ -249,7 +266,8 @@ def cl_forward(cls,
    if mlm_outputs is not None and mlm_labels is not None:
        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
-       masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
+       masked_lm_loss = loss_fct(
+           prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
        loss = loss + cls.model_args.mlm_weight * masked_lm_loss

    if not return_dict:
  if not return_dict:
@@ -290,7 +308,8 @@ def sentemb_forward(
290
  head_mask=head_mask,
291
  inputs_embeds=inputs_embeds,
292
  output_attentions=output_attentions,
293
- output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False,
 
294
  return_dict=True,
295
  )
296
 
@@ -308,17 +327,23 @@ def sentemb_forward(
308
  )
309
 
310
 
 
 
 
 
 
 
 
 
311
  class BertForCL(BertPreTrainedModel):
312
  _keys_to_ignore_on_load_missing = [r"position_ids"]
313
 
314
  def __init__(self, config, *model_args, **model_kargs):
315
  super().__init__(config)
316
- self.model_args = model_kargs["model_args"]
317
  self.bert = BertModel(config, add_pooling_layer=False)
318
-
319
  if self.model_args.do_mlm:
320
  self.lm_head = BertLMPredictionHead(config)
321
-
322
  cl_init(self, config)
323
 
324
  def forward(self,
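The Namespace fallback above lets these classes be constructed without the training-time model_args (do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False). A hedged usage sketch, assuming this file is importable as models and using a standard BERT checkpoint as a placeholder:

from transformers import AutoConfig, AutoTokenizer
from models import BertForCL  # this file

config = AutoConfig.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# No model_args kwarg is passed, so __init__ falls back to default_model_args.
model = BertForCL.from_pretrained("bert-base-uncased", config=config)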
@@ -375,12 +400,10 @@ class RobertaForCL(RobertaPreTrainedModel):

    def __init__(self, config, *model_args, **model_kargs):
        super().__init__(config)
-       self.model_args = model_kargs["model_args"]
        self.roberta = RobertaModel(config, add_pooling_layer=False)
-
+       self.model_args = model_kargs.get('model_args') or default_model_args
        if self.model_args.do_mlm:
            self.lm_head = RobertaLMHead(config)
-
        cl_init(self, config)

    def forward(self,
  def forward(self,
@@ -427,7 +450,6 @@ class RobertaForCL(RobertaPreTrainedModel):
427
  return_dict=return_dict,
428
  mlm_input_ids=mlm_input_ids,
429
  mlm_labels=mlm_labels,
430
- left_emb=left_emb,
431
- right_emb=right_emb,
432
  )
433
-
 