shima-n committed on
Commit
c74b017
·
verified ·
1 Parent(s): bad6501

Upload 3 files

Files changed (3)
  1. app.py +566 -0
  2. int_to_doc_id.pkl +3 -0
  3. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,566 @@
+ from qdrant_client import QdrantClient
+ from qdrant_client.models import VectorParams, Distance
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from datasets import load_dataset
+ import numpy as np
+ import pandas as pd
+ import time
+ from tqdm import tqdm
+ import os, pickle
+ import gradio as gr
+ from gradio_client import Client
+ from math import log2
+
+ # =====================
+ # PARAMETERS
+ # =====================
+ retrieval_n = 50
+ num_queries = 10
+ docs_n = 100000
+ batch_size = 1000
+ embedding_models = ["all-MiniLM-L6-v2"]
+ rerank_models = [
+     "cross-encoder/ms-marco-MiniLM-L-6-v2",
+     "cross-encoder/ms-marco-TinyBERT-L-6",
+     # "cross-encoder/nli-deberta-v3-base-biomed",  # biomedical NLI fine-tune
+     # "ncbi/MedCPT-Cross-Encoder-msmarco",         # biomedical passage reranker
+ ]
+
+ collection_name = "trec_covid"
+ qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333")
+ k_values = [1, 3, 5, 10, 20]
+
+ # =====================
+ # LOAD DATA
+ # =====================
+ print("Loading datasets...")
+ corpus = load_dataset("BeIR/trec-covid", "corpus")
+ queries = load_dataset("BeIR/trec-covid", "queries")
+ qrels = load_dataset("BeIR/trec-covid-qrels", split='test')
+
+ print(f"Preparing corpus dict from first {docs_n} docs...")
+ corpus_docs = corpus['corpus'][:docs_n]
+ corpus_dict = {}
+ for i in tqdm(range(len(corpus_docs['_id'])), desc="Corpus dict build"):
+     corpus_dict[corpus_docs['_id'][i]] = corpus_docs['text'][i]
+ doc_ids_set = set(corpus_dict.keys())
+
+ print("Building qrels dictionary...")
+ qrels_dict = {}
+ for row in tqdm(qrels, desc="Processing qrels"):
+     qid = int(row['query-id'])
+     if qid not in qrels_dict:
+         qrels_dict[qid] = {}
+     if row['corpus-id'] in doc_ids_set:
+         qrels_dict[qid][row['corpus-id']] = int(row['score'])
+
+ filtered_qids = [qid for qid in qrels_dict.keys() if len(qrels_dict[qid]) > 0][:num_queries]
+
+ print(f"Filtering and loading {len(filtered_qids)} queries...")
+ queries_list = []
+ for qid in tqdm(filtered_qids, desc="Loading queries"):
+     filtered_query = queries['queries'].filter(lambda x: x['_id'] == str(qid))
+     if len(filtered_query) > 0:
+         queries_list.append((qid, filtered_query[0]['text']))
+
+ avg_relevant_docs = np.mean([len([doc for doc, score in rel.items() if score >= 2]) for rel in qrels_dict.values()])
+ print(f"Average relevant docs per query: {avg_relevant_docs:.2f}")
+
+
+ # =====================
+ # METRICS FUNCTIONS
+ # =====================
+ def recall_at_k(relevant, retrieved, k):
+     relevant_set = set(relevant.keys())
+     retrieved_k = set(retrieved[:k])
+     return len(relevant_set.intersection(retrieved_k)) / len(relevant_set) if relevant_set else 0
+
+ def precision_at_k(relevant, retrieved, k, rel_threshold=1):
+     relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold)
+     retrieved_k = retrieved[:k]
+     return sum(1 for doc in retrieved_k if doc in relevant_set) / k
+
+ def dcg_at_k(rels, k):
+     return sum((2**rel - 1) / np.log2(idx + 2) for idx, rel in enumerate(rels[:k]))
+
+ def ndcg_at_k(relevant_scores, retrieved_ids, k):
+     retrieved_rels = [relevant_scores.get(doc_id, 0) for doc_id in retrieved_ids[:k]]
+     ideal_rels = sorted(relevant_scores.values(), reverse=True)[:k]
+     ideal_dcg = dcg_at_k(ideal_rels, k)
+     actual_dcg = dcg_at_k(retrieved_rels, k)
+     return actual_dcg / ideal_dcg if ideal_dcg > 0 else 0
+
+ def average_precision(relevant, retrieved, rel_threshold=1):
+     relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold)
+     hits = 0
+     sum_prec = 0.0
+     for i, doc_id in enumerate(retrieved):
+         if doc_id in relevant_set:
+             hits += 1
+             sum_prec += hits / (i + 1)
+     return sum_prec / len(relevant_set) if relevant_set else 0
+
+ def reciprocal_rank(relevant, retrieved, rel_threshold=1):
+     relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold)
+     for i, doc_id in enumerate(retrieved):
+         if doc_id in relevant_set:
+             return 1 / (i + 1)
+     return 0
+
+ def success_at_k(relevant, retrieved, k, rel_threshold=1):
+     relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold)
+     return int(any(doc in relevant_set for doc in retrieved[:k]))
+
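+ # Worked example (added for illustration; not in the original upload):
+ # with graded relevance {"d1": 2, "d2": 1} and retrieved order ["d2", "d1"],
+ # dcg_at_k gives (2**1 - 1)/log2(2) + (2**2 - 1)/log2(3) ≈ 2.893, while the ideal
+ # ordering gives (2**2 - 1)/log2(2) + (2**1 - 1)/log2(3) ≈ 3.631,
+ # so ndcg_at_k({"d1": 2, "d2": 1}, ["d2", "d1"], 2) ≈ 2.893 / 3.631 ≈ 0.797.
+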
+ # =====================
+ # METRICS EVALUATION FUNCTION
+ # =====================
+ def evaluate_metrics(results_data, qrels_dict, k_values):
+     rows = []
+     for model_name, data in results_data.items():
+         recalls = {k: [] for k in k_values}
+         precisions = {k: [] for k in k_values}
+         ndcgs = {k: [] for k in k_values}
+         success = {k: [] for k in k_values}
+         maps = []
+         mrrs = []
+         retrieval_times = data.get("retrieval_times", [])
+         rerank_times = data.get("rerank_times", [])
+
+         print(f"Evaluating metrics for {model_name} ...")
+         for i, (qid, retrieved, rerank_scores) in enumerate(tqdm(zip(data["qids"], data["retrieved"], data["rerank_scores"]), total=len(data["qids"]), desc=f"Metrics {model_name}")):
+             relevant = qrels_dict.get(qid, {})
+             if rerank_scores:
+                 sorted_docs = [doc for doc, score in sorted(zip(retrieved, rerank_scores), key=lambda x: x[1], reverse=True)]
+             else:
+                 sorted_docs = retrieved
+
+             for k in k_values:
+                 recalls[k].append(recall_at_k(relevant, sorted_docs, k))
+                 precisions[k].append(precision_at_k(relevant, sorted_docs, k))
+                 ndcgs[k].append(ndcg_at_k(relevant, sorted_docs, k))
+                 success[k].append(success_at_k(relevant, sorted_docs, k))
+
+             maps.append(average_precision(relevant, sorted_docs))
+             mrrs.append(reciprocal_rank(relevant, sorted_docs))
+
+         avg_retrieval_time = np.mean(retrieval_times) if retrieval_times else 0
+         avg_rerank_time = np.mean(rerank_times) if rerank_times else 0
+
+         row = {"Model": model_name}
+         for k in k_values:
+             row[f"Recall@{k}"] = round(np.mean(recalls[k]), 4)
+             row[f"Precision@{k}"] = round(np.mean(precisions[k]), 4)
+             row[f"NDCG@{k}"] = round(np.mean(ndcgs[k]), 4)
+             row[f"Success@{k}"] = round(np.mean(success[k]), 4)
+         row["MAP"] = round(np.mean(maps), 4)
+         row["MRR"] = round(np.mean(mrrs), 4)
+         row["AvgRetrievalTime(s)"] = round(avg_retrieval_time, 4)
+         row["AvgRerankTime(s)"] = round(avg_rerank_time, 4)
+         rows.append(row)
+     return pd.DataFrame(rows)
+
+ # =====================
+ # Encoding + Upload
+ # =====================
+
+ def encode_and_upload():
+     client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY"))
+
+     for embedding_model in embedding_models:
+         print(f"Encoding corpus with embedding model {embedding_model} ...")
+         embedder = SentenceTransformer(embedding_model)
+
+         corpus_ids = list(doc_ids_set)
+         corpus_texts = [corpus_dict[doc_id] for doc_id in tqdm(corpus_ids, desc="Encoding corpus texts")]
+
+         # Normalize embeddings for cosine similarity
+         vectors = embedder.encode(corpus_texts, normalize_embeddings=True).tolist()
+
+         global doc_id_to_int, int_to_doc_id
+         doc_id_to_int = {doc_id: i for i, doc_id in enumerate(corpus_ids)}
+         int_to_doc_id = {i: doc_id for doc_id, i in doc_id_to_int.items()}
+
+         # Create collection only if it doesn't exist
+         if not client.collection_exists(collection_name):
+             print(f"Creating collection '{collection_name}' ...")
+             client.create_collection(
+                 collection_name=collection_name,
+                 vectors_config=VectorParams(size=len(vectors[0]), distance=Distance.COSINE)
+             )
+         else:
+             print(f"Collection '{collection_name}' already exists. Skipping creation.")
+
+         # Check already uploaded points
+         existing_ids = set()
+         scroll_res, _ = client.scroll(collection_name=collection_name, with_payload=False, limit=100000)
+         existing_ids = {point.id for point in scroll_res}
+         print(f"Already stored {len(existing_ids)} points in '{collection_name}'.")
+
+         # Prepare points for only missing IDs
+         new_points = []
+         for doc_id, vec in zip(corpus_ids, vectors):
+             pid = doc_id_to_int[doc_id]
+             if pid not in existing_ids:
+                 new_points.append({"id": pid, "vector": vec, "payload": {"text": corpus_dict[doc_id]}})
+
+         print(f"Uploading {len(new_points)} new points to collection '{collection_name}' ...")
+         for i in tqdm(range(0, len(new_points), batch_size), desc="Upserting points in batches"):
+             batch = new_points[i:i + batch_size]
+             client.upsert(collection_name=collection_name, points=batch)
+
+         # Preview first 5 stored docs
+         preview, _ = client.scroll(collection_name=collection_name, limit=5, with_payload=True)
+         print("\nPreview of stored points:")
+         for point in preview:
+             print(f"ID: {point.id} | Text: {point.payload['text'][:80]}...")
+
+     return embedder
+
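+ # Added sketch (not part of the original commit): the demo section below expects a
+ # pickled mapping file (MAPPING_FILE = "int_to_doc_id.pkl"), but encode_and_upload()
+ # only builds int_to_doc_id in memory. A helper along these lines could persist it
+ # right after indexing:
+ def save_int_to_doc_id(path="int_to_doc_id.pkl"):
+     # Assumes encode_and_upload() has already populated the global int_to_doc_id dict.
+     with open(path, "wb") as f:
+         pickle.dump(int_to_doc_id, f)
+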
+ # =====================
+ # Baseline Retrieval (No rerank)
+ # =====================
+ def run_retrieval(embedder):
+     client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY"))
+     retrieval_times = []
+     retrieved_docs_list = []
+     rerank_scores_list = []
+     qids = []
+
+     print("Running baseline retrieval ...")
+     for qid, qtext in tqdm(queries_list, desc="Baseline retrieval queries"):
+         q_vec = embedder.encode([qtext], normalize_embeddings=True)[0]
+
+         start_time = time.time()
+         search_result = client.query_points(
+             collection_name=collection_name,
+             query=q_vec,
+             limit=retrieval_n,
+             with_payload=True
+         )
+         retrieval_time = time.time() - start_time
+         retrieval_times.append(retrieval_time)
+
+         retrieved_ids_int = [hit.id for hit in search_result.points]
+         retrieved_ids = [int_to_doc_id[i] for i in retrieved_ids_int]
+
+         qids.append(qid)
+         retrieved_docs_list.append(retrieved_ids)
+         rerank_scores_list.append([])
+
+     results = {
+         "qids": qids,
+         "retrieved": retrieved_docs_list,
+         "rerank_scores": rerank_scores_list,
+         "retrieval_times": retrieval_times,
+         "rerank_times": []
+     }
+     return results
+
+ # =====================
+ # Retrieval + Rerank
+ # =====================
+ def run_rerank(embedder):
+     client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY"))
+     results_data = {}
+
+     for rerank_model in rerank_models:
+         print(f"Running retrieval + reranking with model {rerank_model} ...")
+         reranker = CrossEncoder(rerank_model, trust_remote_code=True)
+         retrieval_times = []
+         rerank_times = []
+         retrieved_docs_list = []
+         rerank_scores_list = []
+         qids = []
+
+         for qid, qtext in tqdm(queries_list, desc=f"Retrieval + rerank with {rerank_model}"):
+             q_vec = embedder.encode([qtext], normalize_embeddings=True)[0]
+
+             start_retrieval = time.time()
+             search_result = client.query_points(
+                 collection_name=collection_name,
+                 query=q_vec,
+                 limit=retrieval_n,
+                 with_payload=True
+             )
+             retrieval_time = time.time() - start_retrieval
+             retrieval_times.append(retrieval_time)
+
+             retrieved_ids_int = [hit.id for hit in search_result.points]
+             retrieved_ids = [int_to_doc_id[i] for i in retrieved_ids_int]
+             retrieved_texts = [hit.payload['text'] for hit in search_result.points]
+
+             start_rerank = time.time()
+             pairs = [(qtext, txt) for txt in retrieved_texts]
+             rerank_scores = reranker.predict(pairs)
+             rerank_time = time.time() - start_rerank
+             rerank_times.append(rerank_time)
+
+             qids.append(qid)
+             retrieved_docs_list.append(retrieved_ids)
+             rerank_scores_list.append(list(rerank_scores))
+
+         results_data[rerank_model] = {
+             "qids": qids,
+             "retrieved": retrieved_docs_list,
+             "rerank_scores": rerank_scores_list,
+             "retrieval_times": retrieval_times,
+             "rerank_times": rerank_times
+         }
+
+     return results_data
+
+
+ # =====================
+ # MAIN RUN
+ # =====================
+ if __name__ == "__main__":
+     embedder = encode_and_upload()
+
+     baseline_results = run_retrieval(embedder)
+     rerank_results = run_rerank(embedder)
+
+     all_results = {"Qdrant Baseline": baseline_results}
+     all_results.update(rerank_results)
+
+     df_metrics = evaluate_metrics(all_results, qrels_dict, k_values)
+
+     # Prepare column groups
+     recall_cols = ["Model"] + [f"Recall@{k}" for k in k_values] + [f"Precision@{k}" for k in k_values]
+     ndcg_success_cols = ["Model"] + [f"NDCG@{k}" for k in k_values] + [f"Success@{k}" for k in k_values]
+     summary_cols = ["Model", "MAP", "MRR", "AvgRetrievalTime(s)", "AvgRerankTime(s)"]
+
+     print("\n--- Recall and Precision ---")
+     print(df_metrics[recall_cols].to_string(index=False))
+
+     print("\n--- NDCG and Success ---")
+     print(df_metrics[ndcg_success_cols].to_string(index=False))
+
+     print("\n--- Summary Metrics and Timing ---")
+     print(df_metrics[summary_cols].to_string(index=False))
+
+     avg_relevant_docs = np.mean([len([doc for doc, score in rel.items() if score >= 1]) for rel in qrels_dict.values()])
+     print(f"Average relevant docs per query: {avg_relevant_docs:.2f}")
+
+
+ # --------------------
+ # CONFIG
+ # --------------------
+ QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
+ COLLECTION_NAME = "trec_covid"
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+ MAPPING_FILE = "int_to_doc_id.pkl"
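+ # Note (added comment, not in the original upload): QDRANT_URL and QDRANT_API_KEY are
+ # read from the environment (e.g. Space secrets); the localhost default above and
+ # port 7860 (see the PORT lookup in the __main__ block) are used when they are unset.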
+ # --------------------
+ # DATA
+ # --------------------
+ corpus = load_dataset("BeIR/trec-covid", "corpus")
+ queries = load_dataset("BeIR/trec-covid", "queries")
+ qrels = load_dataset("BeIR/trec-covid-qrels", split="test")
+
+ qrels_dict = {}
+ for row in qrels:
+     qid = int(row["query-id"])
+     qrels_dict.setdefault(qid, {})[row["corpus-id"]] = int(row["score"])
+
+ qds = queries["queries"]
+ max_dd = min(200, len(qds))
+ _qids = qds["_id"][:max_dd]
+ _texts = qds["text"][:max_dd]
+ trec_queries = [(f"{_qids[i]}: {_texts[i][:80]}", int(_qids[i]), _texts[i]) for i in range(max_dd)]
+ label2qt = {lab: (qid, txt) for (lab, qid, txt) in trec_queries}
+
+ # --------------------
+ # ID MAP
+ # --------------------
+ if not os.path.exists(MAPPING_FILE):
+     raise FileNotFoundError(f"Missing {MAPPING_FILE}. Save it during indexing.")
+ with open(MAPPING_FILE, "rb") as f:
+     int_to_doc_id = pickle.load(f)
+ INDEXED_DOC_IDS = set(int_to_doc_id.values())
+
+ # --------------------
+ # Lazy singletons
+ # --------------------
+ _client = None
+ _embedder = None
+ _rerankers = {}
+ def get_client():
+     global _client
+     if _client is None:
+         _client = QdrantClient(url=QDRANT_URL, api_key=os.getenv("QDRANT_API_KEY"))
+     return _client
+
+ def get_embedder():
+     global _embedder
+     if _embedder is None:
+         _embedder = SentenceTransformer(EMBEDDING_MODEL)
+     return _embedder
+
+ def get_reranker(model_name):
+     if model_name not in _rerankers:
+         _rerankers[model_name] = CrossEncoder(model_name, trust_remote_code=True)
+     return _rerankers[model_name]
+
+ # --------------------
+ # Metrics
+ # --------------------
+ def recall_at_k(relevant_ids_set, retrieved_ids, k):
+     if not relevant_ids_set:
+         return None
+     return len(relevant_ids_set.intersection(retrieved_ids[:k])) / len(relevant_ids_set)
+
+ def precision_at_k(relevant_ids_set, retrieved_ids, k):
+     if k == 0:
+         return None
+     return len(relevant_ids_set.intersection(retrieved_ids[:k])) / k
+
+ def hit_at_k(relevant_ids_set, retrieved_ids, k):
+     return int(len(relevant_ids_set.intersection(retrieved_ids[:k])) > 0)
+
+ def ndcg_at_k(relevant_ids_scores, retrieved_ids, k):
+     dcg = 0.0
+     idcg = 0.0
+     for i, doc_id in enumerate(retrieved_ids[:k]):
+         rel = relevant_ids_scores.get(doc_id, 0)
+         if rel > 0:
+             dcg += (2**rel - 1) / log2(i + 2)
+     sorted_rels = sorted(relevant_ids_scores.values(), reverse=True)[:k]
+     for i, rel in enumerate(sorted_rels):
+         if rel > 0:
+             idcg += (2**rel - 1) / log2(i + 2)
+     return dcg / idcg if idcg > 0 else None
+
+ def evaluate_model(relevant_in_collection, relevant_scores_in_collection, doc_order, k):
+     # Metrics can be None when there are no relevant docs in the collection for this query
+     recall = recall_at_k(relevant_in_collection, doc_order, k)
+     precision = precision_at_k(relevant_in_collection, doc_order, k)
+     ndcg = ndcg_at_k(relevant_scores_in_collection, doc_order, k)
+     return {
+         "Recall@k": None if recall is None else round(recall, 4),
+         "Precision@k": None if precision is None else round(precision, 4),
+         "Hit@k": hit_at_k(relevant_in_collection, doc_order, k),
+         "NDCG@k": None if ndcg is None else round(ndcg, 4),
+     }
+
+ # --------------------
+ # Core
+ # --------------------
+ def run_demo(
+     query_text, retrieval_n, top_k, use_trec, trec_label, rel_threshold,
+     use_baseline, *selected_rerankers
+ ):
+     client = get_client()
+     embedder = get_embedder()
+
+     qid = None
+     if use_trec and trec_label:
+         qid, query_text = label2qt[trec_label]
+
+     if not query_text or not query_text.strip():
+         return pd.DataFrame(), {"Note": "Empty query."}
+
+     q_vec = embedder.encode([query_text], normalize_embeddings=True)[0]
+     res = client.query_points(
+         collection_name=COLLECTION_NAME,
+         query=q_vec,
+         limit=int(retrieval_n),
+         with_payload=True
+     )
+     points = getattr(res, "points", res)
+
+     cand_docs, cand_texts, cand_qdrant_scores = [], [], []
+     for p in points:
+         payload = getattr(p, "payload", {}) or {}
+         pid = int(getattr(p, "id"))
+         doc_id = payload.get("doc_id", int_to_doc_id.get(pid, str(pid)))
+         cand_docs.append(doc_id)
+         cand_texts.append(payload.get("text", ""))
+         cand_qdrant_scores.append(getattr(p, "score", None))
+
+     # Keep all columns the same length, even if fewer candidates than top_k came back
+     n_rows = min(int(top_k), len(cand_docs))
+     cols = {
+         "rank": list(range(1, n_rows + 1)),
+         "doc_id": [],
+         "score_qdrant": [],
+         "text_snippet": [],
+     }
+     reranker_scores = {}
+
+     for model_name, is_selected in zip(rerank_models, selected_rerankers):
+         if is_selected:
+             rr = get_reranker(model_name)
+             reranker_scores[model_name] = rr.predict([(query_text, t) for t in cand_texts])
+
+     for i in range(n_rows):
+         cols["doc_id"].append(cand_docs[i])
+         cols["score_qdrant"].append(cand_qdrant_scores[i])
+         txt = cand_texts[i]
+         cols["text_snippet"].append(txt[:300] + ("…" if len(txt) > 300 else ""))
+         for model_name in reranker_scores:
+             col_key = f"score_{model_name.split('/')[-1]}"
+             if col_key not in cols:
+                 cols[col_key] = []
+             cols[col_key].append(float(reranker_scores[model_name][i]))
+
+     df = pd.DataFrame(cols)
+
+     metrics = {}
+     if qid is not None:
+         rels = qrels_dict.get(qid, {})
+         relevant_all = {d for d, s in rels.items() if s >= rel_threshold}
+         relevant_in_collection = relevant_all & INDEXED_DOC_IDS
+         relevant_scores_in_collection = {d: s for d, s in rels.items() if d in INDEXED_DOC_IDS}
+         ceiling_recall = round(len(relevant_in_collection) / len(relevant_all), 4) if relevant_all else None
+
+         if use_baseline:
+             metrics["Qdrant"] = evaluate_model(relevant_in_collection, relevant_scores_in_collection, cand_docs, int(top_k))
+
+         for model_name, is_selected in zip(rerank_models, selected_rerankers):
+             if is_selected:
+                 order = sorted(range(len(cand_docs)), key=lambda i: reranker_scores[model_name][i], reverse=True)
+                 top_docs = [cand_docs[i] for i in order[:int(top_k)]]
+                 metrics[model_name] = evaluate_model(relevant_in_collection, relevant_scores_in_collection, top_docs, int(top_k))
+
+         metrics["QueryID"] = int(qid)
+         metrics["Relevant>=threshold (all)"] = len(relevant_all)
+         metrics["Relevant in collection"] = len(relevant_in_collection)
+         metrics["Recall Ceiling (collection)"] = ceiling_recall
+
+     return df, metrics
+
+ # --------------------
+ # UI
+ # --------------------
+ with gr.Blocks(title="Qdrant Retrieval Demo") as demo:
+     gr.Markdown("### Qdrant Retrieval Demo (TREC-COVID) + Multiple Metrics")
+
+     with gr.Row():
+         query_text = gr.Textbox(label="Query (free text)", placeholder="e.g., ACE2 inhibitors and COVID-19", lines=2)
+     with gr.Row():
+         retrieval_n = gr.Slider(10, 2000, value=50, step=10, label="retrieval_n (candidates from Qdrant)")
+         top_k = gr.Slider(1, 500, value=10, step=1, label="top_k (metrics cutoff)")
+     with gr.Row():
+         use_trec = gr.Checkbox(label="Use a TREC-COVID query", value=True)
+         trec_choice = gr.Dropdown(choices=[lab for (lab, _, _) in trec_queries],
+                                   value=trec_queries[0][0] if trec_queries else None,
+                                   label="Pick TREC-COVID query")
+         rel_threshold = gr.Radio(choices=[1, 2], value=1, label="Relevance threshold")
+
+     gr.Markdown("**Models to evaluate:**")
+     with gr.Row():
+         use_baseline = gr.Checkbox(label="Qdrant baseline", value=True)
+         ce_checkboxes = [gr.Checkbox(label=model_name, value=False) for model_name in rerank_models]
+
+     run_btn = gr.Button("Search")
+     out_df = gr.Dataframe(label="Retrieved Docs + Scores", wrap=True)
+     out_metrics = gr.JSON(label="Metrics (per selected model + ceiling recall)")
+
+     run_btn.click(
+         fn=run_demo,
+         inputs=[query_text, retrieval_n, top_k, use_trec, trec_choice, rel_threshold,
+                 use_baseline, *ce_checkboxes],
+         outputs=[out_df, out_metrics]
+     )
+ # demo.launch(...)  # disabled for Spaces; see __main__ block below
559
+
560
+
561
+ if __name__ == "__main__":
562
+ try:
563
+ demo # Gradio Blocks defined in the notebook
564
+ except NameError:
565
+ raise RuntimeError("Could not find `demo`. Ensure your notebook defines `demo = gr.Blocks(...)`.")
566
+ demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
int_to_doc_id.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2646d9e29946c295f2f697dfe63232cf5a8540cc24c2a98a4c1fcbf0d6b4a870
+ size 1469086
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio
+ fastapi
+ uvicorn
+ qdrant-client
+ sentence-transformers
+ transformers
+ datasets
+ pandas
+ numpy
+ scikit-learn
+ torch
+ accelerate