priyanandanwar committed on
Commit
fecd08d
·
verified ·
1 Parent(s): 67e6fc9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -36
main.py CHANGED
@@ -4,59 +4,58 @@ import torch
4
  import numpy as np
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
- from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoModel
8
 
9
- # Hugging Face Cache Directory (For HF Spaces)
10
  os.environ["HF_HOME"] = "/app/huggingface"
11
 
12
  app = FastAPI()
13
 
14
- # --- Load NER Model ---
15
  model_name = "priyanandanwar/fine-tuned-gatortron"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- ner_model = AutoModelForTokenClassification.from_pretrained(model_name)
18
 
19
- # --- Load FAISS Index ---
20
  dimension = 768
21
- faiss_index_path = "clinical_trials.index"
 
 
22
 
23
- if os.path.exists(faiss_index_path):
24
- index = faiss.read_index(faiss_index_path)
25
- print("✅ FAISS Index Loaded!")
26
- else:
27
- index = faiss.IndexFlatL2(dimension)
28
- print("⚠️ FAISS Index Not Found! Using Empty Index.")
29
-
30
- # --- Load Retrieval Model for Embeddings ---
31
- retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
32
- retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
33
- retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
34
-
35
- # --- Request Model ---
36
  class QueryRequest(BaseModel):
37
  text: str
38
- top_k: int = 5
39
-
40
- # --- Generate Embedding for Query ---
41
def generate_embedding(text):
    """Embed ``text`` with the retrieval model.

    The text is tokenized (truncated and padded to 512 tokens), run through
    ``retrieval_model`` without gradient tracking, and the hidden state at
    sequence position 0 is returned as a NumPy array with one row per input.
    """
    encoded = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state

    # Keep only the first position of each sequence (the CLS token embedding).
    first_token = hidden_states[:, 0, :]
    return first_token.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # --- Retrieval Endpoint ---
48
  @app.post("/retrieve")
49
  async def retrieve_trial(request: QueryRequest):
50
  """Retrieve Clinical Trial based on text"""
51
- query_vector = generate_embedding(request.text)
52
- distances, indices = index.search(query_vector, request.top_k)
53
-
54
- # Convert retrieved indices to NCT IDs directly
55
- results = [{"NCT_ID": str(int(idx)), "similarity": float(round(100 / (1 + dist), 2))} for idx, dist in zip(indices[0], distances[0])]
56
 
57
- return {"matched_trials": results}
 
 
 
58
 
59
- # --- Root Endpoint ---
60
  @app.get("/")
61
  async def root():
62
- return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}
 
4
  import numpy as np
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
8
 
 
9
  os.environ["HF_HOME"] = "/app/huggingface"
10
 
11
  app = FastAPI()
12
 
13
+ # Load Model for NER
14
  model_name = "priyanandanwar/fine-tuned-gatortron"
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+ model = AutoModelForTokenClassification.from_pretrained(model_name)
17
 
18
+ # Dummy FAISS Retrieval System
19
  dimension = 768
20
+ index = faiss.IndexFlatL2(dimension)
21
+ db_vectors = np.random.rand(10, dimension).astype('float32')
22
+ index.add(db_vectors)
23
 
24
+ # Request Model
 
 
 
 
 
 
 
 
 
 
 
 
25
  class QueryRequest(BaseModel):
26
  text: str
27
+ temperature: float = 0.7
28
+ max_tokens: int = 256
29
+ top_p: float = 0.9
30
+ top_k: int = 50
31
+
32
@app.post("/ner")
async def predict_ner(request: QueryRequest):
    """Perform Named Entity Recognition (NER).

    Tokenizes ``request.text`` (truncated to ``request.max_tokens``), runs the
    token-classification model and returns the arg-max label id per token.

    Returns:
        A dict with:
          * ``tokens`` - the tokens actually fed to the model, including any
            special tokens the tokenizer inserted and after truncation, so
            that ``tokens[i]`` corresponds to ``labels[i]``;
          * ``labels`` - predicted label id for each of those tokens;
          * ``temperature``, ``top_p``, ``top_k`` - echoed from the request.
    """
    encoded = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=request.max_tokens,
    )

    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        outputs = model(**encoded)
    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    # Derive the token strings from the ids that were really sent to the
    # model. Re-tokenizing the raw text would omit special tokens and ignore
    # truncation, leaving `tokens` and `labels` with different lengths.
    tokenized_text = tokenizer.convert_ids_to_tokens(
        encoded["input_ids"][0].tolist()
    )

    return {
        "tokens": tokenized_text,
        "labels": predictions,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "top_k": request.top_k,
    }
47
 
 
48
  @app.post("/retrieve")
49
  async def retrieve_trial(request: QueryRequest):
50
  """Retrieve Clinical Trial based on text"""
51
+ query_vector = np.random.rand(1, dimension).astype('float32') # Dummy Query Encoding
52
+ _, indices = index.search(query_vector, request.top_k) # Retrieve Top K Matches
 
 
 
53
 
54
+ return {
55
+ "matched_trial_ids": indices.tolist(),
56
+ "top_k": request.top_k
57
+ }
58
 
 
59
  @app.get("/")
60
  async def root():
61
+ return {"message": "TrialGPT API is Running with Parameterized Inputs!"}