Update main.py
main.py CHANGED

@@ -4,59 +4,58 @@ import torch
 import numpy as np
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoModelForTokenClassification, AutoTokenizer
+from transformers import AutoModelForTokenClassification, AutoTokenizer
 
-# Hugging Face Cache Directory (For HF Spaces)
 os.environ["HF_HOME"] = "/app/huggingface"
 
 app = FastAPI()
 
-#
+# Load Model for NER
 model_name = "priyanandanwar/fine-tuned-gatortron"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-
+model = AutoModelForTokenClassification.from_pretrained(model_name)
 
-#
+# Dummy FAISS Retrieval System
 dimension = 768
-
+index = faiss.IndexFlatL2(dimension)
+db_vectors = np.random.rand(10, dimension).astype('float32')
+index.add(db_vectors)
 
-
-    index = faiss.read_index(faiss_index_path)
-    print("✅ FAISS Index Loaded!")
-else:
-    index = faiss.IndexFlatL2(dimension)
-    print("⚠️ FAISS Index Not Found! Using Empty Index.")
-
-# --- Load Retrieval Model for Embeddings ---
-retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
-retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
-retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
-
-# --- Request Model ---
+# Request Model
 class QueryRequest(BaseModel):
     text: str
-
-
-
-
-
-
-
-
+    temperature: float = 0.7
+    max_tokens: int = 256
+    top_p: float = 0.9
+    top_k: int = 50
+
+@app.post("/ner")
+async def predict_ner(request: QueryRequest):
+    """Perform Named Entity Recognition (NER)"""
+    tokens = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=request.max_tokens)
+    outputs = model(**tokens)
+    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
+    tokenized_text = tokenizer.tokenize(request.text)
+
+    return {
+        "tokens": tokenized_text,
+        "labels": predictions,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "top_k": request.top_k
+    }
 
-# --- Retrieval Endpoint ---
 @app.post("/retrieve")
 async def retrieve_trial(request: QueryRequest):
     """Retrieve Clinical Trial based on text"""
-    query_vector =
-
-
-    # Convert retrieved indices to NCT IDs directly
-    results = [{"NCT_ID": str(int(idx)), "similarity": float(round(100 / (1 + dist), 2))} for idx, dist in zip(indices[0], distances[0])]
+    query_vector = np.random.rand(1, dimension).astype('float32')  # Dummy Query Encoding
+    _, indices = index.search(query_vector, request.top_k)  # Retrieve Top K Matches
 
-    return {
+    return {
+        "matched_trial_ids": indices.tolist(),
+        "top_k": request.top_k
+    }
 
-# --- Root Endpoint ---
 @app.get("/")
 async def root():
-    return {"message": "TrialGPT API is Running with
+    return {"message": "TrialGPT API is Running with Parameterized Inputs!"}
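
For context, a minimal client-side sketch of how the updated endpoints could be exercised. This is not part of the commit: the `requests` dependency, the base URL, and the sample payloads are assumptions for illustration; the field names follow the new `QueryRequest` model.

# Hypothetical client for the updated API; assumes the app is served
# locally, e.g. `uvicorn main:app --port 8000`.
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server, not the live Space URL

# /ner returns the tokenized text plus one predicted label id per encoded token.
ner = requests.post(f"{BASE_URL}/ner", json={
    "text": "Patient diagnosed with stage II breast cancer.",
    "max_tokens": 128,
}).json()
print(ner["tokens"])
print(ner["labels"])

# /retrieve returns indices of the nearest (currently random) FAISS vectors.
ret = requests.post(f"{BASE_URL}/retrieve", json={
    "text": "breast cancer immunotherapy trial",
    "top_k": 5,
}).json()
print(ret["matched_trial_ids"])

Two caveats follow from the committed code itself: the dummy index holds only 10 vectors, so the default top_k of 50 exceeds the index size and FAISS pads the missing results with -1 ids (passing a smaller top_k, as above, avoids that); and the "labels" list is two entries longer than "tokens", because the model sees the [CLS]/[SEP] special tokens that tokenizer.tokenize() omits.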