File size: 4,168 Bytes
8da3546
 
 
 
 
 
 
 
 
f3ab87f
8da3546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import html
import time

import streamlit as st
import torch
from transformers import pipeline
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


@st.cache_resource
def _load_pipeline():
    """Build the token-classification pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be re-instantiated on each rerun.
    ``st.cache_resource`` keeps a single shared instance alive instead.

    Returns:
        The ``transformers`` NER pipeline placed on ``device``.
    """
    print('Preparing pipeline ...\n')
    ner_pipeline = pipeline(
        "ner",
        model="seddiktrk/xlm-roberta-base-finetuned-panx-all",
        device=device,
    )
    print('\nPipe Ready !!!')
    return ner_pipeline


# Load the NER pipeline (cached across reruns).
pipe = _load_pipeline()
# Sample sentences offered in the "Select an example" dropdown, keyed by a
# short language tag. Insertion order is the order shown in the dropdown.
# NOTE(review): "jp" is not the ISO 639-1 code for Japanese ("ja"); the key is
# kept as-is because it is displayed to the user.
examples = {
    # English
    "en": "My name is Clara and I live in Berkeley, California.",
    # French
    "fr": "Je m'appelle Marie et je travaille dans un café à Lyon.",
    # Arabic
    "ar": "اسمي أحمد وأدرس في جامعة القاهرة.",
    # German
    "de": "Mein Name ist Hans und ich komme aus München.",
    # Spanish
    "es": "Mi nombre es Lucía y vivo en una pequeña ciudad en México.",
    # Italian
    "it": "Mi chiamo Giulia e faccio il medico a Roma.",
    # Portuguese
    "pt": "Chamo-me Ana e moro em uma fazenda no Brasil.",
    # Russian
    "ru": "Меня зовут Ольга, и я живу в Санкт-Петербурге.",
    # Japanese
    "jp": "私の名前は佐藤です。東京でITエンジニアとして働いています",
    # Chinese
    "zh": "我叫李华,在北京的一家公司上班",
}

# Colors for each entity type, as (highlight background, label badge color).
ENTITY_COLORS = {
    "PER": ("#F7D4DA", "#E31A1C"),  # Light pink background, red badge
    "ORG": ("#D4E2F4", "#2171B5"),  # Light blue background, blue badge
    "LOC": ("#E8DAEF", "#6A51A3"),  # Light purple background, purple badge
    #"MISC": ("#FFE5B4", "#FF8C00"),  # Light orange background, dark orange badge
}

# Fallback colors for labels that have no entry in ENTITY_COLORS.
_DEFAULT_ENTITY_COLORS = ("#FFD700", "#FF4500")

# Single-line markup for one highlighted entity followed by its label badge.
# It is kept free of newlines and leading indentation so the result does not
# depend on how a Markdown renderer treats indented continuation lines.
_ENTITY_TEMPLATE = (
    '<span style="background-color:{background}; padding: 3px 5px; '
    'border-radius: 5px; margin: 0 2px; display: inline-block;">'
    '{text}'
    '<span style="background-color:{badge}; color: white; padding: 1px 5px; '
    'border-radius: 5px; margin-left: 5px; font-size: 0.85em; '
    'vertical-align: middle;">{label}</span>'
    '</span>'
)


def get_colored_text(text, entities):
    """Return ``text`` as HTML with every entity span highlighted.

    Args:
        text: The original input string the entities refer to.
        entities: Iterable of dicts, each providing ``'start'`` and ``'end'``
            character offsets into ``text`` and an ``'entity_group'`` label.

    Returns:
        An HTML string in which each entity is wrapped in a colored ``<span>``
        followed by a badge showing its label. Text outside and inside the
        entities is HTML-escaped, so markup typed by the user is displayed
        literally instead of being interpreted.

    Notes:
        Entities are processed in order of their start offset regardless of
        the order they are passed in. An entity that starts inside a span
        that was already emitted is skipped, since it cannot be rendered
        without corrupting the markup.
    """
    pieces = []
    cursor = 0  # index in ``text`` up to which output has been produced
    for entity in sorted(entities, key=lambda item: item['start']):
        start, end = entity['start'], entity['end']
        if start < cursor:
            # Overlaps the previously rendered entity; skip it.
            continue
        label = entity['entity_group']
        background_color, text_color = ENTITY_COLORS.get(
            label, _DEFAULT_ENTITY_COLORS
        )
        # Plain text between the previous entity and this one.
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            _ENTITY_TEMPLATE.format(
                background=background_color,
                badge=text_color,
                text=html.escape(text[start:end], quote=False),
                label=html.escape(str(label), quote=False),
            )
        )
        cursor = end
    # Remaining text after the last entity (or the whole text if none).
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
# Streamlit app: page title and short description.
st.title('Multilingual NER')
st.markdown(
    """
    <p style='color: grey; font-size: 0.85em;'>
    This application performs Named Entity Recognition (NER) across 100+ languages.
    The model excels in cross-lingual transfer and is capable of processing text that contains multiple languages simultaneously.
    </p>
    """,
    unsafe_allow_html=True
)
st.write("### 🔠 Token Classification")


# Two-column layout: wide text area on the left, example picker on the right.
col1, col2 = st.columns([4, 1])  # Adjust column widths as needed

# Dropdown in the right column. It is created first because the text area's
# initial value is derived from the selected example.
with col2:
    selected_example = st.selectbox(
        'Select an example:',
        list(examples),
    )

# Text area in the left column, pre-filled with the selected example.
with col1:
    user_input = st.text_area('Enter your text here:', value=examples[selected_example])


# Button to compute
if st.button("Compute"):
    if not user_input.strip():
        # Nothing to analyze; skip running the model on blank input.
        st.warning("Please enter some text to analyze.")
    else:
        with st.spinner():
            # Time only the model call so the reported figure is the
            # inference time, not inference plus HTML building and rendering.
            start_time = time.perf_counter()
            ner_results = pipe(user_input, aggregation_strategy="simple")
            elapsed = time.perf_counter() - start_time

            # Display the highlighted text
            colored_text = get_colored_text(user_input, ner_results)
            st.markdown(colored_text, unsafe_allow_html=True)
            st.write(f"Inference time: {elapsed:.2f} seconds")

            with st.expander("Show raw output"):
                # Second call without an aggregation strategy, so the
                # expander shows the pipeline's non-aggregated output.
                raw_results = pipe(user_input)
                st.json(raw_results)