Update main.py
main.py CHANGED
@@ -3,20 +3,48 @@ import os
 os.environ["HF_HOME"] = "/app/huggingface"
 
 from fastapi import FastAPI
+from pydantic import BaseModel
 from transformers import AutoModelForTokenClassification, AutoTokenizer
 import torch
+import faiss
+import numpy as np
 
 app = FastAPI()
 
-# Load Model
+# Load Model for NER
 model_name = "priyanandanwar/fine-tuned-gatortron"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForTokenClassification.from_pretrained(model_name)
 
-
-
+# Dummy Retrieval System (Replace with Real FAISS Index)
+dimension = 768
+index = faiss.IndexFlatL2(dimension)
+db_vectors = np.random.rand(10, dimension).astype('float32')
+index.add(db_vectors)
+
+# Define Request Model
+class QueryRequest(BaseModel):
+    text: str
+
+@app.post("/ner")
+async def predict_ner(request: QueryRequest):
+    """Perform Named Entity Recognition (NER)"""
+    text = request.text
     tokens = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
     outputs = model(**tokens)
     predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
+    tokenized_text = tokenizer.tokenize(text)
 
-    return {"tokens":
+    return {"tokens": tokenized_text, "labels": predictions}
+
+@app.post("/retrieve")
+async def retrieve_trial(request: QueryRequest):
+    """Retrieve Clinical Trial based on text"""
+    query_vector = np.random.rand(1, dimension).astype('float32')  # Dummy Query Encoding
+    _, indices = index.search(query_vector, 1)  # Retrieve Top 1 Match
+    return {"matched_trial_id": int(indices[0][0])}
+
+@app.get("/")
+async def root():
+    return {"message": "TrialGPT API is Running!"}
+
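For a quick sanity check of the updated app, the three routes can be exercised with a short client script. A minimal sketch, assuming the Space is served locally (for example via `uvicorn main:app --port 7860`; the base URL and the use of `requests` are illustrative assumptions, not part of the commit):

# Smoke test for the routes added in this commit.
# BASE_URL is a hypothetical local deployment, not part of the commit.
import requests

BASE_URL = "http://localhost:7860"

# Health check against the new root endpoint
print(requests.get(f"{BASE_URL}/").json())

# NER: returns tokens plus predicted label ids
payload = {"text": "Patient with type 2 diabetes, currently on metformin."}
ner = requests.post(f"{BASE_URL}/ner", json=payload).json()
print(ner["tokens"], ner["labels"])

# Retrieval: index of the nearest (currently random) database vector
match = requests.post(f"{BASE_URL}/retrieve", json=payload).json()
print(match["matched_trial_id"])

One caveat worth noting: `predictions` is computed over the encoded input, which for a BERT-style checkpoint like GatorTron includes special tokens ([CLS], [SEP]), while `tokenizer.tokenize(text)` omits them, so the returned `labels` list will generally be two entries longer than `tokens`.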
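The retrieval route is explicitly a placeholder: both the indexed vectors and the query vector are random. One way the dummy query encoding could be replaced is to embed the request text with the same checkpoint and mean-pool its hidden states into a single 768-dim vector; a sketch under that assumption (loading the base encoder with `AutoModel` is itself an assumption, since the commit only loads the token-classification model):

# Hedged sketch: encode text with the same checkpoint and mean-pool
# the last hidden state into one 768-dim FAISS query vector.
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

encoder = AutoModel.from_pretrained("priyanandanwar/fine-tuned-gatortron")
tok = AutoTokenizer.from_pretrained("priyanandanwar/fine-tuned-gatortron")

def embed(text: str) -> np.ndarray:
    inputs = tok(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state  # (1, seq_len, 768)
    vec = hidden.mean(dim=1)  # mean-pool over tokens; mask-aware pooling omitted
    return vec.numpy().astype("float32")

# In retrieve_trial, embed(request.text) would replace the np.random.rand call,
# and the indexed db_vectors would likewise come from real trial descriptions.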