priyanandanwar committed on
Commit
cf4089c
·
verified ·
1 Parent(s): 9ae18e2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +64 -0
main.py CHANGED
@@ -1,3 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  @app.post("/retrieve")
2
  async def retrieve_trial(request: QueryRequest):
3
  query_vector = generate_embedding(request.text)
@@ -20,3 +79,8 @@ async def retrieve_trial(request: QueryRequest):
20
  matched_trials.append(trial_data)
21
 
22
  return {"matched_trials": matched_trials}
 
 
 
 
 
 
1
import os

# Hugging Face settings have to be exported BEFORE transformers (and therefore
# huggingface_hub) is imported: huggingface_hub reads these variables into
# module-level constants on first import, so assigning them afterwards has no
# effect on the cache location or the download timeout.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

import faiss  # noqa: E402
import torch  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# ASGI application object; the route decorators below register against it.
app = FastAPI()
14
+
15
# --- Trial table (CSV) ---
# The service cannot start without the CSV, so a missing file is fatal.
csv_path = "ctg-studies.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError("❌ CSV File Not Found!")
df_trials = pd.read_csv(csv_path)
print("✅ CSV Loaded!")

# --- FAISS index ---
# `dimension` is only used to size the empty fallback index created when no
# saved index file is present.
dimension = 768
faiss_index_path = "clinical_trials.index"

if not os.path.exists(faiss_index_path):
    index = faiss.IndexFlatL2(dimension)
    print("⚠️ FAISS Index Empty!")
else:
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")

# --- Retrieval model ---
# Tokenizer and encoder are loaded once at import time and shared by requests.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
38
+
39
+ # --- Request Model ---
40
# Request body for POST /retrieve.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema description stays unchanged.)
class QueryRequest(BaseModel):
    # Free-text query that gets embedded and searched against the index.
    text: str
    # Number of results requested; defaults to 5 when omitted.
    top_k: int = 5
43
+
44
+ # --- Generate Embedding ---
45
def generate_embedding(text):
    """Embed *text* with the retrieval model and L2-normalize each row.

    Args:
        text: Input passed straight to ``retrieval_tokenizer`` (a single
            string in the visible caller).

    Returns:
        A 2-D NumPy array with one row per input sequence, taken from the
        first-token position of ``last_hidden_state`` and scaled to unit
        L2 norm. Rows whose norm is exactly zero are returned unchanged
        instead of becoming NaN.
    """
    # NOTE(review): padding="max_length" pads every query to 512 tokens;
    # kept as-is so embeddings stay consistent with however the index was
    # built — confirm before changing.
    inputs = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    with torch.no_grad():
        outputs = retrieval_model(**inputs)

    # First-token hidden state per sequence; .cpu() is a no-op on CPU and
    # keeps .numpy() working if the model is ever moved to an accelerator.
    emb = outputs.last_hidden_state[:, 0, :].cpu().numpy()

    # Normalize row by row. A single global norm is only correct for one
    # input; with several rows it would scale them all by the same factor.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero -> NaN
    return emb / norms
54
+
55
+ # --- Compute Similarity ---
56
def compute_similarity(distance):
    """Map a distance to a percentage score via exponential decay.

    Computes ``exp(-distance) * 100`` rounded to two decimals, so a distance
    of 0 gives 100.0 and larger distances approach 0.0. (This is an
    exponential decay, not a softmax.)

    Args:
        distance: A real number; Python floats and NumPy scalars are accepted.

    Returns:
        A built-in ``float``. Converting from NumPy scalars avoids returning
        ``numpy.float32`` values, which are not plain JSON-serializable
        floats and do not round cleanly to two decimals.
    """
    score = float(np.exp(-float(distance))) * 100
    return round(score, 2)
58
+
59
+ # --- Retrieve Trials ---
60
  @app.post("/retrieve")
61
  async def retrieve_trial(request: QueryRequest):
62
  query_vector = generate_embedding(request.text)
 
79
  matched_trials.append(trial_data)
80
 
81
  return {"matched_trials": matched_trials}
82
+
83
+
84
# Landing route: returns a fixed status message.
# (Comment instead of docstring so the OpenAPI description is unchanged.)
@app.get("/")
async def root():
    status_payload = {"message": "TrialGPT API Running!"}
    return status_payload