Upload evaluation script

#3
Files changed (1)
  1. st_eval.py +341 -0
st_eval.py ADDED
@@ -0,0 +1,341 @@
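+ # st_eval.py -- NanoBEIR reranking evaluation for GLiClass zero-shot classification
+ # pipelines, built on sentence-transformers' CrossEncoder evaluators.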
+ import csv
+ import logging
+ import os
+ from typing import Literal, Union
+ 
+ import numpy as np
+ from tqdm import tqdm
+ 
+ from sentence_transformers.cross_encoder.evaluation import (
+     CrossEncoderNanoBEIREvaluator,
+     CrossEncoderRerankingEvaluator,
+ )
+ from sentence_transformers.util import is_datasets_available
+ 
+ from gliclass import ZeroShotClassificationPipeline, ZeroShotClassificationWithLabelsChunkingPipeline
+ 
+ logger = logging.getLogger(__name__)
+ 
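+ # Dataset aliases accepted by the evaluator, their BM25-candidate dataset ids on
+ # the Hugging Face Hub, and human-readable display names.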
+ DatasetNameType = Literal[
+     "climatefever",
+     "dbpedia",
+     "fever",
+     "fiqa2018",
+     "hotpotqa",
+     "msmarco",
+     "nfcorpus",
+     "nq",
+     "quoraretrieval",
+     "scidocs",
+     "arguana",
+     "scifact",
+     "touche2020",
+ ]
+ 
+ dataset_name_to_id = {
+     "climatefever": "sentence-transformers/NanoClimateFEVER-bm25",
+     "dbpedia": "sentence-transformers/NanoDBPedia-bm25",
+     "fever": "sentence-transformers/NanoFEVER-bm25",
+     "fiqa2018": "sentence-transformers/NanoFiQA2018-bm25",
+     "hotpotqa": "sentence-transformers/NanoHotpotQA-bm25",
+     "msmarco": "sentence-transformers/NanoMSMARCO-bm25",
+     "nfcorpus": "sentence-transformers/NanoNFCorpus-bm25",
+     "nq": "sentence-transformers/NanoNQ-bm25",
+     "quoraretrieval": "sentence-transformers/NanoQuoraRetrieval-bm25",
+     "scidocs": "sentence-transformers/NanoSCIDOCS-bm25",
+     "arguana": "sentence-transformers/NanoArguAna-bm25",
+     "scifact": "sentence-transformers/NanoSciFact-bm25",
+     "touche2020": "sentence-transformers/NanoTouche2020-bm25",
+ }
+ 
+ dataset_name_to_human_readable = {
+     "climatefever": "ClimateFEVER",
+     "dbpedia": "DBPedia",
+     "fever": "FEVER",
+     "fiqa2018": "FiQA2018",
+     "hotpotqa": "HotpotQA",
+     "msmarco": "MSMARCO",
+     "nfcorpus": "NFCorpus",
+     "nq": "NQ",
+     "quoraretrieval": "QuoraRetrieval",
+     "scidocs": "SCIDOCS",
+     "arguana": "ArguAna",
+     "scifact": "SciFact",
+     "touche2020": "Touche2020",
+ }
+ 
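+ # Reranking evaluator backed by a GLiClass pipeline instead of a CrossEncoder.
+ # Each query is passed to the pipeline as the input text with its candidate
+ # documents as the labels, so per-label classification scores double as
+ # reranking scores.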
+ class GLiClassRerankingEvaluator(CrossEncoderRerankingEvaluator):
+     def __call__(
+         self,
+         model: Union[ZeroShotClassificationPipeline, ZeroShotClassificationWithLabelsChunkingPipeline],
+         output_path: str | None = None,
+         epoch: int = -1,
+         steps: int = -1,
+         labels_chunk_size: int = -1,
+     ) -> dict[str, float]:
+         if epoch != -1:
+             if steps == -1:
+                 out_txt = f" after epoch {epoch}"
+             else:
+                 out_txt = f" in epoch {epoch} after {steps} steps"
+         else:
+             out_txt = ""
+ 
+         logger.info(f"GLiClassRerankingEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")
+ 
+         base_mrr_scores = []
+         base_ndcg_scores = []
+         base_ap_scores = []
+         all_mrr_scores = []
+         all_ndcg_scores = []
+         all_ap_scores = []
+         num_queries = 0
+         num_positives = []
+         num_negatives = []
+         for instance in tqdm(self.samples, desc="Evaluating samples", disable=not self.show_progress_bar, leave=False):
+             if "query" not in instance:
+                 raise ValueError("GLiClassRerankingEvaluator requires a 'query' key in each sample.")
+             if "positive" not in instance:
+                 raise ValueError("GLiClassRerankingEvaluator requires a 'positive' key in each sample.")
+             if ("negative" in instance and "documents" in instance) or (
+                 "negative" not in instance and "documents" not in instance
+             ):
+                 raise ValueError(
+                     "GLiClassRerankingEvaluator requires exactly one of 'negative' and 'documents' in each sample."
+                 )
+ 
+             query = instance["query"]
+             positive = instance["positive"]
+             if isinstance(positive, str):
+                 positive = [positive]
+ 
+             negative = instance.get("negative", None)
+             documents = instance.get("documents", None)
+ 
+             if documents:
+                 base_is_relevant = [int(sample in positive) for sample in documents]
+                 if sum(base_is_relevant) == 0:
+                     base_mrr, base_ndcg, base_ap = 0, 0, 0
+                 else:
+                     # If not all positives are in the documents, add them at the end
+                     base_is_relevant += [1] * (len(positive) - sum(base_is_relevant))
+                     base_pred_scores = np.array(range(len(base_is_relevant), 0, -1))
+                     base_mrr, base_ndcg, base_ap = self.compute_metrics(base_is_relevant, base_pred_scores)
+                 base_mrr_scores.append(base_mrr)
+                 base_ndcg_scores.append(base_ndcg)
+                 base_ap_scores.append(base_ap)
+ 
+                 if self.always_rerank_positives:
+                     docs = positive + [doc for doc in documents if doc not in positive]
+                     is_relevant = [1] * len(positive) + [0] * (len(docs) - len(positive))
+                 else:
+                     docs = documents
+                     is_relevant = [int(sample in positive) for sample in documents]
+             else:
+                 docs = positive + negative
+                 is_relevant = [1] * len(positive) + [0] * len(negative)
+ 
+             num_queries += 1
+             num_positives.append(len(positive))
+             num_negatives.append(len(is_relevant) - sum(is_relevant))
+ 
+             if sum(is_relevant) == 0:
+                 all_mrr_scores.append(0)
+                 all_ndcg_scores.append(0)
+                 all_ap_scores.append(0)
+                 continue
+ 
+             # Score the candidate documents as labels for the query
+             if labels_chunk_size > 0 and isinstance(model, ZeroShotClassificationWithLabelsChunkingPipeline):
+                 gliclass_outputs = model(query, docs, threshold=0.0, labels_chunk_size=labels_chunk_size)
+             else:
+                 gliclass_outputs = model(query, docs, threshold=0.0)
+ 
+             pred_scores = np.array([item["score"] for item in gliclass_outputs[0]])
+             # Add the ignored positives at the end
+             if num_ignored_positives := len(is_relevant) - len(pred_scores):
+                 pred_scores = np.concatenate([pred_scores, np.zeros(num_ignored_positives)])
+ 
+             mrr, ndcg, ap = self.compute_metrics(is_relevant, pred_scores)
+ 
+             all_mrr_scores.append(mrr)
+             all_ndcg_scores.append(ndcg)
+             all_ap_scores.append(ap)
+ 
+         mean_mrr = np.mean(all_mrr_scores)
+         mean_ndcg = np.mean(all_ndcg_scores)
+         mean_ap = np.mean(all_ap_scores)
+         metrics = {
+             "map": mean_ap,
+             f"mrr@{self.at_k}": mean_mrr,
+             f"ndcg@{self.at_k}": mean_ndcg,
+         }
+ 
+         logger.info(
+             f"Queries: {num_queries}\t"
+             f"Positives: Min {np.min(num_positives):.1f}, Mean {np.mean(num_positives):.1f}, Max {np.max(num_positives):.1f}\t"
+             f"Negatives: Min {np.min(num_negatives):.1f}, Mean {np.mean(num_negatives):.1f}, Max {np.max(num_negatives):.1f}"
+         )
+         # Base metrics exist only when the samples carried "documents"; checking the
+         # collected scores avoids relying on the loop-local `documents` variable.
+         if base_mrr_scores:
+             mean_base_mrr = np.mean(base_mrr_scores)
+             mean_base_ndcg = np.mean(base_ndcg_scores)
+             mean_base_ap = np.mean(base_ap_scores)
+             base_metrics = {
+                 "base_map": mean_base_ap,
+                 f"base_mrr@{self.at_k}": mean_base_mrr,
+                 f"base_ndcg@{self.at_k}": mean_base_ndcg,
+             }
+             logger.info(f"{' ' * len(str(self.at_k))} Base -> Reranked")
+             logger.info(f"MAP:{' ' * len(str(self.at_k))} {mean_base_ap * 100:.2f} -> {mean_ap * 100:.2f}")
+             logger.info(f"MRR@{self.at_k}: {mean_base_mrr * 100:.2f} -> {mean_mrr * 100:.2f}")
+             logger.info(f"NDCG@{self.at_k}: {mean_base_ndcg * 100:.2f} -> {mean_ndcg * 100:.2f}")
+ 
+             model_card_metrics = {
+                 "map": f"{mean_ap:.4f} ({mean_ap - mean_base_ap:+.4f})",
+                 f"mrr@{self.at_k}": f"{mean_mrr:.4f} ({mean_mrr - mean_base_mrr:+.4f})",
+                 f"ndcg@{self.at_k}": f"{mean_ndcg:.4f} ({mean_ndcg - mean_base_ndcg:+.4f})",
+             }
+             model_card_metrics = self.prefix_name_to_metrics(model_card_metrics, self.name)
+             # Store the base -> reranked deltas, mirroring the upstream CrossEncoder evaluator
+             self.store_metrics_in_model_card_data(model, model_card_metrics, epoch, steps)
+ 
+             metrics.update(base_metrics)
+             metrics = self.prefix_name_to_metrics(metrics, self.name)
+         else:
+             logger.info(f"MAP:{' ' * len(str(self.at_k))} {mean_ap * 100:.2f}")
+             logger.info(f"MRR@{self.at_k}: {mean_mrr * 100:.2f}")
+             logger.info(f"NDCG@{self.at_k}: {mean_ndcg * 100:.2f}")
+ 
+             metrics = self.prefix_name_to_metrics(metrics, self.name)
+             self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
+ 
+         if output_path is not None and self.write_csv:
+             csv_path = os.path.join(output_path, self.csv_file)
+             output_file_exists = os.path.isfile(csv_path)
+             with open(csv_path, mode="a" if output_file_exists else "w", encoding="utf-8", newline="") as f:
+                 writer = csv.writer(f)
+                 if not output_file_exists:
+                     writer.writerow(self.csv_headers)
+ 
+                 writer.writerow([epoch, steps, mean_ap, mean_mrr, mean_ndcg])
+ 
+         return metrics
+ 
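+ # NanoBEIR harness: loads each dataset's BM25-ranked candidates, wraps them in a
+ # GLiClassRerankingEvaluator, and aggregates MAP/MRR/NDCG across datasets.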
+ class GLiClassNanoBEIREvaluator(CrossEncoderNanoBEIREvaluator):
+     def _load_dataset(self, dataset_name: DatasetNameType, **ir_evaluator_kwargs) -> CrossEncoderRerankingEvaluator:
+         if not is_datasets_available():
+             raise ValueError(
+                 "datasets is not available. Please install it to use the GLiClassNanoBEIREvaluator via `pip install datasets`."
+             )
+         from datasets import load_dataset
+ 
+         dataset_path = dataset_name_to_id[dataset_name.lower()]
+         corpus = load_dataset(dataset_path, "corpus", split="train")
+         corpus_mapping = dict(zip(corpus["_id"], corpus["text"]))
+         queries = load_dataset(dataset_path, "queries", split="train")
+         query_mapping = dict(zip(queries["_id"], queries["text"]))
+         relevance = load_dataset(dataset_path, "relevance", split="train")
+ 
+         def mapper(sample, corpus_mapping: dict[str, str], query_mapping: dict[str, str], rerank_k: int):
+             query = query_mapping[sample["query-id"]]
+             positives = [corpus_mapping[positive_id] for positive_id in sample["positive-corpus-ids"]]
+             documents = [corpus_mapping[document_id] for document_id in sample["bm25-ranked-ids"][:rerank_k]]
+             return {
+                 "query": query,
+                 "positive": positives,
+                 "documents": documents,
+             }
+ 
+         relevance = relevance.map(
+             mapper,
+             fn_kwargs={"corpus_mapping": corpus_mapping, "query_mapping": query_mapping, "rerank_k": self.rerank_k},
+         )
+ 
+         human_readable_name = self._get_human_readable_name(dataset_name)
+         return GLiClassRerankingEvaluator(
+             samples=list(relevance),
+             name=human_readable_name,
+             **ir_evaluator_kwargs,
+         )
+ 
+     def __call__(
+         self,
+         model: Union[ZeroShotClassificationPipeline, ZeroShotClassificationWithLabelsChunkingPipeline],
+         output_path: str | None = None,
+         epoch: int = -1,
+         steps: int = -1,
+         *args,
+         **kwargs,
+     ) -> dict[str, float]:
+         per_metric_results = {}
+         per_dataset_results = {}
+         if epoch != -1:
+             if steps == -1:
+                 out_txt = f" after epoch {epoch}"
+             else:
+                 out_txt = f" in epoch {epoch} after {steps} steps"
+         else:
+             out_txt = ""
+         logger.info(f"NanoBEIR Evaluation of the model on {self.dataset_names} dataset{out_txt}:")
+ 
+         for evaluator in tqdm(self.evaluators, desc="Evaluating datasets", disable=not self.show_progress_bar):
+             logger.info(f"Evaluating {evaluator.name}")
+             evaluation = evaluator(model, output_path, epoch, steps)
+             for k in evaluation:
+                 dataset, _rerank_k, metric = k.split("_", maxsplit=2)
+                 if metric not in per_metric_results:
+                     per_metric_results[metric] = []
+                 per_dataset_results[f"{dataset}_R{self.rerank_k}_{metric}"] = evaluation[k]
+                 per_metric_results[metric].append(evaluation[k])
+             logger.info("")
+ 
+         agg_results = {}
+         for metric in per_metric_results:
+             agg_results[metric] = self.aggregate_fn(per_metric_results[metric])
+ 
+         if output_path is not None and self.write_csv:
+             csv_path = os.path.join(output_path, self.csv_file)
+             output_file_exists = os.path.isfile(csv_path)
+             with open(csv_path, mode="a" if output_file_exists else "w", encoding="utf-8") as fOut:
+                 if not output_file_exists:
+                     fOut.write(",".join(self.csv_headers))
+                     fOut.write("\n")
+ 
+                 output_data = [
+                     epoch,
+                     steps,
+                     agg_results["map"],
+                     agg_results[f"mrr@{self.at_k}"],
+                     agg_results[f"ndcg@{self.at_k}"],
+                 ]
+                 fOut.write(",".join(map(str, output_data)))
+                 fOut.write("\n")
+ 
+         logger.info("GLiClassNanoBEIREvaluator: Aggregated Results:")
+         logger.info(f"{' ' * len(str(self.at_k))} Base -> Reranked")
+         logger.info(
+             f"MAP:{' ' * len(str(self.at_k))} {agg_results['base_map'] * 100:.2f} -> {agg_results['map'] * 100:.2f}"
+         )
+         logger.info(
+             f"MRR@{self.at_k}: {agg_results[f'base_mrr@{self.at_k}'] * 100:.2f} -> {agg_results[f'mrr@{self.at_k}'] * 100:.2f}"
+         )
+         logger.info(
+             f"NDCG@{self.at_k}: {agg_results[f'base_ndcg@{self.at_k}'] * 100:.2f} -> {agg_results[f'ndcg@{self.at_k}'] * 100:.2f}"
+         )
+ 
+         model_card_metrics = {
+             "map": f"{agg_results['map']:.4f} ({agg_results['map'] - agg_results['base_map']:+.4f})",
+             f"mrr@{self.at_k}": f"{agg_results[f'mrr@{self.at_k}']:.4f} ({agg_results[f'mrr@{self.at_k}'] - agg_results[f'base_mrr@{self.at_k}']:+.4f})",
+             f"ndcg@{self.at_k}": f"{agg_results[f'ndcg@{self.at_k}']:.4f} ({agg_results[f'ndcg@{self.at_k}'] - agg_results[f'base_ndcg@{self.at_k}']:+.4f})",
+         }
+         model_card_metrics = self.prefix_name_to_metrics(model_card_metrics, self.name)
+         # Store the aggregated deltas, mirroring the upstream CrossEncoder evaluator
+         self.store_metrics_in_model_card_data(model, model_card_metrics, epoch, steps)
+ 
+         agg_results = self.prefix_name_to_metrics(agg_results, self.name)
+         per_dataset_results.update(agg_results)
+ 
+         return per_dataset_results
+ 
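+ # Example: evaluate a pretrained GLiClass model on a NanoBEIR subset.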
+ if __name__ == "__main__":
+     from transformers import AutoTokenizer
+ 
+     from gliclass import GLiClassModel
+ 
+     # Toggle between the plain pipeline and the labels-chunking pipeline
+     chunk_pipeline = True
+     model_path = "knowledgator/gliclass-modern-base-v2.0"
+ 
+     model = GLiClassModel.from_pretrained(model_path)
+     tokenizer = AutoTokenizer.from_pretrained(model_path, add_prefix_space=True)
+     if not chunk_pipeline:
+         pipeline = ZeroShotClassificationPipeline(
+             model, tokenizer, classification_type="multi-label", device="cuda:0", max_length=8192, progress_bar=False
+         )
+     else:
+         pipeline = ZeroShotClassificationWithLabelsChunkingPipeline(
+             model, tokenizer, classification_type="multi-label", device="cuda:0", max_length=8192, progress_bar=False
+         )
+ 
+     dataset_names = ["msmarco", "nfcorpus", "nq"]
+     evaluator = GLiClassNanoBEIREvaluator(dataset_names)
+     results = evaluator(pipeline)
+     print(results)