# demo-api/main.py
import os
import faiss
import torch
import numpy as np
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Hugging Face Cache Directory
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
app = FastAPI()
# --- Load Clinical Trials CSV ---
csv_path = "ctg-studies.csv"
if os.path.exists(csv_path):
    df_trials = pd.read_csv(csv_path)
    print("✅ CSV File Loaded Successfully!")
else:
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
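
# The CSV is expected to provide at least an "NCT Number" column; the
# retrieval endpoint below uses it to map FAISS hits back to trial records.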
# --- Load FAISS Index ---
dimension = 768
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found! Using Empty Index.")
# --- Load Retrieval Model ---
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
retrieval_model.eval()  # inference only; disables dropout
# --- Request Model ---
class QueryRequest(BaseModel):
    text: str
    top_k: int = 5
# --- Generate Embedding for Query ---
def generate_embedding(text):
    inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    with torch.no_grad():
        outputs = retrieval_model(**inputs)
    return outputs.last_hidden_state[:, 0, :].numpy()  # [CLS] token embedding
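
# --- (Sketch) How clinical_trials.index could be built offline ---
# The search code below assumes one 768-dim vector per CSV row, stored in row
# order, so FAISS ids map directly to df_trials.iloc positions. A minimal,
# hypothetical build script (the "Study Title" column is an assumption, not
# confirmed by this repo):
#
#   vectors = np.vstack([generate_embedding(t) for t in df_trials["Study Title"]])
#   build_index = faiss.IndexFlatL2(dimension)
#   build_index.add(vectors.astype(np.float32))
#   faiss.write_index(build_index, faiss_index_path)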
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    trial_info = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")
    return trial_info[0] if trial_info else None
# --- Retrieval Endpoint ---
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
"""Retrieve Clinical Trial based on text"""
query_vector = generate_embedding(request.text)
distances, indices = index.search(query_vector, request.top_k)
matched_trials = []
for idx, dist in zip(indices[0], distances[0]):
if idx < len(df_trials):
nct_id = df_trials.iloc[idx]["NCT Number"]
trial_data = get_trial_info(nct_id)
if trial_data:
# *Ensure distance is valid*
if np.isnan(dist) or np.isinf(dist):
dist = 1e9 # Set large number instead of NaN/Inf
# βœ… *Proper Similarity Calculation*
similarity = round((1 / (1 + dist)) * 100, 2)
trial_data["similarity"] = f"{similarity}%" # Convert to percentage format
matched_trials.append(trial_data)
# *Convert NaN values to None before returning JSON*
def clean_json(data):
if isinstance(data, dict):
return {k: clean_json(v) for k, v in data.items()}
elif isinstance(data, list):
return [clean_json(v) for v in data]
elif isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
return None
return data
return {"matched_trials": clean_json(matched_trials)}
# --- Root Endpoint ---
@app.get("/")
async def root():
return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}