coang committed (verified)
Commit 77010c2 · Parent(s): 6bd7052

Upload 14 files

src/cross_rerank/__init__.py ADDED
File without changes
src/cross_rerank/collator.py ADDED
@@ -0,0 +1,30 @@

import torch
import transformers
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import BatchEncoding, DataCollatorWithPadding


@dataclass
class CrossEncoderCollator(DataCollatorWithPadding):

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        # Each feature groups train_n_passages tokenized query-passage pairs;
        # flatten the groups into individual pairs before padding.
        unpack_features = []
        for ex in features:
            keys = list(ex.keys())
            # assert all(len(ex[k]) == 8 for k in keys)
            for idx in range(len(ex[keys[0]])):
                unpack_features.append({k: ex[k][idx] for k in keys})

        old_level = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        collated_batch_dict = self.tokenizer.pad(
            unpack_features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        transformers.logging.set_verbosity(old_level)

        # The positive passage is always placed first in each group, so the label is index 0.
        collated_batch_dict['labels'] = torch.zeros(len(features), dtype=torch.long)

        return collated_batch_dict
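
A minimal usage sketch of the collator (not part of the uploaded files), assuming `src/` is on `PYTHONPATH` and a BERT-style tokenizer; the feature dicts mirror what `CrossEncoderDataLoader._transform_func` produces, where each key holds `train_n_passages` tokenized query-passage pairs per example:

# Hypothetical usage sketch: pad two grouped examples into one batch.
from transformers import AutoTokenizer
from cross_rerank.collator import CrossEncoderCollator

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
collator = CrossEncoderCollator(tokenizer=tokenizer)

# Each feature groups train_n_passages=2 tokenized query-passage pairs (positive first).
features = [
    tokenizer(['what is faiss?'] * 2,
              text_pair=['faiss is a similarity search library', 'an unrelated passage']),
    tokenizer(['what is bm25?'] * 2,
              text_pair=['bm25 is a ranking function', 'an unrelated passage']),
]
batch = collator(features)
print(batch['input_ids'].shape)  # (4, max_len): 2 examples x 2 passages
print(batch['labels'])           # tensor([0, 0]): the positive is always at index 0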
src/cross_rerank/config.py ADDED
@@ -0,0 +1,230 @@

import os
import torch
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments

from .logger_config import logger


@dataclass
class Arguments(TrainingArguments):
    model_name_or_path: str = field(
        default='bert-base-uncased',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    corpus_file: str = field(
        default=None, metadata={"help": "Path to corpus file"}
    )

    data_dir: str = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    task_type: str = field(
        default='ir', metadata={"help": "task type: ir / qa"}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics on (a jsonlines file)."
        },
    )

    train_n_passages: int = field(
        default=8,
        metadata={"help": "number of passages for each example (including both positive and negative passages)"}
    )
    share_encoder: bool = field(
        default=True,
        metadata={"help": "share weights between the query and passage encoders"}
    )
    use_first_positive: bool = field(
        default=False,
        metadata={"help": "Always use the first positive passage"}
    )
    use_first_negative: bool = field(
        default=False,
        metadata={"help": "Always use the first negative passages"}
    )
    use_scaled_loss: bool = field(
        default=True,
        metadata={"help": "Use scaled loss or not"}
    )
    loss_scale: float = field(
        default=-1.,
        metadata={"help": "loss scale, -1 will use world_size"}
    )
    add_pooler: bool = field(default=False)
    out_dimension: int = field(
        default=768,
        metadata={"help": "output dimension for pooler"}
    )
    t: float = field(default=0.05, metadata={"help": "temperature for biencoder training"})
    l2_normalize: bool = field(default=True, metadata={"help": "L2 normalize embeddings or not"})
    t_warmup: bool = field(default=False, metadata={"help": "warmup temperature"})
    full_contrastive_loss: bool = field(default=True, metadata={"help": "use full contrastive loss or not"})

    # the following arguments are used for encoding documents
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    encode_in_path: str = field(default=None, metadata={"help": "Path to data to encode"})
    encode_save_dir: str = field(default=None, metadata={"help": "where to save the encoded embeddings"})
    encode_shard_size: int = field(default=int(2 * 10**6))
    encode_batch_size: int = field(default=256)

    # used for index search
    do_search: bool = field(default=False, metadata={"help": "run the index search loop"})
    search_split: str = field(default='dev', metadata={"help": "which split to search"})
    search_batch_size: int = field(default=128, metadata={"help": "query batch size for index search"})
    search_topk: int = field(default=200, metadata={"help": "return topk search results"})
    search_out_dir: str = field(default='', metadata={"help": "output directory for writing search results"})

    # used for reranking
    do_rerank: bool = field(default=False, metadata={"help": "run the reranking loop"})
    rerank_max_length: int = field(default=256, metadata={"help": "max length for rerank inputs"})
    rerank_in_path: str = field(default='', metadata={"help": "Path to predictions for rerank"})
    rerank_out_path: str = field(default='', metadata={"help": "Path to write rerank results"})
    rerank_split: str = field(default='dev', metadata={"help": "which split to rerank"})
    rerank_batch_size: int = field(default=128, metadata={"help": "rerank batch size"})
    rerank_depth: int = field(default=1000, metadata={"help": "rerank depth, useful for debugging purposes"})
    rerank_forward_factor: int = field(
        default=1,
        metadata={"help": "forward n passages, then select top n/factor passages for backward"}
    )
    rerank_use_rdrop: bool = field(default=False, metadata={"help": "use R-Drop regularization for the re-ranker"})

    # used for knowledge distillation
    do_kd_gen_score: bool = field(default=False, metadata={"help": "run the score generation for distillation"})
    kd_gen_score_split: str = field(default='dev', metadata={
        "help": "Which split to use for generation of teacher scores"
    })
    kd_gen_score_batch_size: int = field(default=128, metadata={"help": "batch size for teacher score generation"})
    kd_gen_score_n_neg: int = field(default=30, metadata={"help": "number of negatives to compute teacher scores"})

    do_kd_biencoder: bool = field(default=False, metadata={"help": "knowledge distillation to biencoder"})
    kd_mask_hn: bool = field(default=True, metadata={"help": "mask out hard negatives for distillation"})
    kd_cont_loss_weight: float = field(default=1.0, metadata={"help": "weight for contrastive loss"})

    rlm_generator_model_name: Optional[str] = field(
        default='google/electra-base-generator',
        metadata={"help": "generator for replace LM pre-training"}
    )
    rlm_freeze_generator: Optional[bool] = field(
        default=True,
        metadata={'help': 'freeze generator params or not'}
    )
    rlm_generator_mlm_weight: Optional[float] = field(
        default=0.2,
        metadata={'help': 'weight for generator MLM loss'}
    )
    all_use_mask_token: Optional[bool] = field(
        default=False,
        metadata={'help': 'Do not use 80:10:10 masking, use [MASK] for all masked positions'}
    )
    rlm_num_eval_samples: Optional[int] = field(
        default=4096,
        metadata={"help": "number of evaluation samples for pre-training"}
    )
    rlm_max_length: Optional[int] = field(
        default=144,
        metadata={"help": "max length for MatchLM pre-training"}
    )
    rlm_decoder_layers: Optional[int] = field(
        default=2,
        metadata={"help": "number of transformer layers for the MatchLM decoder part"}
    )
    rlm_encoder_mask_prob: Optional[float] = field(
        default=0.3,
        metadata={'help': 'mask rate for encoder'}
    )
    rlm_decoder_mask_prob: Optional[float] = field(
        default=0.5,
        metadata={'help': 'mask rate for decoder'}
    )

    q_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the query."
        },
    )
    p_max_len: int = field(
        default=144,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the passage."
        },
    )

    chunk_size: int = field(
        default=8,
        metadata={
            "help": "The maximum number of query-passage pairs per forward chunk (used for gradient-cached training)."
        },
    )

    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    dry_run: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set dry_run to True for debugging purposes'}
    )

    def __post_init__(self):
        assert os.path.exists(self.data_dir)
        assert torch.cuda.is_available(), 'Only support running on GPUs'
        assert self.task_type in ['ir', 'qa']

        if self.dry_run:
            self.logging_steps = 1
            self.max_train_samples = self.max_train_samples or 128
            self.num_train_epochs = 1
            self.per_device_train_batch_size = min(2, self.per_device_train_batch_size)
            self.train_n_passages = min(4, self.train_n_passages)
            self.rerank_forward_factor = 1
            self.gradient_accumulation_steps = 1
            self.rlm_num_eval_samples = min(256, self.rlm_num_eval_samples)
            self.max_steps = 30
            self.save_steps = self.eval_steps = 30
            logger.warning('Dry run: set logging_steps=1')

        if self.do_encode:
            assert self.encode_save_dir
            os.makedirs(self.encode_save_dir, exist_ok=True)
            assert os.path.exists(self.encode_in_path)

        if self.do_search:
            assert os.path.exists(self.encode_save_dir)
            assert self.search_out_dir
            os.makedirs(self.search_out_dir, exist_ok=True)

        if self.do_rerank:
            assert os.path.exists(self.rerank_in_path)
            logger.info('Rerank result will be written to {}'.format(self.rerank_out_path))
            assert self.train_n_passages > 1, 'Having positive passages only does not make sense for training a re-ranker'
            assert self.train_n_passages % self.rerank_forward_factor == 0

        if self.do_kd_gen_score:
            assert os.path.exists('{}/{}.jsonl'.format(self.data_dir, self.kd_gen_score_split))

        if self.do_kd_biencoder:
            if self.use_scaled_loss:
                assert not self.kd_mask_hn, 'Scaled loss only works without masking out hard negatives'

        if torch.cuda.device_count() <= 1:
            self.logging_steps = min(50, self.logging_steps)

        super(Arguments, self).__post_init__()

        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)

        self.label_names = ['labels']
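
Since `Arguments` subclasses `TrainingArguments`, it can be populated with `HfArgumentParser` in the usual way. A minimal sketch (not part of the uploaded files; the paths and flag values below are placeholders, and `__post_init__` additionally requires a CUDA device and an existing `data_dir`):

# Hypothetical parsing sketch; paths are placeholders.
from transformers import HfArgumentParser
from cross_rerank.config import Arguments

parser = HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses(args=[
    '--model_name_or_path', 'bert-base-uncased',
    '--data_dir', './data',
    '--corpus_file', './data/corpus.csv',
    '--output_dir', './checkpoints',
    '--do_train',
    '--train_n_passages', '8',
])
print(args.task_type, args.train_n_passages)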
src/cross_rerank/data_loader.py ADDED
@@ -0,0 +1,106 @@

import os.path
import random
import pandas as pd
import transformers
from typing import Tuple, Dict, List, Optional
from datasets import load_dataset, DatasetDict, Dataset
from transformers.file_utils import PaddingStrategy
from transformers import PreTrainedTokenizerFast, Trainer

from .config import Arguments
from .logger_config import logger
from .loader_utils import group_doc_ids


class CrossEncoderDataLoader:
    def __init__(self, args: Arguments, tokenizer):
        self.args = args
        self.negative_size = args.train_n_passages - 1
        assert self.negative_size > 0
        #self.hard_neg_size = args.hard_neg_size if args.hard_neg_size < self.negative_size else self.negative_size
        #self.rand_neg_size = self.negative_size - self.hard_neg_size if self.hard_neg_size < self.negative_size else 0
        self.tokenizer = tokenizer
        #corpus_path = os.path.join(args.data_dir, 'passages.jsonl.gz')
        #self.corpus: Dataset = load_dataset('json', data_files=corpus_path)['train']
        self.corpus = pd.read_csv(args.corpus_file)
        #self.corpus = dcorpus['tokenized_text'].to_list()
        #self.corpus_bad_ids = [idx for idx in range(len(self.corpus)) if (type(self.corpus['bm25_text'][idx]) is not str)]
        self.train_dataset, self.eval_dataset = self._get_transformed_datasets()

        # use its state to decide which positives/negatives to sample
        self.trainer: Optional[Trainer] = None

    def _transform_func(self, examples: Dict[str, List]) -> Dict[str, List]:
        current_epoch = int(self.trainer.state.epoch or 0)

        input_doc_ids = group_doc_ids(
            examples=examples,
            negative_size=self.negative_size,
            offset=current_epoch + self.args.seed,
            use_first_positive=self.args.use_first_positive,
            use_first_negative=self.args.use_first_negative
        )
        assert len(input_doc_ids) == len(examples['query']) * self.args.train_n_passages

        input_queries, input_docs = [], []
        for idx, doc_id in enumerate(input_doc_ids):
            #prefix = ''
            #if self.corpus[doc_id].get('title', ''):
            #    prefix = self.corpus[doc_id]['title'] + ': '

            input_docs.append(self.corpus['tokenized_text'][doc_id])
            input_queries.append(examples['query'][idx // self.args.train_n_passages])

        old_level = transformers.logging.get_verbosity()
        transformers.logging.set_verbosity_error()
        batch_dict = self.tokenizer(input_queries,
                                    text_pair=input_docs,
                                    max_length=self.args.rerank_max_length,
                                    padding=PaddingStrategy.DO_NOT_PAD,
                                    truncation=True)
        transformers.logging.set_verbosity(old_level)

        # regroup the flat pair list into train_n_passages pairs per query
        packed_batch_dict = {}
        for k in batch_dict:
            packed_batch_dict[k] = []
            assert len(examples['query']) * self.args.train_n_passages == len(batch_dict[k])
            for idx in range(len(examples['query'])):
                start = idx * self.args.train_n_passages
                packed_batch_dict[k].append(batch_dict[k][start:(start + self.args.train_n_passages)])

        return packed_batch_dict

    def _get_transformed_datasets(self) -> Tuple:
        #data_files = {}
        #if self.args.train_file is not None:
        #    data_files["train"] = self.args.train_file.split(',')
        #if self.args.validation_file is not None:
        #    data_files["validation"] = self.args.validation_file
        #raw_datasets: DatasetDict = load_dataset('json', data_files=data_files)

        train_dataset, eval_dataset = None, None

        if self.args.do_train:
            #if "train" not in raw_datasets:
            try:
                train_dataset = load_dataset('json', data_files=os.path.join(self.args.data_dir, 'train.jsonl'))['train']
            except Exception:
                raise ValueError("--do_train requires a train dataset")
            #train_dataset = raw_datasets["train"]
            if self.args.max_train_samples is not None:
                train_dataset = train_dataset.select(range(self.args.max_train_samples))
            # Log a few random samples from the training set:
            for index in random.sample(range(len(train_dataset)), 1):
                logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
            train_dataset.set_transform(self._transform_func)

        if self.args.do_eval:
            #if "validation" not in raw_datasets:
            try:
                eval_dataset = load_dataset('json', data_files=os.path.join(self.args.data_dir, 'eval.jsonl'))['train']
            except Exception:
                raise ValueError("--do_eval requires a validation dataset")
            #eval_dataset = raw_datasets["validation"]
            eval_dataset.set_transform(self._transform_func)

        return train_dataset, eval_dataset
src/cross_rerank/data_utils.py ADDED
@@ -0,0 +1,211 @@

import os
import random
import tqdm
import json

from typing import Dict, List, Any
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field

from .logger_config import logger
from .config import Arguments
from .utils import save_json_to_file


@dataclass
class ScoredDoc:
    qid: str
    pid: str
    rank: int
    score: float = field(default=-1)


def load_qrels(path: str) -> Dict[str, Dict[str, int]]:
    assert path.endswith('.txt')

    # qid -> pid -> score
    qrels = {}
    for line in open(path, 'r', encoding='utf-8'):
        qid, _, pid, score = line.strip().split('\t')
        if qid not in qrels:
            qrels[qid] = {}
        qrels[qid][pid] = int(score)

    logger.info('Load {} queries {} qrels from {}'.format(len(qrels), sum(len(v) for v in qrels.values()), path))
    return qrels


def load_queries(path: str, task_type: str = 'ir') -> Dict[str, str]:
    assert path.endswith('.tsv')

    if task_type == 'qa':
        qid_to_query = load_query_answers(path)
        qid_to_query = {k: v['query'] for k, v in qid_to_query.items()}
    elif task_type == 'ir':
        qid_to_query = {}
        for line in open(path, 'r', encoding='utf-8'):
            qid, query = line.strip().split('\t')
            qid_to_query[qid] = query
    else:
        raise ValueError('Unknown task type: {}'.format(task_type))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def normalize_qa_text(text: str) -> str:
    # TriviaQA has some weird formats
    # For example: """What breakfast food gets its name from the German word for """"stirrup""""?"""
    while text.startswith('"') and text.endswith('"'):
        text = text[1:-1].replace('""', '"')
    return text


def get_question_key(question: str) -> str:
    # For QA datasets, use the normalized question string as the dict key
    return question


def load_query_answers(path: str) -> Dict[str, Dict[str, Any]]:
    assert path.endswith('.tsv')

    qid_to_query = {}
    for line in open(path, 'r', encoding='utf-8'):
        query, answers = line.strip().split('\t')
        query = normalize_qa_text(query)
        answers = normalize_qa_text(answers)
        qid = get_question_key(query)
        if qid in qid_to_query:
            logger.warning('Duplicate question: {} vs {}'.format(query, qid_to_query[qid]['query']))
            continue

        qid_to_query[qid] = {}
        qid_to_query[qid]['query'] = query
        qid_to_query[qid]['answers'] = list(eval(answers))

    logger.info('Load {} queries from {}'.format(len(qid_to_query), path))
    return qid_to_query


def load_corpus(path: str) -> Dataset:
    assert path.endswith('.jsonl') or path.endswith('.jsonl.gz')

    # two fields: id, contents
    corpus = load_dataset('json', data_files=path)['train']
    logger.info('Load {} documents from {} with columns {}'.format(len(corpus), path, corpus.column_names))
    logger.info('A random document: {}'.format(random.choice(corpus)))
    return corpus


def load_msmarco_predictions(path: str) -> Dict[str, List[ScoredDoc]]:
    assert path.endswith('.txt')

    qid_to_scored_doc = {}
    for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), desc='load prediction', mininterval=3):
        fs = line.strip().split('\t')
        qid, pid, rank = fs[:3]
        rank = int(rank)
        score = round(1 / rank, 4) if len(fs) == 3 else float(fs[3])

        if qid not in qid_to_scored_doc:
            qid_to_scored_doc[qid] = []
        scored_doc = ScoredDoc(qid=qid, pid=pid, rank=rank, score=score)
        qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {qid: sorted(scored_docs, key=lambda sd: sd.rank)
                         for qid, scored_docs in qid_to_scored_doc.items()}

    logger.info('Load {} query predictions from {}'.format(len(qid_to_scored_doc), path))
    return qid_to_scored_doc


def save_preds_to_msmarco_format(preds: Dict[str, List[ScoredDoc]], out_path: str):
    with open(out_path, 'w', encoding='utf-8') as writer:
        for qid in preds:
            for idx, scored_doc in enumerate(preds[qid]):
                writer.write('{}\t{}\t{}\t{}\n'.format(qid, scored_doc.pid, idx + 1, round(scored_doc.score, 3)))
    logger.info('Successfully saved to {}'.format(out_path))


def save_to_readable_format(in_path: str, corpus: Dataset):
    out_path = '{}/readable_{}'.format(os.path.dirname(in_path), os.path.basename(in_path))
    dataset: Dataset = load_dataset('json', data_files=in_path)['train']

    max_to_keep = 5

    def _create_readable_field(samples: Dict[str, List]) -> List:
        readable_ex = []
        for idx in range(min(len(samples['doc_id']), max_to_keep)):
            doc_id = samples['doc_id'][idx]
            readable_ex.append({'doc_id': doc_id,
                                'title': corpus[int(doc_id)].get('title', ''),
                                'contents': corpus[int(doc_id)]['contents'],
                                'score': samples['score'][idx]})
        return readable_ex

    def _mp_func(ex: Dict) -> Dict:
        ex['positives'] = _create_readable_field(ex['positives'])
        ex['negatives'] = _create_readable_field(ex['negatives'])
        return ex
    dataset = dataset.map(_mp_func, num_proc=8)

    dataset.to_json(out_path, force_ascii=False, lines=False, indent=4)
    logger.info('Done converting {} to readable format in {}'.format(in_path, out_path))


def get_rerank_shard_path(args: Arguments, worker_idx: int) -> str:
    return '{}_shard_{}'.format(args.rerank_out_path, worker_idx)


def merge_rerank_predictions(args: Arguments, gpu_count: int):
    # deferred relative import to avoid a circular dependency with metrics.py
    from .metrics import trec_eval, compute_mrr

    qid_to_scored_doc: Dict[str, List[ScoredDoc]] = {}
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        for line in tqdm.tqdm(open(path, 'r', encoding='utf-8'), 'merge results', mininterval=3):
            fs = line.strip().split('\t')
            qid, pid, _, score = fs
            score = float(score)

            if qid not in qid_to_scored_doc:
                qid_to_scored_doc[qid] = []
            scored_doc = ScoredDoc(qid=qid, pid=pid, rank=-1, score=score)
            qid_to_scored_doc[qid].append(scored_doc)

    qid_to_scored_doc = {k: sorted(v, key=lambda sd: sd.score, reverse=True) for k, v in qid_to_scored_doc.items()}

    ori_preds = load_msmarco_predictions(path=args.rerank_in_path)
    for query_id in list(qid_to_scored_doc.keys()):
        remain_scored_docs = ori_preds[query_id][args.rerank_depth:]
        for idx, sd in enumerate(remain_scored_docs):
            # make sure the order is not broken
            sd.score = qid_to_scored_doc[query_id][-1].score - idx - 1
        qid_to_scored_doc[query_id] += remain_scored_docs
        assert len(set([sd.pid for sd in qid_to_scored_doc[query_id]])) == len(qid_to_scored_doc[query_id])

    save_preds_to_msmarco_format(qid_to_scored_doc, out_path=args.rerank_out_path)

    path_qrels = '{}/{}_qrels.txt'.format(args.data_dir, args.rerank_split)
    if os.path.exists(path_qrels):
        qrels = load_qrels(path=path_qrels)
        all_metrics = trec_eval(qrels=qrels, predictions=qid_to_scored_doc)
        all_metrics['mrr'] = compute_mrr(qrels=qrels, predictions=qid_to_scored_doc)

        logger.info('{} trec metrics = {}'.format(args.rerank_split, json.dumps(all_metrics, ensure_ascii=False, indent=4)))
        metrics_out_path = '{}/metrics_rerank_{}.json'.format(os.path.dirname(args.rerank_out_path), args.rerank_split)
        save_json_to_file(all_metrics, metrics_out_path)
    else:
        logger.warning('No qrels found for {}'.format(args.rerank_split))

    # clean up intermediate per-worker shards
    for worker_idx in range(gpu_count):
        path = get_rerank_shard_path(args, worker_idx)
        os.remove(path)


if __name__ == '__main__':
    load_qrels('./data/msmarco/dev_qrels.txt')
    load_queries('./data/msmarco/dev_queries.tsv')
    corpus = load_corpus('./data/msmarco/passages.jsonl.gz')
    preds = load_msmarco_predictions('./data/bm25.msmarco.txt')
src/cross_rerank/loader_utils.py ADDED
@@ -0,0 +1,46 @@

from typing import List, Dict


def _slice_with_mod(elements: List, offset: int, cnt: int) -> List:
    return [elements[(offset + idx) % len(elements)] for idx in range(cnt)]


def group_doc_ids(examples: Dict[str, List],
                  negative_size: int,
                  offset: int,
                  use_first_positive: bool = False,
                  use_first_negative: bool = True) -> List[int]:
    pos_doc_ids: List[int] = []
    positives: List[Dict[str, List]] = examples['positives']
    for idx, ex_pos in enumerate(positives):
        all_pos_doc_ids = ex_pos['doc_id']

        if use_first_positive:
            # keep only the positives with the highest score
            all_pos_doc_ids = [doc_id for p_idx, doc_id in enumerate(all_pos_doc_ids)
                               if ex_pos['score'][p_idx] == max(ex_pos['score'])]

        cur_pos_doc_id = _slice_with_mod(all_pos_doc_ids, offset=offset, cnt=1)[0]
        pos_doc_ids.append(int(cur_pos_doc_id))

    neg_doc_ids: List[List[int]] = []
    negatives: List[Dict[str, List]] = examples['negatives']
    for ex_neg in negatives:
        if use_first_negative:
            cur_neg_doc_ids = ex_neg['doc_id'][:negative_size]
        else:
            cur_neg_doc_ids = _slice_with_mod(ex_neg['doc_id'],
                                              offset=offset * negative_size,
                                              cnt=negative_size)
        cur_neg_doc_ids = [int(doc_id) for doc_id in cur_neg_doc_ids]
        neg_doc_ids.append(cur_neg_doc_ids)

    assert len(pos_doc_ids) == len(neg_doc_ids), '{} != {}'.format(len(pos_doc_ids), len(neg_doc_ids))
    assert all(len(doc_ids) == negative_size for doc_ids in neg_doc_ids)

    input_doc_ids: List[int] = []
    for pos_doc_id, neg_ids in zip(pos_doc_ids, neg_doc_ids):
        # the positive passage always comes first, followed by negative_size negatives
        input_doc_ids.append(pos_doc_id)
        input_doc_ids += neg_ids

    return input_doc_ids
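
A toy example (not from the repository's data files) showing how `group_doc_ids` lays out one positive followed by `negative_size` negatives per query, assuming `src/` is on `PYTHONPATH`:

# Hypothetical toy input.
from cross_rerank.loader_utils import group_doc_ids

examples = {
    'positives': [{'doc_id': [11, 12], 'score': [2.0, 1.0]},
                  {'doc_id': [21],     'score': [1.5]}],
    'negatives': [{'doc_id': [101, 102, 103]},
                  {'doc_id': [201, 202, 203]}],
}
doc_ids = group_doc_ids(examples, negative_size=2, offset=0, use_first_positive=True)
print(doc_ids)  # [11, 101, 102, 21, 201, 202]: positive first, then 2 negatives per query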
src/cross_rerank/logger_config.py ADDED
@@ -0,0 +1,32 @@

import os
import logging

from transformers.trainer_callback import TrainerCallback


def _setup_logger():
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)

    data_dir = './data/'
    os.makedirs(data_dir, exist_ok=True)
    file_handler = logging.FileHandler('{}/log.txt'.format(data_dir))
    file_handler.setFormatter(log_format)

    logger.handlers = [console_handler, file_handler]

    return logger


logger = _setup_logger()


class LoggerCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_world_process_zero:
            logger.info(logs)
src/cross_rerank/loss.py ADDED
@@ -0,0 +1,36 @@

import torch
import torch.nn.functional as F


class CrossEncoderNllLoss(object):
    def __init__(self,
                 score_type="dot"):
        self.score_type = score_type

    def calc(
            self,
            logits,
            labels):
        """
        Computes the NLL loss for the given logits and labels.
        Returns the mean loss over the batch.
        """

        #if len(q_vectors.size()) > 1:
        #    q_num = q_vectors.size(0)
        #    scores = scores.view(q_num, -1)
        #    positive_idx_per_question = [i for i in range(q_num)]

        softmax_scores = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(
            softmax_scores,
            labels,
            reduction="mean",
        )
        #max_score, max_idxs = torch.max(softmax_scores, 1)
        #correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
        return loss  #, correct_predictions_count
src/cross_rerank/metrics.py ADDED
@@ -0,0 +1,105 @@

import torch
import pytrec_eval

from typing import List, Dict, Tuple

from .data_utils import ScoredDoc
from .logger_config import logger


def trec_eval(qrels: Dict[str, Dict[str, int]],
              predictions: Dict[str, List[ScoredDoc]],
              k_values: Tuple[int] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
    ndcg, _map, recall = {}, {}, {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])

    results: Dict[str, Dict[str, float]] = {}
    for query_id, scored_docs in predictions.items():
        results.update({query_id: {sd.pid: sd.score for sd in scored_docs}})

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string})
    scores = evaluator.evaluate(results)

    for query_id in scores:
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]

    def _normalize(m: dict) -> dict:
        return {k: round(v / len(scores), 5) for k, v in m.items()}

    ndcg = _normalize(ndcg)
    _map = _normalize(_map)
    recall = _normalize(recall)

    all_metrics = {}
    for mt in [ndcg, _map, recall]:
        all_metrics.update(mt)

    return all_metrics


@torch.no_grad()
def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> List[float]:
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size).item())
    return res


@torch.no_grad()
def batch_mrr(output: torch.Tensor, target: torch.Tensor) -> float:
    assert len(output.shape) == 2
    assert len(target.shape) == 1
    sorted_score, sorted_indices = torch.sort(output, dim=-1, descending=True)
    _, rank = torch.nonzero(sorted_indices.eq(target.unsqueeze(-1)).long(), as_tuple=True)
    assert rank.shape[0] == output.shape[0]

    rank = rank + 1
    mrr = torch.sum(100 / rank.float()) / rank.shape[0]
    return mrr.item()


def get_rel_threshold(qrels: Dict[str, Dict[str, int]]) -> int:
    # For MS MARCO passage ranking, score >= 1 is relevant;
    # for TREC DL 2019 & 2020, score >= 2 is relevant
    rel_labels = set()
    for q_id in qrels:
        for doc_id, label in qrels[q_id].items():
            rel_labels.add(label)

    logger.info('relevance labels: {}'.format(rel_labels))
    return 2 if max(rel_labels) >= 3 else 1


def compute_mrr(qrels: Dict[str, Dict[str, int]],
                predictions: Dict[str, List[ScoredDoc]],
                k: int = 10) -> float:
    threshold = get_rel_threshold(qrels)
    mrr = 0
    for qid in qrels:
        scored_docs = predictions.get(qid, [])
        for idx, scored_doc in enumerate(scored_docs[:k]):
            if scored_doc.pid in qrels[qid] and qrels[qid][scored_doc.pid] >= threshold:
                mrr += 1 / (idx + 1)
                break

    return round(mrr / len(qrels) * 100, 4)
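
A small worked example for `compute_mrr` with made-up qrels and predictions (MRR@10 is reported on a 0-100 scale), assuming `src/` is on `PYTHONPATH`:

# Hypothetical toy data, not taken from any dataset.
from cross_rerank.data_utils import ScoredDoc
from cross_rerank.metrics import compute_mrr

qrels = {'q1': {'d1': 1}, 'q2': {'d9': 1}}
predictions = {
    'q1': [ScoredDoc('q1', 'd3', 1, 9.0), ScoredDoc('q1', 'd1', 2, 8.0)],  # relevant doc at rank 2
    'q2': [ScoredDoc('q2', 'd9', 1, 7.5)],                                 # relevant doc at rank 1
}
print(compute_mrr(qrels, predictions))  # (1/2 + 1/1) / 2 * 100 = 75.0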
src/cross_rerank/model.py ADDED
@@ -0,0 +1,58 @@

import torch
import torch.nn as nn

from typing import Optional, Dict
from transformers import (
    PreTrainedModel,
    AutoModelForSequenceClassification
)
from transformers.modeling_outputs import SequenceClassifierOutput

from .config import Arguments
from .loss import CrossEncoderNllLoss


class Reranker(nn.Module):
    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        super().__init__()
        self.hf_model = hf_model
        self.args = args
        self._keys_to_ignore_on_save = None

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        #self.contrastive = CrossEncoderNllLoss()
        #self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, input_ids, attention_mask, token_type_ids) -> SequenceClassifierOutput:
        #n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor

        outputs: SequenceClassifierOutput = self.hf_model(input_ids, attention_mask, token_type_ids, return_dict=True)
        #outputs.logits = outputs.logits.view(-1, n_psg_per_query)
        #loss = self.cross_entropy(outputs.logits, labels)

        return outputs  #, loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(hf_model, all_args)

    def save_pretrained(self, output_dir: str):
        self.hf_model.save_pretrained(output_dir)


class RerankerForInference(nn.Module):
    def __init__(self, model_checkpoint, hf_model: Optional[PreTrainedModel] = None):
        super().__init__()
        self.hf_model = hf_model
        if hf_model is None:
            self.hf_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
        self.hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        return self.hf_model(**batch)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        return cls(pretrained_model_name_or_path, hf_model)
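
A minimal inference sketch (not part of the uploaded files), assuming a fine-tuned sequence-classification checkpoint; the checkpoint path below is a placeholder and `src/` is assumed to be on `PYTHONPATH`:

# Hypothetical scoring sketch; 'path/to/reranker-checkpoint' is a placeholder.
import torch
from transformers import AutoTokenizer
from cross_rerank.model import RerankerForInference

tokenizer = AutoTokenizer.from_pretrained('path/to/reranker-checkpoint')
reranker = RerankerForInference.from_pretrained('path/to/reranker-checkpoint')

batch = tokenizer(['what is faiss?'],
                  text_pair=['faiss is a library for similarity search'],
                  truncation=True, padding=True, return_tensors='pt')
with torch.no_grad():
    out = reranker(batch)
print(out.logits)  # one relevance score per query-passage pair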
src/cross_rerank/trainer.py ADDED
@@ -0,0 +1,247 @@

import os
import torch
from torch import nn
from torch.utils.checkpoint import get_device_states, set_device_states
#from typing import Optional, Union, Dict, Any
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.trainer import Trainer
from transformers.modeling_outputs import SequenceClassifierOutput
from collections.abc import Mapping

from .logger_config import logger
from .metrics import accuracy
from .utils import AverageMeter


def nested_detach(tensors):
    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
    return tensors.detach()


class RandContext:
    """Saves the CPU/GPU RNG state so a re-computed forward pass reproduces
    the original dropout pattern (used for gradient-cached training)."""

    def __init__(self, *tensors):
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

    def __enter__(self):
        self._fork = torch.random.fork_rng(
            devices=self.fwd_gpu_devices,
            enabled=True
        )
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None


class RerankerTrainer(Trainer):
    def __init__(self, *pargs, **kwargs):
        super(RerankerTrainer, self).__init__(*pargs, **kwargs)

        self.acc_meter = AverageMeter('acc', round_digits=2)
        self.last_epoch = 0

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to {}".format(output_dir))

        self.model.save_pretrained(output_dir)

        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

    def compute_loss(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']
        outputs = model(input_ids, attention_mask, token_type_ids)
        outputs.logits = outputs.logits.view(-1, n_psg_per_query)
        loss = self.model.cross_entropy(outputs.logits, labels)

        if self.model.training:
            step_acc = accuracy(output=outputs.logits.detach(), target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, outputs) if return_outputs else loss

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss_train(model, inputs)

        return loss.detach() / self.args.gradient_accumulation_steps

    def compute_loss_train(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']

        all_reps, rnds = [], []

        # split the batch so that only chunk_size query-passage pairs are forwarded at a time
        id_chunks = input_ids.split(self.args.chunk_size)
        attn_mask_chunks = attention_mask.split(self.args.chunk_size)
        type_ids_chunks = token_type_ids.split(self.args.chunk_size)

        # first pass: no-grad forward to collect all logits, caching the RNG state per chunk
        for id_chunk, attn_chunk, type_chunk in zip(id_chunks, attn_mask_chunks, type_ids_chunks):
            rnds.append(RandContext(id_chunk, attn_chunk, type_chunk))
            with torch.no_grad():
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
            all_reps.append(chunk_reps)
        all_reps = torch.cat(all_reps)
        all_reps = all_reps.view(-1, n_psg_per_query)

        # compute the loss on the detached logits and backprop to obtain their gradients
        all_reps = all_reps.float().detach().requires_grad_()
        loss = self.model.cross_entropy(all_reps, labels)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        self.accelerator.backward(loss)
        #if self.args.gradient_accumulation_steps > 1:
        #    loss = loss / self.args.gradient_accumulation_steps
        #loss.backward()
        grads = all_reps.grad.split(int(self.args.chunk_size / n_psg_per_query))

        # second pass: re-run each chunk with gradients enabled and backprop the cached gradients
        for id_chunk, attn_chunk, type_chunk, grad, rnd in zip(id_chunks, attn_mask_chunks, type_ids_chunks, grads, rnds):
            with rnd:
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
                surrogate = torch.dot(chunk_reps.flatten().float(), grad.flatten())
                self.accelerator.backward(surrogate)

        if self.model.training:
            step_acc = accuracy(all_reps, target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, all_reps) if return_outputs else loss

    '''def compute_loss_pred(self, model, inputs, return_outputs=False):
        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']
        labels = inputs['labels']

        all_reps, rnds = [], []

        id_chunks = input_ids.split(self.args.chunk_size)
        attn_mask_chunks = attention_mask.split(self.args.chunk_size)
        type_ids_chunks = token_type_ids.split(self.args.chunk_size)

        for id_chunk, attn_chunk, type_chunk in zip(id_chunks, attn_mask_chunks, type_ids_chunks):
            rnds.append(RandContext(id_chunk, attn_chunk, type_chunk))
            with torch.no_grad():
                chunk_reps = self.model(id_chunk, attn_chunk, type_chunk).logits
            all_reps.append(chunk_reps)
        all_reps = torch.cat(all_reps)
        all_reps = all_reps.view(-1, n_psg_per_query)
        loss = self.model.cross_entropy(all_reps, labels)

        if self.model.training:
            step_acc = accuracy(all_reps, target=labels)[0]
            self.acc_meter.update(step_acc)
            if self.state.global_step > 0 and self.state.global_step % self.args.logging_steps == 0:
                logger.info('step: {}, {}'.format(self.state.global_step, self.acc_meter))

            self._reset_meters_if_needed()

        return (loss, all_reps) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels or loss_without_labels:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)'''

    def _reset_meters_if_needed(self):
        if int(self.state.epoch) != self.last_epoch:
            self.last_epoch = int(self.state.epoch)
            self.acc_meter.reset()
src/cross_rerank/utils.py ADDED
@@ -0,0 +1,137 @@

import json
import torch
import torch.distributed as dist

from typing import List, Union, Optional, Tuple, Mapping, Dict


def save_json_to_file(objects: Union[List, dict], path: str, line_by_line: bool = False):
    if line_by_line:
        assert isinstance(objects, list), 'Only a list can be saved in line-by-line format'

    with open(path, 'w', encoding='utf-8') as writer:
        if not line_by_line:
            json.dump(objects, writer, ensure_ascii=False, indent=4, separators=(',', ':'))
        else:
            for obj in objects:
                writer.write(json.dumps(obj, ensure_ascii=False, separators=(',', ':')))
                writer.write('\n')


def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple([_move_to_cuda(x) for x in maybe_tensor])
        elif isinstance(maybe_tensor, Mapping):
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if t is None:
        return None

    t = t.contiguous()
    all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(all_tensors, t)

    all_tensors[dist.get_rank()] = t
    all_tensors = torch.cat(all_tensors, dim=0)
    return all_tensors


@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    assert len(scores.shape) == 2
    batch_size = scores.shape[0]
    assert batch_size * group_size <= scores.shape[1]

    indices = torch.arange(0, group_size, dtype=torch.long)
    indices = indices.repeat(batch_size, 1)
    indices += torch.arange(0, batch_size, dtype=torch.long).unsqueeze(-1) * group_size
    indices += start

    return indices.to(scores.device)


def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0])

    train_n_passages = key.shape[0] // query.shape[0]
    labels = torch.arange(0, query.shape[0], dtype=torch.long, device=query.device)
    labels = labels * train_n_passages

    # batch_size x (batch_size x n_psg)
    qk = torch.mm(query, key.t())

    if not use_all_pairs:
        return qk, labels

    # batch_size x dim
    sliced_key = key.index_select(dim=0, index=labels)
    assert query.shape[0] == sliced_key.shape[0]

    # batch_size x batch_size
    kq = torch.mm(sliced_key, query.t())
    kq.fill_diagonal_(float('-inf'))

    qq = torch.mm(query, query.t())
    qq.fill_diagonal_(float('-inf'))

    kk = torch.mm(sliced_key, sliced_key.t())
    kk.fill_diagonal_(float('-inf'))

    scores = torch.cat([qk, kq, qq, kk], dim=-1)

    return scores, labels


def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    return {k[len(prefix):]: v for k, v in batch_dict.items() if k.startswith(prefix)}


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name
        self.round_digits = round_digits
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))


if __name__ == '__main__':
    query = torch.randn(4, 16)
    key = torch.randn(4 * 3, 16)
    scores, labels = full_contrastive_scores_and_labels(query, key)
    print(scores.shape)
    print(labels)
src/eval_cross.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import torch
3
+ import logging
4
+ import json
5
+ import numpy as np
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ from typing import Optional
9
+ from dataclasses import dataclass, field
10
+ from transformers import HfArgumentParser
11
+ from transformers import AutoTokenizer
12
+ from bi.model import SharedBiEncoder
13
+ from bi.preprocess import preprocess_question
14
+ from cross_rerank.model import RerankerForInference
15
+ #from src.process import process_query, process_text, concat_str
16
+ import itertools
17
+ from pyvi.ViTokenizer import tokenize
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class Args:
24
+ encoder: str = field(
25
+ default="vinai/phobert-base-v2",
26
+ metadata={'help': 'The encoder name or path.'}
27
+ )
28
+ tokenizer: str = field(
29
+ default=None,
30
+ metadata={'help': 'The encoder name or path.'}
31
+ )
32
+ cross_checkpoint: str = field(
33
+ default="vinai/phobert-base-v2",
34
+ metadata={'help': 'The encoder name or path.'}
35
+ )
36
+ cross_tokenizer: str = field(
37
+ default=None,
38
+ metadata={'help': 'The encoder name or path.'}
39
+ )
40
+ sentence_pooling_method: str = field(
41
+ default="cls",
42
+ metadata={'help': 'Embedding method'}
43
+ )
44
+ fp16: bool = field(
45
+ default=False,
46
+ metadata={'help': 'Use fp16 in inference?'}
47
+ )
48
+ max_query_length: int = field(
49
+ default=32,
50
+ metadata={'help': 'Max query length.'}
51
+ )
52
+ max_passage_length: int = field(
53
+ default=256,
54
+ metadata={'help': 'Max passage length.'}
55
+ )
56
+ cross_max_length: int = field(
57
+ default=256,
58
+ metadata={'help': 'Max cross length.'}
59
+ )
60
+ cross_batch_size: int = field(
61
+ default=32,
62
+ metadata={'help': 'Inference batch size.'}
63
+ )
64
+ batch_size: int = field(
65
+ default=128,
66
+ metadata={'help': 'Inference batch size.'}
67
+ )
68
+ index_factory: str = field(
69
+ default="Flat",
70
+ metadata={'help': 'Faiss index factory.'}
71
+ )
72
+ k: int = field(
73
+ default=1000,
74
+ metadata={'help': 'How many neighbors to retrieve?'}
75
+ )
76
+ top_k: int = field(
77
+ default=1000,
78
+ metadata={'help': 'How many neighbors to rerank?'}
79
+ )
80
+ data_path: str = field(
81
+ default="/kaggle/input/zalo-data",
82
+ metadata={'help': 'Path to zalo data.'}
83
+ )
84
+ data_type: str = field(
85
+ default="test",
86
+ metadata={'help': 'Type data to test'}
87
+ )
88
+ corpus_file: str = field(
89
+ default="/kaggle/input/zalo-data",
90
+ metadata={'help': 'Path to zalo corpus.'}
91
+ )
92
+
93
+ data_file: str = field(
94
+ default=None,
95
+ metadata={'help': 'Path to evaluated data.'}
96
+ )
97
+
98
+ bi_data: bool = field(
99
+ default=False,
100
+ metadata={'help': 'Data for bi-encoder training'}
101
+ )
102
+
103
+ save_embedding: bool = field(
104
+ default=False,
105
+ metadata={'help': 'Save embeddings in memmap at save_dir?'}
106
+ )
107
+ load_embedding: str = field(
108
+ default='',
109
+ metadata={'help': 'Path to saved embeddings.'}
110
+ )
111
+
112
+ save_path: str = field(
113
+ default="embeddings.memmap",
114
+ metadata={'help': 'Path to save embeddings.'}
115
+ )
116
+
117
+ def index(model: SharedBiEncoder, tokenizer:AutoTokenizer, corpus, batch_size: int = 16, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False):
118
+ """
119
+ 1. Encode the entire corpus into dense embeddings;
120
+ 2. Create faiss index;
121
+ 3. Optionally save embeddings.
122
+ """
123
+ if load_embedding != '':
124
+ test_tokens = tokenizer(['test'],
125
+ padding=True,
126
+ truncation=True,
127
+ max_length=128,
128
+ return_tensors="pt").to('cuda')
129
+ test = model.encoder.get_representation(test_tokens['input_ids'], test_tokens['attention_mask'])
130
+ test = test.cpu().numpy()
131
+ dtype = test.dtype
132
+ dim = test.shape[-1]
133
+
134
+ all_embeddings = np.memmap(
135
+ load_embedding,
136
+ mode="r",
137
+ dtype=dtype
138
+ ).reshape(-1, dim)
139
+
140
+ else:
141
+ #df_corpus = pd.DataFrame()
142
+ #df_corpus['text'] = corpus
143
+ #pandarallel.initialize(progress_bar=True, use_memory_fs=False, nb_workers=12)
144
+ #df_corpus['processed_text'] = df_corpus['text'].parallel_apply(process_text)
145
+ #processed_corpus = df_corpus['processed_text'].tolist()
146
+ #model.to('cuda')
147
+ all_embeddings = []
148
+ for start_index in tqdm(range(0, len(corpus), batch_size), desc="Inference Embeddings",
149
+ disable=len(corpus) < batch_size):
150
+ passages_batch = corpus[start_index:start_index + batch_size]
151
+ d_collated = tokenizer(
152
+ passages_batch,
153
+ padding=True,
154
+ truncation=True,
155
+ max_length=max_length,
156
+ return_tensors="pt",
157
+ ).to('cuda')
158
+
159
+ with torch.no_grad():
160
+ corpus_embeddings = model.encoder.get_representation(d_collated['input_ids'], d_collated['attention_mask'])
161
+
162
+ corpus_embeddings = corpus_embeddings.cpu().numpy()
163
+ all_embeddings.append(corpus_embeddings)
164
+
165
+ all_embeddings = np.concatenate(all_embeddings, axis=0)
166
+ dim = all_embeddings.shape[-1]
167
+
168
+ if save_embedding:
169
+ logger.info(f"saving embeddings at {save_path}...")
170
+ memmap = np.memmap(
171
+ save_path,
172
+ shape=all_embeddings.shape,
173
+ mode="w+",
174
+ dtype=all_embeddings.dtype
175
+ )
176
+
177
+ length = all_embeddings.shape[0]
178
+ # add in batch
179
+ save_batch_size = 10000
180
+ if length > save_batch_size:
181
+ for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
182
+ j = min(i + save_batch_size, length)
183
+ memmap[i: j] = all_embeddings[i: j]
184
+ else:
185
+ memmap[:] = all_embeddings
186
+ # create faiss index
187
+ faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
188
+
189
+ #if model.device == torch.device("cuda"):
190
+ if True:  # always clone the index to GPU 0
191
+ co = faiss.GpuClonerOptions()
192
+ #co = faiss.GpuMultipleClonerOptions()
193
+ #co.useFloat16 = True
194
+ faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
195
+ #faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
196
+
197
+ # NOTE: faiss only accepts float32
198
+ logger.info("Adding embeddings...")
199
+ all_embeddings = all_embeddings.astype(np.float32)
200
+ #print(all_embeddings[0])
201
+ faiss_index.train(all_embeddings)
202
+ faiss_index.add(all_embeddings)
203
+ return faiss_index
204
+
205
+
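
A small, self-contained sketch of the memmap save/load round trip that `index()` performs when `save_embedding` / `load_embedding` are used; the array values and the embedding dimension here are stand-ins.

```python
import numpy as np

# Stand-in embeddings (the real ones come from the bi-encoder).
emb = np.random.rand(10, 4).astype(np.float32)

# Save: write the array into a memmap file of the same shape and dtype.
mm = np.memmap("embeddings.memmap", shape=emb.shape, mode="w+", dtype=emb.dtype)
mm[:] = emb
mm.flush()

# Load: the file is a flat buffer, so the embedding dimension must be known
# to restore the shape, exactly as index() does with its probe encoding.
loaded = np.memmap("embeddings.memmap", mode="r", dtype=np.float32).reshape(-1, 4)
print(np.allclose(loaded, emb))  # True
```
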
206
+ def search(model: SharedBiEncoder, tokenizer: AutoTokenizer, questions, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 128):
207
+ """
208
+ 1. Encode queries into dense embeddings;
209
+ 2. Search through faiss index
210
+ """
211
+ #model.to('cuda')
212
+ q_embeddings = []
213
+ #questions = queries['tokenized_question'].tolist()
214
+ #questions = [process_query(x) for x in questions]
215
+ for start_index in tqdm(range(0, len(questions), batch_size), desc="Inference Embeddings",
216
+ disable=len(questions) < batch_size):
217
+
218
+ q_collated = tokenizer(
219
+ questions[start_index: start_index + batch_size],
220
+ padding=True,
221
+ truncation=True,
222
+ max_length=max_length,
223
+ return_tensors="pt",
224
+ ).to('cuda')
225
+
226
+ with torch.no_grad():
227
+ query_embeddings = model.encoder.get_representation(q_collated['input_ids'], q_collated['attention_mask'])
228
+ query_embeddings = query_embeddings.cpu().numpy()
229
+ q_embeddings.append(query_embeddings)
230
+
231
+ q_embeddings = np.concatenate(q_embeddings, axis=0)
232
+ query_size = q_embeddings.shape[0]
233
+ all_scores = []
234
+ all_indices = []
235
+
236
+ for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
237
+ j = min(i + batch_size, query_size)
238
+ q_embedding = q_embeddings[i: j]
239
+ score, indice = faiss_index.search(q_embedding.astype(np.float32), k=k)
240
+ all_scores.append(score)
241
+ all_indices.append(indice)
242
+
243
+ all_scores = np.concatenate(all_scores, axis=0)
244
+ all_indices = np.concatenate(all_indices, axis=0)
245
+ return all_scores, all_indices
246
+
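
A self-contained sketch of the FAISS flow used by `index()` and `search()` above: build a flat inner-product index, add corpus embeddings, then retrieve the top-k neighbors for query embeddings. Random vectors stand in for the encoder outputs.

```python
import faiss
import numpy as np

rng = np.random.default_rng(0)
corpus_emb = rng.normal(size=(100, 64)).astype(np.float32)  # stand-in corpus embeddings
query_emb = rng.normal(size=(4, 64)).astype(np.float32)     # stand-in query embeddings

index = faiss.index_factory(64, "Flat", faiss.METRIC_INNER_PRODUCT)
index.train(corpus_emb)            # a no-op for "Flat"; required for e.g. IVF factories
index.add(corpus_emb)

scores, indices = index.search(query_emb, 5)  # top-5 neighbors by inner product
print(scores.shape, indices.shape)            # (4, 5) (4, 5)
```
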
247
+ def rerank(reranker: RerankerForInference, tokenizer: AutoTokenizer, questions, corpus, retrieved_ids, batch_size=128, max_length=256, top_k=30):
248
+ eos = tokenizer.eos_token
249
+ #questions = queries['tokenized_question'].tolist()
250
+ texts = []
251
+ for idx in range(len(questions)):
252
+ for j in range(top_k):
253
+ texts.append(questions[idx] + eos + eos + corpus[retrieved_ids[idx][j]])
254
+ reranked_ids = []
255
+ rerank_scores = []
256
+
257
+ for start_index in tqdm(range(0, len(questions), batch_size), desc="Rerank",
258
+ disable=len(questions) < batch_size):
259
+ batch_retrieved_ids = retrieved_ids[start_index: start_index+batch_size]
260
+ collated = tokenizer(
261
+ texts[start_index*top_k: (start_index + batch_size)*top_k],
262
+ padding=True,
263
+ truncation=True,
264
+ max_length=max_length,
265
+ return_tensors="pt",
266
+ ).to('cuda')
267
+ with torch.no_grad():
+ reranked_scores = reranker(collated).logits
268
+ reranked_scores = reranked_scores.view(-1,top_k).to('cpu').tolist()
269
+ for m in range(len(reranked_scores)):
270
+ tuple_lst = [(batch_retrieved_ids[m][n], reranked_scores[m][n]) for n in range(top_k)]
271
+ tuple_lst.sort(key=lambda tup: tup[1], reverse=True)
272
+ reranked_ids.append([tup[0] for tup in tuple_lst])
273
+ rerank_scores.append([tup[1] for tup in tuple_lst])
274
+
275
+ return reranked_ids, rerank_scores
276
+
277
+
278
+
279
+
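
A toy illustration of the bookkeeping in `rerank()`: each query is paired with its top_k retrieved passages (query + eos + eos + passage), the pairs are scored, and the retrieved ids are re-sorted by score. The scores below are made up; a real run gets them from the cross-encoder.

```python
questions = ["q1", "q2"]
corpus = ["d0", "d1", "d2", "d3"]
retrieved_ids = [[0, 2, 3], [1, 0, 2]]   # top_k = 3 retrieved passage ids per query
eos = "</s>"                             # stands in for tokenizer.eos_token

pairs = [q + eos + eos + corpus[i] for q, ids in zip(questions, retrieved_ids) for i in ids]
print(pairs[0])  # q1</s></s>d0

# Fake cross-encoder scores: one row per query, one column per candidate.
scores = [[0.1, 0.9, 0.4], [0.7, 0.2, 0.5]]

reranked = [sorted(zip(ids, s), key=lambda t: t[1], reverse=True)
            for ids, s in zip(retrieved_ids, scores)]
print(reranked[0])  # [(2, 0.9), (3, 0.4), (0, 0.1)]
```
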
280
+ def evaluate(preds, labels, cutoffs=[1,5,10,30,100]):
281
+ """
282
+ Evaluate MRR and Recall at cutoffs.
283
+ """
284
+ metrics = {}
285
+
286
+ # MRR
287
+ mrrs = np.zeros(len(cutoffs))
288
+ for pred, label in zip(preds, labels):
289
+ jump = False
290
+ for i, x in enumerate(pred, 1):
291
+ if x in label:
292
+ for k, cutoff in enumerate(cutoffs):
293
+ if i <= cutoff:
294
+ mrrs[k] += 1 / i
295
+ jump = True
296
+ if jump:
297
+ break
298
+ mrrs /= len(preds)
299
+ for i, cutoff in enumerate(cutoffs):
300
+ mrr = mrrs[i]
301
+ metrics[f"MRR@{cutoff}"] = mrr
302
+
303
+ # Recall
304
+ recalls = np.zeros(len(cutoffs))
305
+ for pred, label in zip(preds, labels):
306
+ for k, cutoff in enumerate(cutoffs):
307
+ recall = np.intersect1d(label, pred[:cutoff])
308
+ recalls[k] += len(recall) / len(label)
309
+ recalls /= len(preds)
310
+ for i, cutoff in enumerate(cutoffs):
311
+ recall = recalls[i]
312
+ metrics[f"Recall@{cutoff}"] = recall
313
+
314
+ return metrics
315
+
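
A quick sanity check of `evaluate()` (assuming the function above is in scope): one query whose first relevant passage appears at rank 2 gets no credit at cutoff 1 and a reciprocal rank of 0.5 at cutoff 3, and recalls one of its two relevant passages within the top 3.

```python
preds = [["d3", "d1", "d9"]]    # ranked predictions for a single query
labels = [["d1", "d5"]]         # its two relevant passages
print(evaluate(preds, labels, cutoffs=[1, 3]))
# MRR@1 = 0.0, MRR@3 = 0.5, Recall@1 = 0.0, Recall@3 = 0.5
```
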
316
+ def calculate_score(ground_ids, retrieved_list):
317
+ all_count = 0
318
+ hit_count = 0
319
+ for i in range(len(ground_ids)):
320
+ all_check = True
321
+ hit_check = False
322
+ retrieved_ids = retrieved_list[i]
323
+ ans_ids = ground_ids[i]
324
+ for a_ids in ans_ids:
325
+ com = [a_id for a_id in a_ids if a_id in retrieved_ids]
326
+ if len(com) > 0:
327
+ hit_check = True
328
+ else:
329
+ all_check = False
330
+
331
+ if hit_check:
332
+ hit_count += 1
333
+ if all_check:
334
+ all_count += 1
335
+
336
+ all_acc = all_count/len(ground_ids)
337
+ hit_acc = hit_count/len(ground_ids)
338
+ return hit_acc, all_acc
339
+
340
+ def check(ground_ids, retrieved_list, cutoffs=[1,5,10,30,100]):
341
+ metrics = {}
342
+ for cutoff in cutoffs:
343
+ retrieved_k = [x[:cutoff] for x in retrieved_list]
344
+ hit_acc, all_acc = calculate_score(ground_ids, retrieved_k)
345
+ metrics[f"All@{cutoff}"] = all_acc
346
+ metrics[f"Hit@{cutoff}"] = hit_acc
347
+ return metrics
348
+
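
A quick sanity check of `check()` (assuming the functions above are in scope): `ground_ids` holds, per query, one group of passage ids per relevant article; Hit@k only needs one group covered by the retrieved ids, while All@k needs every group covered.

```python
ground_ids = [[[0, 1], [5]]]      # one query with two relevant articles
retrieved_list = [[1, 9, 8, 7]]   # article {0, 1} is covered, article {5} is not
print(check(ground_ids, retrieved_list, cutoffs=[4]))
# All@4 = 0.0, Hit@4 = 1.0
```
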
349
+ def save_bi_data(tokenized_queries, ground_ids, indices, scores, file, org_questions=None):
350
+ rst = []
351
+ #tokenized_queries = test_data['tokenized_question'].tolist()
352
+ for i in range(len(tokenized_queries)):
353
+ scores_i = scores[i]
354
+ indices_i = indices[i]
355
+ ans_ids = ground_ids[i]
356
+ all_ans_id = [element for x in ans_ids for element in x]
357
+ neg_doc_ids = []
358
+ neg_scores = []
359
+ for count in range(len(indices_i)):
360
+ if indices_i[count] not in all_ans_id and indices_i[count] != -1:
361
+ neg_doc_ids.append(indices_i[count])
362
+ neg_scores.append(scores_i[count])
363
+
364
+ for j in range(len(ans_ids)):
365
+ ans_id = ans_ids[j]
366
+ item = {}
367
+ if org_questions is not None:
368
+ item['question'] = org_questions[i]
369
+ item['query'] = tokenized_queries[i]
370
+ item['positives'] = {}
371
+ item['negatives'] = {}
372
+ item['positives']['doc_id'] = []
373
+ item['positives']['score'] = []
374
+ item['negatives']['doc_id'] = neg_doc_ids
375
+ item['negatives']['score'] = neg_scores
376
+ for pos_id in ans_id:
377
+ item['positives']['doc_id'].append(pos_id)
378
+ try:
379
+ idx = indices_i.index(pos_id)
380
+ item['positives']['score'].append(scores_i[idx])
381
+ except ValueError:
382
+ item['positives']['score'].append(scores_i[-1])
383
+
384
+ rst.append(item)
385
+
386
+ with open(f'{file}.jsonl', 'w') as jsonl_file:
387
+ for item in rst:
388
+ json_line = json.dumps(item, ensure_ascii=False)
389
+ jsonl_file.write(json_line + '\n')
390
+
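
For orientation, the shape of a single JSONL record written by `save_bi_data()`; all values below are toy placeholders, and the `question` key is only present when `org_questions` is passed.

```python
example_record = {
    "question": "original question text",
    "query": "tokenized question text",
    "positives": {"doc_id": [12], "score": [41.7]},
    "negatives": {"doc_id": [873, 54], "score": [39.2, 38.5]},
}
```
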
391
+ def main():
392
+ parser = HfArgumentParser([Args])
393
+ args: Args = parser.parse_args_into_dataclasses()[0]
394
+ print(args)
395
+ model = SharedBiEncoder(model_checkpoint=args.encoder,
396
+ representation=args.sentence_pooling_method,
397
+ fixed=True)
398
+ model.to('cuda')
399
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer if args.tokenizer else args.encoder)
400
+ reranker = RerankerForInference(model_checkpoint=args.cross_checkpoint)
401
+ reranker.to('cuda')
402
+ reranker_tokenizer = AutoTokenizer.from_pretrained(args.cross_tokenizer if args.cross_tokenizer else args.cross_checkpoint)
403
+ csv_file = True
404
+ if args.data_file:
405
+ if args.data_file.endswith("jsonl"):
406
+ test_data = []
407
+ with open(args.data_file, 'r') as jsonl_file:
408
+ for line in jsonl_file:
409
+ temp = json.loads(line)
410
+ test_data.append(temp)
411
+ csv_file=False
412
+ elif args.data_file.endswith("json"):
413
+ csv_file=False
414
+ with open(args.data_file, 'r') as json_file:
415
+ test_data = json.load(json_file)
416
+ elif args.data_file.endswith("csv"):
417
+ test_data = pd.read_csv(args.data_file)
418
+
419
+ elif args.data_type == 'eval':
420
+ test_data = pd.read_csv(args.data_path + "/tval.csv")
421
+ elif args.data_type == 'train':
422
+ test_data = pd.read_csv(args.data_path + "/ttrain.csv")
423
+ elif args.data_type == 'all':
424
+ data1 = pd.read_csv(args.data_path + "/ttrain.csv")
425
+ data2 = pd.read_csv(args.data_path + "/ttest.csv")
426
+ data3 = pd.read_csv(args.data_path + "/tval.csv")
427
+ test_data = pd.concat([data1, data3, data2], ignore_index=True)
428
+
429
+ else:
430
+ test_data = pd.read_csv(args.data_path + "/ttest.csv")
431
+ corpus_data = pd.read_csv(args.corpus_file)
432
+ #dcorpus = pd.DataFrame(corpus_data)
433
+ #pandarallel.initialize(progress_bar=True, use_memory_fs=False, nb_workers=12)
434
+ #dcorpus["full_text"] = dcorpus.parallel_apply(concat_str, axis=1)
435
+ corpus = corpus_data['tokenized_text'].tolist()
436
+
437
+ if csv_file:
438
+ ans_ids = []
439
+ ground_ids = []
440
+ org_questions = test_data['question'].tolist()
441
+ questions = test_data['tokenized_question'].tolist()
442
+ for i in range(len(test_data)):
443
+ ans_ids.append(json.loads(test_data['best_ans_id'][i]))
444
+ ground_ids.append(json.loads(test_data['ans_id'][i]))
445
+ ground_truths = []
446
+ for sample in ans_ids:
447
+ temp = [corpus_data['law_id'][y] + "_" + str(corpus_data['article_id'][y]) for y in sample]
448
+ ground_truths.append(temp)
449
+ else:
450
+ ground_truths = []
451
+ ground_ids = []
452
+ org_questions = [sample['question'] for sample in test_data]
453
+ questions = [tokenize(preprocess_question(sample['question'], remove_end_phrase=False)) for sample in test_data]
454
+ for sample in test_data:
455
+ try:
456
+ temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevance_articles']]
457
+ tempp = [it['ans_id'] for it in sample['relevance_articles']]
458
+ except KeyError:
459
+ temp = [it['law_id'] + "_" + it['article_id'] for it in sample['relevant_articles']]
460
+ tempp = [it['ans_id'] for it in sample['relevant_articles']]
461
+ ground_truths.append(temp)
462
+ ground_ids.append(tempp)
463
+
464
+ faiss_index = index(
465
+ model=model,
466
+ tokenizer=tokenizer,
467
+ corpus=corpus,
468
+ batch_size=args.batch_size,
469
+ max_length=args.max_passage_length,
470
+ index_factory=args.index_factory,
471
+ save_path=args.save_path,
472
+ save_embedding=args.save_embedding,
473
+ load_embedding=args.load_embedding
474
+ )
475
+
476
+ scores, indices = search(
477
+ model=model,
478
+ tokenizer=tokenizer,
479
+ questions=questions,
480
+ faiss_index=faiss_index,
481
+ k=args.k,
482
+ batch_size=args.batch_size,
483
+ max_length=args.max_query_length
484
+ )
485
+
486
+
487
+ retrieval_results, retrieval_ids = [], []
488
+ for indice in indices:
489
+ # filter invalid indices
490
+ indice = indice[indice != -1].tolist()
491
+ rst = []
492
+ for x in indice:
493
+ temp = corpus_data['law_id'][x] + "_" + str(corpus_data['article_id'][x])
494
+ if temp not in rst:
495
+ rst.append(temp)
496
+ retrieval_results.append(rst)
497
+ retrieval_ids.append(indice)
498
+
499
+ rerank_ids, rerank_scores = rerank(reranker, reranker_tokenizer, questions, corpus, retrieval_ids, args.cross_batch_size, args.cross_max_length, args.top_k)
500
+
501
+ if args.bi_data:
502
+ save_bi_data(questions, ground_ids, rerank_ids, rerank_scores, args.data_type, org_questions)
503
+
504
+ metrics = check(ground_ids, retrieval_ids)
505
+ print(metrics)
506
+ metrics = evaluate(retrieval_results, ground_truths)
507
+ print(metrics)
508
+ metrics = check(ground_ids, rerank_ids, cutoffs=[1,5,10,30])
509
+ print(metrics)
510
+
511
+ if __name__ == "__main__":
512
+ main()
src/train_cross_encoder.py ADDED
@@ -0,0 +1,112 @@
1
+ import logging
2
+
3
+ import torch
4
+ from typing import Dict
5
+ from transformers.utils.logging import enable_explicit_format
6
+ from transformers.trainer_callback import PrinterCallback
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ HfArgumentParser,
10
+ EvalPrediction,
11
+ set_seed,
12
+ PreTrainedTokenizerFast
13
+ )
14
+
15
+ from cross_rerank.logger_config import logger, LoggerCallback
16
+ from cross_rerank.config import Arguments
17
+ from cross_rerank.trainer import RerankerTrainer
18
+ from cross_rerank.data_loader import CrossEncoderDataLoader
19
+ from cross_rerank.collator import CrossEncoderCollator
20
+ from cross_rerank.metrics import accuracy
21
+ from cross_rerank.model import Reranker
22
+
23
+
24
+ def _common_setup(args: Arguments):
25
+ if args.process_index > 0:
26
+ logger.setLevel(logging.WARNING)
27
+ enable_explicit_format()
28
+ set_seed(args.seed)
29
+
30
+
31
+ def _compute_metrics(eval_pred: EvalPrediction) -> Dict:
32
+ preds = eval_pred.predictions
33
+ if isinstance(preds, tuple):
34
+ preds = preds[-1]
35
+ logits = torch.tensor(preds).float()
36
+ labels = torch.tensor(eval_pred.label_ids).long()
37
+ acc = accuracy(output=logits, target=labels)[0]
38
+
39
+ return {'acc': acc}
40
+
41
+
42
+ def main():
43
+ parser = HfArgumentParser((Arguments,))
44
+ args: Arguments = parser.parse_args_into_dataclasses()[0]
45
+ _common_setup(args)
46
+ logger.info('Args={}'.format(str(args)))
47
+
48
+ try:
49
+ tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
50
+ except Exception:
51
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
52
+
53
+ model: Reranker = Reranker.from_pretrained(
54
+ all_args=args,
55
+ pretrained_model_name_or_path=args.model_name_or_path,
56
+ num_labels=1)
57
+
58
+ logger.info(model)
59
+ logger.info('Vocab size: {}'.format(len(tokenizer)))
60
+
61
+ data_collator = CrossEncoderCollator(
62
+ tokenizer=tokenizer,
63
+ pad_to_multiple_of=256)
64
+
65
+ rerank_data_loader = CrossEncoderDataLoader(args=args, tokenizer=tokenizer)
66
+ train_dataset = rerank_data_loader.train_dataset
67
+ eval_dataset = rerank_data_loader.eval_dataset
68
+
69
+ trainer = RerankerTrainer(
70
+ model=model,
71
+ args=args,
72
+ train_dataset=train_dataset if args.do_train else None,
73
+ eval_dataset=eval_dataset if args.do_eval else None,
74
+ data_collator=data_collator,
75
+ compute_metrics=_compute_metrics,
76
+ tokenizer=tokenizer,
77
+ )
78
+ trainer.remove_callback(PrinterCallback)
79
+ trainer.add_callback(LoggerCallback)
80
+ rerank_data_loader.trainer = trainer
81
+
82
+ if args.do_eval:
83
+ logger.info("*** Evaluate ***")
84
+ metrics = trainer.evaluate(metric_key_prefix="eval")
85
+ metrics["eval_samples"] = len(eval_dataset)
86
+
87
+ trainer.log_metrics("eval", metrics)
88
+ trainer.save_metrics("eval", metrics)
89
+
90
+ if args.do_train:
91
+ train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
92
+ trainer.save_model()
93
+
94
+ metrics = train_result.metrics
95
+ metrics["train_samples"] = len(train_dataset)
96
+
97
+ trainer.log_metrics("train", metrics)
98
+ trainer.save_metrics("train", metrics)
99
+
100
+ if args.do_eval:
101
+ logger.info("*** Evaluate ***")
102
+ metrics = trainer.evaluate(metric_key_prefix="eval")
103
+ metrics["eval_samples"] = len(eval_dataset)
104
+
105
+ trainer.log_metrics("eval", metrics)
106
+ trainer.save_metrics("eval", metrics)
107
+
108
+ return
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()
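
A hedged sketch of launching this training script programmatically; the checkpoint name and output directory are placeholders, only flags actually read in this file plus the standard `--output_dir` are shown, and the dataset-related arguments a real run also needs are omitted.

```python
import subprocess
import sys

# Hypothetical launch; the model checkpoint and output path are placeholders.
subprocess.run([
    sys.executable, "src/train_cross_encoder.py",
    "--model_name_or_path", "bert-base-uncased",
    "--output_dir", "outputs/cross_rerank",
    "--do_train",
    "--do_eval",
    "--fp16",
], check=True)
```
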