Roman Solomatin committed: fix shapes
listconranker.py CHANGED (+9 -31)
@@ -39,12 +39,11 @@ from typing import Union, List, Optional
 class ListConRankerConfig(PretrainedConfig):
     """Configuration class for ListConRanker model."""
 
-    model_type = "
+    model_type = "ListConRanker"
 
     def __init__(
         self,
         list_transformer_layers: int = 2,
-        num_attention_heads: int = 8,
         hidden_size: int = 1792,
         base_hidden_size: int = 1024,
         num_labels: int = 1,
@@ -52,12 +51,12 @@ class ListConRankerConfig(PretrainedConfig):
     ):
         super().__init__(**kwargs)
         self.list_transformer_layers = list_transformer_layers
-        self.num_attention_heads = num_attention_heads
         self.hidden_size = hidden_size
         self.base_hidden_size = base_hidden_size
         self.num_labels = num_labels
 
         self.bert_config = BertConfig(**kwargs)
+        self.bert_config.hidden_size = self.base_hidden_size
         self.bert_config.output_hidden_states = True
 
 class QueryEmbedding(nn.Module):
@@ -85,7 +84,8 @@ class ListTransformer(nn.Module):
         self.linear_score2 = nn.Linear(config.hidden_size * 2, config.hidden_size)
         self.linear_score1 = nn.Linear(config.hidden_size * 2, 1)
 
-    def forward(self, pair_features
+    def forward(self, pair_features: torch.Tensor):
+        pair_nums = pair_features.size(0)
         pair_nums = [x + 1 for x in pair_nums]
         batch_pair_features = pair_features.split(pair_nums)
 
@@ -154,7 +154,7 @@ class ListConRankerModel(PreTrainedModel):
         super().__init__(config)
         self.config = config
         self.num_labels = config.num_labels
-        self.hf_model = BertModel(config)
+        self.hf_model = BertModel(config.bert_config)
 
         self.sigmoid = nn.Sigmoid()
 
@@ -176,17 +176,8 @@ class ListConRankerModel(PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        pair_num: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[SequenceClassifierOutput, tuple]:
-        # Handle pair_num parameter
-        if pair_num is not None:
-            pair_nums = pair_num.tolist()
-        else:
-            # Default behavior if pair_num is not provided
-            batch_size = input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
-            pair_nums = [1] * batch_size
-
         # Get device
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         self.list_transformer.device = device
@@ -195,20 +186,7 @@ class ListConRankerModel(PreTrainedModel):
         if self.training:
             pass
         else:
-
-            if sum(pair_nums) > split_batch:
-                last_hidden_state_list = []
-                input_ids_list = input_ids.split(split_batch)
-                attention_mask_list = attention_mask.split(split_batch)
-                for i in range(len(input_ids_list)):
-                    last_hidden_state = self.hf_model(
-                        input_ids=input_ids_list[i],
-                        attention_mask=attention_mask_list[i],
-                        return_dict=True).hidden_states[-1]
-                    last_hidden_state_list.append(last_hidden_state)
-                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
-            else:
-                ranker_out = self.hf_model(
+            ranker_out = self.hf_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids,
@@ -217,12 +195,12 @@ class ListConRankerModel(PreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 output_attentions=output_attentions,
                 return_dict=True)
-
+            last_hidden_state = ranker_out.last_hidden_state
 
         pair_features = self.average_pooling(last_hidden_state, attention_mask)
         pair_features = self.linear_in_embedding(pair_features)
 
-        logits, pair_features_after_list_transformer = self.list_transformer(pair_features
+        logits, pair_features_after_list_transformer = self.list_transformer(pair_features)
         logits = self.sigmoid(logits)
 
         return logits
@@ -249,4 +227,4 @@ class ListConRankerModel(PreTrainedModel):
         except FileNotFoundError:
             print(f"Warning: Could not load custom weights from {model_name_or_path}")
 
-            return model
+        return model
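Note on the shape fix, for context: after this change the BERT backbone is built from config.bert_config, whose hidden_size is forced to base_hidden_size (1024), while the ListTransformer operates at hidden_size (1792); linear_in_embedding (defined outside these hunks, assumed here to be a 1024-to-1792 nn.Linear) bridges the two. A minimal, self-contained sketch of that wiring, using a stand-in for average_pooling:

import torch
from torch import nn

base_hidden_size = 1024   # bert_config.hidden_size after this commit
hidden_size = 1792        # width expected by the ListTransformer layers

# assumed projection; the real linear_in_embedding is defined elsewhere in the file
linear_in_embedding = nn.Linear(base_hidden_size, hidden_size)

last_hidden_state = torch.randn(4, 128, base_hidden_size)  # [batch, seq, 1024] from BertModel
attention_mask = torch.ones(4, 128)

# mean over non-padding tokens (stand-in for self.average_pooling)
mask = attention_mask.unsqueeze(-1)
pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
pair_features = linear_in_embedding(pooled)                # [batch, 1792], fed to the list transformer
print(pair_features.shape)  # torch.Size([4, 1792])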