File size: 3,774 Bytes
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd6f11
f9bc688
f625748
f9bc688
 
 
e947e04
 
 
 
f625748
e947e04
f9bc688
f625748
9daea47
f9bc688
e947e04
f625748
e947e04
f625748
e947e04
f9bc688
f625748
f9bc688
 
70ca6fa
df09c16
f9bc688
 
 
 
 
df09c16
f9bc688
 
 
 
70ca6fa
f9bc688
 
 
df09c16
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
f625748
 
 
 
 
70ca6fa
f9bc688
f625748
 
f9bc688
 
70ca6fa
f9bc688
f18abee
378ec43
 
 
 
f18abee
 
 
 
 
378ec43
f9bc688
 
 
 
 
 
f625748
 
f9bc688
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Open the British Library books corpus as a stream (records are fetched lazily
# rather than downloading the whole split), then shuffle it with a fixed seed so
# the order of records is reproducible between runs.
# NOTE(review): trust_remote_code=True allows the dataset repo to run its own
# loading code on this machine — only keep it for sources you trust.
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
dataset_iter = dataset_iter.shuffle(seed=42)

# GLiNER checkpoint (urchade/gliner_multi-v2.1) shared by every NER request.
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", trust_remote_code=True)

def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Run GLiNER over ``text`` and format the result for the two Gradio outputs.

    Args:
        text: Input text; only its first 384 characters are analysed.
        labels: Comma-separated entity labels, e.g. "Person, Location".
            Blank entries (from stray or trailing commas) are ignored.
        threshold: Minimum confidence score, forwarded to ``predict_entities``.
        nested_ner: Forwarded as ``flat_ner=not nested_ner``; True requests
            nested (possibly overlapping) spans.

    Returns:
        A pair ``(highlighted, entities)``. ``highlighted`` is a
        ``{"text": ..., "entities": [...]}`` dict for ``gr.HighlightedText``,
        whose spans use the keys ``start``/``end``/``entity``. ``entities`` is
        the raw list returned by the model, shown in the JSON output. When
        there is no text or no usable label, the model is not called and both
        entity lists are empty.
    """
    # Discard blank entries so "Person, Location," does not ask the model for
    # an empty-string label.
    labels_list = [label.strip() for label in (labels or "").split(",") if label.strip()]

    # Character-level cut (not tokens): a cheap proxy for keeping the input
    # within the model's context window — presumably ~384 tokens, TODO confirm.
    max_length = 384
    truncated_text = (text or "")[:max_length]

    # Nothing to tag: skip the model call and return empty results.
    if not truncated_text.strip() or not labels_list:
        return {"text": truncated_text, "entities": []}, []

    entities = model.predict_entities(
        truncated_text, labels_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Re-key each prediction into the span format gr.HighlightedText expects.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
        for ent in entities
    ]

    # First value feeds HighlightedText, second feeds the JSON component.
    return {"text": truncated_text, "entities": highlights}, entities

with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        """
        # GLiNER British Library Books Demo
        This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1). 
        """
    )

    # Editable input box: starts with a hint and is overwritten by the
    # "Get New Snippet" button.
    input_text = gr.Textbox(
        lines=5,
        label="Text input",
        placeholder="Enter your text here",
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
    )
    refresh_btn = gr.Button("Get New Snippet")

    # NER settings side by side; scale gives the label box twice the relative
    # width of the threshold slider. These values are passed straight to ner().
    with gr.Row() as row:
        labels = gr.Textbox(
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            value="Person, Location",
            scale=2,
        )
        threshold = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.5,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            label="Nested NER",
            info="Enable Nested NER?",
            value=False,
        )
    submit_btn = gr.Button("Find Entities!")

    # Results: colour-coded spans over the text, plus the raw model output.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Persistent cursor into ``dataset_iter``. Iterating the dataset afresh on
    # every click (as ``zip(dataset_iter, ...)`` did) restarts the seeded stream
    # from the beginning, so every click drew from the same first 100 records.
    _snippet_stream = None

    def get_new_snippet():
        """Return a random snippet of at most 384 characters from the BL stream.

        Reads the next batch of up to 100 records from a cursor that persists
        across calls, skips records with empty text, and picks one at random.
        If the cursor is exhausted it is reset to the start of the stream and
        read once more.

        Returns:
            The snippet string, or "No more snippets available." when no record
            with non-empty text could be read.
        """
        global _snippet_stream
        max_length = 384  # Maximum snippet length, in characters
        pool_size = 100  # Records read per click to choose from

        # The second attempt only happens after resetting an exhausted cursor.
        for attempt in range(2):
            if _snippet_stream is None:
                _snippet_stream = iter(dataset_iter)
            # range comes first in zip so no extra record is pulled and dropped.
            samples = [
                sample["text"][:max_length]
                for _, sample in zip(range(pool_size), _snippet_stream)
                if sample.get("text")
            ]
            if samples:
                return random.choice(samples)
            _snippet_stream = None  # Exhausted: start over from the beginning.

        return "No more snippets available."  # Return this if no valid snippets are found
    
    # Event wiring. The refresh button writes a new snippet into the input box.
    # The submit button sends the text and NER settings to ner() and fills both
    # output components.
    refresh_btn.click(get_new_snippet, inputs=None, outputs=input_text)
    submit_btn.click(
        ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

if __name__ == "__main__":
    # Enable the request queue (queue() returns the Blocks instance), then
    # start the server. debug=True blocks the main thread and surfaces errors
    # in the console. The guard keeps imports of this module (e.g. Gradio's
    # reload mode) from launching a second server.
    demo.queue().launch(debug=True)