import os
import faiss
import torch
import numpy as np
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Hugging Face Cache Directory
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
app = FastAPI()
# --- Load Clinical Trials CSV ---
csv_path = "ctg-studies.csv"
if os.path.exists(csv_path):
    df_trials = pd.read_csv(csv_path)
    print("✅ CSV File Loaded Successfully!")
else:
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
# --- Load FAISS Index ---
dimension = 768
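# NOTE: 'dimension' must match the embedding size produced by the retrieval model below;
# it is only used when falling back to a fresh, empty index.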
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    index = faiss.IndexFlatL2(dimension)
    print("❌ FAISS Index Not Found! Using Empty Index.")
# --- Load Retrieval Model ---
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
# --- Request Model ---
class QueryRequest(BaseModel):
    text: str
    top_k: int = 5
# --- Generate Embedding for Query ---
def generate_embedding(text):
    inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    with torch.no_grad():
        outputs = retrieval_model(**inputs)
    return outputs.last_hidden_state[:, 0, :].numpy()  # [CLS] token embedding, shape (1, hidden_size)
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    trial_info = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")
    return trial_info[0] if trial_info else None
# --- Retrieval Endpoint ---
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
"""Retrieve Clinical Trial based on text"""
query_vector = generate_embedding(request.text)
distances, indices = index.search(query_vector, request.top_k)
matched_trials = []
for idx, dist in zip(indices[0], distances[0]):
if idx < len(df_trials):
nct_id = df_trials.iloc[idx]["NCT Number"]
trial_data = get_trial_info(nct_id)
if trial_data:
# *Ensure distance is valid*
if np.isnan(dist) or np.isinf(dist):
dist = 1e9 # Set large number instead of NaN/Inf
# β
*Proper Similarity Calculation*
similarity = round((1 / (1 + dist)) * 100, 2)
trial_data["similarity"] = f"{similarity}%" # Convert to percentage format
matched_trials.append(trial_data)
# *Convert NaN values to None before returning JSON*
def clean_json(data):
if isinstance(data, dict):
return {k: clean_json(v) for k, v in data.items()}
elif isinstance(data, list):
return [clean_json(v) for v in data]
elif isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
return None
return data
return {"matched_trials": clean_json(matched_trials)}
# --- Root Endpoint ---
@app.get("/")
async def root():
return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"} |