mgbam committed on
Commit
6a2b285
·
verified ·
1 Parent(s): 1bc3e18

Update retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +20 -18
retrieval.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import os
2
  import tempfile
3
  import requests
@@ -15,8 +23,8 @@ PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
15
  #############################################
16
  def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
17
  """
18
- Retrieves PubMed abstracts for a given clinical query using NCBI's E-utilities.
19
- Designed to quickly fetch up to 'max_results' abstracts.
20
  """
21
  search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
22
  params = {
@@ -29,7 +37,6 @@ def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
29
  r = requests.get(search_url, params=params, timeout=10)
30
  r.raise_for_status()
31
  data = r.json()
32
-
33
  pmid_list = data["esearchresult"].get("idlist", [])
34
  abstracts = []
35
  fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
@@ -53,8 +60,7 @@ def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
53
  #############################################
54
  class EmbedFunction:
55
  """
56
- Uses a Hugging Face embedding model to generate embeddings for a list of strings.
57
- This function is crucial for indexing abstracts for similarity search.
58
  """
59
  def __init__(self, model_name: str):
60
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -73,15 +79,15 @@ class EmbedFunction:
73
  )
74
  with torch.no_grad():
75
  outputs = self.model(**tokenized, output_hidden_states=True)
 
76
  last_hidden = outputs.hidden_states[-1]
77
  pooled = last_hidden.mean(dim=1)
78
- embeddings = pooled.cpu().tolist()
79
- return embeddings
80
 
81
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
82
  embed_function = EmbedFunction(EMBED_MODEL_NAME)
83
 
84
- # Use a temporary directory for persistent storage to ensure a fresh initialization.
85
  temp_dir = tempfile.mkdtemp()
86
  print("Using temporary persist_directory:", temp_dir)
87
 
@@ -92,13 +98,13 @@ client = chromadb.Client(
92
  )
93
  )
94
 
95
- # Create or retrieve the collection for medical abstracts.
96
  collection = client.get_or_create_collection(
97
  name="ai_medical_knowledge",
98
  embedding_function=embed_function
99
  )
100
 
101
- # Optional: Force initialization with a dummy document to ensure the schema is set up.
102
  try:
103
  collection.add(documents=["dummy"], ids=["dummy"])
104
  _ = collection.query(query_texts=["dummy"], n_results=1)
@@ -108,8 +114,7 @@ except Exception as init_err:
108
 
109
  def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
110
  """
111
- Indexes PubMed abstracts into the Chroma vector store.
112
- Each document is assigned a unique ID based on the query prefix.
113
  """
114
  for i, doc in enumerate(docs):
115
  if doc.strip():
@@ -123,8 +128,7 @@ def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
123
 
124
  def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
125
  """
126
- Searches the indexed abstracts for those most similar to the given query.
127
- Returns the top 'top_k' documents.
128
  """
129
  results = collection.query(query_texts=[query], n_results=top_k)
130
  return results["documents"][0] if results and results["documents"] else []
@@ -135,11 +139,9 @@ def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
135
  def get_relevant_pubmed_docs(user_query: str) -> List[str]:
136
  """
137
  Complete retrieval pipeline:
138
- 1. Fetch PubMed abstracts for the query.
139
- 2. Index the abstracts into the vector store.
140
  3. Retrieve and return the most similar documents.
141
-
142
- Designed for clinicians to quickly access relevant literature.
143
  """
144
  new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
145
  if not new_abstracts:
 
1
+ """
2
+ retrieval.py
3
+ ------------
4
+ This module handles retrieval of PubMed abstracts and indexing via Chromadb.
5
+ It fetches abstracts using NCBI's E-utilities and indexes them in a vector store
6
+ to enable similarity search for clinical queries.
7
+ """
8
+
9
  import os
10
  import tempfile
11
  import requests
 
23
  #############################################
24
  def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
25
  """
26
+ Retrieves PubMed abstracts for the given clinical query.
27
+ Returns a list of abstract texts.
28
  """
29
  search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
30
  params = {
 
37
  r = requests.get(search_url, params=params, timeout=10)
38
  r.raise_for_status()
39
  data = r.json()
 
40
  pmid_list = data["esearchresult"].get("idlist", [])
41
  abstracts = []
42
  fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
 
60
  #############################################
61
  class EmbedFunction:
62
  """
63
+ Uses a Hugging Face embedding model to generate embeddings for clinical texts.
 
64
  """
65
  def __init__(self, model_name: str):
66
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
79
  )
80
  with torch.no_grad():
81
  outputs = self.model(**tokenized, output_hidden_states=True)
82
+ # Mean-pooling over the last hidden state.
83
  last_hidden = outputs.hidden_states[-1]
84
  pooled = last_hidden.mean(dim=1)
85
+ return pooled.cpu().tolist()
 
86
 
87
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
88
  embed_function = EmbedFunction(EMBED_MODEL_NAME)
89
 
90
+ # Create a temporary directory for the Chromadb persistent storage.
91
  temp_dir = tempfile.mkdtemp()
92
  print("Using temporary persist_directory:", temp_dir)
93
 
 
98
  )
99
  )
100
 
101
+ # Create or retrieve the collection for clinical abstracts.
102
  collection = client.get_or_create_collection(
103
  name="ai_medical_knowledge",
104
  embedding_function=embed_function
105
  )
106
 
107
+ # Force initialization with a dummy document.
108
  try:
109
  collection.add(documents=["dummy"], ids=["dummy"])
110
  _ = collection.query(query_texts=["dummy"], n_results=1)
 
114
 
115
  def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
116
  """
117
+ Indexes the retrieved PubMed abstracts into the Chromadb vector store.
 
118
  """
119
  for i, doc in enumerate(docs):
120
  if doc.strip():
 
128
 
129
  def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
130
  """
131
+ Performs a similarity search on the indexed abstracts and returns the top relevant documents.
 
132
  """
133
  results = collection.query(query_texts=[query], n_results=top_k)
134
  return results["documents"][0] if results and results["documents"] else []
 
139
  def get_relevant_pubmed_docs(user_query: str) -> List[str]:
140
  """
141
  Complete retrieval pipeline:
142
+ 1. Fetch PubMed abstracts.
143
+ 2. Index them into the vector store.
144
  3. Retrieve and return the most similar documents.
 
 
145
  """
146
  new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
147
  if not new_abstracts: