Lamp Socrates
commited on
Commit
·
4efeb3b
1
Parent(s):
b022555
latest
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import uvicorn
|
2 |
import threading
|
|
|
3 |
from typing import Optional
|
4 |
from transformers import pipeline
|
5 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
@@ -13,11 +14,14 @@ from fastapi import FastAPI
|
|
13 |
from pydantic import BaseModel
|
14 |
from typing import List, Dict
|
15 |
|
|
|
16 |
# Define the FastAPI app
|
17 |
app = FastAPI()
|
18 |
model_cache: Optional[object] = None
|
|
|
19 |
|
20 |
def load_model():
|
|
|
21 |
|
22 |
tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
|
23 |
model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
|
@@ -36,6 +40,12 @@ def load_plod_cw_dataset():
|
|
36 |
dataset = load_dataset("surrey-nlp/PLOD-CW")
|
37 |
return dataset
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def get_cached_model():
|
40 |
global model_cache
|
41 |
if model_cache is None:
|
@@ -44,8 +54,7 @@ def get_cached_model():
|
|
44 |
|
45 |
# Cache the model when the server starts
|
46 |
model = get_cached_model()
|
47 |
-
|
48 |
-
|
49 |
|
50 |
class Entity(BaseModel):
|
51 |
entity: str
|
@@ -62,15 +71,20 @@ class NERRequest(BaseModel):
|
|
62 |
|
63 |
@app.get("/hello")
|
64 |
def read_root():
|
|
|
65 |
return {"message": "Hello, World!"}
|
66 |
|
67 |
|
68 |
@app.post("/ner", response_model=NERResponse)
|
69 |
def get_entities(request: NERRequest):
|
|
|
70 |
print(request)
|
|
|
71 |
model = get_cached_model()
|
|
|
72 |
# Use the NER model to detect entities
|
73 |
entities = model(request.text)
|
|
|
74 |
print(entities[0].keys())
|
75 |
# Convert entities to the response model
|
76 |
response_entities = [Entity(**entity) for entity in entities]
|
@@ -81,8 +95,9 @@ def get_color_for_label(label: str) -> str:
|
|
81 |
# Define a mapping of labels to colors
|
82 |
color_mapping = {
|
83 |
"I-LF": "red",
|
|
|
84 |
"B-AC": "blue",
|
85 |
-
"
|
86 |
# Add more labels and colors as needed
|
87 |
}
|
88 |
return color_mapping.get(label, "black") # Default to black if label not found
|
@@ -90,30 +105,73 @@ def get_color_for_label(label: str) -> str:
|
|
90 |
|
91 |
# Define the Gradio interface function
|
92 |
def ner_demo(text):
|
|
|
93 |
model = get_cached_model()
|
94 |
entities = model(text)
|
95 |
-
#return {"entities": entities}
|
96 |
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
99 |
for entity in entities:
|
100 |
-
#print(entity)
|
101 |
start, end, label = entity["start"], entity["end"], entity["entity"]
|
102 |
-
color = get_color_for_label(label)
|
103 |
entity_text = text[start:end]
|
104 |
-
colored_entity = f'<span style="color: {color}; font-weight: bold;">{entity_text}</span>'
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def echo(text, request: gr.Request):
|
|
|
112 |
if request:
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
# Create the Gradio interface
|
119 |
demo = gr.Interface(
|
@@ -124,26 +182,27 @@ demo = gr.Interface(
|
|
124 |
title="Named Entity Recognition on PLOD-CW ",
|
125 |
description=f"{PROJECT_INTRO}\n\nEnter text to extract named entities using a NER model."
|
126 |
)
|
127 |
-
'''
|
128 |
-
with gr.Blocks() as demo:
|
129 |
-
gr.Markdown("# Page Title")
|
130 |
-
gr.Markdown("## Subtitle with h2 Font")
|
131 |
-
inputs=gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
|
132 |
-
|
133 |
-
with gr.Column():
|
134 |
-
echo_output = gr.Textbox(label="Echo Output")
|
135 |
-
html_output = ner_demo
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
|
142 |
-
|
143 |
|
144 |
-
# Function to run FastAPI
|
145 |
-
def run_fastapi():
|
146 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
147 |
|
148 |
# Function to run Gradio
|
149 |
|
|
|
1 |
import uvicorn
|
2 |
import threading
|
3 |
+
from collections import Counter
|
4 |
from typing import Optional
|
5 |
from transformers import pipeline
|
6 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
|
14 |
from pydantic import BaseModel
|
15 |
from typing import List, Dict
|
16 |
|
17 |
+
|
18 |
# Define the FastAPI app
|
19 |
app = FastAPI()
|
20 |
model_cache: Optional[object] = None
|
21 |
+
dataset_cache : Optional[object] = None
|
22 |
|
23 |
def load_model():
|
24 |
+
""" We load the model at startup"""
|
25 |
|
26 |
tokenizer = AutoTokenizer.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
|
27 |
model = AutoModelForTokenClassification.from_pretrained("LampOfSocrates/bert-cased-plodcw-sourav")
|
|
|
40 |
dataset = load_dataset("surrey-nlp/PLOD-CW")
|
41 |
return dataset
|
42 |
|
43 |
+
def get_cached_data():
|
44 |
+
global dataset_cache
|
45 |
+
if dataset_cache is None:
|
46 |
+
dataset_cache = load_plod_cw_dataset()
|
47 |
+
return dataset_cache
|
48 |
+
|
49 |
def get_cached_model():
|
50 |
global model_cache
|
51 |
if model_cache is None:
|
|
|
54 |
|
55 |
# Cache the model when the server starts
|
56 |
model = get_cached_model()
|
57 |
+
#plod_cw = get_cached_data()
|
|
|
58 |
|
59 |
class Entity(BaseModel):
|
60 |
entity: str
|
|
|
71 |
|
72 |
@app.get("/hello")
|
73 |
def read_root():
|
74 |
+
"""useful for testing connections"""
|
75 |
return {"message": "Hello, World!"}
|
76 |
|
77 |
|
78 |
@app.post("/ner", response_model=NERResponse)
|
79 |
def get_entities(request: NERRequest):
|
80 |
+
""" This is invoked while API Testing """
|
81 |
print(request)
|
82 |
+
|
83 |
model = get_cached_model()
|
84 |
+
|
85 |
# Use the NER model to detect entities
|
86 |
entities = model(request.text)
|
87 |
+
|
88 |
print(entities[0].keys())
|
89 |
# Convert entities to the response model
|
90 |
response_entities = [Entity(**entity) for entity in entities]
|
|
|
95 |
# Define a mapping of labels to colors
|
96 |
color_mapping = {
|
97 |
"I-LF": "red",
|
98 |
+
"B-LF": "pink",
|
99 |
"B-AC": "blue",
|
100 |
+
"B-O": "green",
|
101 |
# Add more labels and colors as needed
|
102 |
}
|
103 |
return color_mapping.get(label, "black") # Default to black if label not found
|
|
|
105 |
|
106 |
# Define the Gradio interface function
|
107 |
def ner_demo(text):
|
108 |
+
""" This is invoked while rendering the page"""
|
109 |
model = get_cached_model()
|
110 |
entities = model(text)
|
|
|
111 |
|
112 |
+
print("Entities detected {}".format(Counter( [ entity['entity'] for entity in entities])))
|
113 |
+
|
114 |
+
all_html = ""
|
115 |
+
last_index = 0
|
116 |
+
|
117 |
for entity in entities:
|
|
|
118 |
start, end, label = entity["start"], entity["end"], entity["entity"]
|
119 |
+
color = get_color_for_label(label)
|
120 |
entity_text = text[start:end]
|
121 |
+
#colored_entity = f'<span style="color: {color}; font-weight: bold;">{entity_text}</span>'
|
122 |
+
colored_entity = f'<sup style="color: {color}; font-weight: bold;">{entity_text}</sup>'
|
123 |
+
|
124 |
+
|
125 |
+
# Append text before the entity
|
126 |
+
all_html += text[last_index:start]
|
127 |
+
# Append the colored entity
|
128 |
+
all_html += colored_entity
|
129 |
+
# Update the last_index
|
130 |
+
last_index = end
|
131 |
+
|
132 |
+
# Append the remaining text after the last entity
|
133 |
+
all_html += text[last_index:]
|
134 |
+
return all_html
|
135 |
+
|
136 |
+
bo_color = get_color_for_label("B-O")
|
137 |
+
bac_color = get_color_for_label("B-AC")
|
138 |
+
ilf_color = get_color_for_label("I-LF")
|
139 |
+
blf_color = get_color_for_label("B-LF")
|
140 |
+
|
141 |
+
PROJECT_INTRO = f"""This is a HF Spaces hosted Gradio App built by NLP Group 27. \n\n
|
142 |
+
The model has been trained on surrey-nlp/PLOD-CW dataset.
|
143 |
+
The following Entities are recognized:
|
144 |
+
<sup style="color: {bo_color}; font-weight: bold;">B-O</sup>
|
145 |
+
<sup style="color: {bac_color}; font-weight: bold;">B-AC</sup>
|
146 |
+
<sup style="color: {ilf_color}; font-weight: bold;">I-LF</sup>
|
147 |
+
<sup style="color: {blf_color}; font-weight: bold;">B-LF</sup>
|
148 |
+
<sup style="color: black; font-weight: bold;">Rest</sup>
|
149 |
+
"""
|
150 |
def echo(text, request: gr.Request):
|
151 |
+
res = '<div>'
|
152 |
if request:
|
153 |
+
res += f"Request headers dictionary: {request.headers} <p>"
|
154 |
+
res += f"IP address: {request.client.host} <p>"
|
155 |
+
res += f"Query parameters: {dict(request.query_params)} <p>"
|
156 |
+
res += "</div>"
|
157 |
+
|
158 |
+
return res
|
159 |
+
|
160 |
+
def sample_data(text):
|
161 |
+
text = "The red dots represents LCI , the bright yellow rectangle represents RV , and the black triangle represents the /TLCnLCI"
|
162 |
+
|
163 |
+
#dat = get_cached_data()
|
164 |
+
|
165 |
+
#df = dat['test']['tokens'].sample(5)
|
166 |
+
|
167 |
+
data = {
|
168 |
+
"Text": [text],
|
169 |
+
"Length": [len(text)]
|
170 |
+
}
|
171 |
+
df = pd.DataFrame(data)
|
172 |
+
return df
|
173 |
+
|
174 |
+
|
175 |
|
176 |
# Create the Gradio interface
|
177 |
demo = gr.Interface(
|
|
|
182 |
title="Named Entity Recognition on PLOD-CW ",
|
183 |
description=f"{PROJECT_INTRO}\n\nEnter text to extract named entities using a NER model."
|
184 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
+
with gr.Blocks() as demo:
|
187 |
+
gr.Markdown("# Named Entity Recognition on PLOD-CW")
|
188 |
+
gr.Markdown(PROJECT_INTRO)
|
189 |
+
gr.Markdown("### Enter text to extract named entities using a NER model.")
|
190 |
+
text_input = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
|
191 |
+
html_output = gr.HTML(label="HTML Output")
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
submit_button = gr.Button("Submit")
|
195 |
+
echo_button = gr.Button("Echo Client")
|
196 |
+
sample_button = gr.Button("Sample PLOD_CW")
|
197 |
+
|
198 |
+
sample_output = gr.Dataframe(label="Sample Table")
|
199 |
+
echo_output = gr.HTML(label="HTML Output")
|
200 |
+
|
201 |
+
submit_button.click(ner_demo, inputs=text_input, outputs=html_output)
|
202 |
|
203 |
+
echo_button.click(echo, inputs=text_input, outputs=echo_output)
|
204 |
+
sample_button.click(sample_data, inputs=text_input, outputs=sample_output)
|
205 |
|
|
|
|
|
|
|
206 |
|
207 |
# Function to run Gradio
|
208 |
|