priyanandanwar commited on
Commit
fbbf31c
·
verified ·
1 Parent(s): a8d46a3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -6
main.py CHANGED
@@ -9,12 +9,12 @@ from transformers import AutoModel, AutoTokenizer
9
 
10
  # Hugging Face Cache Directory
11
  os.environ["HF_HOME"] = "/app/huggingface"
12
- os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60" # Increase timeout to 60 seconds
13
 
14
  app = FastAPI()
15
 
16
- # --- Load Clinical Trials CSV (for metadata lookup) ---
17
- csv_path = "ctg-studies.csv" # Ensure this file is uploaded
18
  if os.path.exists(csv_path):
19
  df_trials = pd.read_csv(csv_path)
20
  print("✅ CSV File Loaded Successfully!")
@@ -59,19 +59,27 @@ def get_trial_info(nct_id):
59
  async def retrieve_trial(request: QueryRequest):
60
  """Retrieve Clinical Trial based on text"""
61
  query_vector = generate_embedding(request.text)
 
 
 
 
 
62
  distances, indices = index.search(query_vector, request.top_k)
63
 
64
  matched_trials = []
65
  for idx, dist in zip(indices[0], distances[0]):
66
  if idx < len(df_trials): # Ensure index is within bounds
67
- nct_id = df_trials.iloc[idx]["NCT Number"] # Get NCT Number using FAISS index mapping
68
- trial_data = get_trial_info(nct_id) # Fetch complete trial details
69
-
70
  if trial_data:
 
71
  if np.isfinite(dist) and dist >= 0:
72
  trial_data["similarity"] = float(round(100 / (1 + dist), 2))
73
  else:
 
74
  trial_data["similarity"] = 0.0
 
75
  matched_trials.append(trial_data)
76
 
77
  return {"matched_trials": matched_trials}
 
9
 
10
  # Hugging Face Cache Directory
11
  os.environ["HF_HOME"] = "/app/huggingface"
12
+ os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60" # Increase timeout
13
 
14
  app = FastAPI()
15
 
16
+ # --- Load Clinical Trials CSV ---
17
+ csv_path = "ctg-studies.csv"
18
  if os.path.exists(csv_path):
19
  df_trials = pd.read_csv(csv_path)
20
  print("✅ CSV File Loaded Successfully!")
 
59
  async def retrieve_trial(request: QueryRequest):
60
  """Retrieve Clinical Trial based on text"""
61
  query_vector = generate_embedding(request.text)
62
+
63
+ # Check if FAISS index has vectors
64
+ if index.ntotal == 0:
65
+ return {"error": "FAISS index is empty. No trials available."}
66
+
67
  distances, indices = index.search(query_vector, request.top_k)
68
 
69
  matched_trials = []
70
  for idx, dist in zip(indices[0], distances[0]):
71
  if idx < len(df_trials): # Ensure index is within bounds
72
+ nct_id = df_trials.iloc[idx]["NCT Number"]
73
+ trial_data = get_trial_info(nct_id)
74
+
75
  if trial_data:
76
+ # Handle NaN & Inf distances safely
77
  if np.isfinite(dist) and dist >= 0:
78
  trial_data["similarity"] = float(round(100 / (1 + dist), 2))
79
  else:
80
+ print(f"⚠️ Invalid distance detected: {dist}")
81
  trial_data["similarity"] = 0.0
82
+
83
  matched_trials.append(trial_data)
84
 
85
  return {"matched_trials": matched_trials}