Spaces:
Build error
Build error
mvy
commited on
Commit
·
8e19b14
1
Parent(s):
6f3f044
add validations checks
Browse files
app.py
CHANGED
|
@@ -24,7 +24,7 @@ examples = [
|
|
| 24 |
ner = NER('knowledgator/UTC-DeBERTa-small')
|
| 25 |
|
| 26 |
gradio_app = gr.Interface(
|
| 27 |
-
ner,
|
| 28 |
inputs = [
|
| 29 |
'text',
|
| 30 |
gr.Textbox(placeholder="Enter sentence here..."),
|
|
|
|
| 24 |
ner = NER('knowledgator/UTC-DeBERTa-small')
|
| 25 |
|
| 26 |
gradio_app = gr.Interface(
|
| 27 |
+
ner.process,
|
| 28 |
inputs = [
|
| 29 |
'text',
|
| 30 |
gr.Textbox(placeholder="Enter sentence here..."),
|
ner.py
CHANGED
|
@@ -4,6 +4,8 @@ import string
|
|
| 4 |
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
| 5 |
import spacy
|
| 6 |
import torch
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class NER:
|
| 9 |
prompt: str = """
|
|
@@ -13,8 +15,14 @@ Identify entities in the text having the following classes:
|
|
| 13 |
Text:
|
| 14 |
"""
|
| 15 |
|
| 16 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
self.sents_batch = sents_batch
|
|
|
|
| 18 |
|
| 19 |
self.nlp: spacy.Language = spacy.load(
|
| 20 |
'en_core_web_sm',
|
|
@@ -23,13 +31,13 @@ Text:
|
|
| 23 |
self.nlp.add_pipe('sentencizer')
|
| 24 |
|
| 25 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 26 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 27 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
| 28 |
|
| 29 |
self.pipeline = pipeline(
|
| 30 |
"ner",
|
| 31 |
model=model,
|
| 32 |
-
tokenizer=tokenizer,
|
| 33 |
aggregation_strategy='first',
|
| 34 |
batch_size=12,
|
| 35 |
device=device
|
|
@@ -115,14 +123,47 @@ Text:
|
|
| 115 |
return outputs
|
| 116 |
|
| 117 |
|
| 118 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
self, labels: str, text: str, threshold: float=0.
|
| 120 |
) -> dict[str, any]:
|
| 121 |
-
labels_list =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
chunks, chunks_starts = self.chunkanize(text)
|
| 124 |
inputs, prompts_lens = self.get_inputs(chunks, labels_list)
|
| 125 |
|
|
|
|
|
|
|
| 126 |
outputs = self.predict(
|
| 127 |
text, inputs, labels_list, chunks_starts, prompts_lens, threshold
|
| 128 |
)
|
|
|
|
| 4 |
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
| 5 |
import spacy
|
| 6 |
import torch
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
|
| 10 |
class NER:
|
| 11 |
prompt: str = """
|
|
|
|
| 15 |
Text:
|
| 16 |
"""
|
| 17 |
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
model_name: str,
|
| 21 |
+
sents_batch: int=10,
|
| 22 |
+
tokens_limit: int=2048
|
| 23 |
+
):
|
| 24 |
self.sents_batch = sents_batch
|
| 25 |
+
self.tokens_limit = tokens_limit
|
| 26 |
|
| 27 |
self.nlp: spacy.Language = spacy.load(
|
| 28 |
'en_core_web_sm',
|
|
|
|
| 31 |
self.nlp.add_pipe('sentencizer')
|
| 32 |
|
| 33 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 34 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 35 |
model = AutoModelForTokenClassification.from_pretrained(model_name)
|
| 36 |
|
| 37 |
self.pipeline = pipeline(
|
| 38 |
"ner",
|
| 39 |
model=model,
|
| 40 |
+
tokenizer=self.tokenizer,
|
| 41 |
aggregation_strategy='first',
|
| 42 |
batch_size=12,
|
| 43 |
device=device
|
|
|
|
| 123 |
return outputs
|
| 124 |
|
| 125 |
|
| 126 |
+
def check_text(self, text: str) -> None:
|
| 127 |
+
if not text:
|
| 128 |
+
raise gr.Error('No text provided. Please provide text.')
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def check_labels(self, labels: list[str]) -> None:
|
| 132 |
+
if not labels:
|
| 133 |
+
raise gr.Error(
|
| 134 |
+
'No labels provided. Please provide labels.'
|
| 135 |
+
' Multiple labels should be divided by commas.'
|
| 136 |
+
' See examples below.'
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def check_tokens_limit(self, inputs: list[str]) -> None:
|
| 141 |
+
tokens = 0
|
| 142 |
+
for input_ in inputs:
|
| 143 |
+
tokens += len(self.tokenizer.encode(input_))
|
| 144 |
+
if tokens > self.tokens_limit:
|
| 145 |
+
raise gr.Error(
|
| 146 |
+
'Too many tokens! Please reduce size of text or amount of labels.'
|
| 147 |
+
f' Max tokens count is: {self.tokens_limit}.'
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def process(
|
| 152 |
self, labels: str, text: str, threshold: float=0.
|
| 153 |
) -> dict[str, any]:
|
| 154 |
+
labels_list = list({
|
| 155 |
+
l for label in labels.split(',')
|
| 156 |
+
if (l:=label.strip())
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
self.check_labels(labels_list)
|
| 160 |
+
self.check_text(text)
|
| 161 |
|
| 162 |
chunks, chunks_starts = self.chunkanize(text)
|
| 163 |
inputs, prompts_lens = self.get_inputs(chunks, labels_list)
|
| 164 |
|
| 165 |
+
self.check_tokens_limit(inputs)
|
| 166 |
+
|
| 167 |
outputs = self.predict(
|
| 168 |
text, inputs, labels_list, chunks_starts, prompts_lens, threshold
|
| 169 |
)
|