Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -9,12 +9,12 @@ from transformers import AutoModel, AutoTokenizer
|
|
9 |
|
10 |
# Hugging Face Cache Directory
|
11 |
os.environ["HF_HOME"] = "/app/huggingface"
|
12 |
-
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60" # Increase timeout
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
-
# --- Load Clinical Trials CSV
|
17 |
-
csv_path = "ctg-studies.csv"
|
18 |
if os.path.exists(csv_path):
|
19 |
df_trials = pd.read_csv(csv_path)
|
20 |
print("✅ CSV File Loaded Successfully!")
|
@@ -59,19 +59,27 @@ def get_trial_info(nct_id):
|
|
59 |
async def retrieve_trial(request: QueryRequest):
|
60 |
"""Retrieve Clinical Trial based on text"""
|
61 |
query_vector = generate_embedding(request.text)
|
|
|
|
|
|
|
|
|
|
|
62 |
distances, indices = index.search(query_vector, request.top_k)
|
63 |
|
64 |
matched_trials = []
|
65 |
for idx, dist in zip(indices[0], distances[0]):
|
66 |
if idx < len(df_trials): # Ensure index is within bounds
|
67 |
-
nct_id = df_trials.iloc[idx]["NCT Number"]
|
68 |
-
trial_data = get_trial_info(nct_id)
|
69 |
-
|
70 |
if trial_data:
|
|
|
71 |
if np.isfinite(dist) and dist >= 0:
|
72 |
trial_data["similarity"] = float(round(100 / (1 + dist), 2))
|
73 |
else:
|
|
|
74 |
trial_data["similarity"] = 0.0
|
|
|
75 |
matched_trials.append(trial_data)
|
76 |
|
77 |
return {"matched_trials": matched_trials}
|
|
|
9 |
|
10 |
# Hugging Face Cache Directory
|
11 |
os.environ["HF_HOME"] = "/app/huggingface"
|
12 |
+
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60" # Increase timeout
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
# --- Load Clinical Trials CSV ---
|
17 |
+
csv_path = "ctg-studies.csv"
|
18 |
if os.path.exists(csv_path):
|
19 |
df_trials = pd.read_csv(csv_path)
|
20 |
print("✅ CSV File Loaded Successfully!")
|
|
|
59 |
async def retrieve_trial(request: QueryRequest):
|
60 |
"""Retrieve Clinical Trial based on text"""
|
61 |
query_vector = generate_embedding(request.text)
|
62 |
+
|
63 |
+
# Check if FAISS index has vectors
|
64 |
+
if index.ntotal == 0:
|
65 |
+
return {"error": "FAISS index is empty. No trials available."}
|
66 |
+
|
67 |
distances, indices = index.search(query_vector, request.top_k)
|
68 |
|
69 |
matched_trials = []
|
70 |
for idx, dist in zip(indices[0], distances[0]):
|
71 |
if idx < len(df_trials): # Ensure index is within bounds
|
72 |
+
nct_id = df_trials.iloc[idx]["NCT Number"]
|
73 |
+
trial_data = get_trial_info(nct_id)
|
74 |
+
|
75 |
if trial_data:
|
76 |
+
# Handle NaN & Inf distances safely
|
77 |
if np.isfinite(dist) and dist >= 0:
|
78 |
trial_data["similarity"] = float(round(100 / (1 + dist), 2))
|
79 |
else:
|
80 |
+
print(f"⚠️ Invalid distance detected: {dist}")
|
81 |
trial_data["similarity"] = 0.0
|
82 |
+
|
83 |
matched_trials.append(trial_data)
|
84 |
|
85 |
return {"matched_trials": matched_trials}
|