|
import os
import tempfile
import time
from typing import List

import chromadb
import requests
import torch
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModel
|
|
|
|
|
# NCBI E-utilities API key, read from the environment. When the variable is
# unset or empty this is None: requests drops None-valued query parameters, so
# calls go out without an ``api_key`` and NCBI applies its keyless rate limit.
# A placeholder string must never be sent, because NCBI returns an error for an
# invalid key instead of falling back to keyless access.
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY") or None
|
|
|
|
|
|
|
|
|
def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
    """Retrieve PubMed abstracts for a clinical query via NCBI E-utilities.

    One ``esearch`` call finds up to ``max_results`` PMIDs, then one ``efetch``
    call per PMID downloads its plain-text abstract. Calls are spaced out so the
    burst stays within NCBI's published E-utilities rate limit (3 requests/s
    without an API key, 10/s with one); firing them back-to-back can make NCBI
    answer with an error status (HTTP 429) partway through.

    Args:
        query: PubMed search term.
        max_results: Maximum number of PMIDs to request from esearch.

    Returns:
        The non-empty abstract texts, in the order esearch listed the PMIDs.
        An empty list if the query is blank, ``max_results`` is not positive,
        or nothing matched.

    Raises:
        requests.HTTPError: If any E-utilities call returns an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    # Nothing meaningful to search for: skip the network round-trips entirely.
    if not (query and query.strip()) or max_results <= 0:
        return []

    eutils_base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
    # Minimum gap between calls: a little over 1/3 s keyless, 1/10 s with a key.
    min_interval = 0.1 if PUBMED_API_KEY else 0.34
    last_call = None

    # A Session reuses the HTTPS connection across calls and is closed on exit.
    with requests.Session() as session:

        def _eutils_get(endpoint, params):
            """GET one E-utilities endpoint, throttled and status-checked."""
            nonlocal last_call
            if last_call is not None:
                remaining = min_interval - (time.monotonic() - last_call)
                if remaining > 0:
                    time.sleep(remaining)
            last_call = time.monotonic()
            resp = session.get(f"{eutils_base}/{endpoint}", params=params, timeout=10)
            resp.raise_for_status()
            return resp

        search_params = {
            "db": "pubmed",
            "term": query,
            "retmax": max_results,
            "api_key": PUBMED_API_KEY,
            "retmode": "json",
        }
        data = _eutils_get("esearch.fcgi", search_params).json()
        # A payload without "esearchresult" (e.g. an error body) yields no PMIDs
        # rather than an opaque KeyError.
        pmid_list = data.get("esearchresult", {}).get("idlist", [])

        abstracts = []
        for pmid in pmid_list:
            fetch_params = {
                "db": "pubmed",
                "id": pmid,
                "rettype": "abstract",
                "retmode": "text",
                "api_key": PUBMED_API_KEY,
            }
            abstract_text = _eutils_get("efetch.fcgi", fetch_params).text.strip()
            if abstract_text:
                abstracts.append(abstract_text)

    return abstracts
|
|
|
|
|
|
|
|
|
class EmbedFunction:
    """Chroma-compatible embedding function backed by a Hugging Face encoder.

    Each input string is tokenized, run through the model, and reduced to one
    vector by averaging the last hidden layer over the real (non-padding)
    tokens only. The collection uses it to embed documents when indexing and
    to embed query text at search time.

    Args:
        model_name: Model id or local path passed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only: eval() disables dropout so embeddings are deterministic.
        self.model.eval()

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` over the sequence axis, ignoring masked-out positions.

        ``hidden`` is (batch, seq, dim) and ``attention_mask`` is (batch, seq)
        with 1 for real tokens and 0 for padding. A row whose mask is all zero
        yields a zero vector instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed each string in ``input`` and return one float list per string.

        The parameter keeps the name ``input`` (shadowing the builtin) because
        chromadb's embedding-function interface calls it by that name.
        NOTE(review): confirm against the installed chromadb version.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        # padding=True pads shorter strings up to the longest in the batch. A
        # plain mean over dim=1 would mix those padding states into the vector,
        # making a text's embedding depend on what it was batched with.
        attention_mask = tokenized.get("attention_mask")
        if attention_mask is None:
            # No mask from the tokenizer: treat every position as a real token.
            attention_mask = torch.ones(last_hidden.shape[:2], dtype=torch.long)
        pooled = self._masked_mean(last_hidden, attention_mask)
        return pooled.cpu().tolist()
|
|
|
# Hugging Face checkpoint used for every embedding in this module.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Built once at import time: constructing it loads the tokenizer and model
# weights, from the Hugging Face hub or the local cache.
embed_function = EmbedFunction(model_name=EMBED_MODEL_NAME)
|
|
|
|
|
# Fresh scratch directory for Chroma, created on every import and never
# removed by this module.
temp_dir = tempfile.mkdtemp()
print("Using temporary persist_directory:", temp_dir)

# NOTE(review): in chromadb >= 0.4, chromadb.Client() is in-memory, and
# persist_directory only takes effect with is_persistent=True or
# chromadb.PersistentClient. Confirm against the installed version before
# expecting anything to be written to temp_dir.
client = chromadb.Client(
    settings=Settings(
        anonymized_telemetry=False,
        persist_directory=temp_dir,
    )
)
|
|
|
|
|
# The single vector collection shared by the indexing and query helpers. It
# embeds documents and query texts with the module's embed_function.
collection = client.get_or_create_collection(
    embedding_function=embed_function,
    name="ai_medical_knowledge",
)
|
|
|
|
|
# Best-effort warm-up: run one add and one query so the embed_function and
# index code paths are exercised once at import time. Failures are reported,
# not raised.
try:
    collection.add(documents=["dummy"], ids=["dummy"])
    _ = collection.query(query_texts=["dummy"], n_results=1)
    print("Dummy initialization successful.")
except Exception as init_err:
    print("Dummy initialization failed:", init_err)
finally:
    # The warm-up document must not stay in the collection. Otherwise it
    # becomes a retrieval candidate and can be returned by later similarity
    # queries as if it were a real abstract.
    try:
        collection.delete(ids=["dummy"])
    except Exception as cleanup_err:
        print("Dummy cleanup failed:", cleanup_err)
|
|
|
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """Store PubMed abstracts in the module's Chroma collection.

    Each non-blank document is written under the id ``f"{prefix}-{i}"``, where
    ``i`` is its position in ``docs``. Blank or whitespace-only entries are
    skipped but still consume an index.

    Documents are written with ``upsert``, so re-indexing the same prefix (for
    example, re-running the same query) replaces the text stored under those
    ids. With ``add``, existing ids would instead raise or be silently ignored,
    depending on the chromadb version.

    Args:
        docs: Abstract texts to index.
        prefix: Id prefix, typically the query that produced ``docs``.

    Raises:
        Whatever the collection raises for a failed write; the failing id is
        printed first.
    """
    for i, doc in enumerate(docs):
        if not doc.strip():
            continue
        doc_id = f"{prefix}-{i}"
        # One document per call, so each abstract is embedded on its own.
        try:
            collection.upsert(documents=[doc], ids=[doc_id])
        except Exception as e:
            print(f"Error adding document id {doc_id}: {e}")
            raise
        print(f"Added document with id: {doc_id}")
|
|
|
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """Return up to ``top_k`` stored documents most similar to ``query``.

    ``n_results`` is clamped to the number of stored items, because some
    chromadb versions raise when more results are requested than exist.

    Args:
        query: Free-text query, embedded by the collection's embedding function.
        top_k: Maximum number of documents to return.

    Returns:
        Matching document texts in the order Chroma ranks them. An empty list
        if ``top_k`` is not positive or the collection holds nothing.
    """
    if top_k <= 0:
        return []
    available = collection.count()
    if available == 0:
        return []
    results = collection.query(query_texts=[query], n_results=min(top_k, available))
    documents = results.get("documents") if results else None
    # One query text was sent, so the first (only) row holds its matches.
    return documents[0] if documents else []
|
|
|
|
|
|
|
|
|
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """Run the full literature-retrieval pipeline for one clinical question.

    The pipeline pulls up to five PubMed abstracts for ``user_query`` and
    stores them in the vector collection, with ids prefixed by the query text.
    It then returns the three stored documents most similar to the query.
    Because the collection is shared, the results may include abstracts
    indexed by earlier queries.

    Returns an empty list, without touching the collection, when PubMed yields
    no abstracts.
    """
    fetched = fetch_pubmed_abstracts(user_query, max_results=5)
    if fetched:
        index_pubmed_docs(fetched, prefix=user_query)
        return query_similar_docs(user_query, top_k=3)
    return []
|
|