priyanandanwar commited on
Commit
6927b07
Β·
verified Β·
1 Parent(s): cf4089c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -31
main.py CHANGED
@@ -7,18 +7,19 @@ from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from transformers import AutoModel, AutoTokenizer
9
 
 
10
  os.environ["HF_HOME"] = "/app/huggingface"
11
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
12
 
13
  app = FastAPI()
14
 
15
- # --- Load CSV ---
16
  csv_path = "ctg-studies.csv"
17
  if os.path.exists(csv_path):
18
  df_trials = pd.read_csv(csv_path)
19
- print("βœ… CSV Loaded!")
20
  else:
21
- raise FileNotFoundError("❌ CSV File Not Found!")
22
 
23
  # --- Load FAISS Index ---
24
  dimension = 768
@@ -29,9 +30,9 @@ if os.path.exists(faiss_index_path):
29
  print("βœ… FAISS Index Loaded!")
30
  else:
31
  index = faiss.IndexFlatL2(dimension)
32
- print("⚠️ FAISS Index Empty!")
33
 
34
- # --- Load Model ---
35
  retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
36
  retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
37
  retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
@@ -41,46 +42,55 @@ class QueryRequest(BaseModel):
41
  text: str
42
  top_k: int = 5
43
 
44
- # --- Generate Embedding ---
45
  def generate_embedding(text):
46
  inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
47
  with torch.no_grad():
48
  outputs = retrieval_model(**inputs)
49
- emb = outputs.last_hidden_state[:, 0, :].numpy()
50
 
51
- # βœ… Normalize Embeddings
52
- emb = emb / np.linalg.norm(emb)
53
- return emb
 
54
 
55
- # --- Compute Similarity ---
56
- def compute_similarity(distance):
57
- return round(np.exp(-distance) * 100, 2) # βœ… Softmax similarity fix
58
-
59
- # --- Retrieve Trials ---
60
  @app.post("/retrieve")
61
  async def retrieve_trial(request: QueryRequest):
 
62
  query_vector = generate_embedding(request.text)
63
  distances, indices = index.search(query_vector, request.top_k)
64
 
65
- # βœ… Fix: Remove NaN/Inf from distances
66
- distances = np.nan_to_num(distances, nan=1e9, posinf=1e9, neginf=1e9)
67
-
68
  matched_trials = []
69
  for idx, dist in zip(indices[0], distances[0]):
70
  if idx < len(df_trials):
71
  nct_id = df_trials.iloc[idx]["NCT Number"]
72
- trial_data = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")[0]
73
-
74
- # βœ… Ensure similarity is calculated safely
75
- similarity = round(np.exp(-dist) * 100, 2)
76
- similarity = max(0.0, similarity) # No negative similarity
77
-
78
- trial_data["similarity"] = f"{similarity}%"
79
- matched_trials.append(trial_data)
80
-
81
- return {"matched_trials": matched_trials}
82
-
83
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @app.get("/")
85
  async def root():
86
- return {"message": "TrialGPT API Running!"}
 
7
  from pydantic import BaseModel
8
  from transformers import AutoModel, AutoTokenizer
9
 
10
+ # Hugging Face Cache Directory
11
  os.environ["HF_HOME"] = "/app/huggingface"
12
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
13
 
14
  app = FastAPI()
15
 
16
+ # --- Load Clinical Trials CSV ---
17
  csv_path = "ctg-studies.csv"
18
  if os.path.exists(csv_path):
19
  df_trials = pd.read_csv(csv_path)
20
+ print("βœ… CSV File Loaded Successfully!")
21
  else:
22
+ raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
23
 
24
  # --- Load FAISS Index ---
25
  dimension = 768
 
30
  print("βœ… FAISS Index Loaded!")
31
  else:
32
  index = faiss.IndexFlatL2(dimension)
33
+ print("⚠ FAISS Index Not Found! Using Empty Index.")
34
 
35
+ # --- Load Retrieval Model ---
36
  retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
37
  retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
38
  retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
 
42
  text: str
43
  top_k: int = 5
44
 
45
+ # --- Generate Embedding for Query ---
46
  def generate_embedding(text):
47
  inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
48
  with torch.no_grad():
49
  outputs = retrieval_model(**inputs)
50
+ return outputs.last_hidden_state[:, 0, :].numpy() # CLS Token Embedding
51
 
52
+ # --- Retrieve Clinical Trial Info ---
53
+ def get_trial_info(nct_id):
54
+ trial_info = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")
55
+ return trial_info[0] if trial_info else None
56
 
57
+ # --- Retrieval Endpoint ---
 
 
 
 
58
  @app.post("/retrieve")
59
  async def retrieve_trial(request: QueryRequest):
60
+ """Retrieve Clinical Trial based on text"""
61
  query_vector = generate_embedding(request.text)
62
  distances, indices = index.search(query_vector, request.top_k)
63
 
 
 
 
64
  matched_trials = []
65
  for idx, dist in zip(indices[0], distances[0]):
66
  if idx < len(df_trials):
67
  nct_id = df_trials.iloc[idx]["NCT Number"]
68
+ trial_data = get_trial_info(nct_id)
69
+
70
+ if trial_data:
71
+ # *Ensure distance is valid*
72
+ if np.isnan(dist) or np.isinf(dist):
73
+ dist = 1e9 # Set large number instead of NaN/Inf
74
+
75
+ # βœ… *Proper Similarity Calculation*
76
+ similarity = round((1 / (1 + dist)) * 100, 2)
77
+ trial_data["similarity"] = f"{similarity}%" # Convert to percentage format
78
+
79
+ matched_trials.append(trial_data)
80
+
81
+ # *Convert NaN values to None before returning JSON*
82
+ def clean_json(data):
83
+ if isinstance(data, dict):
84
+ return {k: clean_json(v) for k, v in data.items()}
85
+ elif isinstance(data, list):
86
+ return [clean_json(v) for v in data]
87
+ elif isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
88
+ return None
89
+ return data
90
+
91
+ return {"matched_trials": clean_json(matched_trials)}
92
+
93
+ # --- Root Endpoint ---
94
  @app.get("/")
95
  async def root():
96
+ return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}