Balaji S committed on
Commit 2c732ed · verified · 1 Parent(s): d87c6b1

Removed unnecessary import

Files changed (1)
  1. model.py +324 -325
model.py CHANGED
@@ -1,325 +1,324 @@
 import math
 import torch
 import numpy as np
 import torch.nn as nn
 from tqdm import tqdm
 import scipy.sparse as sp
 import torch.nn.functional as F
 import torch.distributed as dist

 import transformers
 from transformers import RobertaTokenizer
 from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
 from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead
 from transformers.activations import gelu
 from transformers.file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
 from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions

-from loss_utils import *

 init = nn.init.xavier_uniform_
 uniformInit = nn.init.uniform


 """
 EasyRec
 """
 def dot_product_scores(q_vectors, ctx_vectors):
     r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
     return r

 class MLPLayer(nn.Module):
     """
     Head for getting sentence representations over RoBERTa/BERT's CLS representation.
     """
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()

     def forward(self, features, **kwargs):
         x = self.dense(features)
         x = self.activation(x)
         return x


 class Pooler(nn.Module):
     """
     Parameter-free poolers to get the sentence embedding
     'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
     'cls_before_pooler': [CLS] representation without the original MLP pooler.
     'avg': average of the last layers' hidden states at each token.
     'avg_top2': average of the last two layers.
     'avg_first_last': average of the first and the last layers.
     """
     def __init__(self, pooler_type):
         super().__init__()
         self.pooler_type = pooler_type
         assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

     def forward(self, attention_mask, outputs):
         last_hidden = outputs.last_hidden_state
         pooler_output = outputs.pooler_output
         hidden_states = outputs.hidden_states

         if self.pooler_type in ['cls_before_pooler', 'cls']:
             return last_hidden[:, 0]
         elif self.pooler_type == "avg":
             return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1))
         elif self.pooler_type == "avg_first_last":
             first_hidden = hidden_states[1]
             last_hidden = hidden_states[-1]
             pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
             return pooled_result
         elif self.pooler_type == "avg_top2":
             second_last_hidden = hidden_states[-2]
             last_hidden = hidden_states[-1]
             pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
             return pooled_result
         else:
             raise NotImplementedError


 class Similarity(nn.Module):
     """
     Dot product or cosine similarity
     """
     def __init__(self, temp):
         super().__init__()
         self.temp = temp
         self.cos = nn.CosineSimilarity(dim=-1)

     def forward(self, x, y):
         return self.cos(x, y) / self.temp


 class Easyrec(RobertaPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config, *model_args, **model_kargs):
         super().__init__(config)
         try:
             self.model_args = model_kargs["model_args"]
             self.roberta = RobertaModel(config, add_pooling_layer=False)
             if self.model_args.pooler_type == "cls":
                 self.mlp = MLPLayer(config)
             if self.model_args.do_mlm:
                 self.lm_head = RobertaLMHead(config)
             """
             Contrastive learning class init function.
             """
             self.pooler_type = self.model_args.pooler_type
             self.pooler = Pooler(self.pooler_type)
             self.sim = Similarity(temp=self.model_args.temp)
             self.init_weights()
         except:
             self.roberta = RobertaModel(config, add_pooling_layer=False)
             self.mlp = MLPLayer(config)
             self.lm_head = RobertaLMHead(config)
             self.pooler_type = 'cls'
             self.pooler = Pooler(self.pooler_type)
             self.init_weights()

     def forward(self,
         user_input_ids=None,
         user_attention_mask=None,
         pos_item_input_ids=None,
         pos_item_attention_mask=None,
         neg_item_input_ids=None,
         neg_item_attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
         mlm_input_ids=None,
         mlm_attention_mask=None,
         mlm_labels=None,
     ):
         """
         Contrastive learning forward function.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         batch_size = user_input_ids.size(0)

         # Get user embeddings
         user_outputs = self.roberta(
             input_ids=user_input_ids,
             attention_mask=user_attention_mask,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

         # Get positive item embeddings
         pos_item_outputs = self.roberta(
             input_ids=pos_item_input_ids,
             attention_mask=pos_item_attention_mask,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

         # Get negative item embeddings
         neg_item_outputs = self.roberta(
             input_ids=neg_item_input_ids,
             attention_mask=neg_item_attention_mask,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

         # MLM auxiliary objective
         if mlm_input_ids is not None:
             mlm_outputs = self.roberta(
                 input_ids=mlm_input_ids,
                 attention_mask=mlm_attention_mask,
                 token_type_ids=None,
                 position_ids=None,
                 head_mask=None,
                 inputs_embeds=None,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
             )

         # Pooling
         user_pooler_output = self.pooler(user_attention_mask, user_outputs)
         pos_item_pooler_output = self.pooler(pos_item_attention_mask, pos_item_outputs)
         neg_item_pooler_output = self.pooler(neg_item_attention_mask, neg_item_outputs)

         # If using "cls", we add an extra MLP layer
         # (same as BERT's original implementation) over the representation.
         if self.pooler_type == "cls":
             user_pooler_output = self.mlp(user_pooler_output)
             pos_item_pooler_output = self.mlp(pos_item_pooler_output)
             neg_item_pooler_output = self.mlp(neg_item_pooler_output)

         # Gather all item embeddings if using distributed training
         if dist.is_initialized() and self.training:
             # Dummy vectors for allgather
             user_list = [torch.zeros_like(user_pooler_output) for _ in range(dist.get_world_size())]
             pos_item_list = [torch.zeros_like(pos_item_pooler_output) for _ in range(dist.get_world_size())]
             neg_item_list = [torch.zeros_like(neg_item_pooler_output) for _ in range(dist.get_world_size())]
             # Allgather
             dist.all_gather(tensor_list=user_list, tensor=user_pooler_output.contiguous())
             dist.all_gather(tensor_list=pos_item_list, tensor=pos_item_pooler_output.contiguous())
             dist.all_gather(tensor_list=neg_item_list, tensor=neg_item_pooler_output.contiguous())

             # Since allgather results do not have gradients, we replace the
             # current process's corresponding embeddings with original tensors
             user_list[dist.get_rank()] = user_pooler_output
             pos_item_list[dist.get_rank()] = pos_item_pooler_output
             neg_item_list[dist.get_rank()] = neg_item_pooler_output

             # Get full batch embeddings
             user_pooler_output = torch.cat(user_list, dim=0)
             pos_item_pooler_output = torch.cat(pos_item_list, dim=0)
             neg_item_pooler_output = torch.cat(neg_item_list, dim=0)

         cos_sim = self.sim(user_pooler_output.unsqueeze(1), pos_item_pooler_output.unsqueeze(0))
         neg_sim = self.sim(user_pooler_output.unsqueeze(1), neg_item_pooler_output.unsqueeze(0))
         cos_sim = torch.cat([cos_sim, neg_sim], 1)

         labels = torch.arange(cos_sim.size(0)).long().to(self.device)
         loss_fct = nn.CrossEntropyLoss()

         loss = loss_fct(cos_sim, labels)

         # Calculate loss for MLM
         if mlm_outputs is not None and mlm_labels is not None and self.model_args.do_mlm:
             mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
             prediction_scores = self.lm_head(mlm_outputs.last_hidden_state)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
             loss = loss + self.model_args.mlm_weight * masked_lm_loss

         if not return_dict:
             raise NotImplementedError

         return SequenceClassifierOutput(
             loss=loss,
             logits=cos_sim,
         )

     def encode(self,
         input_ids=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
         labels=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
     ):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         outputs = self.roberta(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
         pooler_output = self.pooler(attention_mask, outputs)
         if self.pooler_type == "cls":
             pooler_output = self.mlp(pooler_output)
         if not return_dict:
             return (outputs[0], pooler_output) + outputs[2:]

         return BaseModelOutputWithPoolingAndCrossAttentions(
             pooler_output=pooler_output,
             last_hidden_state=outputs.last_hidden_state,
             hidden_states=outputs.hidden_states,
         )

     def inference(self,
         user_profile_list,
         item_profile_list,
         dataset_name,
         tokenizer,
         infer_batch_size=128
     ):
         n_user = len(user_profile_list)
         profiles = user_profile_list + item_profile_list
         n_batch = math.ceil(len(profiles) / infer_batch_size)
         text_embeds = []
         for i in tqdm(range(n_batch), desc=f'Encoding Text {dataset_name}'):
             batch_profiles = profiles[i * infer_batch_size: (i + 1) * infer_batch_size]
             inputs = tokenizer(batch_profiles, padding=True, truncation=True, max_length=512, return_tensors="pt")
             for k in inputs:
                 inputs[k] = inputs[k].to(self.device)
             with torch.inference_mode():
                 embeds = self.encode(
                     input_ids=inputs.input_ids,
                     attention_mask=inputs.attention_mask
                 )
             text_embeds.append(embeds.pooler_output.detach().cpu())
         text_embeds = torch.concat(text_embeds, dim=0).cuda()
         user_embeds = F.normalize(text_embeds[: n_user], dim=-1)
         item_embeds = F.normalize(text_embeds[n_user: ], dim=-1)
         return user_embeds, item_embeds
 
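For context, a minimal usage sketch of the Easyrec class defined above. This is an illustration, not part of the commit: the checkpoint name and the profile strings are assumptions, and any RoBERTa-based Easyrec checkpoint compatible with this file would work the same way.

# Minimal usage sketch (assumed names: "hkuds/easyrec-roberta-large" and the
# example profiles are illustrative placeholders).
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer
from model import Easyrec  # the class defined in this file

ckpt = "hkuds/easyrec-roberta-large"  # hypothetical checkpoint name
config = AutoConfig.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Easyrec.from_pretrained(ckpt, config=config).eval()

profiles = [
    "This user enjoys classic science-fiction films.",  # user profile
    "A space-opera movie with large-scale battles.",    # item profile
]
inputs = tokenizer(profiles, padding=True, truncation=True, max_length=512, return_tensors="pt")

with torch.inference_mode():
    out = model.encode(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

embeds = F.normalize(out.pooler_output, dim=-1)
score = (embeds[0] * embeds[1]).sum()  # cosine similarity between the two profiles
print(float(score))

The inference method above wraps this same encode path with batching, device placement, and normalization for full user/item profile lists.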