# max-long's picture
# Update app.py
# 378ec43 verified
import itertools
import random

import gradio as gr
from datasets import load_dataset
from gliner import GLiNER
# British Library books corpus, opened in streaming mode (records are read
# lazily while iterating instead of being downloaded first).
# trust_remote_code=True lets the dataset repository's own loading code run
# (presumably a loading script -- confirm against the dataset card).
_BLBOOKS_LOAD_KWARGS = dict(split="train", streaming=True, trust_remote_code=True)
_blbooks_stream = load_dataset("TheBritishLibrary/blbooks", **_BLBOOKS_LOAD_KWARGS)

# Shuffled view with a fixed seed; this is what the snippet loader iterates.
dataset_iter = _blbooks_stream.shuffle(seed=42)
# GLiNER checkpoint used by ner() for every prediction.
_GLINER_CHECKPOINT = "urchade/gliner_multi-v2.1"
model = GLiNER.from_pretrained(_GLINER_CHECKPOINT, trust_remote_code=True)
def ner(text: str, labels: str, threshold: float, nested_ner: bool):
# Convert user-provided labels (comma-separated string) into a list
labels_list = [label.strip() for label in labels.split(",")]
# Truncate the text to avoid length exceeding model limits (e.g., 384 tokens)
max_length = 384
truncated_text = text[:max_length]
# Predict entities using the GLiNER model
entities = model.predict_entities(truncated_text, labels_list, flat_ner=not nested_ner, threshold=threshold)
# Prepare entities for color-coded display using gr.HighlightedText
highlights = [{"start": ent["start"], "end": ent["end"], "entity": ent["label"]} for ent in entities]
# Return both the highlighted text and the raw entities in JSON format
return {
"text": truncated_text,
"entities": highlights
}, entities # Return both outputs: the first for HighlightedText, the second for JSON
# Gradio UI: a text box with a snippet loader, label/threshold/nested-NER
# controls, and two views of the predictions (highlighted spans + raw JSON).
with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        """
# GLiNER British Library Books Demo
This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1).
"""
    )
    # Text to analyse: filled by "Get New Snippet" or typed by the user.
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    refresh_btn = gr.Button("Get New Snippet")
    with gr.Row():
        labels = gr.Textbox(
            value="Person, Location",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Initial value; forwarded to ner() on each submit
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )
    submit_btn = gr.Button("Find Entities!")
    # Output components: color-coded spans plus the raw entity list.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Snippets read from the dataset stream; filled on the first successful
    # fetch and reused by every later click.
    snippet_pool = []

    def get_new_snippet():
        """Return a random snippet (at most 384 characters) from the dataset.

        The first call reads up to 100 records from ``dataset_iter`` and caches
        the non-blank texts; later calls pick uniformly from that cache.
        Iterating a streaming dataset starts a new pass over the remote source
        each time, and with a fixed shuffle seed (no ``set_epoch``) every pass
        yields the same records, so re-reading them on each click only added
        latency. If nothing usable is fetched the cache stays empty, and the
        next click retries.
        """
        max_length = 384  # Same character cap that ner() applies to its input
        pool_size = 100
        if not snippet_pool:
            # Build locally so a failed/partial fetch never leaves a half-filled pool.
            fetched = []
            # islice stops after pool_size records without pulling an extra one.
            for sample in itertools.islice(dataset_iter, pool_size):
                snippet = (sample.get("text") or "")[:max_length]
                if snippet.strip():  # Skip missing or whitespace-only texts
                    fetched.append(snippet)
            snippet_pool.extend(fetched)
        if snippet_pool:
            return random.choice(snippet_pool)
        return "No more snippets available."  # Stream yielded no usable text

    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)