priyanandanwar committed on
Commit
7397e3c
·
verified ·
1 Parent(s): e5e0954

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -4
main.py CHANGED
@@ -3,20 +3,48 @@ import os
3
  os.environ["HF_HOME"] = "/app/huggingface"
4
 
5
  from fastapi import FastAPI
 
6
  from transformers import AutoModelForTokenClassification, AutoTokenizer
7
  import torch
 
 
8
 
9
  app = FastAPI()
10
 
11
- # Load Model
12
  model_name = "priyanandanwar/fine-tuned-gatortron"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForTokenClassification.from_pretrained(model_name)
15
 
16
@app.post("/predict")
async def predict(text: str):
    """Run the token-classification model on ``text``.

    Returns ``{"tokens": [...], "labels": [...]}`` where ``labels[i]`` is the
    arg-max class id predicted for ``tokens[i]``. Positions that the
    tokenizer marks as special tokens are dropped from both lists so the two
    stay index-aligned, and both reflect any truncation applied to the input.
    """
    # Request the special-tokens mask so tokenizer-inserted positions can be
    # filtered out of the response; it is not a model input, so pop it.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_special_tokens_mask=True,
    )
    special_mask = encoded.pop("special_tokens_mask")[0].tolist()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoded)
    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    # Derive tokens from the ids the model actually saw, rather than
    # re-tokenizing the raw text, so tokens and labels cannot drift apart.
    pieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
    kept = [
        (piece, label)
        for piece, label, is_special in zip(pieces, predictions, special_mask)
        if not is_special
    ]
    return {
        "tokens": [piece for piece, _ in kept],
        "labels": [label for _, label in kept],
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  os.environ["HF_HOME"] = "/app/huggingface"
4
 
5
  from fastapi import FastAPI
6
+ from pydantic import BaseModel
7
  from transformers import AutoModelForTokenClassification, AutoTokenizer
8
  import torch
9
+ import faiss
10
+ import numpy as np
11
 
12
  app = FastAPI()
13
 
14
+ # Load Model for NER
15
  model_name = "priyanandanwar/fine-tuned-gatortron"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForTokenClassification.from_pretrained(model_name)
18
 
19
# Dummy Retrieval System (Replace with Real FAISS Index)
# The index is filled with 10 random vectors purely as a stand-in corpus.
# NOTE(review): 768 is a placeholder width; it must equal the embedding size
# of whatever encoder eventually produces the query vectors. Confirm.
dimension = 768
db_vectors = np.random.rand(10, dimension).astype('float32')
index = faiss.IndexFlatL2(dimension)
index.add(db_vectors)
24
+
25
# Define Request Model
# JSON body accepted by the POST endpoints that take a QueryRequest.
class QueryRequest(BaseModel):
    # Free-text input supplied by the client.
    text: str
28
+
29
@app.post("/ner")
async def predict_ner(request: QueryRequest):
    """Perform Named Entity Recognition (NER)

    Tokenizes ``request.text``, runs the token-classification model, and
    returns ``{"tokens": [...], "labels": [...]}`` where ``labels[i]`` is the
    arg-max class id predicted for ``tokens[i]``. Positions that the
    tokenizer marks as special tokens are dropped from both lists so the two
    stay index-aligned, and both reflect any truncation applied to the input.
    """
    text = request.text

    # Request the special-tokens mask so tokenizer-inserted positions can be
    # filtered out of the response; it is not a model input, so pop it.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_special_tokens_mask=True,
    )
    special_mask = encoded.pop("special_tokens_mask")[0].tolist()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoded)
    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    # Derive tokens from the ids the model actually saw, rather than
    # re-tokenizing the raw text, so tokens and labels cannot drift apart.
    pieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
    kept = [
        (piece, label)
        for piece, label, is_special in zip(pieces, predictions, special_mask)
        if not is_special
    ]
    return {
        "tokens": [piece for piece, _ in kept],
        "labels": [label for _, label in kept],
    }
39
+
40
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Retrieve Clinical Trial based on text"""
    # Placeholder encoding: the query is a random vector, so the request text
    # does not influence the result yet.
    random_query = np.random.rand(1, dimension).astype('float32')  # Dummy Query Encoding

    # Nearest-neighbour lookup for the single closest entry; distances unused.
    _distances, nearest = index.search(random_query, 1)  # Retrieve Top 1 Match
    top_match = int(nearest[0][0])
    return {"matched_trial_id": top_match}
46
+
47
@app.get("/")
async def root():
    # Static status payload served at the API root.
    status = {"message": "TrialGPT API is Running!"}
    return status
50
+