priyanandanwar committed
Commit e80c43e · verified · 1 Parent(s): fecd08d

Update main.py

Files changed (1): main.py (+50, -38)
main.py CHANGED
@@ -2,60 +2,72 @@ import os
 import faiss
 import torch
 import numpy as np
+import pandas as pd
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoModelForTokenClassification, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer

+# Hugging Face Cache Directory
 os.environ["HF_HOME"] = "/app/huggingface"

 app = FastAPI()

-# Load Model for NER
-model_name = "priyanandanwar/fine-tuned-gatortron"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForTokenClassification.from_pretrained(model_name)
+# --- Load Clinical Trials CSV (for metadata lookup) ---
+csv_path = "clinical_trials.csv"  # Ensure this file is uploaded
+df_trials = pd.read_csv(csv_path)

-# Dummy FAISS Retrieval System
+# --- Load FAISS Index ---
 dimension = 768
-index = faiss.IndexFlatL2(dimension)
-db_vectors = np.random.rand(10, dimension).astype('float32')
-index.add(db_vectors)
+faiss_index_path = "clinical_trials.index"

-# Request Model
+if os.path.exists(faiss_index_path):
+    index = faiss.read_index(faiss_index_path)
+    print("✅ FAISS Index Loaded!")
+else:
+    index = faiss.IndexFlatL2(dimension)
+    print("⚠️ FAISS Index Not Found! Using Empty Index.")
+
+# --- Load Retrieval Model ---
+retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
+retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
+retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
+
+# --- Request Model ---
 class QueryRequest(BaseModel):
     text: str
-    temperature: float = 0.7
-    max_tokens: int = 256
-    top_p: float = 0.9
-    top_k: int = 50
-
-@app.post("/ner")
-async def predict_ner(request: QueryRequest):
-    """Perform Named Entity Recognition (NER)"""
-    tokens = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=request.max_tokens)
-    outputs = model(**tokens)
-    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
-    tokenized_text = tokenizer.tokenize(request.text)
-
-    return {
-        "tokens": tokenized_text,
-        "labels": predictions,
-        "temperature": request.temperature,
-        "top_p": request.top_p,
-        "top_k": request.top_k
-    }
+    top_k: int = 5
+
+# --- Generate Embedding for Query ---
+def generate_embedding(text):
+    inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
+    with torch.no_grad():
+        outputs = retrieval_model(**inputs)
+    return outputs.last_hidden_state[:, 0, :].numpy()  # CLS Token Embedding

+# --- Retrieve Clinical Trial Info ---
+def get_trial_info(nct_id):
+    trial_info = df_trials[df_trials["NCT_ID"] == nct_id].to_dict(orient="records")
+    return trial_info[0] if trial_info else None
+
+# --- Retrieval Endpoint ---
 @app.post("/retrieve")
 async def retrieve_trial(request: QueryRequest):
     """Retrieve Clinical Trial based on text"""
-    query_vector = np.random.rand(1, dimension).astype('float32')  # Dummy Query Encoding
-    _, indices = index.search(query_vector, request.top_k)  # Retrieve Top K Matches
-
-    return {
-        "matched_trial_ids": indices.tolist(),
-        "top_k": request.top_k
-    }
+    query_vector = generate_embedding(request.text)
+    distances, indices = index.search(query_vector, request.top_k)
+
+    matched_trials = []
+    for idx, dist in zip(indices[0], distances[0]):
+        nct_id = df_trials.iloc[idx]["NCT_ID"]  # Get NCT_ID using FAISS index mapping
+        trial_data = get_trial_info(nct_id)  # Fetch complete trial details
+
+        if trial_data:
+            trial_data["similarity"] = round(100 / (1 + dist), 2)  # Convert L2 distance to a similarity score
+            matched_trials.append(trial_data)
+
+    return {"matched_trials": matched_trials}

+# --- Root Endpoint ---
 @app.get("/")
 async def root():
-    return {"message": "TrialGPT API is Running with Parameterized Inputs!"}
+    return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}