File size: 3,774 Bytes
f9bc688 7fd6f11 f9bc688 f625748 f9bc688 e947e04 f625748 e947e04 f9bc688 f625748 9daea47 f9bc688 e947e04 f625748 e947e04 f625748 e947e04 f9bc688 f625748 f9bc688 70ca6fa df09c16 f9bc688 df09c16 f9bc688 70ca6fa f9bc688 df09c16 f9bc688 f625748 70ca6fa f9bc688 f625748 f9bc688 70ca6fa f9bc688 f18abee 378ec43 f18abee 378ec43 f9bc688 f625748 f9bc688 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# Open the British Library books corpus lazily: with streaming=True the
# records are fetched while iterating instead of being downloaded up front.
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
# Shuffle the stream with a fixed seed so the sample order is reproducible.
dataset_iter = dataset_iter.shuffle(seed=42)

# GLiNER checkpoint used for every entity prediction in this app.
model = GLiNER.from_pretrained(
    "urchade/gliner_multi-v2.1",
    trust_remote_code=True,
)
def ner(text: str, labels: str, threshold: float, nested_ner: bool) -> tuple:
    """Find entities in ``text`` with GLiNER and format them for the UI.

    Args:
        text: Input text. Only its first 384 characters are analysed.
        labels: Comma-separated entity labels, e.g. ``"Person, Location"``.
            Blank fragments and repeated labels are ignored.
        threshold: Forwarded to ``model.predict_entities`` as ``threshold``.
        nested_ner: When true the model is called with ``flat_ner=False``.

    Returns:
        A 2-tuple ``(highlighted, entities)``. ``highlighted`` is a dict with
        the analysed ``"text"`` and a list of ``"entities"`` (each holding
        ``start``, ``end`` and ``entity``) for ``gr.HighlightedText``;
        ``entities`` is the raw list returned by the model, for ``gr.JSON``.
        When there is no text or no usable label, both entity lists are empty
        and the model is not called.
    """
    # The cap is applied to characters, not tokens: it is a cheap guard that
    # keeps the input short before it reaches the model.
    max_length = 384
    truncated_text = (text or "")[:max_length]

    # Split the user's labels, dropping blanks (e.g. from a trailing comma)
    # and duplicates while keeping the order in which they were typed.
    stripped = (label.strip() for label in (labels or "").split(","))
    labels_list = list(dict.fromkeys(label for label in stripped if label))

    # Nothing to search in, or nothing to search for: skip the model call.
    if not truncated_text.strip() or not labels_list:
        return {"text": truncated_text, "entities": []}, []

    entities = model.predict_entities(
        truncated_text,
        labels_list,
        flat_ner=not nested_ner,
        threshold=threshold,
    )

    # gr.HighlightedText expects character spans tagged with an "entity" name.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
        for ent in entities
    ]
    return {"text": truncated_text, "entities": highlights}, entities
# Gradio UI: a text box (optionally filled with a random dataset snippet),
# the NER controls, and two views of the prediction result.
with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        "\n"
        "# GLiNER British Library Books Demo\n"
        "This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1).\n"
    )

    # Text to analyse; starts with a hint and can be replaced by a snippet.
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    refresh_btn = gr.Button("Get New Snippet")

    # NOTE(review): the original indentation was lost; the three controls
    # that carry layout hints are assumed to share this row - confirm.
    with gr.Row() as row:
        labels = gr.Textbox(
            value="Person, Location",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Initial value; whatever is selected is passed to ner()
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )

    submit_btn = gr.Button("Find Entities!")

    # HighlightedText gives the color-coded view, JSON the raw model output.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Snippets already pulled from the dataset stream. Filled on the first
    # request and reused afterwards, so later clicks do not stream again.
    snippet_pool = []

    def get_new_snippet():
        """Return a random dataset snippet of at most 384 characters.

        The first call reads up to 100 non-empty texts from ``dataset_iter``
        and caches their truncated form; every call then picks one of them
        at random. If no snippet could be read, a fallback message is
        returned and the next call tries to fill the pool again.
        """
        from itertools import islice

        max_length = 384  # Same character cap that ner() applies
        pool_size = 100

        if not snippet_pool:
            # Every new iteration over dataset_iter starts another pass over
            # the stream, so it is done once here rather than on each click.
            # islice stops after pool_size items without fetching an extra
            # record; records whose text is empty or missing are skipped.
            usable = (
                sample["text"][:max_length]
                for sample in dataset_iter
                if sample["text"]
            )
            snippet_pool.extend(islice(usable, pool_size))

        if snippet_pool:
            return random.choice(snippet_pool)
        return "No more snippets available."

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)