Spaces:
Running
Running
import os

# Hugging Face cache directory and download timeout.
# These are set *before* importing `transformers` on purpose: `huggingface_hub`
# (pulled in by `transformers`) resolves HF_HOME and HF_HUB_DOWNLOAD_TIMEOUT
# into module-level constants when it is first imported, so assigning them
# after the import has no effect on where models are cached or how long
# downloads may take.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

import faiss  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# ASGI application object; request handlers defined below belong to it.
app = FastAPI()
# --- Load Clinical Trials CSV ---
# The trial metadata table is required: fail fast at startup when it is absent.
csv_path = "ctg-studies.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"β CSV File Not Found: {csv_path}. Upload it first.")
df_trials = pd.read_csv(csv_path)
print("β CSV File Loaded Successfully!")
# --- Load FAISS Index ---
# Width of the vectors held by the fallback index.
# NOTE(review): this must equal the embedding width produced by the retrieval
# model, otherwise searches against the fallback index will fail — confirm.
dimension = 768
faiss_index_path = "clinical_trials.index"
index_on_disk = os.path.exists(faiss_index_path)
# Prefer the prebuilt index; otherwise start from an empty flat L2 index so
# the service can still boot (searches then have nothing to match against).
index = (
    faiss.read_index(faiss_index_path)
    if index_on_disk
    else faiss.IndexFlatL2(dimension)
)
print(
    "β FAISS Index Loaded!"
    if index_on_disk
    else "β FAISS Index Not Found! Using Empty Index."
)
# --- Load Retrieval Model ---
# Tokenizer and encoder are loaded once at import time and shared by every
# request; both come from the same Hugging Face Hub repository.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(
    retrieval_model_name,
)
retrieval_model = AutoModel.from_pretrained(
    retrieval_model_name,
)
# --- Request Model ---
# Body of a retrieval request. Documented with comments rather than a class
# docstring so the generated schema description stays unchanged.
class QueryRequest(BaseModel):
    # Free-text query that gets embedded and searched against the index.
    text: str
    # How many nearest neighbours to request from the index (defaults to 5).
    top_k: int = 5
# --- Generate Embedding for Query ---
def generate_embedding(text):
    """Embed *text* with the retrieval model and return it as a NumPy array.

    The text is tokenized to a fixed length of 512 tokens (truncated or
    padded as needed), run through the model with gradient tracking
    disabled, and the hidden state at sequence position 0 (the CLS token)
    is returned.

    Args:
        text: Input passed straight to the tokenizer.

    Returns:
        ``last_hidden_state[:, 0, :]`` converted with ``.numpy()``.
        Presumably of shape (batch, hidden_size) — TODO confirm against the
        model configuration.
    """
    encoded = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state
    cls_vectors = hidden_states[:, 0, :]  # CLS token embedding
    return cls_vectors.numpy()
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    """Look up a trial by its "NCT Number" in the loaded CSV table.

    Args:
        nct_id: Value compared for equality against the "NCT Number" column.

    Returns:
        A dict (column name -> value) for the first matching row, or
        ``None`` when no row matches.
    """
    matches = df_trials.loc[df_trials["NCT Number"] == nct_id]
    if matches.empty:
        return None
    # Only the first match is used, even if the id appears more than once.
    return matches.to_dict(orient="records")[0]
# --- Retrieval Endpoint ---
def _clean_json(data):
    """Recursively replace NaN/Inf floats with None so the payload is valid JSON."""
    if isinstance(data, dict):
        return {key: _clean_json(value) for key, value in data.items()}
    if isinstance(data, list):
        return [_clean_json(value) for value in data]
    if isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
        return None
    return data


# NOTE(review): no route decorator is visible on this handler in this view of
# the file — confirm it is registered on `app` (path and method unknown here).
async def retrieve_trial(request: QueryRequest):
    """Retrieve Clinical Trial based on text

    Embeds ``request.text``, searches the FAISS index for the nearest
    vectors, and maps each hit back to its row in the trials table.

    Args:
        request: Query text plus the number of neighbours wanted (``top_k``).

    Returns:
        ``{"matched_trials": [...]}`` where each entry is the trial's CSV row
        as a dict with an added ``"similarity"`` string such as ``"42.5%"``.
        NaN/Inf values in the rows are replaced by ``None``. The list is
        empty when ``top_k`` is below 1 or the index holds no vectors.
    """
    # FAISS rejects k < 1, and asking for more neighbours than the index
    # holds only yields padding entries, so clamp k to what is searchable.
    top_k = min(request.top_k, index.ntotal)
    if top_k < 1:
        return {"matched_trials": []}

    query_vector = generate_embedding(request.text)
    distances, indices = index.search(query_vector, top_k)

    matched_trials = []
    for idx, dist in zip(indices[0], distances[0]):
        # FAISS marks "no neighbour found" with label -1. Without the lower
        # bound, iloc[-1] would silently return the LAST row of the table as
        # if it were a match.
        if idx < 0 or idx >= len(df_trials):
            continue

        nct_id = df_trials.iloc[idx]["NCT Number"]
        trial_data = get_trial_info(nct_id)
        if not trial_data:
            continue

        # Ensure the distance is usable: map NaN/Inf to a very large distance
        # so the similarity below collapses to ~0% instead of propagating.
        if np.isnan(dist) or np.isinf(dist):
            dist = 1e9

        # Map distance in [0, inf) to a percentage in (0, 100]; 0 -> 100%.
        similarity = round((1 / (1 + dist)) * 100, 2)
        trial_data["similarity"] = f"{similarity}%"
        matched_trials.append(trial_data)

    # Convert NaN values to None before returning JSON.
    return {"matched_trials": _clean_json(matched_trials)}
# --- Root Endpoint ---
# NOTE(review): no route decorator is visible on this handler in this view of
# the file — confirm it is registered on `app`.
async def root():
    """Return a static status payload indicating that the API is running."""
    status_message = "TrialGPT API is Running with FAISS-based Retrieval!"
    return {"message": status_message}