mgbam commited on
Commit
6799b1d
·
verified ·
1 Parent(s): 5586859

Update retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +19 -15
retrieval.py CHANGED
@@ -7,7 +7,7 @@ import chromadb
7
  from chromadb.config import Settings
8
  from transformers import AutoTokenizer, AutoModel
9
 
10
- # Optional: Set your PubMed API key from environment variables
11
  PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
12
 
13
  #############################################
@@ -15,8 +15,8 @@ PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
15
  #############################################
16
  def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
17
  """
18
- Fetches PubMed abstracts for the given query using NCBI's E-utilities.
19
- Returns a list of abstract texts.
20
  """
21
  search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
22
  params = {
@@ -26,7 +26,7 @@ def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
26
  "api_key": PUBMED_API_KEY,
27
  "retmode": "json"
28
  }
29
- r = requests.get(search_url, params=params)
30
  r.raise_for_status()
31
  data = r.json()
32
 
@@ -41,7 +41,7 @@ def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
41
  "retmode": "text",
42
  "api_key": PUBMED_API_KEY
43
  }
44
- fetch_resp = requests.get(fetch_url, params=fetch_params)
45
  fetch_resp.raise_for_status()
46
  abstract_text = fetch_resp.text.strip()
47
  if abstract_text:
@@ -53,7 +53,8 @@ def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
53
  #############################################
54
  class EmbedFunction:
55
  """
56
- Wraps a Hugging Face embedding model to produce embeddings for a list of strings.
 
57
  """
58
  def __init__(self, model_name: str):
59
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -80,7 +81,7 @@ class EmbedFunction:
80
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
81
  embed_function = EmbedFunction(EMBED_MODEL_NAME)
82
 
83
- # Use a temporary directory for persistent storage.
84
  temp_dir = tempfile.mkdtemp()
85
  print("Using temporary persist_directory:", temp_dir)
86
 
@@ -91,24 +92,24 @@ client = chromadb.Client(
91
  )
92
  )
93
 
94
- # Create or get the collection. Use a clear name.
95
  collection = client.get_or_create_collection(
96
  name="ai_medical_knowledge",
97
  embedding_function=embed_function
98
  )
99
 
100
- # Force initialization: add a dummy document and perform a dummy query.
101
  try:
102
  collection.add(documents=["dummy"], ids=["dummy"])
103
  _ = collection.query(query_texts=["dummy"], n_results=1)
104
- # Optionally, remove the dummy document if needed (Chromadb might not support deletion, so you can ignore it)
105
  print("Dummy initialization successful.")
106
  except Exception as init_err:
107
  print("Dummy initialization failed:", init_err)
108
 
109
  def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
110
  """
111
- Adds documents to the Chromadb collection with unique IDs.
 
112
  """
113
  for i, doc in enumerate(docs):
114
  if doc.strip():
@@ -122,7 +123,8 @@ def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
122
 
123
  def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
124
  """
125
- Retrieves the top_k similar documents from Chromadb based on embedding similarity.
 
126
  """
127
  results = collection.query(query_texts=[query], n_results=top_k)
128
  return results["documents"][0] if results and results["documents"] else []
@@ -132,10 +134,12 @@ def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
132
  #############################################
133
  def get_relevant_pubmed_docs(user_query: str) -> List[str]:
134
  """
135
- End-to-end pipeline:
136
  1. Fetch PubMed abstracts for the query.
137
- 2. Index them in Chromadb.
138
- 3. Retrieve the top relevant documents.
 
 
139
  """
140
  new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
141
  if not new_abstracts:
 
7
  from chromadb.config import Settings
8
  from transformers import AutoTokenizer, AutoModel
9
 
10
+ # Optional: Set your PubMed API key from environment variables.
11
  PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
12
 
13
  #############################################
 
15
  #############################################
16
  def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
17
  """
18
+ Retrieves PubMed abstracts for a given clinical query using NCBI's E-utilities.
19
+ Designed to quickly fetch up to 'max_results' abstracts.
20
  """
21
  search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
22
  params = {
 
26
  "api_key": PUBMED_API_KEY,
27
  "retmode": "json"
28
  }
29
+ r = requests.get(search_url, params=params, timeout=10)
30
  r.raise_for_status()
31
  data = r.json()
32
 
 
41
  "retmode": "text",
42
  "api_key": PUBMED_API_KEY
43
  }
44
+ fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=10)
45
  fetch_resp.raise_for_status()
46
  abstract_text = fetch_resp.text.strip()
47
  if abstract_text:
 
53
  #############################################
54
  class EmbedFunction:
55
  """
56
+ Uses a Hugging Face embedding model to generate embeddings for a list of strings.
57
+ This function is crucial for indexing abstracts for similarity search.
58
  """
59
  def __init__(self, model_name: str):
60
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
81
  EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
82
  embed_function = EmbedFunction(EMBED_MODEL_NAME)
83
 
84
+ # Use a temporary directory for persistent storage to ensure a fresh initialization.
85
  temp_dir = tempfile.mkdtemp()
86
  print("Using temporary persist_directory:", temp_dir)
87
 
 
92
  )
93
  )
94
 
95
+ # Create or retrieve the collection for medical abstracts.
96
  collection = client.get_or_create_collection(
97
  name="ai_medical_knowledge",
98
  embedding_function=embed_function
99
  )
100
 
101
+ # Optional: Force initialization with a dummy document to ensure the schema is set up.
102
  try:
103
  collection.add(documents=["dummy"], ids=["dummy"])
104
  _ = collection.query(query_texts=["dummy"], n_results=1)
 
105
  print("Dummy initialization successful.")
106
  except Exception as init_err:
107
  print("Dummy initialization failed:", init_err)
108
 
109
  def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
110
  """
111
+ Indexes PubMed abstracts into the Chroma vector store.
112
+ Each document is assigned a unique ID based on the query prefix.
113
  """
114
  for i, doc in enumerate(docs):
115
  if doc.strip():
 
123
 
124
  def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
125
  """
126
+ Searches the indexed abstracts for those most similar to the given query.
127
+ Returns the top 'top_k' documents.
128
  """
129
  results = collection.query(query_texts=[query], n_results=top_k)
130
  return results["documents"][0] if results and results["documents"] else []
 
134
  #############################################
135
  def get_relevant_pubmed_docs(user_query: str) -> List[str]:
136
  """
137
+ Complete retrieval pipeline:
138
  1. Fetch PubMed abstracts for the query.
139
+ 2. Index the abstracts into the vector store.
140
+ 3. Retrieve and return the most similar documents.
141
+
142
+ Designed for clinicians to quickly access relevant literature.
143
  """
144
  new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
145
  if not new_abstracts: