Update retrieval.py
Browse files- retrieval.py +9 -5
retrieval.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import requests
|
3 |
import torch
|
4 |
from typing import List
|
@@ -79,10 +80,13 @@ class EmbedFunction:
|
|
79 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
80 |
embed_function = EmbedFunction(EMBED_MODEL_NAME)
|
81 |
|
82 |
-
# Use
|
|
|
|
|
|
|
83 |
client = chromadb.Client(
|
84 |
settings=Settings(
|
85 |
-
persist_directory=
|
86 |
anonymized_telemetry=False
|
87 |
)
|
88 |
)
|
@@ -95,7 +99,7 @@ collection = client.get_or_create_collection(
|
|
95 |
|
96 |
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
|
97 |
"""
|
98 |
-
Adds documents to the
|
99 |
"""
|
100 |
for i, doc in enumerate(docs):
|
101 |
if doc.strip():
|
@@ -109,7 +113,7 @@ def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
|
|
109 |
|
110 |
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
|
111 |
"""
|
112 |
-
Retrieves the top_k similar documents from
|
113 |
"""
|
114 |
results = collection.query(query_texts=[query], n_results=top_k)
|
115 |
return results["documents"][0] if results and results["documents"] else []
|
@@ -121,7 +125,7 @@ def get_relevant_pubmed_docs(user_query: str) -> List[str]:
|
|
121 |
"""
|
122 |
End-to-end pipeline:
|
123 |
1. Fetch PubMed abstracts for the query.
|
124 |
-
2. Index them in
|
125 |
3. Retrieve the top relevant documents.
|
126 |
"""
|
127 |
new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
|
|
|
1 |
import os
|
2 |
+
import tempfile
|
3 |
import requests
|
4 |
import torch
|
5 |
from typing import List
|
|
|
80 |
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
81 |
embed_function = EmbedFunction(EMBED_MODEL_NAME)
|
82 |
|
83 |
+
# Use a fresh temporary directory as the Chroma persist_directory (a new one is created on every run, so indexed data is not reused across runs)
|
84 |
+
temp_dir = tempfile.mkdtemp()
|
85 |
+
print("Using temporary persist_directory:", temp_dir)
|
86 |
+
|
87 |
client = chromadb.Client(
|
88 |
settings=Settings(
|
89 |
+
persist_directory=temp_dir,
|
90 |
anonymized_telemetry=False
|
91 |
)
|
92 |
)
|
|
|
99 |
|
100 |
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
|
101 |
"""
|
102 |
+
Adds documents to the ChromaDB collection with unique IDs.
|
103 |
"""
|
104 |
for i, doc in enumerate(docs):
|
105 |
if doc.strip():
|
|
|
113 |
|
114 |
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
|
115 |
"""
|
116 |
+
Retrieves the top_k most similar documents from ChromaDB based on embedding similarity.
|
117 |
"""
|
118 |
results = collection.query(query_texts=[query], n_results=top_k)
|
119 |
return results["documents"][0] if results and results["documents"] else []
|
|
|
125 |
"""
|
126 |
End-to-end pipeline:
|
127 |
1. Fetch PubMed abstracts for the query.
|
128 |
+
2. Index them in ChromaDB.
|
129 |
3. Retrieve the top relevant documents.
|
130 |
"""
|
131 |
new_abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
|