File size: 3,305 Bytes
cf4089c
 
 
 
 
 
 
 
 
6927b07
cf4089c
 
 
 
 
6927b07
cf4089c
 
 
6927b07
cf4089c
6927b07
cf4089c
 
 
 
 
 
 
 
 
 
6927b07
cf4089c
6927b07
cf4089c
 
 
 
 
 
 
 
 
6927b07
cf4089c
 
 
 
6927b07
cf4089c
6927b07
 
 
 
cf4089c
6927b07
7397e3c
 
6927b07
e80c43e
 
 
 
 
7022d0b
fbbf31c
6927b07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf4089c
 
6927b07
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os

# Hugging Face cache directory and download timeout.
# These are exported BEFORE the third-party imports on purpose:
# `transformers` pulls in `huggingface_hub`, which resolves its cache location
# and timeout settings from the environment when it is first imported, so
# assigning them after the import can leave the defaults in effect.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

import faiss  # noqa: E402
import torch  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# ASGI application object; the route decorators below register on it.
app = FastAPI()

# --- Load Clinical Trials CSV ---
# The retrieval endpoint looks trials up in this table, so a missing file
# aborts start-up with a FileNotFoundError instead of serving a broken API.
csv_path = "ctg-studies.csv"
try:
    # Attempt the read directly rather than checking for existence first;
    # this avoids a second filesystem lookup and the check/use race.
    df_trials = pd.read_csv(csv_path)
except FileNotFoundError as err:
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.") from err
print("✅ CSV File Loaded Successfully!")

# --- Load FAISS Index ---
# Vector width used only when an empty fallback index has to be created below.
dimension = 768
faiss_index_path = "clinical_trials.index"

if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
    # The retrieval endpoint maps FAISS result positions onto CSV row
    # positions, so a size mismatch means some hits cannot be resolved
    # (or resolve to the wrong trial). Surface that at start-up.
    if index.ntotal != len(df_trials):
        print(
            f"⚠ FAISS Index has {index.ntotal} vectors but the CSV has "
            f"{len(df_trials)} rows; results may not line up."
        )
else:
    # No persisted index: fall back to an empty L2 index so the app still
    # starts; searches against it return no usable neighbours.
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found! Using Empty Index.")

# --- Load Retrieval Model ---
# Tokenizer and encoder are both resolved from the same Hugging Face
# checkpoint name; the tokenizer is loaded first, then the model.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer, retrieval_model = (
    loader.from_pretrained(retrieval_model_name)
    for loader in (AutoTokenizer, AutoModel)
)

# --- Request Model ---
class QueryRequest(BaseModel):
    # Request body accepted by the POST /retrieve endpoint.

    # Free-text query; it is embedded and searched against the FAISS index.
    text: str
    # Number of nearest neighbours requested from the index (defaults to 5).
    top_k: int = 5

# --- Generate Embedding for Query ---
def generate_embedding(text):
    """Embed *text* with the retrieval model and return it as a NumPy array.

    The input is tokenized to a fixed length of 512 tokens (truncated and
    padded), run through the model without gradient tracking, and the hidden
    state at sequence position 0 (the CLS token) is returned.
    """
    encoded = retrieval_tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state

    # Keep the batch axis, take position 0 of the sequence axis.
    cls_vector = hidden_states[:, 0, :]
    return cls_vector.numpy()

# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    """Return the first CSV row whose "NCT Number" equals *nct_id*.

    The row is returned as a plain dict keyed by column name; ``None`` is
    returned when no row matches.
    """
    matches = df_trials.loc[df_trials["NCT Number"] == nct_id]
    if matches.empty:
        return None
    return matches.to_dict(orient="records")[0]

# --- Retrieval Endpoint ---
def _clean_json(data):
    # Recursively replace NaN/Inf floats with None so the payload is valid JSON.
    if isinstance(data, dict):
        return {key: _clean_json(value) for key, value in data.items()}
    if isinstance(data, list):
        return [_clean_json(item) for item in data]
    if isinstance(data, float) and not np.isfinite(data):
        return None
    return data


@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Retrieve Clinical Trial based on text"""
    # Embeds request.text, asks the FAISS index for the request.top_k nearest
    # vectors, and maps each hit back to its CSV row by position.
    #
    # Returns {"matched_trials": [...]}, where every entry is the trial's CSV
    # row as a dict plus a "similarity" percentage string derived from the
    # FAISS distance. Responds with HTTP 422 when top_k is not positive.
    from fastapi import HTTPException  # local: keeps the file-level imports untouched

    if request.top_k < 1:
        # FAISS rejects non-positive k; report it as a client error instead
        # of letting it surface as an internal server error.
        raise HTTPException(status_code=422, detail="top_k must be a positive integer")

    query_vector = generate_embedding(request.text)
    distances, indices = index.search(query_vector, request.top_k)

    total_trials = len(df_trials)
    matched_trials = []
    for idx, dist in zip(indices[0], distances[0]):
        # FAISS pads the result with label -1 when it finds fewer than top_k
        # neighbours (always the case for an empty index). A negative label
        # must be skipped: used with iloc it would silently select rows from
        # the end of the table and return unrelated trials.
        if not 0 <= idx < total_trials:
            continue

        nct_id = df_trials.iloc[idx]["NCT Number"]
        trial_data = get_trial_info(nct_id)
        if not trial_data:
            continue

        # Work with a plain Python float so rounding/formatting does not
        # depend on numpy scalar promotion rules.
        dist = float(dist)
        if not np.isfinite(dist):
            dist = 1e9  # Treat an invalid distance as "very far" (~0% similar)

        # Map distance to a percentage: 0 -> 100%, larger distance -> closer to 0%.
        similarity = round((1 / (1 + dist)) * 100, 2)
        trial_data["similarity"] = f"{similarity}%"

        matched_trials.append(trial_data)

    # CSV cells may hold NaN, which is not valid JSON; convert them to None.
    return {"matched_trials": _clean_json(matched_trials)}

# --- Root Endpoint ---
@app.get("/")
async def root():
    # Landing route: takes no input and reports that the service is up.
    status_message = "TrialGPT API is Running with FAISS-based Retrieval!"
    return dict(message=status_message)