Medic / retrieval.py
mgbam's picture
Update retrieval.py
9920af3 verified
raw
history blame
4.44 kB
import os
import requests
import torch
from typing import List
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModel
# NCBI E-utilities API key, read from the PUBMED_API_KEY environment variable.
# When the variable is not set, the literal placeholder string below is used
# as the value.
PUBMED_API_KEY = os.getenv("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
#############################################
# 1) FETCH PUBMED ABSTRACTS
#############################################
def fetch_pubmed_abstracts(
    query: str, max_results: int = 5, timeout: float = 10.0
) -> List[str]:
    """
    Fetches PubMed abstracts for the given query using NCBI's E-utilities.

    Runs one ESearch request to resolve the query to PMIDs, then one EFetch
    request per PMID to download its abstract as plain text.

    Args:
        query: Search term passed to ESearch.
        max_results: Maximum number of PMIDs to request (ESearch ``retmax``).
        timeout: Per-request timeout in seconds, so a stalled connection
            cannot block the caller indefinitely.

    Returns:
        A list of non-empty abstract texts, in the order ESearch returned the
        PMIDs. Empty when the query is blank, ``max_results`` is not positive,
        or nothing matched.

    Raises:
        requests.HTTPError: If NCBI answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Nothing to search for: skip the network round-trips entirely.
    if not query or not query.strip() or max_results <= 0:
        return []

    common_params = {"db": "pubmed"}
    # Only send an API key that was actually configured; the unconfigured
    # placeholder is not a valid key and should not be sent to NCBI.
    if PUBMED_API_KEY and PUBMED_API_KEY != "<YOUR_NCBI_API_KEY>":
        common_params["api_key"] = PUBMED_API_KEY

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    search_params = {
        **common_params,
        "term": query,
        "retmax": max_results,
        "retmode": "json",
    }
    search_resp = requests.get(search_url, params=search_params, timeout=timeout)
    search_resp.raise_for_status()
    # Tolerate payloads without "esearchresult"/"idlist" (e.g. error replies).
    pmid_list = search_resp.json().get("esearchresult", {}).get("idlist", [])

    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    abstracts = []
    for pmid in pmid_list:
        fetch_params = {
            **common_params,
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
        }
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts
#############################################
# 2) CHROMA + EMBEDDINGS SETUP
#############################################
class EmbedFunction:
    """
    Wraps a Hugging Face embedding model to produce embeddings for a list of strings.

    Each text is tokenized (truncated to 512 tokens), run through the model,
    and reduced to one vector by averaging the last hidden layer over the
    text's real tokens.
    """

    def __init__(self, model_name: str):
        """
        Loads the tokenizer and model for ``model_name`` and switches the
        model to evaluation mode.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    # NOTE(review): the parameter is named `input` (shadowing the builtin)
    # because this object is passed to Chroma as an embedding function and is
    # called with that keyword; confirm before renaming.
    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embeds ``input`` and returns one list of floats per text.

        Returns an empty list when ``input`` is empty.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]

        attention_mask = tokenized.get("attention_mask")
        if attention_mask is None:
            # No mask available: every position counts as a real token.
            pooled = last_hidden.mean(dim=1)
        else:
            # Average only over real tokens. Padding positions added by
            # `padding=True` must not contribute, otherwise a text's
            # embedding would change with the other texts in the batch.
            weights = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            token_counts = weights.sum(dim=1).clamp(min=1.0)
            pooled = (last_hidden * weights).sum(dim=1) / token_counts
        return pooled.cpu().tolist()
# Embedding model used for both indexing and querying.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Use ephemeral (in-memory) storage to avoid persistent storage issues.
# NOTE(review): `in_memory` is not a documented Chroma `Settings` field and is
# expected to be rejected by settings validation; a client created without
# persistence options is already in-memory, so only telemetry is configured.
client = chromadb.Client(settings=Settings(anonymized_telemetry=False))

# Create or get the collection. Use a clear name.
collection = client.get_or_create_collection(
    name="ai_medical_knowledge",
    embedding_function=embed_function,
)
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Writes documents to the Chroma collection under the IDs ``{prefix}-{i}``.

    ``i`` is the document's position in ``docs``; blank documents are skipped
    (their index is simply not used). Because the IDs are deterministic,
    indexing the same prefix again would collide with the existing entries,
    so documents are upserted: an existing ID is overwritten instead of
    producing a duplicate-ID failure or a silently ignored write.

    Documents are written one at a time so that a failure can be reported
    against the specific ID that caused it.

    Raises:
        Exception: Whatever the collection raises; it is logged and re-raised.
    """
    for i, doc in enumerate(docs):
        if not doc or not doc.strip():
            continue
        doc_id = f"{prefix}-{i}"
        try:
            collection.upsert(documents=[doc], ids=[doc_id])
        except Exception as e:
            print(f"Error adding document id {doc_id}: {e}")
            raise
        else:
            print(f"Added document with id: {doc_id}")
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Retrieves the top_k similar documents from Chroma based on embedding similarity.

    Args:
        query: Text to search for.
        top_k: Maximum number of documents to return.

    Returns:
        Up to ``top_k`` document texts for the query, or an empty list when
        ``top_k`` is not positive, the collection is empty, or the query
        yields no documents.
    """
    if top_k <= 0:
        return []

    # Never ask for more neighbours than there are indexed documents, and
    # skip the query entirely when there is nothing to search.
    available = collection.count()
    if available == 0:
        return []

    results = collection.query(
        query_texts=[query],
        n_results=min(top_k, available),
    )
    # "documents" holds one list per query text; only one query was sent.
    return results["documents"][0] if results and results["documents"] else []
#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(
    user_query: str, max_results: int = 5, top_k: int = 3
) -> List[str]:
    """
    End-to-end pipeline:
    1. Fetch PubMed abstracts for the query.
    2. Index them in Chroma.
    3. Retrieve the top relevant documents.

    Args:
        user_query: The search text; also used as the ID prefix when indexing.
        max_results: How many abstracts to fetch from PubMed.
        top_k: How many documents to return from the similarity search.

    Returns:
        The most similar documents for the query, or an empty list when the
        query is blank or PubMed returned no abstracts.
    """
    # A blank query cannot match anything; avoid the network and index work.
    if not user_query or not user_query.strip():
        return []

    new_abstracts = fetch_pubmed_abstracts(user_query, max_results=max_results)
    if not new_abstracts:
        return []
    index_pubmed_docs(new_abstracts, prefix=user_query)
    return query_similar_docs(user_query, top_k=top_k)