Update app.py
Browse files
app.py
CHANGED
@@ -20,8 +20,6 @@ id2label = {
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
23 |
-
label2id = {v: k for k, v in id2label.items()}
|
24 |
-
|
25 |
# Entity colors for highlights
|
26 |
label_colors = {
|
27 |
"STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
|
@@ -34,38 +32,23 @@ def post_process_entities(result):
|
|
34 |
prev_entity_type = None
|
35 |
for token_data in result:
|
36 |
labels = token_data["labels"]
|
37 |
-
labels = list(set(labels))
|
38 |
-
|
39 |
-
# Handle conflicting B- and I- tags for the same entity
|
40 |
-
for entity_type in ["GEN", "UNFAIR", "STEREO"]:
|
41 |
-
if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
|
42 |
-
labels.remove(f"I-{entity_type}")
|
43 |
|
44 |
# Handle sequence rules
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
if label.startswith("B-")
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
if current_label.startswith("I-") and prev_entity_type != current_entity_type:
|
57 |
-
labels.remove(current_label)
|
58 |
-
labels.append(f"B-{current_entity_type}")
|
59 |
-
|
60 |
-
prev_entity_type = current_entity_type
|
61 |
-
else:
|
62 |
-
prev_entity_type = None
|
63 |
-
|
64 |
-
token_data["labels"] = labels
|
65 |
return result
|
66 |
|
67 |
-
# Generate JSON results
|
68 |
-
def
|
69 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
70 |
input_ids = inputs['input_ids'].to(model.device)
|
71 |
attention_mask = inputs['attention_mask'].to(model.device)
|
@@ -74,26 +57,24 @@ def generate_json(sentence):
|
|
74 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
75 |
logits = outputs.logits
|
76 |
probabilities = torch.sigmoid(logits)
|
77 |
-
predicted_labels = (probabilities > 0.5).int()
|
78 |
|
79 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
80 |
result = []
|
81 |
for i, token in enumerate(tokens):
|
82 |
if token not in tokenizer.all_special_tokens:
|
83 |
-
label_indices = (
|
84 |
-
labels = [
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
result.append({"token": token.replace("##", ""), "labels": labels})
|
86 |
|
87 |
result = post_process_entities(result)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
# Predict function
|
92 |
-
def predict_ner_tags_with_json(sentence):
|
93 |
-
json_result = generate_json(sentence)
|
94 |
-
|
95 |
-
result = json.loads(json_result)
|
96 |
-
|
97 |
word_row = []
|
98 |
stereo_row = []
|
99 |
gen_row = []
|
@@ -105,19 +86,28 @@ def predict_ner_tags_with_json(sentence):
|
|
105 |
|
106 |
word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
|
107 |
|
108 |
-
|
|
|
|
|
|
|
109 |
stereo_row.append(
|
110 |
f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
|
111 |
if stereo_labels else " "
|
112 |
)
|
113 |
|
114 |
-
|
|
|
|
|
|
|
115 |
gen_row.append(
|
116 |
f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
|
117 |
if gen_labels else " "
|
118 |
)
|
119 |
|
120 |
-
|
|
|
|
|
|
|
121 |
unfair_row.append(
|
122 |
f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
|
123 |
if unfair_labels else " "
|
@@ -144,7 +134,7 @@ def predict_ner_tags_with_json(sentence):
|
|
144 |
</table>
|
145 |
"""
|
146 |
|
147 |
-
return
|
148 |
|
149 |
# Gradio Interface
|
150 |
iface = gr.Blocks()
|
@@ -168,7 +158,7 @@ with iface:
|
|
168 |
with gr.Row():
|
169 |
input_box = gr.Textbox(label="Input Sentence")
|
170 |
with gr.Row():
|
171 |
-
output_box = gr.HTML(label="Entity Matrix
|
172 |
|
173 |
input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
|
174 |
|
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
|
|
|
|
23 |
# Entity colors for highlights
|
24 |
label_colors = {
|
25 |
"STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
|
|
|
32 |
prev_entity_type = None
|
33 |
for token_data in result:
|
34 |
labels = token_data["labels"]
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
# Handle sequence rules
|
37 |
+
new_labels = []
|
38 |
+
for label_data in labels:
|
39 |
+
label = label_data['label']
|
40 |
+
if label.startswith("B-") and prev_entity_type == label[2:]:
|
41 |
+
new_labels.append({"label": f"I-{label[2:]}", "confidence": label_data["confidence"]})
|
42 |
+
elif label.startswith("I-") and prev_entity_type != label[2:]:
|
43 |
+
new_labels.append({"label": f"B-{label[2:]}", "confidence": label_data["confidence"]})
|
44 |
+
else:
|
45 |
+
new_labels.append(label_data)
|
46 |
+
prev_entity_type = label[2:]
|
47 |
+
token_data["labels"] = new_labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
return result
|
49 |
|
50 |
+
# Generate JSON results with probabilities
|
51 |
+
def predict_ner_tags_with_json(sentence):
|
52 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
53 |
input_ids = inputs['input_ids'].to(model.device)
|
54 |
attention_mask = inputs['attention_mask'].to(model.device)
|
|
|
57 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
58 |
logits = outputs.logits
|
59 |
probabilities = torch.sigmoid(logits)
|
|
|
60 |
|
61 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
62 |
result = []
|
63 |
for i, token in enumerate(tokens):
|
64 |
if token not in tokenizer.all_special_tokens:
|
65 |
+
label_indices = (probabilities[0][i] > 0.52).nonzero(as_tuple=False).squeeze(-1)
|
66 |
+
labels = [
|
67 |
+
{
|
68 |
+
"label": id2label[idx.item()],
|
69 |
+
"confidence": round(probabilities[0][i][idx].item() * 100, 2)
|
70 |
+
}
|
71 |
+
for idx in label_indices
|
72 |
+
]
|
73 |
result.append({"token": token.replace("##", ""), "labels": labels})
|
74 |
|
75 |
result = post_process_entities(result)
|
76 |
|
77 |
+
# Create table rows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
word_row = []
|
79 |
stereo_row = []
|
80 |
gen_row = []
|
|
|
86 |
|
87 |
word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
|
88 |
|
89 |
+
# STEREO
|
90 |
+
stereo_labels = [
|
91 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "STEREO" in label_data["label"]
|
92 |
+
]
|
93 |
stereo_row.append(
|
94 |
f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
|
95 |
if stereo_labels else " "
|
96 |
)
|
97 |
|
98 |
+
# GEN
|
99 |
+
gen_labels = [
|
100 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "GEN" in label_data["label"]
|
101 |
+
]
|
102 |
gen_row.append(
|
103 |
f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
|
104 |
if gen_labels else " "
|
105 |
)
|
106 |
|
107 |
+
# UNFAIR
|
108 |
+
unfair_labels = [
|
109 |
+
f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "UNFAIR" in label_data["label"]
|
110 |
+
]
|
111 |
unfair_row.append(
|
112 |
f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
|
113 |
if unfair_labels else " "
|
|
|
134 |
</table>
|
135 |
"""
|
136 |
|
137 |
+
return matrix_html
|
138 |
|
139 |
# Gradio Interface
|
140 |
iface = gr.Blocks()
|
|
|
158 |
with gr.Row():
|
159 |
input_box = gr.Textbox(label="Input Sentence")
|
160 |
with gr.Row():
|
161 |
+
output_box = gr.HTML(label="Entity Matrix")
|
162 |
|
163 |
input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
|
164 |
|