Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -1,3 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
@app.post("/retrieve")
|
2 |
async def retrieve_trial(request: QueryRequest):
|
3 |
query_vector = generate_embedding(request.text)
|
@@ -20,3 +79,8 @@ async def retrieve_trial(request: QueryRequest):
|
|
20 |
matched_trials.append(trial_data)
|
21 |
|
22 |
return {"matched_trials": matched_trials}
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import faiss
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
from fastapi import FastAPI
|
7 |
+
from pydantic import BaseModel
|
8 |
+
from transformers import AutoModel, AutoTokenizer
|
9 |
+
|
10 |
+
os.environ["HF_HOME"] = "/app/huggingface"
|
11 |
+
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
12 |
+
|
13 |
+
app = FastAPI()
|
14 |
+
|
15 |
+
# --- Load CSV ---
|
16 |
+
csv_path = "ctg-studies.csv"
|
17 |
+
if os.path.exists(csv_path):
|
18 |
+
df_trials = pd.read_csv(csv_path)
|
19 |
+
print("✅ CSV Loaded!")
|
20 |
+
else:
|
21 |
+
raise FileNotFoundError("❌ CSV File Not Found!")
|
22 |
+
|
23 |
+
# --- Load FAISS Index ---
|
24 |
+
dimension = 768
|
25 |
+
faiss_index_path = "clinical_trials.index"
|
26 |
+
|
27 |
+
if os.path.exists(faiss_index_path):
|
28 |
+
index = faiss.read_index(faiss_index_path)
|
29 |
+
print("✅ FAISS Index Loaded!")
|
30 |
+
else:
|
31 |
+
index = faiss.IndexFlatL2(dimension)
|
32 |
+
print("⚠️ FAISS Index Empty!")
|
33 |
+
|
34 |
+
# --- Load Model ---
|
35 |
+
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
|
36 |
+
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
|
37 |
+
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
|
38 |
+
|
39 |
+
# --- Request Model ---
|
40 |
+
class QueryRequest(BaseModel):
|
41 |
+
text: str
|
42 |
+
top_k: int = 5
|
43 |
+
|
44 |
+
# --- Generate Embedding ---
|
45 |
+
def generate_embedding(text):
|
46 |
+
inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
47 |
+
with torch.no_grad():
|
48 |
+
outputs = retrieval_model(**inputs)
|
49 |
+
emb = outputs.last_hidden_state[:, 0, :].numpy()
|
50 |
+
|
51 |
+
# ✅ Normalize Embeddings
|
52 |
+
emb = emb / np.linalg.norm(emb)
|
53 |
+
return emb
|
54 |
+
|
55 |
+
# --- Compute Similarity ---
|
56 |
+
def compute_similarity(distance):
|
57 |
+
return round(np.exp(-distance) * 100, 2) # ✅ Softmax similarity fix
|
58 |
+
|
59 |
+
# --- Retrieve Trials ---
|
60 |
@app.post("/retrieve")
|
61 |
async def retrieve_trial(request: QueryRequest):
|
62 |
query_vector = generate_embedding(request.text)
|
|
|
79 |
matched_trials.append(trial_data)
|
80 |
|
81 |
return {"matched_trials": matched_trials}
|
82 |
+
|
83 |
+
|
84 |
+
@app.get("/")
|
85 |
+
async def root():
|
86 |
+
return {"message": "TrialGPT API Running!"}
|