Upload 14 files
- src/cross_rerank/__init__.py +0 -0
- src/cross_rerank/collator.py +30 -0
- src/cross_rerank/config.py +230 -0
- src/cross_rerank/data_loader.py +106 -0
- src/cross_rerank/data_utils.py +211 -0
- src/cross_rerank/loader_utils.py +46 -0
- src/cross_rerank/logger_config.py +32 -0
- src/cross_rerank/loss.py +36 -0
- src/cross_rerank/metrics.py +105 -0
- src/cross_rerank/model.py +58 -0
- src/cross_rerank/trainer.py +247 -0
- src/cross_rerank/utils.py +137 -0
- src/eval_cross.py +512 -0
- src/train_cross_encoder.py +112 -0
src/cross_rerank/__init__.py
ADDED
File without changes
src/cross_rerank/collator.py
ADDED
@@ -0,0 +1,30 @@
import torch
import transformers
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import BatchEncoding, DataCollatorWithPadding


@dataclass
class CrossEncoderCollator(DataCollatorWithPadding):

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # each feature packs train_n_passages tokenized (query, passage) pairs
        # for one query; flatten them into one entry per pair
        unpack_features = []
        for ex in features:
            keys = list(ex.keys())
            # assert all(len(ex[k]) == 8 for k in keys)
            for idx in range(len(ex[keys[0]])):
                unpack_features.append({k: ex[k][idx] for k in keys})

        # temporarily silence tokenizer warnings while padding
        old_level = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        collated_batch_dict = self.tokenizer.pad(
            unpack_features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        transformers.logging.set_verbosity(old_level)

        # the positive passage is always at index 0 within each group
        collated_batch_dict['labels'] = torch.zeros(len(features), dtype=torch.long)

        return collated_batch_dict
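
For reference, a minimal usage sketch of the collator above; the checkpoint name and texts are illustrative, not taken from this commit, and it assumes src/ is on PYTHONPATH:

# Hypothetical wiring: one feature = one query with 2 packed passages.
from transformers import AutoTokenizer
from cross_rerank.collator import CrossEncoderCollator

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed checkpoint
collator = CrossEncoderCollator(tokenizer=tokenizer)

# Tokenize (query, passage) pairs without padding, as the data loader would.
example = tokenizer(['what is ir?', 'what is ir?'],
                    text_pair=['information retrieval ...', 'an unrelated passage'],
                    truncation=True)
features = [{k: example[k] for k in example}]  # one dict per query

batch = collator(features)
print(batch['input_ids'].shape)  # (2, max_len): passages flattened and padded
print(batch['labels'])           # tensor([0]): positive is index 0 in each group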
src/cross_rerank/config.py
ADDED
@@ -0,0 +1,230 @@
import os
import torch
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments

from .logger_config import logger


@dataclass
class Arguments(TrainingArguments):
    model_name_or_path: str = field(
        default='bert-base-uncased',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    corpus_file: str = field(
        default=None, metadata={"help": "Path to corpus file"}
    )

    data_dir: str = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    task_type: str = field(
        default='ir', metadata={"help": "task type: ir / qa"}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics on (a jsonlines file)."
        },
    )

    train_n_passages: int = field(
        default=8,
        metadata={"help": "number of passages for each example (including both positive and negative passages)"}
    )
    share_encoder: bool = field(
        default=True,
        metadata={"help": "share encoder weights between query and passage encoders"}
    )
    use_first_positive: bool = field(
        default=False,
        metadata={"help": "Always use the first positive passage"}
    )
    use_first_negative: bool = field(
        default=False,
        metadata={"help": "Always use the first negative passages"}
    )
    use_scaled_loss: bool = field(
        default=True,
        metadata={"help": "Use scaled loss or not"}
    )
    loss_scale: float = field(
        default=-1.,
        metadata={"help": "loss scale, -1 will use world_size"}
    )
    add_pooler: bool = field(default=False)
    out_dimension: int = field(
        default=768,
        metadata={"help": "output dimension for pooler"}
    )
    t: float = field(default=0.05, metadata={"help": "temperature of biencoder training"})
    l2_normalize: bool = field(default=True, metadata={"help": "L2 normalize embeddings or not"})
    t_warmup: bool = field(default=False, metadata={"help": "warmup temperature"})
    full_contrastive_loss: bool = field(default=True, metadata={"help": "use full contrastive loss or not"})

    # following arguments are used for encoding documents
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    encode_in_path: str = field(default=None, metadata={"help": "Path to data to encode"})
    encode_save_dir: str = field(default=None, metadata={"help": "where to save the encodings"})
    encode_shard_size: int = field(default=int(2 * 10**6))
    encode_batch_size: int = field(default=256)

    # used for index search
    do_search: bool = field(default=False, metadata={"help": "run the index search loop"})
    search_split: str = field(default='dev', metadata={"help": "which split to search"})
    search_batch_size: int = field(default=128, metadata={"help": "query batch size for index search"})
    search_topk: int = field(default=200, metadata={"help": "return topk search results"})
    search_out_dir: str = field(default='', metadata={"help": "output directory for writing search results"})

    # used for reranking
    do_rerank: bool = field(default=False, metadata={"help": "run the reranking loop"})
    rerank_max_length: int = field(default=256, metadata={"help": "max length for rerank inputs"})
    rerank_in_path: str = field(default='', metadata={"help": "Path to predictions for rerank"})
    rerank_out_path: str = field(default='', metadata={"help": "Path to write rerank results"})
    rerank_split: str = field(default='dev', metadata={"help": "which split to rerank"})
    rerank_batch_size: int = field(default=128, metadata={"help": "rerank batch size"})
    rerank_depth: int = field(default=1000, metadata={"help": "rerank depth, useful for debugging purpose"})
    rerank_forward_factor: int = field(
        default=1,
        metadata={"help": "forward n passages, then select top n/factor passages for backward"}
    )
    rerank_use_rdrop: bool = field(default=False, metadata={"help": "use R-Drop regularization for re-ranker"})

    # used for knowledge distillation
    do_kd_gen_score: bool = field(default=False, metadata={"help": "run the score generation for distillation"})
    kd_gen_score_split: str = field(default='dev', metadata={
        "help": "Which split to use for generation of teacher score"
    })
    kd_gen_score_batch_size: int = field(default=128, metadata={"help": "batch size for teacher score generation"})
    kd_gen_score_n_neg: int = field(default=30, metadata={"help": "number of negatives to compute teacher scores"})

    do_kd_biencoder: bool = field(default=False, metadata={"help": "knowledge distillation to biencoder"})
    kd_mask_hn: bool = field(default=True, metadata={"help": "mask out hard negatives for distillation"})
    kd_cont_loss_weight: float = field(default=1.0, metadata={"help": "weight for contrastive loss"})

    rlm_generator_model_name: Optional[str] = field(
        default='google/electra-base-generator',
        metadata={"help": "generator for replace LM pre-training"}
    )
    rlm_freeze_generator: Optional[bool] = field(
        default=True,
        metadata={'help': 'freeze generator params or not'}
    )
    rlm_generator_mlm_weight: Optional[float] = field(
        default=0.2,
        metadata={'help': 'weight for generator MLM loss'}
    )
    all_use_mask_token: Optional[bool] = field(
        default=False,
        metadata={'help': 'Do not use 80:10:10 mask, use [MASK] for all places'}
    )
    rlm_num_eval_samples: Optional[int] = field(
        default=4096,
        metadata={"help": "number of evaluation samples for pre-training"}
    )
    rlm_max_length: Optional[int] = field(
        default=144,
        metadata={"help": "max length for MatchLM pre-training"}
    )
    rlm_decoder_layers: Optional[int] = field(
        default=2,
        metadata={"help": "number of transformer layers for MatchLM decoder part"}
    )
    rlm_encoder_mask_prob: Optional[float] = field(
        default=0.3,
        metadata={'help': 'mask rate for encoder'}
    )
    rlm_decoder_mask_prob: Optional[float] = field(
        default=0.5,
        metadata={'help': 'mask rate for decoder'}
    )

    q_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query."
        },
    )
    p_max_len: int = field(
        default=144,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage."
        },
    )

    chunk_size: int = field(
        default=8,
        metadata={
            "help": "The maximum chunk size for the chunked forward/backward passes"
        },
    )

    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    dry_run: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set dry_run to True for debugging purpose'}
    )

    def __post_init__(self):
        assert os.path.exists(self.data_dir)
        assert torch.cuda.is_available(), 'Only support running on GPUs'
        assert self.task_type in ['ir', 'qa']

        if self.dry_run:
            self.logging_steps = 1
            self.max_train_samples = self.max_train_samples or 128
            self.num_train_epochs = 1
            self.per_device_train_batch_size = min(2, self.per_device_train_batch_size)
            self.train_n_passages = min(4, self.train_n_passages)
            self.rerank_forward_factor = 1
            self.gradient_accumulation_steps = 1
            self.rlm_num_eval_samples = min(256, self.rlm_num_eval_samples)
            self.max_steps = 30
            self.save_steps = self.eval_steps = 30
            logger.warning('Dry run: set logging_steps=1')

        if self.do_encode:
            assert self.encode_save_dir
            os.makedirs(self.encode_save_dir, exist_ok=True)
            assert os.path.exists(self.encode_in_path)

        if self.do_search:
            assert os.path.exists(self.encode_save_dir)
            assert self.search_out_dir
            os.makedirs(self.search_out_dir, exist_ok=True)

        if self.do_rerank:
            assert os.path.exists(self.rerank_in_path)
            logger.info('Rerank result will be written to {}'.format(self.rerank_out_path))
            assert self.train_n_passages > 1, 'Having positive passages only does not make sense for training re-ranker'
            assert self.train_n_passages % self.rerank_forward_factor == 0

        if self.do_kd_gen_score:
            assert os.path.exists('{}/{}.jsonl'.format(self.data_dir, self.kd_gen_score_split))

        if self.do_kd_biencoder:
            if self.use_scaled_loss:
                assert not self.kd_mask_hn, 'Use scaled loss only works with not masking out hard negatives'

        if torch.cuda.device_count() <= 1:
            self.logging_steps = min(50, self.logging_steps)

        super(Arguments, self).__post_init__()

        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        self.label_names = ['labels']
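
Since Arguments subclasses TrainingArguments, these options are presumably parsed with Hugging Face's HfArgumentParser; a minimal entry-point sketch (the flag values are made up, and the actual entry point in this commit is src/train_cross_encoder.py, which is not shown):

# Hypothetical entry-point sketch, assuming src/ is on PYTHONPATH.
from transformers import HfArgumentParser

from cross_rerank.config import Arguments

parser = HfArgumentParser(Arguments)
# e.g. python train.py --data_dir ./data --output_dir ./out --do_train
(args,) = parser.parse_args_into_dataclasses()
print(args.model_name_or_path, args.train_n_passages)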
src/cross_rerank/data_loader.py
ADDED
@@ -0,0 +1,106 @@
import os.path
import random
import pandas as pd
import transformers
from typing import Tuple, Dict, List, Optional
from datasets import load_dataset, DatasetDict, Dataset
from transformers.file_utils import PaddingStrategy
from transformers import PreTrainedTokenizerFast, Trainer

from .config import Arguments
from .logger_config import logger
from .loader_utils import group_doc_ids


class CrossEncoderDataLoader:
    def __init__(self, args: Arguments, tokenizer):
        self.args = args
        self.negative_size = args.train_n_passages - 1
        assert self.negative_size > 0
        #self.hard_neg_size = args.hard_neg_size if args.hard_neg_size < self.negative_size else self.negative_size
        #self.rand_neg_size = self.negative_size - self.hard_neg_size if self.hard_neg_size < self.negative_size else 0
        self.tokenizer = tokenizer
        #corpus_path = os.path.join(args.data_dir, 'passages.jsonl.gz')
        #self.corpus: Dataset = load_dataset('json', data_files=corpus_path)['train']
        self.corpus = pd.read_csv(args.corpus_file)
        #self.corpus = dcorpus['tokenized_text'].to_list()
        #self.corpus_bad_ids = [idx for idx in range(len(self.corpus)) if (type(self.corpus['bm25_text'][idx]) is not str)]
        self.train_dataset, self.eval_dataset = self._get_transformed_datasets()

        # use its state to decide which positives/negatives to sample
        self.trainer: Optional[Trainer] = None

    def _transform_func(self, examples: Dict[str, List]) -> Dict[str, List]:
        current_epoch = int(self.trainer.state.epoch or 0)

        input_doc_ids = group_doc_ids(
            examples=examples,
            negative_size=self.negative_size,
            offset=current_epoch + self.args.seed,
            use_first_positive=self.args.use_first_positive,
            use_first_negative=self.args.use_first_negative
        )
        assert len(input_doc_ids) == len(examples['query']) * self.args.train_n_passages

        input_queries, input_docs = [], []
        for idx, doc_id in enumerate(input_doc_ids):
            #prefix = ''
            #if self.corpus[doc_id].get('title', ''):
            #    prefix = self.corpus[doc_id]['title'] + ': '

            input_docs.append(self.corpus['tokenized_text'][doc_id])
            input_queries.append(examples['query'][idx // self.args.train_n_passages])

        # temporarily silence tokenizer warnings while batch-encoding
        old_level = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        batch_dict = self.tokenizer(input_queries,
                                    text_pair=input_docs,
                                    max_length=self.args.rerank_max_length,
                                    padding=PaddingStrategy.DO_NOT_PAD,
                                    truncation=True)
        transformers.logging.set_verbosity(old_level)

        # re-group the flat encodings so each row holds train_n_passages entries
        packed_batch_dict = {}
        for k in batch_dict:
            packed_batch_dict[k] = []
            assert len(examples['query']) * self.args.train_n_passages == len(batch_dict[k])
            for idx in range(len(examples['query'])):
                start = idx * self.args.train_n_passages
                packed_batch_dict[k].append(batch_dict[k][start:(start + self.args.train_n_passages)])

        return packed_batch_dict

    def _get_transformed_datasets(self) -> Tuple:
        #data_files = {}
        #if self.args.train_file is not None:
        #    data_files["train"] = self.args.train_file.split(',')
        #if self.args.validation_file is not None:
        #    data_files["validation"] = self.args.validation_file
        #raw_datasets: DatasetDict = load_dataset('json', data_files=data_files)

        train_dataset, eval_dataset = None, None

        if self.args.do_train:
            #if "train" not in raw_datasets:
            try:
                train_dataset = load_dataset('json', data_files=os.path.join(self.args.data_dir, 'train.jsonl'))['train']
            except Exception:
                raise ValueError("--do_train requires a train dataset")
            #train_dataset = raw_datasets["train"]
            if self.args.max_train_samples is not None:
                train_dataset = train_dataset.select(range(self.args.max_train_samples))
            # Log a few random samples from the training set:
            for index in random.sample(range(len(train_dataset)), 1):
                logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
            train_dataset.set_transform(self._transform_func)

        if self.args.do_eval:
            #if "validation" not in raw_datasets:
            try:
                eval_dataset = load_dataset('json', data_files=os.path.join(self.args.data_dir, 'eval.jsonl'))['train']
            except Exception:
                raise ValueError("--do_eval requires a validation dataset")
            #eval_dataset = raw_datasets["validation"]
            eval_dataset.set_transform(self._transform_func)

        return train_dataset, eval_dataset
src/cross_rerank/data_utils.py
ADDED
@@ -0,0 +1,211 @@
import os
import random
import tqdm
import json

from typing import Dict, List, Any
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field

from .logger_config import logger
from .config import Arguments
from .utils import save_json_to_file


@dataclass
class ScoredDoc:
    qid: str
    pid: str
    rank: int
    score: float = field(default=-1)


def load_qrels(path: str) -> Dict[str, Dict[str, int]]:
    assert path.endswith('.txt')

    # qid -> pid -> score
    qrels = {}
    for line in open(path, 'r', encoding='utf-8'):
        qid, _, pid, score = line.strip().split('\t')
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][pid] = int(score)

    logger.info('Load {} queries {} qrels from {}'.format(len(qrels), sum(len(v) for v in qrels.values()), path))
    return qrels


def load_queries(path: str, task_type: str = 'ir') -> Dict[str, str]:
    assert path.endswith('.tsv')

    if task_type == 'qa':
        qid_to_query = load_query_answers(path)
        qid_to_query = {k: v['query'] for k, v in qid_to_query.items()}
    elif task_type == 'ir':
        qid_to_query = {}
        for line in open(path, 'r', encoding='utf-8'):
            qid, query = line.strip().split('\t')
            qid_to_query[qid] = query
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def normalize_qa_text(text: str) -> str:
    # TriviaQA has some weird formats
    # For example: """What breakfast food gets its name from the German word for """"stirrup""""?"""
    while text.startswith('"') and text.endswith('"'):
        text = text[1:-1].replace('""', '"')
    return text


def get_question_key(question: str) -> str:
    # For QA dataset, we'll use normalized question strings as dict key
    return question


def load_query_answers(path: str) -> Dict[str, Dict[str, Any]]:
    assert path.endswith('.tsv')

    qid_to_query = {}
    for line in open(path, 'r', encoding='utf-8'):
        query, answers = line.strip().split('\t')
        query = normalize_qa_text(query)
        answers = normalize_qa_text(answers)
        qid = get_question_key(query)
        if qid in qid_to_query:
            logger.warning('Duplicate question: {} vs {}'.format(query, qid_to_query[qid]['query']))
            continue

        qid_to_query[qid] = {}
        qid_to_query[qid]['query'] = query
        qid_to_query[qid]['answers'] = list(eval(answers))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def load_corpus(path: str) -> Dataset:
    assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')

    # two fields: id, contents
    corpus = load_dataset('json', data_files=path)['train']
    logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
    logger.info('A random document: {}'.format(random.choice(corpus)))
    return corpus


def load_msmarco_predictions(path: str) -> Dict[str, List[ScoredDoc]]:
    assert path.endswith('.txt')

    qid_to_scored_doc = {}
    for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='load prediction', mininterval=3):
        fs = line.strip().split('\t')
        qid, pid, rank = fs[:3]
        rank = int(rank)
        score = round(1 / rank, 4) if len(fs) == 3 else float(fs[3])

        if qid not in qid_to_scored_doc:
            qid_to_scored_doc[qid] = []
        scored_doc = ScoredDoc(qid=qid, pid=pid, rank=rank, score=score)
        qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {qid: sorted(scored_docs, key=lambda sd: sd.rank)
                         for qid, scored_docs in qid_to_scored_doc.items()}

    logger.info('Load {} query predictions from {}'.format(len(qid_to_scored_doc), path))
    return qid_to_scored_doc


def save_preds_to_msmarco_format(preds: Dict[str, List[ScoredDoc]], out_path: str):
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qid in preds:
            for idx, scored_doc in enumerate(preds[qid]):
                writer.write('{}\t{}\t{}\t{}\n'.format(qid, scored_doc.pid, idx + 1, round(scored_doc.score, 3)))
    logger.info('Successfully saved to {}'.format(out_path))


def save_to_readable_format(in_path: str, corpus: Dataset):
    out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
    dataset: Dataset = load_dataset('json', data_files=in_path)['train']

    max_to_keep = 5

    def _create_readable_field(samples: Dict[str, List]) -> List:
        readable_ex = []
        for idx in range(min(len(samples['doc_id']), max_to_keep)):
            doc_id = samples['doc_id'][idx]
            readable_ex.append({'doc_id': doc_id,
                                'title': corpus[int(doc_id)].get('title', ''),
                                'contents': corpus[int(doc_id)]['contents'],
                                'score': samples['score'][idx]})
        return readable_ex

    def _mp_func(ex: Dict) -> Dict:
        ex['positives'] = _create_readable_field(ex['positives'])
        ex['negatives'] = _create_readable_field(ex['negatives'])
        return ex

    dataset = dataset.map(_mp_func, num_proc=8)

    dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
    logger.info('Done convert {} to readable format in {}'.format(in_path, out_path))


def get_rerank_shard_path(args: Arguments, worker_idx: int) -> str:
    return '{}_shard_{}'.format(args.rerank_out_path, worker_idx)


def merge_rerank_predictions(args: Arguments, gpu_count: int):
    # deferred relative import to avoid a circular dependency with metrics.py
    from .metrics import trec_eval, compute_mrr

    qid_to_scored_doc: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), 'merge results', mininterval=3):
            fs = line.strip().split('\t')
            qid, pid, _, score = fs
            score = float(score)

            if qid not in qid_to_scored_doc:
                qid_to_scored_doc[qid] = []
            scored_doc = ScoredDoc(qid=qid, pid=pid, rank=-1, score=score)
            qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {k: sorted(v, key=lambda sd: sd.score, reverse=True) for k, v in qid_to_scored_doc.items()}

    ori_preds = load_msmarco_predictions(path=args.rerank_in_path)
    for query_id in list(qid_to_scored_doc.keys()):
        remain_scored_docs = ori_preds[query_id][args.rerank_depth:]
        for idx, sd in enumerate(remain_scored_docs):
            # make sure the order is not broken
            sd.score = qid_to_scored_doc[query_id][-1].score - idx - 1
        qid_to_scored_doc[query_id] += remain_scored_docs
        assert len(set([sd.pid for sd in qid_to_scored_doc[query_id]])) == len(qid_to_scored_doc[query_id])

    save_preds_to_msmarco_format(qid_to_scored_doc, out_path=args.rerank_out_path)

    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.rerank_split)
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=qid_to_scored_doc)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=qid_to_scored_doc)

        logger.info('{} trec metrics = {}'.format(args.rerank_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        metrics_out_path = '{}/metrics_rerank_{}.json'.format(os.path.dirname(args.rerank_out_path), args.rerank_split)
        save_json_to_file(all_metrics, metrics_out_path)
    else:
        logger.warning('No qrels found for {}'.format(args.rerank_split))

    # cleanup some intermediate results
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        os.remove(path)


if __name__ == '__main__':
    load_qrels('./data/msmarco/dev_qrels.txt')
    load_queries('./data/msmarco/dev_queries.tsv')
    corpus = load_corpus('./data/msmarco/passages.jsonl.gz')
    preds = load_msmarco_predictions('./data/bm25.msmarco.txt')
src/cross_rerank/loader_utils.py
ADDED
@@ -0,0 +1,46 @@
from typing import List, Dict


def _slice_with_mod(elements: List, offset: int, cnt: int) -> List:
    return [elements[(offset + idx) % len(elements)] for idx in range(cnt)]


def group_doc_ids(examples: Dict[str, List],
                  negative_size: int,
                  offset: int,
                  use_first_positive: bool = False,
                  use_first_negative: bool = True) -> List[int]:
    pos_doc_ids: List[int] = []
    positives: List[Dict[str, List]] = examples['positives']
    for idx, ex_pos in enumerate(positives):
        all_pos_doc_ids = ex_pos['doc_id']

        if use_first_positive:
            # keep only the positives whose score equals the maximum,
            # i.e. the ones scored at least as high as every other positive
            all_pos_doc_ids = [doc_id for p_idx, doc_id in enumerate(all_pos_doc_ids)
                               if ex_pos['score'][p_idx] == max(ex_pos['score'])]

        cur_pos_doc_id = _slice_with_mod(all_pos_doc_ids, offset=offset, cnt=1)[0]
        pos_doc_ids.append(int(cur_pos_doc_id))

    neg_doc_ids: List[List[int]] = []
    negatives: List[Dict[str, List]] = examples['negatives']
    for ex_neg in negatives:
        if use_first_negative:
            cur_neg_doc_ids = ex_neg['doc_id'][:negative_size]
        else:
            cur_neg_doc_ids = _slice_with_mod(ex_neg['doc_id'],
                                              offset=offset * negative_size,
                                              cnt=negative_size)
        cur_neg_doc_ids = [int(doc_id) for doc_id in cur_neg_doc_ids]
        neg_doc_ids.append(cur_neg_doc_ids)

    assert len(pos_doc_ids) == len(neg_doc_ids), '{} != {}'.format(len(pos_doc_ids), len(neg_doc_ids))
    assert all(len(doc_ids) == negative_size for doc_ids in neg_doc_ids)

    # interleave: each query contributes its positive followed by its negatives
    input_doc_ids: List[int] = []
    for pos_doc_id, neg_ids in zip(pos_doc_ids, neg_doc_ids):
        input_doc_ids.append(pos_doc_id)
        input_doc_ids += neg_ids

    return input_doc_ids
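
A small worked example of the grouping logic above, with toy ids that are not from any real dataset (assumes src/ is on PYTHONPATH):

# Toy input: one query with two positives and three mined negatives.
from cross_rerank.loader_utils import group_doc_ids

examples = {
    'positives': [{'doc_id': ['10', '11'], 'score': [1.0, 0.5]}],
    'negatives': [{'doc_id': ['20', '21', '22']}],
}
ids = group_doc_ids(examples, negative_size=2, offset=0,
                    use_first_positive=False, use_first_negative=True)
print(ids)  # [10, 20, 21]: positive first, then negative_size negatives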
src/cross_rerank/logger_config.py
ADDED
@@ -0,0 +1,32 @@
import os
import logging

from transformers.trainer_callback import TrainerCallback


def _setup_logger():
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)

    data_dir = './data/'
    os.makedirs(data_dir, exist_ok=True)
    file_handler = logging.FileHandler('{}/log.txt'.format(data_dir))
    file_handler.setFormatter(log_format)

    logger.handlers = [console_handler, file_handler]

    return logger


logger = _setup_logger()


class LoggerCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_world_process_zero:
            logger.info(logs)
src/cross_rerank/loss.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn.functional as F


class CrossEncoderNllLoss(object):
    def __init__(self, score_type="dot"):
        self.score_type = score_type

    def calc(self, logits, labels):
        """
        Computes NLL loss over the given per-query logits.
        Returns the mean loss value for the batch.
        """
        #if len(q_vectors.size()) > 1:
        #    q_num = q_vectors.size(0)
        #    scores = scores.view(q_num, -1)
        #    positive_idx_per_question = [i for i in range(q_num)]

        softmax_scores = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(
            softmax_scores,
            labels,
            reduction="mean",
        )
        #max_score, max_idxs = torch.max(softmax_scores, 1)
        #correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
        return loss  #, correct_predictions_count
src/cross_rerank/metrics.py
ADDED
@@ -0,0 +1,105 @@
import torch
import pytrec_eval

from typing import List, Dict, Tuple

from .data_utils import ScoredDoc
from .logger_config import logger


def trec_eval(qrels: Dict[str, Dict[str, int]],
              predictions: Dict[str, List[ScoredDoc]],
              k_values: Tuple[int] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
    ndcg, _map, recall = {}, {}, {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])

    results: Dict[str, Dict[str, float]] = {}
    for query_id, scored_docs in predictions.items():
        results.update({query_id: {sd.pid: sd.score for sd in scored_docs}})

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string})
    scores = evaluator.evaluate(results)

    for query_id in scores:
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]

    def _normalize(m: dict) -> dict:
        return {k: round(v / len(scores), 5) for k, v in m.items()}

    ndcg = _normalize(ndcg)
    _map = _normalize(_map)
    recall = _normalize(recall)

    all_metrics = {}
    for mt in [ndcg, _map, recall]:
        all_metrics.update(mt)

    return all_metrics


@torch.no_grad()
def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> List[float]:
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res


@torch.no_grad()
def batch_mrr(output: torch.Tensor, target: torch.Tensor) -> float:
    assert len(output.shape) == 2
    assert len(target.shape) == 1
    sorted_score, sorted_indices = torch.sort(output, dim=-1, descending=True)
    _, rank = torch.nonzero(sorted_indices.eq(target.unsqueeze(-1)).long(), as_tuple=True)
    assert rank.shape[0] == output.shape[0]

    rank = rank + 1
    mrr = torch.sum(100 / rank.float()) / rank.shape[0]
    return mrr.item()


def get_rel_threshold(qrels: Dict[str, Dict[str, int]]) -> int:
    # For ms-marco passage ranking, score >= 1 is relevant
    # for trec dl 2019 & 2020, score >= 2 is relevant
    rel_labels = set()
    for q_id in qrels:
        for doc_id, label in qrels[q_id].items():
            rel_labels.add(label)

    logger.info('relevance labels: {}'.format(rel_labels))
    return 2 if max(rel_labels) >= 3 else 1


def compute_mrr(qrels: Dict[str, Dict[str, int]],
                predictions: Dict[str, List[ScoredDoc]],
                k: int = 10) -> float:
    threshold = get_rel_threshold(qrels)
    mrr = 0
    for qid in qrels:
        scored_docs = predictions.get(qid, [])
        for idx, scored_doc in enumerate(scored_docs[:k]):
            if scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= threshold:
                mrr += 1 / (idx + 1)
                break

    return round(mrr / len(qrels) * 100, 4)
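
A quick hand-checkable sanity test for batch_mrr, with logits chosen so the correct class ranks third in both rows (assumes src/ is on PYTHONPATH):

import torch
from cross_rerank.metrics import batch_mrr

logits = torch.tensor([[0.1, 0.9, 0.2],    # class 0 is correct, ranked 3rd
                       [0.3, 0.2, 0.5]])   # class 1 is correct, ranked 3rd
target = torch.tensor([0, 1])
print(batch_mrr(logits, target))  # 100/3 for both rows -> ~33.33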
src/cross_rerank/model.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn

from typing import Optional, Dict
from transformers import (
    PreTrainedModel,
    AutoModelForSequenceClassification
)
from transformers.modeling_outputs import SequenceClassifierOutput

from .config import Arguments
from .loss import CrossEncoderNllLoss


class Reranker(nn.Module):
    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        super().__init__()
        self.hf_model = hf_model
        self.args = args
        self._keys_to_ignore_on_save = None

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        #self.contrastive = CrossEncoderNllLoss()
        #self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, input_ids, attention_mask, token_type_ids) -> SequenceClassifierOutput:
        #n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor

        outputs: SequenceClassifierOutput = self.hf_model(input_ids, attention_mask, token_type_ids, return_dict=True)
        #outputs.logits = outputs.logits.view(-1, n_psg_per_query)
        #loss = self.cross_entropy(outputs.logits, labels)

        return outputs  #, loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(hf_model, all_args)

    def save_pretrained(self, output_dir: str):
        self.hf_model.save_pretrained(output_dir)


class RerankerForInference(nn.Module):
    def __init__(self, model_checkpoint, hf_model: Optional[PreTrainedModel] = None):
        super().__init__()
        self.hf_model = hf_model
        if hf_model is None:
            self.hf_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
        self.hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        return self.hf_model(**batch)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        # pass the loaded model via the hf_model keyword so __init__ does not
        # mistake it for a checkpoint path
        return cls(pretrained_model_name_or_path, hf_model=hf_model)
src/cross_rerank/trainer.py
ADDED
@@ -0,0 +1,247 @@
import os
import torch
from torch import nn
from torch.utils.checkpoint import get_device_states, set_device_states
#from typing import Optional, Union, Dict, Any
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.trainer import Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
from collections.abc import Mapping
from .logger_config import logger
from .metrics import accuracy
from .utils import AverageMeter


def nested_detach(tensors):
    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
    return tensors.detach()


class RandContext:
    """Saves CPU/GPU RNG state so a later re-forward reproduces dropout exactly."""

    def __init__(self, *tensors):
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

    def __enter__(self):
        self._fork = torch.random.fork_rng(
            devices=self.fwd_gpu_devices,
            enabled=True
        )
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None


class RerankerTrainer(Trainer):
    def __init__(self, *pargs, **kwargs):
        super(RerankerTrainer, self).__init__(*pargs, **kwargs)

        self.acc_meter = AverageMeter('acc', round_digits=2)
        self.last_epoch = 0

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to {}".format(output_dir))

        self.model.save_pretrained(output_dir)

        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

    def compute_loss(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']
        outputs = model(input_ids, attention_mask, token_type_ids)
        outputs.logits = outputs.logits.view(-1, n_psg_per_query)
        loss = self.model.cross_entropy(outputs.logits, labels)

        if self.model.training:
            step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, outputs) if return_outputs else loss

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss_train(model, inputs)

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss_train(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']

        all_reps, rnds = [], []

        id_chunks = input_ids.split(self.args.chunk_size)
        attn_mask_chunks = attention_mask.split(self.args.chunk_size)
        type_ids_chunks = token_type_ids.split(self.args.chunk_size)

        # first pass: score every chunk without building a graph, caching the
        # RNG state of each chunk so the second forward is bitwise identical
        for id_chunk, attn_chunk, type_chunk in zip(id_chunks, attn_mask_chunks, type_ids_chunks):
            rnds.append(RandContext(id_chunk, attn_chunk, type_chunk))
            with torch.no_grad():
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
            all_reps.append(chunk_reps)
        all_reps = torch.cat(all_reps)
        all_reps = all_reps.view(-1, n_psg_per_query)

        # backprop the loss into the cached logits only
        all_reps = all_reps.float().detach().requires_grad_()
        loss = self.model.cross_entropy(all_reps, labels)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        self.accelerator.backward(loss)
        #if self.args.gradient_accumulation_steps > 1:
        #    loss = loss / self.args.gradient_accumulation_steps
        #loss.backward()
        grads = all_reps.grad.split(self.args.chunk_size // n_psg_per_query)

        # second pass: re-forward each chunk with grad enabled and inject the
        # cached logit gradients through a dot-product surrogate
        for id_chunk, attn_chunk, type_chunk, grad, rnd in zip(id_chunks, attn_mask_chunks, type_ids_chunks, grads, rnds):
            with rnd:
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
            surrogate = torch.dot(chunk_reps.flatten().float(), grad.flatten())

            self.accelerator.backward(surrogate)

        #outputs, loss = model(input_ids, attention_mask, token_type_ids, labels)

        if self.model.training:
            step_acc = accuracy(all_reps, target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, all_reps) if return_outputs else loss

    '''def compute_loss_pred(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']

        all_reps, rnds = [], []

        id_chunks = input_ids.split(self.args.chunk_size)
        attn_mask_chunks = attention_mask.split(self.args.chunk_size)
        type_ids_chunks = token_type_ids.split(self.args.chunk_size)

        for id_chunk, attn_chunk, type_chunk in zip(id_chunks, attn_mask_chunks, type_ids_chunks):
            rnds.append(RandContext(id_chunk, attn_chunk, type_chunk))
            with torch.no_grad():
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
            all_reps.append(chunk_reps)
        all_reps = torch.cat(all_reps)
        all_reps = all_reps.view(-1, n_psg_per_query)
        loss = self.model.cross_entropy(all_reps, labels)

        if self.model.training:
            step_acc = accuracy(all_reps, target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, all_reps) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels or loss_without_labels:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)'''

    def _reset_meters_if_needed(self):
        if int(self.state.epoch) != self.last_epoch:
            self.last_epoch = int(self.state.epoch)
            self.acc_meter.reset()
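
compute_loss_train above implements a gradient-cache style two-pass backward: forward in no-grad chunks, backprop the loss into the cached logits, then re-forward each chunk and backprop a dot-product surrogate. A minimal self-contained sketch of the same trick with a toy linear scorer (all shapes and names are illustrative; the real trainer also replays RNG state via RandContext so dropout matches between passes):

import torch
import torch.nn as nn

# Toy stand-ins: 8 "passages" grouped 4 per query, scored by a linear layer.
model = nn.Linear(16, 1)
inputs = torch.randn(8, 16)
labels = torch.zeros(2, dtype=torch.long)   # positive is index 0 in each group
loss_fn = nn.CrossEntropyLoss()
chunk_size, n_psg = 4, 4

# Pass 1: score all chunks without building graphs, cache the logits.
chunks = inputs.split(chunk_size)
with torch.no_grad():
    logits = torch.cat([model(c).squeeze(-1) for c in chunks])
logits = logits.view(-1, n_psg).detach().requires_grad_()

# Backprop only through the tiny loss head to get d(loss)/d(logits).
loss = loss_fn(logits, labels)
loss.backward()
grads = logits.grad.flatten().split(chunk_size)

# Pass 2: re-forward each chunk with grad enabled and inject the cached
# gradient via a dot-product surrogate; parameter grads accumulate per chunk.
for c, g in zip(chunks, grads):
    surrogate = torch.dot(model(c).squeeze(-1), g)
    surrogate.backward()
print(model.weight.grad.shape)  # torch.Size([1, 16])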
src/cross_rerank/utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
|
5 |
+
from typing import List, Union, Optional, Tuple, Mapping, Dict
|
6 |
+
|
7 |
+
|
8 |
+
def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
|
9 |
+
if line_by_line:
|
10 |
+
assert isinstance(objects, list), 'Only list can be saved in line by line format'
|
11 |
+
|
12 |
+
with open(path, 'w', encoding='utf-8') as writer:
|
13 |
+
if not line_by_line:
|
14 |
+
json.dump(objects, writer, ensure_ascii=False, indent=4, separators=(',', ':'))
|
15 |
+
else:
|
16 |
+
for obj in objects:
|
17 |
+
writer.write(json.dumps(obj, ensure_ascii=False, separators=(',', ':')))
|
18 |
+
writer.write('\n')
|
19 |
+
|
20 |
+
|
21 |
+
def move_to_cuda(sample):
|
22 |
+
if len(sample) == 0:
|
23 |
+
return {}
|
24 |
+
|
25 |
+
def _move_to_cuda(maybe_tensor):
|
26 |
+
if torch.is_tensor(maybe_tensor):
|
27 |
+
return maybe_tensor.cuda(non_blocking=True)
|
28 |
+
elif isinstance(maybe_tensor, dict):
|
29 |
+
return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
|
30 |
+
elif isinstance(maybe_tensor, list):
|
31 |
+
return [_move_to_cuda(x) for x in maybe_tensor]
|
32 |
+
elif isinstance(maybe_tensor, tuple):
|
33 |
+
return tuple([_move_to_cuda(x) for x in maybe_tensor])
|
34 |
+
elif isinstance(maybe_tensor, Mapping):
|
35 |
+
return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
|
36 |
+
else:
|
37 |
+
return maybe_tensor
|
38 |
+
|
39 |
+
return _move_to_cuda(sample)
|
40 |
+
|
41 |
+
|
42 |
+
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
43 |
+
if t is None:
|
44 |
+
return None
|
45 |
+
|
46 |
+
t = t.contiguous()
|
47 |
+
all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
|
48 |
+
dist.all_gather(all_tensors, t)
|
49 |
+
|
50 |
+
all_tensors[dist.get_rank()] = t
|
51 |
+
all_tensors = torch.cat(all_tensors, dim=0)
|
52 |
+
return all_tensors
|
53 |
+
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
def select_grouped_indices(scores: torch.Tensor,
|
57 |
+
group_size: int,
|
58 |
+
start: int = 0) -> torch.Tensor:
|
59 |
+
assert len(scores.shape) == 2
|
60 |
+
batch_size = scores.shape[0]
|
61 |
+
assert batch_size * group_size <= scores.shape[1]
|
62 |
+
|
63 |
+
indices = torch.arange(0, group_size, dtype=torch.long)
|
64 |
+
indices = indices.repeat(batch_size, 1)
|
65 |
+
indices += torch.arange(0, batch_size, dtype=torch.long).unsqueeze(-1) * group_size
|
66 |
+
indices += start
|
67 |
+
|
68 |
+
return indices.to(scores.device)
|
69 |
+
|
70 |
+
|
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0])

    train_n_passages = key.shape[0] // query.shape[0]
    # the positive passage of query i sits at column i * train_n_passages of qk
    labels = torch.arange(0, query.shape[0], dtype=torch.long, device=query.device)
    labels = labels * train_n_passages

    # batch_size x (batch_size x n_psg)
    qk = torch.mm(query, key.t())

    if not use_all_pairs:
        return qk, labels

    # batch_size x dim: only the positive passage of each query
    sliced_key = key.index_select(dim=0, index=labels)
    assert query.shape[0] == sliced_key.shape[0]

    # batch_size x batch_size; mask self-pairs so they cannot act as negatives
    kq = torch.mm(sliced_key, query.t())
    kq.fill_diagonal_(float('-inf'))

    qq = torch.mm(query, query.t())
    qq.fill_diagonal_(float('-inf'))

    kk = torch.mm(sliced_key, sliced_key.t())
    kk.fill_diagonal_(float('-inf'))

    scores = torch.cat([qk, kq, qq, kk], dim=-1)

    return scores, labels

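A quick worked example of the concatenated layout, matching the __main__ demo at the bottom of this file: with batch_size=4 and 3 passages per query, qk is 4 x 12 while kq, qq and kk are 4 x 4 each, so scores is 4 x 24 and labels is tensor([0, 3, 6, 9]).
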
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    return {k[len(prefix):]: v for k, v in batch_dict.items() if k.startswith(prefix)}


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))


if __name__ == '__main__':
    query = torch.randn(4, 16)
    key = torch.randn(4 * 3, 16)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)

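A minimal usage sketch for AverageMeter (the loop values are made up): update takes a batch statistic plus the batch size, and avg tracks the running weighted mean.

loss_meter = AverageMeter('loss', round_digits=4)
for batch_loss in [0.9, 0.7, 0.5]:
    loss_meter.update(val=batch_loss, n=32)  # n = examples in the batch
print(loss_meter)  # prints "loss: 0.7"
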
src/eval_cross.py
ADDED
@@ -0,0 +1,512 @@
import faiss
import torch
import logging
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from transformers import AutoTokenizer
from bi.model import SharedBiEncoder
from bi.preprocess import preprocess_question
from cross_rerank.model import RerankerForInference
from pyvi.ViTokenizer import tokenize

logger = logging.getLogger(__name__)


@dataclass
class Args:
    encoder: str = field(
        default="vinai/phobert-base-v2",
        metadata={'help': 'The bi-encoder name or path.'}
    )
    tokenizer: str = field(
        default=None,
        metadata={'help': 'The bi-encoder tokenizer name or path.'}
    )
    cross_checkpoint: str = field(
        default="vinai/phobert-base-v2",
        metadata={'help': 'The cross-encoder checkpoint name or path.'}
    )
    cross_tokenizer: str = field(
        default=None,
        metadata={'help': 'The cross-encoder tokenizer name or path.'}
    )
    sentence_pooling_method: str = field(
        default="cls",
        metadata={'help': 'Embedding pooling method.'}
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Use fp16 in inference?'}
    )
    max_query_length: int = field(
        default=32,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=256,
        metadata={'help': 'Max passage length.'}
    )
    cross_max_length: int = field(
        default=256,
        metadata={'help': 'Max cross-encoder input length.'}
    )
    cross_batch_size: int = field(
        default=32,
        metadata={'help': 'Reranker inference batch size.'}
    )
    batch_size: int = field(
        default=128,
        metadata={'help': 'Retriever inference batch size.'}
    )
    index_factory: str = field(
        default="Flat",
        metadata={'help': 'Faiss index factory.'}
    )
    k: int = field(
        default=1000,
        metadata={'help': 'How many neighbors to retrieve?'}
    )
    top_k: int = field(
        default=1000,
        metadata={'help': 'How many neighbors to rerank?'}
    )
    data_path: str = field(
        default="/kaggle/input/zalo-data",
        metadata={'help': 'Path to zalo data.'}
    )
    data_type: str = field(
        default="test",
        metadata={'help': 'Type of data to test.'}
    )
    corpus_file: str = field(
        default="/kaggle/input/zalo-data",
        metadata={'help': 'Path to zalo corpus.'}
    )

    data_file: str = field(
        default=None,
        metadata={'help': 'Path to evaluated data.'}
    )

    bi_data: bool = field(
        default=False,
        metadata={'help': 'Save mined data for bi-encoder training?'}
    )

    save_embedding: bool = field(
        default=False,
        metadata={'help': 'Save embeddings in memmap at save_path?'}
    )
    load_embedding: str = field(
        default='',
        metadata={'help': 'Path to saved embeddings.'}
    )

    save_path: str = field(
        default="embeddings.memmap",
        metadata={'help': 'Path to save embeddings.'}
    )


def index(model: SharedBiEncoder, tokenizer: AutoTokenizer, corpus, batch_size: int = 16,
          max_length: int = 512, index_factory: str = "Flat", save_path: str = None,
          save_embedding: bool = False, load_embedding: str = ''):
    """
    1. Encode the entire corpus into dense embeddings;
    2. Create a faiss index;
    3. Optionally save the embeddings.
    """
    if load_embedding != '':
        # probe the encoder once to recover the embedding dtype and dimension
        test_tokens = tokenizer(['test'],
                                padding=True,
                                truncation=True,
                                max_length=128,
                                return_tensors="pt").to('cuda')
        test = model.encoder.get_representation(test_tokens['input_ids'], test_tokens['attention_mask'])
        test = test.cpu().numpy()
        dtype = test.dtype
        dim = test.shape[-1]

        all_embeddings = np.memmap(
            load_embedding,
            mode="r",
            dtype=dtype
        ).reshape(-1, dim)

    else:
        all_embeddings = []
        for start_index in tqdm(range(0, len(corpus), batch_size), desc="Inference Embeddings",
                                disable=len(corpus) < batch_size):
            passages_batch = corpus[start_index:start_index + batch_size]
            d_collated = tokenizer(
                passages_batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to('cuda')

            with torch.no_grad():
                corpus_embeddings = model.encoder.get_representation(d_collated['input_ids'], d_collated['attention_mask'])

            corpus_embeddings = corpus_embeddings.cpu().numpy()
            all_embeddings.append(corpus_embeddings)

        all_embeddings = np.concatenate(all_embeddings, axis=0)
        dim = all_embeddings.shape[-1]

        if save_embedding:
            logger.info(f"saving embeddings at {save_path}...")
            memmap = np.memmap(
                save_path,
                shape=all_embeddings.shape,
                mode="w+",
                dtype=all_embeddings.dtype
            )

            length = all_embeddings.shape[0]
            # write in batches to bound peak memory
            save_batch_size = 10000
            if length > save_batch_size:
                for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
                    j = min(i + save_batch_size, length)
                    memmap[i: j] = all_embeddings[i: j]
            else:
                memmap[:] = all_embeddings
    # create the faiss index and move it to GPU 0
    faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
    co = faiss.GpuClonerOptions()
    faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)

    # NOTE: faiss only accepts float32
    logger.info("Adding embeddings...")
    all_embeddings = all_embeddings.astype(np.float32)
    faiss_index.train(all_embeddings)
    faiss_index.add(all_embeddings)
    return faiss_index

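A small sketch of the caching path (the file names are placeholders): a first run with save_embedding=True writes the memmap, and later runs pass the same file through load_embedding to skip re-encoding the corpus.

faiss_index = index(model, tokenizer, corpus, save_path="embeddings.memmap", save_embedding=True)
faiss_index = index(model, tokenizer, corpus, load_embedding="embeddings.memmap")
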
def search(model: SharedBiEncoder, tokenizer: AutoTokenizer, questions, faiss_index: faiss.Index,
           k: int = 100, batch_size: int = 256, max_length: int = 128):
    """
    1. Encode queries into dense embeddings;
    2. Search through the faiss index.
    """
    q_embeddings = []
    for start_index in tqdm(range(0, len(questions), batch_size), desc="Inference Embeddings",
                            disable=len(questions) < batch_size):

        q_collated = tokenizer(
            questions[start_index: start_index + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to('cuda')

        with torch.no_grad():
            query_embeddings = model.encoder.get_representation(q_collated['input_ids'], q_collated['attention_mask'])
        query_embeddings = query_embeddings.cpu().numpy()
        q_embeddings.append(query_embeddings)

    q_embeddings = np.concatenate(q_embeddings, axis=0)
    query_size = q_embeddings.shape[0]
    all_scores = []
    all_indices = []

    for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
        j = min(i + batch_size, query_size)
        q_embedding = q_embeddings[i: j]
        score, indice = faiss_index.search(q_embedding.astype(np.float32), k=k)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices


def rerank(reranker: RerankerForInference, tokenizer: AutoTokenizer, questions, corpus,
           retrieved_ids, batch_size=128, max_length=256, top_k=30):
    eos = tokenizer.eos_token
    # one "<query><eos><eos><passage>" pair per (question, candidate)
    texts = []
    for idx in range(len(questions)):
        for j in range(top_k):
            texts.append(questions[idx] + eos + eos + corpus[retrieved_ids[idx][j]])
    reranked_ids = []
    rerank_scores = []

    for start_index in tqdm(range(0, len(questions), batch_size), desc="Rerank",
                            disable=len(questions) < batch_size):
        batch_retrieved_ids = retrieved_ids[start_index: start_index + batch_size]
        # each question contributes top_k pairs, so slice the text list in units of top_k
        collated = tokenizer(
            texts[start_index * top_k: (start_index + batch_size) * top_k],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to('cuda')
        with torch.no_grad():
            reranked_scores = reranker(collated).logits
        reranked_scores = reranked_scores.view(-1, top_k).to('cpu').tolist()
        for m in range(len(reranked_scores)):
            tuple_lst = [(batch_retrieved_ids[m][n], reranked_scores[m][n]) for n in range(top_k)]
            tuple_lst.sort(key=lambda tup: tup[1], reverse=True)
            reranked_ids.append([tup[0] for tup in tuple_lst])
            rerank_scores.append([tup[1] for tup in tuple_lst])

    return reranked_ids, rerank_scores

def evaluate(preds, labels, cutoffs=[1, 5, 10, 30, 100]):
    """
    Evaluate MRR and Recall at cutoffs.
    """
    metrics = {}

    # MRR: reciprocal rank of the first relevant hit per query
    mrrs = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        jump = False
        for i, x in enumerate(pred, 1):
            if x in label:
                for k, cutoff in enumerate(cutoffs):
                    if i <= cutoff:
                        mrrs[k] += 1 / i
                jump = True
            if jump:
                break
    mrrs /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        mrr = mrrs[i]
        metrics[f"MRR@{cutoff}"] = mrr

    # Recall: fraction of relevant items found within each cutoff
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        for k, cutoff in enumerate(cutoffs):
            recall = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(recall) / len(label)
    recalls /= len(preds)
    for i, cutoff in enumerate(cutoffs):
        recall = recalls[i]
        metrics[f"Recall@{cutoff}"] = recall

    return metrics

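A worked example of the metric: one query whose first relevant document appears at rank 2.

print(evaluate(preds=[['b', 'a']], labels=[['a']], cutoffs=[5]))
# -> {'MRR@5': 0.5, 'Recall@5': 1.0}
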
def calculate_score(ground_ids, retrieved_list):
    # Hit: at least one answer group is covered; All: every answer group is covered.
    all_count = 0
    hit_count = 0
    for i in range(len(ground_ids)):
        all_check = True
        hit_check = False
        retrieved_ids = retrieved_list[i]
        ans_ids = ground_ids[i]
        for a_ids in ans_ids:
            com = [a_id for a_id in a_ids if a_id in retrieved_ids]
            if len(com) > 0:
                hit_check = True
            else:
                all_check = False

        if hit_check:
            hit_count += 1
        if all_check:
            all_count += 1

    all_acc = all_count / len(ground_ids)
    hit_acc = hit_count / len(ground_ids)
    return hit_acc, all_acc


def check(ground_ids, retrieved_list, cutoffs=[1, 5, 10, 30, 100]):
    metrics = {}
    for cutoff in cutoffs:
        retrieved_k = [x[:cutoff] for x in retrieved_list]
        hit_acc, all_acc = calculate_score(ground_ids, retrieved_k)
        metrics[f"All@{cutoff}"] = all_acc
        metrics[f"Hit@{cutoff}"] = hit_acc
    return metrics

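A worked example of the Hit/All distinction: with two answer groups of which only one is covered, the query counts as a hit but not as fully answered.

hit_acc, all_acc = calculate_score(ground_ids=[[[1, 2], [7]]], retrieved_list=[[2, 3]])
# hit_acc == 1.0 (group [1, 2] is covered via doc 2), all_acc == 0.0 (group [7] is missed)
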
def save_bi_data(tokenized_queries, ground_ids, indices, scores, file, org_questions=None):
    rst = []
    for i in range(len(tokenized_queries)):
        scores_i = scores[i]
        indices_i = indices[i]
        ans_ids = ground_ids[i]
        all_ans_id = [element for x in ans_ids for element in x]
        # every retrieved id that belongs to no answer group becomes a hard negative
        neg_doc_ids = []
        neg_scores = []
        for count in range(len(indices_i)):
            if indices_i[count] not in all_ans_id and indices_i[count] != -1:
                neg_doc_ids.append(indices_i[count])
                neg_scores.append(scores_i[count])

        for j in range(len(ans_ids)):
            ans_id = ans_ids[j]
            item = {}
            if org_questions is not None:
                item['question'] = org_questions[i]
            item['query'] = tokenized_queries[i]
            item['positives'] = {}
            item['negatives'] = {}
            item['positives']['doc_id'] = []
            item['positives']['score'] = []
            item['negatives']['doc_id'] = neg_doc_ids
            item['negatives']['score'] = neg_scores
            for pos_id in ans_id:
                item['positives']['doc_id'].append(pos_id)
                try:
                    idx = indices_i.index(pos_id)
                    item['positives']['score'].append(scores_i[idx])
                except ValueError:
                    # positive was not retrieved: fall back to the lowest retrieved score
                    item['positives']['score'].append(scores_i[-1])

            rst.append(item)

    with open(f'{file}.jsonl', 'w') as jsonl_file:
        for item in rst:
            json_line = json.dumps(item, ensure_ascii=False)
            jsonl_file.write(json_line + '\n')

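Each emitted JSONL record therefore has the shape {"question": ..., "query": ..., "positives": {"doc_id": [...], "score": [...]}, "negatives": {"doc_id": [...], "score": [...]}}, with one record per answer group of a question.
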
def main():
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    print(args)
    model = SharedBiEncoder(model_checkpoint=args.encoder,
                            representation=args.sentence_pooling_method,
                            fixed=True)
    model.to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer if args.tokenizer else args.encoder)
    reranker = RerankerForInference(model_checkpoint=args.cross_checkpoint)
    reranker.to('cuda')
    reranker_tokenizer = AutoTokenizer.from_pretrained(args.cross_tokenizer if args.cross_tokenizer else args.cross_checkpoint)
    csv_file = True
    if args.data_file:
        if args.data_file.endswith("jsonl"):
            test_data = []
            with open(args.data_file, 'r') as jsonl_file:
                for line in jsonl_file:
                    temp = json.loads(line)
                    test_data.append(temp)
            csv_file = False
        elif args.data_file.endswith("json"):
            csv_file = False
            with open(args.data_file, 'r') as json_file:
                test_data = json.load(json_file)
        elif args.data_file.endswith("csv"):
            test_data = pd.read_csv(args.data_file)

    elif args.data_type == 'eval':
        test_data = pd.read_csv(args.data_path + "/tval.csv")
    elif args.data_type == 'train':
        test_data = pd.read_csv(args.data_path + "/ttrain.csv")
    elif args.data_type == 'all':
        data1 = pd.read_csv(args.data_path + "/ttrain.csv")
        data2 = pd.read_csv(args.data_path + "/ttest.csv")
        data3 = pd.read_csv(args.data_path + "/tval.csv")
        test_data = pd.concat([data1, data3, data2], ignore_index=True)

    else:
        test_data = pd.read_csv(args.data_path + "/ttest.csv")
    corpus_data = pd.read_csv(args.corpus_file)
    corpus = corpus_data['tokenized_text'].tolist()

    if csv_file:
        ans_ids = []
        ground_ids = []
        org_questions = test_data['question'].tolist()
        questions = test_data['tokenized_question'].tolist()
        for i in range(len(test_data)):
            ans_ids.append(json.loads(test_data['best_ans_id'][i]))
            ground_ids.append(json.loads(test_data['ans_id'][i]))
        ground_truths = []
        for sample in ans_ids:
            temp = [corpus_data['law_id'][y] + "_" + str(corpus_data['article_id'][y]) for y in sample]
            ground_truths.append(temp)
    else:
        ground_truths = []
        ground_ids = []
        org_questions = [sample['question'] for sample in test_data]
        questions = [tokenize(preprocess_question(sample['question'], remove_end_phrase=False)) for sample in test_data]
        for sample in test_data:
            # some dumps name the field 'relevance_articles', others 'relevant_articles'
            try:
                temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevance_articles']]
                tempp = [it['ans_id'] for it in sample['relevance_articles']]
            except KeyError:
                temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevant_articles']]
                tempp = [it['ans_id'] for it in sample['relevant_articles']]
            ground_truths.append(temp)
            ground_ids.append(tempp)

    faiss_index = index(
        model=model,
        tokenizer=tokenizer,
        corpus=corpus,
        batch_size=args.batch_size,
        max_length=args.max_passage_length,
        index_factory=args.index_factory,
        save_path=args.save_path,
        save_embedding=args.save_embedding,
        load_embedding=args.load_embedding
    )

    scores, indices = search(
        model=model,
        tokenizer=tokenizer,
        questions=questions,
        faiss_index=faiss_index,
        k=args.k,
        batch_size=args.batch_size,
        max_length=args.max_query_length
    )

    retrieval_results, retrieval_ids = [], []
    for indice in indices:
        # filter invalid indices (faiss pads with -1 when fewer than k hits exist)
        indice = indice[indice != -1].tolist()
        rst = []
        for x in indice:
            temp = corpus_data['law_id'][x] + "_" + str(corpus_data['article_id'][x])
            if temp not in rst:
                rst.append(temp)
        retrieval_results.append(rst)
        retrieval_ids.append(indice)

    rerank_ids, rerank_scores = rerank(reranker, reranker_tokenizer, questions, corpus, retrieval_ids,
                                       args.cross_batch_size, args.cross_max_length, args.top_k)

    if args.bi_data:
        save_bi_data(questions, ground_ids, rerank_ids, rerank_scores, args.data_type, org_questions)

    # retrieval quality before reranking, then after reranking
    metrics = check(ground_ids, retrieval_ids)
    print(metrics)
    metrics = evaluate(retrieval_results, ground_truths)
    print(metrics)
    metrics = check(ground_ids, rerank_ids, cutoffs=[1, 5, 10, 30])
    print(metrics)


if __name__ == "__main__":
    main()

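A hypothetical invocation of this script (paths are placeholders; the flags mirror the Args fields above):

python src/eval_cross.py \
    --encoder /path/to/bi-encoder \
    --cross_checkpoint /path/to/cross-encoder \
    --corpus_file /path/to/corpus.csv \
    --data_path /path/to/zalo-data --data_type test \
    --k 100 --top_k 30
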
src/train_cross_encoder.py
ADDED
@@ -0,0 +1,112 @@
import logging

import torch
from typing import Dict
from transformers.utils.logging import enable_explicit_format
from transformers.trainer_callback import PrinterCallback
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    EvalPrediction,
    set_seed,
    PreTrainedTokenizerFast
)

from cross_rerank.logger_config import logger, LoggerCallback
from cross_rerank.config import Arguments
from cross_rerank.trainer import RerankerTrainer
from cross_rerank.data_loader import CrossEncoderDataLoader
from cross_rerank.collator import CrossEncoderCollator
from cross_rerank.metrics import accuracy
from cross_rerank.model import Reranker


def _common_setup(args: Arguments):
    if args.process_index > 0:
        # only the main process logs at INFO level
        logger.setLevel(logging.WARNING)
    enable_explicit_format()
    set_seed(args.seed)


def _compute_metrics(eval_pred: EvalPrediction) -> Dict:
    preds = eval_pred.predictions
    if isinstance(preds, tuple):
        preds = preds[-1]
    logits = torch.tensor(preds).float()
    labels = torch.tensor(eval_pred.label_ids).long()
    acc = accuracy(output=logits, target=labels)[0]

    return {'acc': acc}


def main():
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    _common_setup(args)
    logger.info('Args={}'.format(str(args)))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)

    model: Reranker = Reranker.from_pretrained(
        all_args=args,
        pretrained_model_name_or_path=args.model_name_or_path,
        num_labels=1)

    logger.info(model)
    logger.info('Vocab size: {}'.format(len(tokenizer)))

    data_collator = CrossEncoderCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=256)

    rerank_data_loader = CrossEncoderDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = rerank_data_loader.train_dataset
    eval_dataset = rerank_data_loader.eval_dataset

    trainer = RerankerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        data_collator=data_collator,
        compute_metrics=_compute_metrics,
        tokenizer=tokenizer,
    )
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)
    rerank_data_loader.trainer = trainer

    if args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval")
        metrics["eval_samples"] = len(eval_dataset)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if args.do_train:
        train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
        trainer.save_model()

        metrics = train_result.metrics
        metrics["train_samples"] = len(train_dataset)

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        # evaluate again with the final weights
        if args.do_eval:
            logger.info("*** Evaluate ***")
            metrics = trainer.evaluate(metric_key_prefix="eval")
            metrics["eval_samples"] = len(eval_dataset)

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

    return


if __name__ == "__main__":
    main()

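A hypothetical launch command (paths are placeholders; --model_name_or_path and --corpus_file come from cross_rerank/config.py, the remaining flags from the inherited HF TrainingArguments):

python src/train_cross_encoder.py \
    --model_name_or_path vinai/phobert-base-v2 \
    --corpus_file /path/to/corpus.csv \
    --do_train --do_eval \
    --output_dir ./cross_encoder_out \
    --fp16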