Update models.py
models.py CHANGED
@@ -6,6 +6,8 @@ import torch.distributed as dist
 from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
 from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
 from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
+from argparse import Namespace
+
 
 class MLPLayer(nn.Module):
     """
@@ -24,10 +26,12 @@ class MLPLayer(nn.Module):
         x = self.activation(x)
         return x
 
+
 class Similarity(nn.Module):
     """
     Dot product or cosine similarity
     """
+
     def __init__(self, temp):
         super().__init__()
         self.temp = temp
@@ -80,9 +84,11 @@ class Pooler(nn.Module):
 
 def mse_loss_mat(tensor_left, tensor_right):
     cos_sim_matrix = torch.matmul(tensor_left, tensor_right.t())
-    cos_sim_matrix /= torch.matmul(torch.norm(tensor_left, dim=1, keepdim=True), torch.norm(tensor_right, dim=1, keepdim=True).t())
+    cos_sim_matrix /= torch.matmul(torch.norm(tensor_left, dim=1, keepdim=True),
+                                   torch.norm(tensor_right, dim=1, keepdim=True).t())
     return cos_sim_matrix
 
+
 def cl_init(cls, config):
     """
     Contrastive learning class init function.
@@ -116,10 +122,10 @@ def cl_forward(cls,
     num_sent = input_ids.size(1)
 
     mlm_outputs = None
-    input_ids = input_ids.view((-1, input_ids.size(-1)))
-    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
+    input_ids = input_ids.view((-1, input_ids.size(-1)))
+    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
     if token_type_ids is not None:
-        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))
+        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))
 
     if inputs_embeds is not None:
         input_ids = None
@@ -133,7 +139,8 @@ def cl_forward(cls,
         head_mask=head_mask,
         inputs_embeds=inputs_embeds,
         output_attentions=output_attentions,
-        output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+        output_hidden_states=True if cls.model_args.pooler_type in [
+            'avg_top2', 'avg_first_last'] else False,
         return_dict=True,
     )
 
@@ -148,7 +155,8 @@ def cl_forward(cls,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
-            output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+            output_hidden_states=True if cls.model_args.pooler_type in [
+                'avg_top2', 'avg_first_last'] else False,
             return_dict=True,
         )
 
@@ -156,7 +164,8 @@ def cl_forward(cls,
     print(outputs.last_hidden_state.shape)
     pooler_output = cls.pooler(attention_mask, outputs)
     print(pooler_output.shape)
-    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
+    pooler_output = pooler_output.view(
+        (batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
     # If using "cls", we add an extra MLP layer
     # (same as BERT's original implementation) over the representation.
     if cls.pooler_type == "cls":
@@ -175,7 +184,8 @@ def cl_forward(cls,
     if dist.is_initialized() and cls.training:
         # Gather hard negative
         if num_sent >= 3:
-            z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]
+            z3_list = [torch.zeros_like(z3)
+                       for _ in range(dist.get_world_size())]
             dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())
             z3_list[dist.get_rank()] = z3
             z3 = torch.cat(z3_list, 0)
@@ -194,9 +204,9 @@ def cl_forward(cls,
         # Get full batch embeddings: (bs x N, hidden)
         z1 = torch.cat(z1_list, 0)
         z2 = torch.cat(z2_list, 0)
-
+
     mse_loss = F.mse_loss(z1, tensor_left) + F.mse_loss(z2, tensor_right)
-
+
     """
     this is KL div loss
    """
@@ -210,17 +220,24 @@ def cl_forward(cls,
     cos_sim_matrix_data = mse_loss_mat(z1, z2)
     beta_scaled_cos_sim_matrix_data = beta * cos_sim_matrix_data
 
-    beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(dim=1)
-    beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(dim=0)
+    beta_scaled_cos_sim_matrix_openai_vertical = beta_scaled_cos_sim_matrix_openai.softmax(
+        dim=1)
+    beta_scaled_cos_sim_matrix_openai_horizontal = beta_scaled_cos_sim_matrix_openai.softmax(
+        dim=0)
 
-    beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(dim=1)
-    beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(dim=0)
+    beta_scaled_cos_sim_matrix_data_vertical = beta_scaled_cos_sim_matrix_data.softmax(
+        dim=1)
+    beta_scaled_cos_sim_matrix_data_horizontal = beta_scaled_cos_sim_matrix_data.softmax(
+        dim=0)
 
-    KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(), beta_scaled_cos_sim_matrix_openai_vertical, reduction="batchmean")
-    KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(), beta_scaled_cos_sim_matrix_openai_horizontal, reduction="batchmean")
+    # remove reduction="batchmean"
+    KL_vertical_loss = KL_loss(beta_scaled_cos_sim_matrix_data_vertical.log(
+    ), beta_scaled_cos_sim_matrix_openai_vertical)
+    KL_horizontal_loss = KL_loss(beta_scaled_cos_sim_matrix_data_horizontal.log(
+    ), beta_scaled_cos_sim_matrix_openai_horizontal)
 
     KL_loss = (KL_vertical_loss + KL_horizontal_loss) / 2
-
+
     ziang_loss = KL_loss + mse_loss
 
     cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
@@ -239,7 +256,7 @@ def cl_forward(cls,
         z3_weight = cls.model_args.hard_negative_weight
         weights = torch.tensor(
             [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (
-                z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
+                z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]
         ).to(cls.device)
         cos_sim = cos_sim + weights
 
@@ -249,7 +266,8 @@ def cl_forward(cls,
     if mlm_outputs is not None and mlm_labels is not None:
         mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
         prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
-        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
+        masked_lm_loss = loss_fct(
+            prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
         loss = loss + cls.model_args.mlm_weight * masked_lm_loss
 
     if not return_dict:
@@ -290,7 +308,8 @@ def sentemb_forward(
         head_mask=head_mask,
         inputs_embeds=inputs_embeds,
         output_attentions=output_attentions,
-        output_hidden_states=True if cls.pooler_type in ['avg_top2', 'avg_first_last'] else False,
+        output_hidden_states=True if cls.pooler_type in [
+            'avg_top2', 'avg_first_last'] else False,
         return_dict=True,
     )
 
@@ -308,17 +327,23 @@ def sentemb_forward(
     )
 
 
+default_model_args = Namespace(
+    do_mlm=None,
+    pooler_type="cls",
+    temp=0.05,
+    mlp_only_train=False
+)
+
+
 class BertForCL(BertPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config, *model_args, **model_kargs):
         super().__init__(config)
-        self.model_args = model_kargs["model_args"]
+        self.model_args = model_kargs.get('model_args') or default_model_args
         self.bert = BertModel(config, add_pooling_layer=False)
-
         if self.model_args.do_mlm:
             self.lm_head = BertLMPredictionHead(config)
-
         cl_init(self, config)
 
     def forward(self,
@@ -375,12 +400,10 @@ class RobertaForCL(RobertaPreTrainedModel):
 
     def __init__(self, config, *model_args, **model_kargs):
         super().__init__(config)
-        self.model_args = model_kargs["model_args"]
         self.roberta = RobertaModel(config, add_pooling_layer=False)
-
+        self.model_args = model_kargs.get('model_args') or default_model_args
        if self.model_args.do_mlm:
             self.lm_head = RobertaLMHead(config)
-
         cl_init(self, config)
 
     def forward(self,
@@ -427,7 +450,6 @@ class RobertaForCL(RobertaPreTrainedModel):
                 return_dict=return_dict,
                 mlm_input_ids=mlm_input_ids,
                 mlm_labels=mlm_labels,
-
-
+                left_emb=left_emb,
+                right_emb=right_emb,
             )
-
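For orientation, the cl_forward changes above add an embedding-distillation objective: an MSE term pulling the two encoded views toward precomputed teacher embeddings, plus a symmetric KL term matching the row- and column-softmaxed similarity matrices of student and teacher. Below is a minimal, self-contained sketch of that objective; the tensor names (left_emb, right_emb for the teacher side), the beta value, and the KL reduction are assumptions, since KL_loss, beta, and the teacher matrix are defined elsewhere in the file.

import torch
import torch.nn.functional as F


def cosine_sim_matrix(a, b):
    # Pairwise cosine similarities: (bs, hidden) x (hidden, bs) -> (bs, bs),
    # normalized by the outer product of the row norms (same as mse_loss_mat above).
    sim = torch.matmul(a, b.t())
    sim /= torch.matmul(torch.norm(a, dim=1, keepdim=True),
                        torch.norm(b, dim=1, keepdim=True).t())
    return sim


def distillation_loss(z1, z2, left_emb, right_emb, beta=1.0):
    # MSE between student embeddings (z1, z2) and teacher embeddings (left_emb, right_emb).
    mse = F.mse_loss(z1, left_emb) + F.mse_loss(z2, right_emb)

    # Beta-scaled similarity matrices for teacher and student.
    sim_teacher = beta * cosine_sim_matrix(left_emb, right_emb)
    sim_student = beta * cosine_sim_matrix(z1, z2)

    # Symmetric KL between row-softmaxed (dim=1) and column-softmaxed (dim=0)
    # distributions. The reduction here is a placeholder: the file defines its own
    # KL_loss elsewhere, and the diff's comment notes reduction="batchmean" was removed.
    kl_vertical = F.kl_div(sim_student.softmax(dim=1).log(),
                           sim_teacher.softmax(dim=1), reduction="batchmean")
    kl_horizontal = F.kl_div(sim_student.softmax(dim=0).log(),
                             sim_teacher.softmax(dim=0), reduction="batchmean")

    return mse + (kl_vertical + kl_horizontal) / 2


# Random stand-ins for encoder outputs and teacher embeddings.
z1, z2 = torch.randn(8, 768), torch.randn(8, 768)
left_emb, right_emb = torch.randn(8, 768), torch.randn(8, 768)
print(distillation_loss(z1, z2, left_emb, right_emb))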
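The hard-negative block in cl_forward uses the usual all_gather trick for contrastive training: gather embeddings from every rank, then overwrite the local slot with the original tensor so gradients still flow to the local shard. A standalone sketch of that pattern, with a single-process fallback so it stays runnable without torch.distributed initialized:

import torch
import torch.distributed as dist


def gather_with_grad(z):
    # All-gather embeddings across ranks for a larger in-batch negative pool.
    # all_gather returns tensors without autograd history, so the local rank's
    # slot is replaced by the original tensor to keep its gradient path.
    if not (dist.is_available() and dist.is_initialized()):
        return z  # single-process fallback
    z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list=z_list, tensor=z.contiguous())
    z_list[dist.get_rank()] = z
    return torch.cat(z_list, 0)


z3 = torch.randn(8, 768, requires_grad=True)
print(gather_with_grad(z3).shape)  # torch.Size([8, 768]) when not distributed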