from qdrant_client import QdrantClient from qdrant_client.models import VectorParams, Distance from sentence_transformers import SentenceTransformer, CrossEncoder from datasets import load_dataset import numpy as np import pandas as pd import time from tqdm import tqdm import os, pickle import gradio as gr from gradio_client import Client from math import log2 os.environ.setdefault("HF_HOME", "/app/.cache") os.environ.setdefault("HF_HUB_CACHE", "/app/.cache/hub") os.environ.setdefault("HF_DATASETS_CACHE", "/app/.cache/datasets") os.environ.setdefault("TRANSFORMERS_CACHE", "/app/.cache/transformers") # ===================== # PARAMETERS # ===================== retrieval_n = 50 num_queries = 10 docs_n = 100000 batch_size = 1000 embedding_models = ["all-MiniLM-L6-v2"] rerank_models = [ "cross-encoder/ms-marco-MiniLM-L-6-v2", "cross-encoder/ms-marco-TinyBERT-L-6", #"cross-encoder/nli-deberta-v3-base-biomed", # biomedical NLI fine-tune #"ncbi/MedCPT-Cross-Encoder-msmarco" # biomedical passage reranker ] collection_name = "trec_covid" qdrant_url = os.getenv("QDRANT_URL", "http://localhost:6333") k_values = [1, 3, 5, 10, 20] # ===================== # LOAD DATA # ===================== print("Loading datasets...") corpus = load_dataset("BeIR/trec-covid", "corpus") queries = load_dataset("BeIR/trec-covid", "queries") qrels = load_dataset("BeIR/trec-covid-qrels", split='test') print(f"Preparing corpus dict from first {docs_n} docs...") corpus_docs = corpus['corpus'][:docs_n] corpus_dict= {} for i in tqdm(range(len(corpus_docs['_id'])), desc="Corpus dict build"): corpus_dict[corpus_docs['_id'][i]] = corpus_docs['text'][i] doc_ids_set = set(corpus_dict.keys()) print("Building qrels dictionary...") qrels_dict = {} for row in tqdm(qrels, desc="Processing qrels"): qid = int(row['query-id']) if qid not in qrels_dict: qrels_dict[qid] = {} if row['corpus-id'] in doc_ids_set: qrels_dict[qid][row['corpus-id']] = int(row['score']) filtered_qids = [qid for qid in qrels_dict.keys() if len(qrels_dict[qid]) > 0][:num_queries] print(f"Filtering and loading {len(filtered_qids)} queries...") queries_list = [] for qid in tqdm(filtered_qids, desc="Loading queries"): filtered_query = queries['queries'].filter(lambda x: x['_id'] == str(qid)) if len(filtered_query) > 0: queries_list.append((qid, filtered_query[0]['text'])) avg_relevant_docs = np.mean([len([doc for doc, score in rel.items() if score >= 2]) for rel in qrels_dict.values()]) print(f"Average relevant docs per query: {avg_relevant_docs:.2f}") # ===================== # METRICS FUNCTIONS # ===================== def recall_at_k(relevant, retrieved, k): relevant_set = set(relevant.keys()) retrieved_k = set(retrieved[:k]) return len(relevant_set.intersection(retrieved_k)) / len(relevant_set) if relevant_set else 0 def precision_at_k(relevant, retrieved, k, rel_threshold=1): relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold) retrieved_k = retrieved[:k] return sum(1 for doc in retrieved_k if doc in relevant_set) / k def dcg_at_k(rels, k): return sum((2**rel - 1) / np.log2(idx + 2) for idx, rel in enumerate(rels[:k])) def ndcg_at_k(relevant_scores, retrieved_ids, k): retrieved_rels = [relevant_scores.get(doc_id, 0) for doc_id in retrieved_ids[:k]] ideal_rels = sorted(relevant_scores.values(), reverse=True)[:k] ideal_dcg = dcg_at_k(ideal_rels, k) actual_dcg = dcg_at_k(retrieved_rels, k) return actual_dcg / ideal_dcg if ideal_dcg > 0 else 0 def average_precision(relevant, retrieved, rel_threshold=1): relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold) hits = 0 sum_prec = 0.0 for i, doc_id in enumerate(retrieved): if doc_id in relevant_set: hits += 1 sum_prec += hits / (i + 1) return sum_prec / len(relevant_set) if relevant_set else 0 def reciprocal_rank(relevant, retrieved, rel_threshold=1): relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold) for i, doc_id in enumerate(retrieved): if doc_id in relevant_set: return 1 / (i + 1) return 0 def success_at_k(relevant, retrieved, k, rel_threshold=1): relevant_set = set(doc for doc, score in relevant.items() if score >= rel_threshold) return int(any(doc in relevant_set for doc in retrieved[:k])) # ===================== # METRICS EVALUATION FUNCTION # ===================== def evaluate_metrics(results_data, qrels_dict, k_values): rows = [] for model_name, data in results_data.items(): recalls = {k: [] for k in k_values} precisions = {k: [] for k in k_values} ndcgs = {k: [] for k in k_values} success = {k: [] for k in k_values} maps = [] mrrs = [] retrieval_times = data.get("retrieval_times", []) rerank_times = data.get("rerank_times", []) print(f"Evaluating metrics for {model_name} ...") for i, (qid, retrieved, rerank_scores) in enumerate(tqdm(zip(data["qids"], data["retrieved"], data["rerank_scores"]), total=len(data["qids"]), desc=f"Metrics {model_name}")): relevant = qrels_dict.get(qid, {}) if rerank_scores: sorted_docs = [doc for doc, score in sorted(zip(retrieved, rerank_scores), key=lambda x: x[1], reverse=True)] else: sorted_docs = retrieved for k in k_values: recalls[k].append(recall_at_k(relevant, sorted_docs, k)) precisions[k].append(precision_at_k(relevant, sorted_docs, k)) ndcgs[k].append(ndcg_at_k(relevant, sorted_docs, k)) success[k].append(success_at_k(relevant, sorted_docs, k)) maps.append(average_precision(relevant, sorted_docs)) mrrs.append(reciprocal_rank(relevant, sorted_docs)) avg_retrieval_time = np.mean(retrieval_times) if retrieval_times else 0 avg_rerank_time = np.mean(rerank_times) if rerank_times else 0 row = {"Model": model_name} for k in k_values: row[f"Recall@{k}"] = round(np.mean(recalls[k]), 4) row[f"Precision@{k}"] = round(np.mean(precisions[k]), 4) row[f"NDCG@{k}"] = round(np.mean(ndcgs[k]), 4) row[f"Success@{k}"] = round(np.mean(success[k]), 4) row["MAP"] = round(np.mean(maps), 4) row["MRR"] = round(np.mean(mrrs), 4) row["AvgRetrievalTime(s)"] = round(avg_retrieval_time, 4) row["AvgRerankTime(s)"] = round(avg_rerank_time, 4) rows.append(row) return pd.DataFrame(rows) # ===================== # Encoding + Upload # ===================== def encode_and_upload(): client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY")) for embedding_model in embedding_models: print(f"Encoding corpus with embedding model {embedding_model} ...") embedder = SentenceTransformer(embedding_model) corpus_ids = list(doc_ids_set) corpus_texts = [corpus_dict[doc_id] for doc_id in tqdm(corpus_ids, desc="Encoding corpus texts")] # Normalize embeddings for cosine similarity vectors = embedder.encode(corpus_texts, normalize_embeddings=True).tolist() global doc_id_to_int, int_to_doc_id doc_id_to_int = {doc_id: i for i, doc_id in enumerate(corpus_ids)} int_to_doc_id = {i: doc_id for doc_id, i in doc_id_to_int.items()} # Create collection only if it doesn't exist if not client.collection_exists(collection_name): print(f"Creating collection '{collection_name}' ...") client.create_collection( collection_name=collection_name, vectors_config=VectorParams(size=len(vectors[0]), distance=Distance.COSINE) ) else: print(f"Collection '{collection_name}' already exists. Skipping creation.") # Check already uploaded points existing_ids = set() scroll_res, _ = client.scroll(collection_name=collection_name, with_payload=False, limit=100000) existing_ids = {point.id for point in scroll_res} print(f"Already stored {len(existing_ids)} points in '{collection_name}'.") # Prepare points for only missing IDs new_points = [] for doc_id, vec in zip(corpus_ids, vectors): pid = doc_id_to_int[doc_id] if pid not in existing_ids: new_points.append({"id": pid, "vector": vec, "payload": {"text": corpus_dict[doc_id]}}) print(f"Uploading {len(new_points)} new points to collection '{collection_name}' ...") for i in tqdm(range(0, len(new_points), batch_size), desc="Upserting points in batches"): batch = new_points[i:i + batch_size] client.upsert(collection_name=collection_name, points=batch) # Preview first 5 stored docs preview, _ = client.scroll(collection_name=collection_name, limit=5, with_payload=True) print("\nPreview of stored points:") for point in preview: print(f"ID: {point.id} | Text: {point.payload['text'][:80]}...") return embedder # ===================== # Baseline Retrieval (No rerank) # ===================== def run_retrieval(embedder): client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY")) retrieval_times = [] retrieved_docs_list = [] rerank_scores_list = [] qids = [] print("Running baseline retrieval ...") for qid, qtext in tqdm(queries_list, desc="Baseline retrieval queries"): q_vec = embedder.encode([qtext], normalize_embeddings=True)[0] start_time = time.time() search_result = client.query_points( collection_name=collection_name, query=q_vec, limit=retrieval_n, with_payload=True ) retrieval_time = time.time() - start_time retrieval_times.append(retrieval_time) retrieved_ids_int = [hit.id for hit in search_result.points] retrieved_ids = [int_to_doc_id[i] for i in retrieved_ids_int] qids.append(qid) retrieved_docs_list.append(retrieved_ids) rerank_scores_list.append([]) results = { "qids": qids, "retrieved": retrieved_docs_list, "rerank_scores": rerank_scores_list, "retrieval_times": retrieval_times, "rerank_times": [] } return results # ===================== # Retrieval + Rerank # ===================== def run_rerank(embedder): client = QdrantClient(url=qdrant_url, api_key=os.getenv("QDRANT_API_KEY")) results_data = {} for rerank_model in rerank_models: print(f"Running retrieval + reranking with model {rerank_model} ...") reranker = CrossEncoder(rerank_model, trust_remote_code=True) retrieval_times = [] rerank_times = [] retrieved_docs_list = [] rerank_scores_list = [] qids = [] for qid, qtext in tqdm(queries_list, desc=f"Retrieval + rerank with {rerank_model}"): q_vec = embedder.encode([qtext], normalize_embeddings=True)[0] start_retrieval = time.time() search_result = client.query_points( collection_name=collection_name, query=q_vec, limit=retrieval_n, with_payload=True ) retrieval_time = time.time() - start_retrieval retrieval_times.append(retrieval_time) retrieved_ids_int = [hit.id for hit in search_result.points] retrieved_ids = [int_to_doc_id[i] for i in retrieved_ids_int] retrieved_texts = [hit.payload['text'] for hit in search_result.points] start_rerank = time.time() pairs = [(qtext, txt) for txt in retrieved_texts] rerank_scores = reranker.predict(pairs) rerank_time = time.time() - start_rerank rerank_times.append(rerank_time) qids.append(qid) retrieved_docs_list.append(retrieved_ids) rerank_scores_list.append(list(rerank_scores)) results_data[rerank_model] = { "qids": qids, "retrieved": retrieved_docs_list, "rerank_scores": rerank_scores_list, "retrieval_times": retrieval_times, "rerank_times": rerank_times } return results_data # ===================== # MAIN RUN # ===================== if __name__ == "__main__": embedder = encode_and_upload() baseline_results = run_retrieval(embedder) rerank_results = run_rerank(embedder) all_results = {"Qdrant Baseline": baseline_results} all_results.update(rerank_results) df_metrics = evaluate_metrics(all_results, qrels_dict, k_values) # Prepare column groups recall_cols = ["Model"] + [f"Recall@{k}" for k in k_values] + [f"Precision@{k}" for k in k_values] ndcg_success_cols = ["Model"] + [f"NDCG@{k}" for k in k_values] + [f"Success@{k}" for k in k_values] summary_cols = ["Model", "MAP", "MRR", "AvgRetrievalTime(s)", "AvgRerankTime(s)"] print("\n--- Recall and Precision ---") print(df_metrics[recall_cols].to_string(index=False)) print("\n--- NDCG and Success ---") print(df_metrics[ndcg_success_cols].to_string(index=False)) print("\n--- Summary Metrics and Timing ---") print(df_metrics[summary_cols].to_string(index=False)) avg_relevant_docs = np.mean([len([doc for doc, score in rel.items() if score >= 1]) for rel in qrels_dict.values()]) print(f"Average relevant docs per query: {avg_relevant_docs:.2f}") # -------------------- # CONFIG # -------------------- QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333") COLLECTION_NAME = "trec_covid" EMBEDDING_MODEL = "all-MiniLM-L6-v2" MAPPING_FILE = "int_to_doc_id.pkl" # -------------------- # DATA # -------------------- corpus = load_dataset("BeIR/trec-covid", "corpus") queries = load_dataset("BeIR/trec-covid", "queries") qrels = load_dataset("BeIR/trec-covid-qrels", split="test") qrels_dict = {} for row in qrels: qid = int(row["query-id"]) qrels_dict.setdefault(qid, {})[row["corpus-id"]] = int(row["score"]) qds = queries["queries"] max_dd = min(200, len(qds)) _qids = qds["_id"][:max_dd] _texts = qds["text"][:max_dd] trec_queries = [(f"{_qids[i]}: {_texts[i][:80]}", int(_qids[i]), _texts[i]) for i in range(max_dd)] label2qt = {lab: (qid, txt) for (lab, qid, txt) in trec_queries} # -------------------- # ID MAP # -------------------- if not os.path.exists(MAPPING_FILE): raise FileNotFoundError(f"Missing {MAPPING_FILE}. Save it during indexing.") with open(MAPPING_FILE, "rb") as f: int_to_doc_id = pickle.load(f) INDEXED_DOC_IDS = set(int_to_doc_id.values()) # -------------------- # Lazy singletons # -------------------- _client = None _embedder = None _rerankers = {} def get_client(): global _client if _client is None: _client = QdrantClient(url=QDRANT_URL, api_key=os.getenv("QDRANT_API_KEY")) return _client def get_embedder(): global _embedder if _embedder is None: _embedder = SentenceTransformer(EMBEDDING_MODEL) return _embedder def get_reranker(model_name): if model_name not in _rerankers: _rerankers[model_name] = CrossEncoder(model_name, trust_remote_code=True) return _rerankers[model_name] # -------------------- # Metrics # -------------------- def recall_at_k(relevant_ids_set, retrieved_ids, k): if not relevant_ids_set: return None return len(relevant_ids_set.intersection(retrieved_ids[:k])) / len(relevant_ids_set) def precision_at_k(relevant_ids_set, retrieved_ids, k): if k == 0: return None return len(relevant_ids_set.intersection(retrieved_ids[:k])) / k def hit_at_k(relevant_ids_set, retrieved_ids, k): return int(len(relevant_ids_set.intersection(retrieved_ids[:k])) > 0) def ndcg_at_k(relevant_ids_scores, retrieved_ids, k): dcg = 0.0 idcg = 0.0 for i, doc_id in enumerate(retrieved_ids[:k]): rel = relevant_ids_scores.get(doc_id, 0) if rel > 0: dcg += (2**rel - 1) / log2(i+2) sorted_rels = sorted(relevant_ids_scores.values(), reverse=True)[:k] for i, rel in enumerate(sorted_rels): if rel > 0: idcg += (2**rel - 1) / log2(i+2) return dcg / idcg if idcg > 0 else None def evaluate_model(relevant_in_collection, relevant_scores_in_collection, doc_order, k): return { "Recall@k": round(recall_at_k(relevant_in_collection, doc_order, k), 4), "Precision@k": round(precision_at_k(relevant_in_collection, doc_order, k), 4), "Hit@k": hit_at_k(relevant_in_collection, doc_order, k), "NDCG@k": None if ndcg_at_k(relevant_scores_in_collection, doc_order, k) is None else round(ndcg_at_k(relevant_scores_in_collection, doc_order, k), 4), } # -------------------- # Core # -------------------- def run_demo( query_text, retrieval_n, top_k, use_trec, trec_label, rel_threshold, use_baseline, *selected_rerankers ): client = get_client() embedder = get_embedder() qid = None if use_trec and trec_label: qid, query_text = label2qt[trec_label] if not query_text or not query_text.strip(): return pd.DataFrame(), {"Note": "Empty query."} q_vec = embedder.encode([query_text], normalize_embeddings=True)[0] res = client.query_points( collection_name=COLLECTION_NAME, query=q_vec, limit=int(retrieval_n), with_payload=True ) points = getattr(res, "points", res) cand_docs, cand_texts, cand_qdrant_scores = [], [], [] for p in points: payload = getattr(p, "payload", {}) or {} pid = int(getattr(p, "id")) doc_id = payload.get("doc_id", int_to_doc_id.get(pid, str(pid))) cand_docs.append(doc_id) cand_texts.append(payload.get("text", "")) cand_qdrant_scores.append(getattr(p, "score", None)) cols = { "rank": list(range(1, int(top_k)+1)), "doc_id": [], "score_qdrant": [], "text_snippet": [], } reranker_scores = {} for model_name, is_selected in zip(rerank_models, selected_rerankers): if is_selected: rr = get_reranker(model_name) reranker_scores[model_name] = rr.predict([(query_text, t) for t in cand_texts]) for i in range(min(int(top_k), len(cand_docs))): cols["doc_id"].append(cand_docs[i]) cols["score_qdrant"].append(cand_qdrant_scores[i]) txt = cand_texts[i] cols["text_snippet"].append(txt[:300] + ("…" if len(txt) > 300 else "")) for model_name in reranker_scores: col_key = f"score_{model_name.split('/')[-1]}" if col_key not in cols: cols[col_key] = [] cols[col_key].append(float(reranker_scores[model_name][i])) df = pd.DataFrame(cols) metrics = {} if qid is not None: rels = qrels_dict.get(qid, {}) relevant_all = {d for d, s in rels.items() if s >= rel_threshold} relevant_in_collection = relevant_all & INDEXED_DOC_IDS relevant_scores_in_collection = {d: s for d, s in rels.items() if d in INDEXED_DOC_IDS} ceiling_recall = round(len(relevant_in_collection) / len(relevant_all), 4) if relevant_all else None if use_baseline: metrics["Qdrant"] = evaluate_model(relevant_in_collection, relevant_scores_in_collection, cand_docs, int(top_k)) for model_name, is_selected in zip(rerank_models, selected_rerankers): if is_selected: order = sorted(range(len(cand_docs)), key=lambda i: reranker_scores[model_name][i], reverse=True) top_docs = [cand_docs[i] for i in order[:int(top_k)]] metrics[model_name] = evaluate_model(relevant_in_collection, relevant_scores_in_collection, top_docs, int(top_k)) metrics["QueryID"] = int(qid) metrics["Relevant>=threshold (all)"] = len(relevant_all) metrics["Relevant in collection"] = len(relevant_in_collection) metrics["Recall Ceiling (collection)"] = ceiling_recall return df, metrics # -------------------- # UI # -------------------- with gr.Blocks(title="Qdrant Retrieval Demo") as demo: gr.Markdown("### Qdrant Retrieval Demo (TREC-COVID) + Multiple Metrics") with gr.Row(): query_text = gr.Textbox(label="Query (free text)", placeholder="e.g., ACE2 inhibitors and COVID-19", lines=2) with gr.Row(): retrieval_n = gr.Slider(10, 2000, value=50, step=10, label="retrieval_n (candidates from Qdrant)") top_k = gr.Slider(1, 500, value=10, step=1, label="top_k (metrics cutoff)") with gr.Row(): use_trec = gr.Checkbox(label="Use a TREC-COVID query", value=True) trec_choice = gr.Dropdown(choices=[lab for (lab, _, _) in trec_queries], value=trec_queries[0][0] if trec_queries else None, label="Pick TREC-COVID query") rel_threshold = gr.Radio(choices=[1, 2], value=1, label="Relevance threshold") gr.Markdown("**Models to evaluate:**") with gr.Row(): use_baseline = gr.Checkbox(label="Qdrant baseline", value=True) ce_checkboxes = [gr.Checkbox(label=model_name, value=False) for model_name in rerank_models] run_btn = gr.Button("Search") out_df = gr.Dataframe(label="Retrieved Docs + Scores", wrap=True) out_metrics = gr.JSON(label="Metrics (per selected model + ceiling recall)") run_btn.click( fn=run_demo, inputs=[query_text, retrieval_n, top_k, use_trec, trec_choice, rel_threshold, use_baseline, *ce_checkboxes], outputs=[out_df, out_metrics] ) # demo.launch(...) # disabled for Spaces; see __main__ block below if __name__ == "__main__": try: demo # Gradio Blocks defined in the notebook except NameError: raise RuntimeError("Could not find `demo`. Ensure your notebook defines `demo = gr.Blocks(...)`.") demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))