File size: 4,971 Bytes
b8986a1
d183895
b8986a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118ab17
d183895
 
 
b8986a1
 
d183895
b8986a1
 
 
 
028a072
b8986a1
 
 
 
 
118ab17
 
 
 
 
 
 
 
 
b8986a1
 
d183895
b8986a1
 
 
 
028a072
 
 
 
 
 
b8986a1
 
 
d183895
b8986a1
 
 
 
 
 
 
 
 
 
 
d183895
b8986a1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import tempfile
import requests
import torch
from typing import List
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModel

# NCBI API key, read from the PUBMED_API_KEY environment variable.
# When the variable is unset, the placeholder string below is used instead.
PUBMED_API_KEY = os.getenv("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")

#############################################
# 1) FETCH PUBMED ABSTRACTS
#############################################
def fetch_pubmed_abstracts(
    query: str, max_results: int = 5, timeout: float = 10.0
) -> List[str]:
    """
    Fetch PubMed abstracts matching ``query`` through NCBI's E-utilities.

    An ``esearch`` request resolves the query to a list of PMIDs, then one
    ``efetch`` request per PMID downloads that record as plain text.

    Args:
        query: Search term sent to PubMed.
        max_results: Maximum number of PMIDs to request (``retmax``).
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        The non-empty record texts, in the order the PMIDs were returned.
        An empty list when the query is blank, ``max_results`` is not
        positive, or the search yields no PMIDs.

    Raises:
        requests.HTTPError: If either endpoint answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Nothing can match, so avoid the network round trip entirely.
    if not query or not query.strip() or max_results <= 0:
        return []

    # Only send a configured key. The unconfigured placeholder is not a real
    # key, so it is left out and the requests are made without one.
    auth_params = {}
    if PUBMED_API_KEY and PUBMED_API_KEY != "<YOUR_NCBI_API_KEY>":
        auth_params["api_key"] = PUBMED_API_KEY

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    search_params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "retmode": "json",
        **auth_params,
    }
    # Without a timeout, requests would wait indefinitely on a stalled server.
    search_resp = requests.get(search_url, params=search_params, timeout=timeout)
    search_resp.raise_for_status()
    data = search_resp.json()

    # Tolerate a response that lacks the expected keys (e.g. an error payload).
    pmid_list = data.get("esearchresult", {}).get("idlist", [])

    abstracts = []
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    for pmid in pmid_list:
        fetch_params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
            **auth_params,
        }
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts

#############################################
# 2) CHROMA + EMBEDDINGS SETUP
#############################################
class EmbedFunction:
    """
    Wraps a Hugging Face model so it can embed a list of strings.

    Instances are callable: given a list of texts they return one embedding
    (a list of floats) per text, obtained by averaging the final hidden
    states over each text's real (non-padding) tokens.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embed every string in ``input``.

        The parameter is named ``input`` because this object is passed to
        Chroma as ``embedding_function`` and is called by it.

        Args:
            input: Texts to embed. Texts longer than 512 tokens are truncated.

        Returns:
            One embedding per text, in input order; ``[]`` for empty input.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]

        if "attention_mask" in tokenized:
            # Shorter texts in a batch are padded to the longest one. A plain
            # mean would average those padding positions in, making a text's
            # embedding depend on what else is in the batch. Weighting by the
            # attention mask restricts the average to real tokens.
            mask = tokenized["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled = (last_hidden * mask).sum(dim=1) / token_counts
        else:
            # No mask available: fall back to averaging every position.
            pooled = last_hidden.mean(dim=1)
        return pooled.cpu().tolist()

EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Chroma is pointed at a freshly created temporary directory. mkdtemp() makes
# a new directory on every run and does not remove it afterwards, so the
# index is not shared between runs.
temp_dir = tempfile.mkdtemp()
print("Using temporary persist_directory:", temp_dir)

client = chromadb.Client(
    settings=Settings(
        persist_directory=temp_dir,
        anonymized_telemetry=False,
    )
)

# Single shared collection; documents are embedded with embed_function.
collection = client.get_or_create_collection(
    name="ai_medical_knowledge",
    embedding_function=embed_function,
)

# Force initialization: add a dummy document and perform a dummy query.
try:
    collection.add(documents=["dummy"], ids=["dummy"])
    _ = collection.query(query_texts=["dummy"], n_results=1)
    # Remove the warm-up document again; left in place it could be returned
    # as a search hit alongside real abstracts.
    collection.delete(ids=["dummy"])
    print("Dummy initialization successful.")
except Exception as init_err:
    # Best effort only: a failed warm-up is reported but does not abort import.
    print("Dummy initialization failed:", init_err)

def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Store documents in the Chroma collection under deterministic IDs.

    Each non-blank document is written as ``"{prefix}-{i}"``, where ``i`` is
    its position in ``docs``. Blank documents are skipped but still consume
    their position. Documents are upserted, so indexing the same prefix again
    replaces the earlier entries instead of colliding on an existing ID.

    Args:
        docs: Document texts to index.
        prefix: String placed in front of the position to form each ID.

    Raises:
        Exception: Any error raised by the collection is printed together
            with the affected ID and then re-raised.
    """
    for position, text in enumerate(docs):
        if not text.strip():
            continue
        doc_id = f"{prefix}-{position}"
        try:
            # One document per call so a failure can be tied to a specific ID.
            collection.upsert(documents=[text], ids=[doc_id])
        except Exception as e:
            print(f"Error adding document id {doc_id}: {e}")
            raise
        else:
            print(f"Added document with id: {doc_id}")

def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Return up to ``top_k`` documents from the collection most similar to ``query``.

    Args:
        query: Text to embed and search with.
        top_k: Maximum number of documents to return.

    Returns:
        The matching document texts as returned by the collection for this
        query, or ``[]`` when ``top_k`` is not positive, the collection is
        empty, or the result carries no documents.
    """
    # A non-positive result count is not a meaningful request.
    if top_k <= 0:
        return []

    # Never ask for more results than there are stored documents.
    available = collection.count()
    if available <= 0:
        return []

    results = collection.query(
        query_texts=[query],
        n_results=min(top_k, available),
    )
    documents = results.get("documents") if results else None
    # The result holds one list of documents per query text; only one was sent.
    return documents[0] if documents else []

#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(
    user_query: str, max_results: int = 5, top_k: int = 3
) -> List[str]:
    """
    End-to-end pipeline:
      1. Fetch PubMed abstracts for the query.
      2. Index them in Chromadb.
      3. Retrieve the top relevant documents.

    The query text itself is used as the ID prefix when indexing. Retrieval
    runs against the whole collection, so documents indexed by earlier calls
    can be returned as well.

    Args:
        user_query: Free-text query used both for the PubMed search and for
            the similarity lookup.
        max_results: Maximum number of abstracts to fetch from PubMed.
        top_k: Maximum number of documents to return.

    Returns:
        Up to ``top_k`` document texts, or ``[]`` when the query is blank or
        PubMed returns no abstracts.
    """
    # A blank query cannot match anything; skip the network and the index.
    if not user_query or not user_query.strip():
        return []

    new_abstracts = fetch_pubmed_abstracts(user_query, max_results=max_results)
    if not new_abstracts:
        return []
    index_pubmed_docs(new_abstracts, prefix=user_query)
    return query_similar_docs(user_query, top_k=top_k)