Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -7,18 +7,19 @@ from fastapi import FastAPI
|
|
7 |
from pydantic import BaseModel
|
8 |
from transformers import AutoModel, AutoTokenizer
|
9 |
|
|
|
10 |
os.environ["HF_HOME"] = "/app/huggingface"
|
11 |
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
-
# --- Load CSV ---
|
16 |
csv_path = "ctg-studies.csv"
|
17 |
if os.path.exists(csv_path):
|
18 |
df_trials = pd.read_csv(csv_path)
|
19 |
-
print("✅ CSV Loaded!")
|
20 |
else:
|
21 |
-
raise FileNotFoundError("❌ CSV File Not Found
|
22 |
|
23 |
# --- Load FAISS Index ---
|
24 |
dimension = 768
|
@@ -29,9 +30,9 @@ if os.path.exists(faiss_index_path):
|
|
29 |
print("✅ FAISS Index Loaded!")
|
30 |
else:
|
31 |
index = faiss.IndexFlatL2(dimension)
|
32 |
-
print("
|
33 |
|
34 |
-
# --- Load Model ---
|
35 |
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
|
36 |
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
|
37 |
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
|
@@ -41,46 +42,55 @@ class QueryRequest(BaseModel):
|
|
41 |
text: str
|
42 |
top_k: int = 5
|
43 |
|
44 |
-
# --- Generate Embedding ---
|
45 |
def generate_embedding(text):
|
46 |
inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
47 |
with torch.no_grad():
|
48 |
outputs = retrieval_model(**inputs)
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
-
# ---
|
56 |
-
def compute_similarity(distance):
|
57 |
-
return round(np.exp(-distance) * 100, 2) # ✅ Softmax similarity fix
|
58 |
-
|
59 |
-
# --- Retrieve Trials ---
|
60 |
@app.post("/retrieve")
|
61 |
async def retrieve_trial(request: QueryRequest):
|
|
|
62 |
query_vector = generate_embedding(request.text)
|
63 |
distances, indices = index.search(query_vector, request.top_k)
|
64 |
|
65 |
-
# ✅ Fix: Remove NaN/Inf from distances
|
66 |
-
distances = np.nan_to_num(distances, nan=1e9, posinf=1e9, neginf=1e9)
|
67 |
-
|
68 |
matched_trials = []
|
69 |
for idx, dist in zip(indices[0], distances[0]):
|
70 |
if idx < len(df_trials):
|
71 |
nct_id = df_trials.iloc[idx]["NCT Number"]
|
72 |
-
trial_data =
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
@app.get("/")
|
85 |
async def root():
|
86 |
-
return {"message": "TrialGPT API Running!"}
|
|
|
7 |
from pydantic import BaseModel
|
8 |
from transformers import AutoModel, AutoTokenizer
|
9 |
|
10 |
+
# Hugging Face Cache Directory
|
11 |
os.environ["HF_HOME"] = "/app/huggingface"
|
12 |
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
16 |
+
# --- Load Clinical Trials CSV ---
|
17 |
csv_path = "ctg-studies.csv"
|
18 |
if os.path.exists(csv_path):
|
19 |
df_trials = pd.read_csv(csv_path)
|
20 |
+
print("✅ CSV File Loaded Successfully!")
|
21 |
else:
|
22 |
+
raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
|
23 |
|
24 |
# --- Load FAISS Index ---
|
25 |
dimension = 768
|
|
|
30 |
print("✅ FAISS Index Loaded!")
|
31 |
else:
|
32 |
index = faiss.IndexFlatL2(dimension)
|
33 |
+
print("❌ FAISS Index Not Found! Using Empty Index.")
|
34 |
|
35 |
+
# --- Load Retrieval Model ---
|
36 |
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
|
37 |
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
|
38 |
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
|
|
|
42 |
text: str
|
43 |
top_k: int = 5
|
44 |
|
45 |
+
# --- Generate Embedding for Query ---
|
46 |
def generate_embedding(text):
|
47 |
inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
48 |
with torch.no_grad():
|
49 |
outputs = retrieval_model(**inputs)
|
50 |
+
return outputs.last_hidden_state[:, 0, :].numpy() # CLS Token Embedding
|
51 |
|
52 |
+
# --- Retrieve Clinical Trial Info ---
|
53 |
+
def get_trial_info(nct_id):
|
54 |
+
trial_info = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")
|
55 |
+
return trial_info[0] if trial_info else None
|
56 |
|
57 |
+
# --- Retrieval Endpoint ---
|
|
|
|
|
|
|
|
|
58 |
@app.post("/retrieve")
|
59 |
async def retrieve_trial(request: QueryRequest):
|
60 |
+
"""Retrieve Clinical Trial based on text"""
|
61 |
query_vector = generate_embedding(request.text)
|
62 |
distances, indices = index.search(query_vector, request.top_k)
|
63 |
|
|
|
|
|
|
|
64 |
matched_trials = []
|
65 |
for idx, dist in zip(indices[0], distances[0]):
|
66 |
if idx < len(df_trials):
|
67 |
nct_id = df_trials.iloc[idx]["NCT Number"]
|
68 |
+
trial_data = get_trial_info(nct_id)
|
69 |
+
|
70 |
+
if trial_data:
|
71 |
+
# *Ensure distance is valid*
|
72 |
+
if np.isnan(dist) or np.isinf(dist):
|
73 |
+
dist = 1e9 # Set large number instead of NaN/Inf
|
74 |
+
|
75 |
+
# ✅ *Proper Similarity Calculation*
|
76 |
+
similarity = round((1 / (1 + dist)) * 100, 2)
|
77 |
+
trial_data["similarity"] = f"{similarity}%" # Convert to percentage format
|
78 |
+
|
79 |
+
matched_trials.append(trial_data)
|
80 |
+
|
81 |
+
# *Convert NaN values to None before returning JSON*
|
82 |
+
def clean_json(data):
|
83 |
+
if isinstance(data, dict):
|
84 |
+
return {k: clean_json(v) for k, v in data.items()}
|
85 |
+
elif isinstance(data, list):
|
86 |
+
return [clean_json(v) for v in data]
|
87 |
+
elif isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
|
88 |
+
return None
|
89 |
+
return data
|
90 |
+
|
91 |
+
return {"matched_trials": clean_json(matched_trials)}
|
92 |
+
|
93 |
+
# --- Root Endpoint ---
|
94 |
@app.get("/")
|
95 |
async def root():
|
96 |
+
return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}
|