import os
import tempfile
from typing import List

import chromadb
import requests
import torch
from chromadb.config import Settings
from transformers import AutoModel, AutoTokenizer

# Optional: set your PubMed (NCBI E-utilities) API key via the environment.
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "")

# Seconds to wait on each NCBI request. Without a timeout a stalled
# connection would block the caller indefinitely.
REQUEST_TIMEOUT_SECONDS = 30

ESEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
EFETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"


#############################################
# 1) FETCH PUBMED ABSTRACTS
#############################################
def _with_api_key(params: dict) -> dict:
    """Return a copy of *params*, adding ``api_key`` only when one is set."""
    merged = dict(params)
    # Sending an empty ``api_key=`` is not the same as sending none at all,
    # so the parameter is left out entirely when no key is configured.
    if PUBMED_API_KEY:
        merged["api_key"] = PUBMED_API_KEY
    return merged


def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
    """
    Fetch PubMed abstracts for the given query using NCBI's E-utilities.

    The query is first sent to ``esearch`` to obtain up to ``max_results``
    PMIDs; each PMID is then requested from ``efetch`` as plain text.

    Args:
        query: Search term passed to ``esearch`` as ``term``.
        max_results: Upper bound on PMIDs requested (``retmax``).

    Returns:
        The non-empty, whitespace-stripped ``efetch`` response texts, in the
        order the PMIDs were returned. Empty when the query is blank, when
        the search yields no PMIDs, or when the response has no
        ``esearchresult`` section.

    Raises:
        requests.HTTPError: If any request returns an error status.
        requests.RequestException: On timeouts and connection failures.
    """
    if not query or not query.strip():
        return []

    search_params = _with_api_key(
        {
            "db": "pubmed",
            "term": query,
            "retmax": max_results,
            "retmode": "json",
        }
    )
    search_resp = requests.get(
        ESEARCH_URL, params=search_params, timeout=REQUEST_TIMEOUT_SECONDS
    )
    search_resp.raise_for_status()
    # Tolerate a payload without "esearchresult" instead of raising KeyError.
    pmid_list = search_resp.json().get("esearchresult", {}).get("idlist", [])

    abstracts = []
    for pmid in pmid_list:
        fetch_params = _with_api_key(
            {
                "db": "pubmed",
                "id": pmid,
                "rettype": "abstract",
                "retmode": "text",
            }
        )
        fetch_resp = requests.get(
            EFETCH_URL, params=fetch_params, timeout=REQUEST_TIMEOUT_SECONDS
        )
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts


#############################################
# 2) CHROMA + EMBEDDINGS SETUP
#############################################
class EmbedFunction:
    """
    Wraps a Hugging Face embedding model to produce embeddings for a list of
    strings. Instances are callable and are passed to Chroma as the
    collection's ``embedding_function``.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for *model_name* and switch to eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    # NOTE: the parameter is named ``input`` (shadowing the builtin) because
    # it is part of this callable's interface; do not rename it.
    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embed each string in *input*.

        Texts are tokenized together (padded to the longest, truncated to 512
        tokens) and each embedding is the mean of the final-layer token
        vectors over the real (non-padding) tokens.

        Returns:
            One list of floats per input string; ``[]`` for empty input.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        token_embeddings = outputs.last_hidden_state

        # Weight by the attention mask so padding tokens do not dilute the
        # average; otherwise a text's embedding would change depending on
        # which other texts happen to share its batch.
        mask = tokenized["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = summed / token_counts
        return pooled.cpu().tolist()


EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Use a temporary directory as the Chroma persist_directory.
# NOTE(review): depending on the installed chromadb version, passing only
# ``persist_directory`` to ``chromadb.Client`` may still give an in-memory
# store - confirm whether on-disk persistence is actually required here.
temp_dir = tempfile.mkdtemp()
print("Using temporary persist_directory:", temp_dir)

client = chromadb.Client(
    settings=Settings(
        persist_directory=temp_dir,
        anonymized_telemetry=False,
    )
)

# Create or get the collection. Use a clear name.
collection = client.get_or_create_collection(
    name="ai_medical_knowledge",
    embedding_function=embed_function,
)

# Force initialization: add a dummy document and perform a dummy query, then
# remove the dummy so it can never be returned as a search result.
try:
    collection.add(documents=["dummy"], ids=["dummy"])
    _ = collection.query(query_texts=["dummy"], n_results=1)
    print("Dummy initialization successful.")
except Exception as init_err:
    # Best-effort warm-up: report the problem but keep the module importable.
    print("Dummy initialization failed:", init_err)
finally:
    try:
        collection.delete(ids=["dummy"])
    except Exception as cleanup_err:
        print("Dummy cleanup failed:", cleanup_err)


def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Store documents in the Chroma collection under IDs ``"{prefix}-{i}"``.

    ``i`` is the document's position in *docs*. Blank documents are skipped.
    Documents are upserted, so indexing again with the same prefix replaces
    the earlier entries instead of colliding with their IDs.

    Args:
        docs: Document texts to index.
        prefix: ID prefix; callers should keep it distinct per document set.

    Raises:
        Exception: Whatever the collection raises for a document, re-raised
            after the error is printed.
    """
    for i, doc in enumerate(docs):
        if not doc.strip():
            continue
        doc_id = f"{prefix}-{i}"
        try:
            collection.upsert(documents=[doc], ids=[doc_id])
            print(f"Added document with id: {doc_id}")
        except Exception as e:
            print(f"Error adding document id {doc_id}: {e}")
            raise


def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Retrieve up to ``top_k`` documents from the Chroma collection, ranked by
    embedding similarity to *query*.

    Args:
        query: Text to search for.
        top_k: Maximum number of documents to return.

    Returns:
        The matching document texts; ``[]`` when ``top_k`` is not positive,
        the collection is empty, or the query returns no documents.
    """
    # Never request more results than the collection holds, and skip the
    # query entirely when there is nothing to return.
    n_results = min(top_k, collection.count())
    if n_results <= 0:
        return []
    results = collection.query(query_texts=[query], n_results=n_results)
    return results["documents"][0] if results and results["documents"] else []


#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    End-to-end pipeline:
      1. Fetch PubMed abstracts for the query.
      2. Index them in Chroma, using the query as the ID prefix.
      3. Retrieve the top relevant documents.

    Args:
        user_query: Free-text query from the user.

    Returns:
        Up to three document texts; ``[]`` when the query is blank or no
        abstracts were fetched.
    """
    if not user_query or not user_query.strip():
        return []

    new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
    if not new_abstracts:
        return []

    index_pubmed_docs(new_abstracts, prefix=user_query)
    return query_similar_docs(user_query, top_k=3)