import os
import tempfile
import requests
import torch
from typing import List
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModel
# Optional: set your PubMed API key via the PUBMED_API_KEY environment
# variable. Without a key, NCBI still serves requests at a lower rate limit,
# so the key is only sent when one is actually configured (NCBI rejects
# requests that carry an invalid placeholder key).
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY")
#############################################
# 1) FETCH PUBMED ABSTRACTS
#############################################
def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
    """
    Retrieves PubMed abstracts for a given clinical query using NCBI's
    E-utilities, fetching up to `max_results` abstracts.
    """
    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json"
    }
    if PUBMED_API_KEY:
        params["api_key"] = PUBMED_API_KEY
    r = requests.get(search_url, params=params, timeout=10)
    r.raise_for_status()
    data = r.json()
    pmid_list = data["esearchresult"].get("idlist", [])

    abstracts = []
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    # efetch also accepts a comma-separated list of PMIDs; fetching one at a
    # time keeps response handling simple at the cost of extra requests.
    for pmid in pmid_list:
        fetch_params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text"
        }
        if PUBMED_API_KEY:
            fetch_params["api_key"] = PUBMED_API_KEY
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=10)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts
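
# A minimal usage sketch (assumes network access; the query string below is
# illustrative, not part of the pipeline):
#
#   sample = fetch_pubmed_abstracts("metformin cardiovascular outcomes", max_results=3)
#   for text in sample:
#       print(text[:200], "...")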
#############################################
# 2) CHROMA + EMBEDDINGS SETUP
#############################################
class EmbedFunction:
    """
    Uses a Hugging Face embedding model to generate embeddings for a list of
    strings. Chroma invokes this on every add/query to embed abstracts for
    similarity search.
    """
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def __call__(self, input: List[str]) -> List[List[float]]:
        # The parameter is named `input` to match Chroma's EmbeddingFunction
        # interface.
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        # Mean-pool over real tokens only: mask out padding positions so that
        # short texts in a padded batch are not averaged against zeros.
        last_hidden = outputs.last_hidden_state
        mask = tokenized["attention_mask"].unsqueeze(-1).float()
        pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return pooled.cpu().tolist()
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embed_function = EmbedFunction(EMBED_MODEL_NAME)
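
# Quick sanity check (sketch): all-MiniLM-L6-v2 yields 384-dimensional
# vectors, so each embedding should be a list of 384 floats.
#
#   vecs = embed_function(["test sentence"])
#   assert len(vecs) == 1 and len(vecs[0]) == 384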
# Use a temporary directory for on-disk storage so every run starts from a
# fresh, empty index. (Note: newer Chroma releases recommend
# chromadb.PersistentClient(path=...) for persistent storage.)
temp_dir = tempfile.mkdtemp()
print("Using temporary persist_directory:", temp_dir)
client = chromadb.Client(
    settings=Settings(
        persist_directory=temp_dir,
        anonymized_telemetry=False
    )
)
# Create or retrieve the collection for medical abstracts.
collection = client.get_or_create_collection(
    name="ai_medical_knowledge",
    embedding_function=embed_function
)
# Optional: force initialization with a dummy document to ensure the schema is
# set up, then remove the dummy so it cannot surface in real query results.
try:
    collection.add(documents=["dummy"], ids=["dummy"])
    _ = collection.query(query_texts=["dummy"], n_results=1)
    collection.delete(ids=["dummy"])
    print("Dummy initialization successful.")
except Exception as init_err:
    print("Dummy initialization failed:", init_err)
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Indexes PubMed abstracts into the Chroma vector store.
    Each document is assigned a unique ID based on the query prefix.
    """
    for i, doc in enumerate(docs):
        if doc.strip():
            doc_id = f"{prefix}-{i}"
            try:
                collection.add(documents=[doc], ids=[doc_id])
                print(f"Added document with id: {doc_id}")
            except Exception as e:
                print(f"Error adding document id {doc_id}: {e}")
                raise
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Searches the indexed abstracts for those most similar to the given query.
    Returns the top `top_k` documents.
    """
    results = collection.query(query_texts=[query], n_results=top_k)
    return results["documents"][0] if results and results["documents"] else []
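
# Example round trip (sketch; the abstract text is made up for illustration):
#
#   index_pubmed_docs(["Aspirin reduces cardiovascular event risk ..."], prefix="demo")
#   print(query_similar_docs("antiplatelet therapy", top_k=1))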
#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    Complete retrieval pipeline:
      1. Fetch PubMed abstracts for the query.
      2. Index the abstracts into the vector store.
      3. Retrieve and return the most similar documents.
    Designed for clinicians to quickly access relevant literature.
    """
    new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
    if not new_abstracts:
        return []
    index_pubmed_docs(new_abstracts, prefix=user_query)
    top_docs = query_similar_docs(user_query, top_k=3)
    return top_docs
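
# Demo entry point: a minimal sketch of the full pipeline. The query below is
# illustrative; any clinical question works.
if __name__ == "__main__":
    for rank, doc in enumerate(get_relevant_pubmed_docs("SGLT2 inhibitors heart failure"), start=1):
        print(f"--- Result {rank} ---")
        print(doc[:300], "...")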