import os

# Hugging Face cache directory -- must be set before transformers is
# imported so the cache location actually takes effect.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

import faiss
import torch
import numpy as np
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

app = FastAPI()

# --- Load Clinical Trials CSV ---
csv_path = "ctg-studies.csv"
if os.path.exists(csv_path):
    df_trials = pd.read_csv(csv_path)
    print("✅ CSV File Loaded Successfully!")
else:
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")

# --- Load FAISS Index ---
dimension = 768
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found! Using Empty Index.")

# --- Load Retrieval Model ---
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)

# --- Request Model ---
class QueryRequest(BaseModel):
    text: str
    top_k: int = 5

# --- Generate Embedding for Query ---
def generate_embedding(text):
    inputs = retrieval_tokenizer(
        text, return_tensors="pt", truncation=True, padding="max_length", max_length=512
    )
    with torch.no_grad():
        outputs = retrieval_model(**inputs)
    # Use the [CLS] token embedding as the query vector
    return outputs.last_hidden_state[:, 0, :].numpy()

# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    trial_info = df_trials[df_trials["NCT Number"] == nct_id].to_dict(orient="records")
    return trial_info[0] if trial_info else None

# --- Retrieval Endpoint ---
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Retrieve clinical trials matching the query text."""
    query_vector = generate_embedding(request.text)
    distances, indices = index.search(query_vector, request.top_k)
    matched_trials = []
    for idx, dist in zip(indices[0], distances[0]):
        # FAISS returns -1 for unfilled result slots, so guard both bounds
        if 0 <= idx < len(df_trials):
            nct_id = df_trials.iloc[idx]["NCT Number"]
            trial_data = get_trial_info(nct_id)
            if trial_data:
                # Ensure the distance is a usable number
                if np.isnan(dist) or np.isinf(dist):
                    dist = 1e9  # Large sentinel instead of NaN/Inf
                # Map L2 distance to a bounded similarity score (percentage)
                similarity = round((1 / (1 + dist)) * 100, 2)
                trial_data["similarity"] = f"{similarity}%"
                matched_trials.append(trial_data)

    # Convert NaN/Inf values to None so the response serializes to valid JSON
    def clean_json(data):
        if isinstance(data, dict):
            return {k: clean_json(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [clean_json(v) for v in data]
        elif isinstance(data, float) and (np.isnan(data) or np.isinf(data)):
            return None
        return data

    return {"matched_trials": clean_json(matched_trials)}

# --- Root Endpoint ---
@app.get("/")
async def root():
    return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}
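
# --- Local Run / Usage Sketch ---
# A minimal way to serve this app locally. This assumes uvicorn is installed
# (`pip install uvicorn`) and that port 8000 is free; both are assumptions,
# not part of the original deployment setup.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example query against the /retrieve endpoint (hypothetical query text):
#   curl -X POST http://localhost:8000/retrieve \
#        -H "Content-Type: application/json" \
#        -d '{"text": "metastatic breast cancer immunotherapy", "top_k": 3}'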