""" retrieval.py ------------ This module handles retrieval of PubMed abstracts and indexing via Chromadb. It fetches abstracts using NCBI's E-utilities and indexes them in a vector store to enable similarity search for clinical queries. """ import os import tempfile import requests import torch from typing import List import chromadb from chromadb.config import Settings from transformers import AutoTokenizer, AutoModel # Optional: Set your PubMed API key from environment variables. PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "") ############################################# # 1) FETCH PUBMED ABSTRACTS ############################################# def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]: """ Retrieves PubMed abstracts for the given clinical query. Returns a list of abstract texts. """ search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" params = { "db": "pubmed", "term": query, "retmax": max_results, "api_key": PUBMED_API_KEY, "retmode": "json" } r = requests.get(search_url, params=params, timeout=10) r.raise_for_status() data = r.json() pmid_list = data["esearchresult"].get("idlist", []) abstracts = [] fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi" for pmid in pmid_list: fetch_params = { "db": "pubmed", "id": pmid, "rettype": "abstract", "retmode": "text", "api_key": PUBMED_API_KEY } fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=10) fetch_resp.raise_for_status() abstract_text = fetch_resp.text.strip() if abstract_text: abstracts.append(abstract_text) return abstracts ############################################# # 2) CHROMA + EMBEDDINGS SETUP ############################################# class EmbedFunction: """ Uses a Hugging Face embedding model to generate embeddings for clinical texts. """ def __init__(self, model_name: str): self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name) self.model.eval() def __call__(self, input: List[str]) -> List[List[float]]: if not input: return [] tokenized = self.tokenizer( input, return_tensors="pt", padding=True, truncation=True, max_length=512 ) with torch.no_grad(): outputs = self.model(**tokenized, output_hidden_states=True) # Mean-pooling over the last hidden state. last_hidden = outputs.hidden_states[-1] pooled = last_hidden.mean(dim=1) return pooled.cpu().tolist() EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" embed_function = EmbedFunction(EMBED_MODEL_NAME) # Create a temporary directory for the Chromadb persistent storage. temp_dir = tempfile.mkdtemp() print("Using temporary persist_directory:", temp_dir) client = chromadb.Client( settings=Settings( persist_directory=temp_dir, anonymized_telemetry=False ) ) # Create or retrieve the collection for clinical abstracts. collection = client.get_or_create_collection( name="ai_medical_knowledge", embedding_function=embed_function ) # Force initialization with a dummy document. try: collection.add(documents=["dummy"], ids=["dummy"]) _ = collection.query(query_texts=["dummy"], n_results=1) print("Dummy initialization successful.") except Exception as init_err: print("Dummy initialization failed:", init_err) def index_pubmed_docs(docs: List[str], prefix: str = "doc"): """ Indexes the retrieved PubMed abstracts into the Chromadb vector store. """ for i, doc in enumerate(docs): if doc.strip(): doc_id = f"{prefix}-{i}" try: collection.add(documents=[doc], ids=[doc_id]) print(f"Added document with id: {doc_id}") except Exception as e: print(f"Error adding document id {doc_id}: {e}") raise def query_similar_docs(query: str, top_k: int = 3) -> List[str]: """ Performs a similarity search on the indexed abstracts and returns the top relevant documents. """ results = collection.query(query_texts=[query], n_results=top_k) return results["documents"][0] if results and results["documents"] else [] ############################################# # 3) MAIN RETRIEVAL PIPELINE ############################################# def get_relevant_pubmed_docs(user_query: str) -> List[str]: """ Complete retrieval pipeline: 1. Fetch PubMed abstracts. 2. Index them into the vector store. 3. Retrieve and return the most similar documents. """ new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5) if not new_abstracts: return [] index_pubmed_docs(new_abstracts, prefix=user_query) top_docs = query_similar_docs(user_query, top_k=3) return top_docs