Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -2,60 +2,72 @@ import os
|
|
2 |
import faiss
|
3 |
import torch
|
4 |
import numpy as np
|
|
|
5 |
from fastapi import FastAPI
|
6 |
from pydantic import BaseModel
|
7 |
-
from transformers import
|
8 |
|
|
|
9 |
os.environ["HF_HOME"] = "/app/huggingface"
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
-
# Load
|
14 |
-
|
15 |
-
|
16 |
-
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
17 |
|
18 |
-
#
|
19 |
dimension = 768
|
20 |
-
|
21 |
-
db_vectors = np.random.rand(10, dimension).astype('float32')
|
22 |
-
index.add(db_vectors)
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class QueryRequest(BaseModel):
|
26 |
text: str
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
tokens = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=request.max_tokens)
|
36 |
-
outputs = model(**tokens)
|
37 |
-
predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
|
38 |
-
tokenized_text = tokenizer.tokenize(request.text)
|
39 |
-
|
40 |
-
return {
|
41 |
-
"tokens": tokenized_text,
|
42 |
-
"labels": predictions,
|
43 |
-
"temperature": request.temperature,
|
44 |
-
"top_p": request.top_p,
|
45 |
-
"top_k": request.top_k
|
46 |
-
}
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
@app.post("/retrieve")
|
49 |
async def retrieve_trial(request: QueryRequest):
|
50 |
"""Retrieve Clinical Trial based on text"""
|
51 |
-
query_vector =
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
"
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
|
|
59 |
@app.get("/")
|
60 |
async def root():
|
61 |
-
return {"message": "TrialGPT API is Running with
|
|
|
2 |
import faiss
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
from fastapi import FastAPI
|
7 |
from pydantic import BaseModel
|
8 |
+
from transformers import AutoModel, AutoTokenizer
|
9 |
|
10 |
+
# Hugging Face Cache Directory
|
11 |
os.environ["HF_HOME"] = "/app/huggingface"
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
# --- Load Clinical Trials CSV (for metadata lookup) ---
|
16 |
+
csv_path = "clinical_trials.csv" # Ensure this file is uploaded
|
17 |
+
df_trials = pd.read_csv(csv_path)
|
|
|
18 |
|
19 |
+
# --- Load FAISS Index ---
|
20 |
dimension = 768
|
21 |
+
faiss_index_path = "clinical_trials.index"
|
|
|
|
|
22 |
|
23 |
+
if os.path.exists(faiss_index_path):
|
24 |
+
index = faiss.read_index(faiss_index_path)
|
25 |
+
print("✅ FAISS Index Loaded!")
|
26 |
+
else:
|
27 |
+
index = faiss.IndexFlatL2(dimension)
|
28 |
+
print("⚠️ FAISS Index Not Found! Using Empty Index.")
|
29 |
+
|
30 |
+
# --- Load Retrieval Model ---
|
31 |
+
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
|
32 |
+
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
|
33 |
+
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
|
34 |
+
|
35 |
+
# --- Request Model ---
|
36 |
class QueryRequest(BaseModel):
|
37 |
text: str
|
38 |
+
top_k: int = 5
|
39 |
+
|
40 |
+
# --- Generate Embedding for Query ---
|
41 |
+
def generate_embedding(text):
|
42 |
+
inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
43 |
+
with torch.no_grad():
|
44 |
+
outputs = retrieval_model(**inputs)
|
45 |
+
return outputs.last_hidden_state[:, 0, :].numpy() # CLS Token Embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
# --- Retrieve Clinical Trial Info ---
|
48 |
+
def get_trial_info(nct_id):
|
49 |
+
trial_info = df_trials[df_trials["NCT_ID"] == nct_id].to_dict(orient="records")
|
50 |
+
return trial_info[0] if trial_info else None
|
51 |
+
|
52 |
+
# --- Retrieval Endpoint ---
|
53 |
@app.post("/retrieve")
|
54 |
async def retrieve_trial(request: QueryRequest):
|
55 |
"""Retrieve Clinical Trial based on text"""
|
56 |
+
query_vector = generate_embedding(request.text)
|
57 |
+
distances, indices = index.search(query_vector, request.top_k)
|
58 |
+
|
59 |
+
matched_trials = []
|
60 |
+
for idx, dist in zip(indices[0], distances[0]):
|
61 |
+
nct_id = df_trials.iloc[idx]["NCT_ID"] # Get NCT_ID using FAISS index mapping
|
62 |
+
trial_data = get_trial_info(nct_id) # Fetch complete trial details
|
63 |
+
|
64 |
+
if trial_data:
|
65 |
+
trial_data["similarity"] = round(100 / (1 + dist), 2) # Convert similarity
|
66 |
+
matched_trials.append(trial_data)
|
67 |
+
|
68 |
+
return {"matched_trials": matched_trials}
|
69 |
|
70 |
+
# --- Root Endpoint ---
|
71 |
@app.get("/")
|
72 |
async def root():
|
73 |
+
return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}
|