|
import itertools
import random

import gradio as gr
from datasets import load_dataset
from gliner import GLiNER
|
|
|
|
|
# Source corpus: the British Library books dataset, streamed from the Hub
# (streaming=True) instead of being downloaded in full. trust_remote_code=True
# lets the repository's own loading code execute locally, so keep it only for
# a repo you trust. shuffle(seed=42) makes the stream order reproducible.
_BLBOOKS_REPO = "TheBritishLibrary/blbooks"
_blbooks_stream = load_dataset(
    path=_BLBOOKS_REPO,
    split="train",
    streaming=True,
    trust_remote_code=True,
)
dataset_iter = _blbooks_stream.shuffle(seed=42)

# Multilingual GLiNER checkpoint queried by ner().
_GLINER_CHECKPOINT = "urchade/gliner_multi-v2.1"
model = GLiNER.from_pretrained(_GLINER_CHECKPOINT, trust_remote_code=True)
|
|
|
def ner(text: str, labels: str, threshold: float, nested_ner: bool): |
|
|
|
labels_list = [label.strip() for label in labels.split(",")] |
|
|
|
|
|
max_length = 384 |
|
truncated_text = text[:max_length] |
|
|
|
|
|
entities = model.predict_entities(truncated_text, labels_list, flat_ner=not nested_ner, threshold=threshold) |
|
|
|
|
|
highlights = [{"start": ent["start"], "end": ent["end"], "entity": ent["label"]} for ent in entities] |
|
|
|
|
|
return { |
|
"text": truncated_text, |
|
"entities": highlights |
|
}, entities |
|
|
|
with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    # Page header shown above the inputs.
    gr.Markdown(
    """
    # GLiNER British Library Books Demo

    This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1).
    """
    )

    # Passage to analyse; filled by "Get New Snippet" or typed by the user.
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    refresh_btn = gr.Button("Get New Snippet")

    # NER controls: entity types, confidence threshold, nested/flat toggle.
    with gr.Row():
        labels = gr.Textbox(
            value="Person, Location",
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )
    submit_btn = gr.Button("Find Entities!")

    # Outputs of ner(): highlighted spans plus the raw prediction list.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    def get_new_snippet():
        """Return a random snippet (first 384 characters) from the dataset.

        The stream is reshuffled on every call. Up to 100 records are read,
        records with empty or missing ``text`` are skipped, and one of the
        rest is chosen at random. If none qualify, a fallback message is
        returned.
        """
        max_length = 384

        # A seeded streaming shuffle replays the same order each time the
        # stream is re-iterated, so every click would draw from the same 100
        # records. A fresh epoch changes the effective shuffle seed.
        dataset_iter.set_epoch(random.randrange(1 << 31))

        # islice bounds the read at 100 records. zip() would pull one extra.
        texts = (sample.get("text") or "" for sample in dataset_iter)
        samples = [
            text[:max_length]
            for text in itertools.islice(texts, 100)
            if text.strip()
        ]

        if samples:
            return random.choice(samples)
        return "No more snippets available."

    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )
|
|
|
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module for tests or tooling does not launch a blocking
# server.
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)