Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,11 @@ import torch
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
6 |
-
# Initialize important things
|
7 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
8 |
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
|
9 |
model.eval()
|
10 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
12 |
-
# IDs to labels we want to display
|
13 |
id2label = {
|
14 |
0: 'O',
|
15 |
1: 'B-STEREO',
|
@@ -20,24 +18,50 @@ id2label = {
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
23 |
-
|
|
|
24 |
label_colors = {
|
25 |
-
"STEREO": "rgba(255, 0, 0, 0.
|
26 |
-
"GEN": "rgba(0, 0, 255, 0.
|
27 |
-
"UNFAIR": "rgba(0, 255, 0, 0.
|
28 |
}
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
42 |
input_ids = inputs['input_ids'].to(model.device)
|
43 |
attention_mask = inputs['attention_mask'].to(model.device)
|
@@ -46,47 +70,81 @@ def predict_ner_tags(sentence):
|
|
46 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
47 |
logits = outputs.logits
|
48 |
probabilities = torch.sigmoid(logits)
|
49 |
-
predicted_labels = (probabilities > 0.5).int()
|
50 |
|
51 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
52 |
-
|
53 |
-
prev_labels = []
|
54 |
-
|
55 |
for i, token in enumerate(tokens):
|
56 |
if token not in tokenizer.all_special_tokens:
|
57 |
-
# Extract the labels for this token
|
58 |
label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
|
59 |
-
labels = [id2label[idx.item()]
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
iface = gr.Interface(
|
87 |
-
fn=
|
88 |
-
inputs=gr.Textbox(label="Input Sentence"),
|
89 |
-
outputs=gr.HTML(label="
|
90 |
title="Social Bias Named Entity Recognition (with BERT) 🕵",
|
91 |
description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
|
92 |
"<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
|
|
|
3 |
from transformers import BertTokenizerFast, BertForTokenClassification
|
4 |
import gradio as gr
|
5 |
|
|
|
6 |
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
7 |
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
|
8 |
model.eval()
|
9 |
model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
|
|
11 |
id2label = {
|
12 |
0: 'O',
|
13 |
1: 'B-STEREO',
|
|
|
18 |
6: 'I-UNFAIR'
|
19 |
}
|
20 |
|
21 |
+
label2id = {v: k for k, v in id2label.items()}
|
22 |
+
|
23 |
# Translucent (alpha 0.2) CSS rgba colour associated with each bias entity type.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",
    GEN="rgba(0, 0, 255, 0.2)",
    UNFAIR="rgba(0, 255, 0, 0.2)",
)
|
28 |
|
29 |
+
def post_process_entities(result):
    """Normalize the BIO tags of every token so each entity span is well formed.

    Each entity type is handled independently: a token is tagged
    ``B-<type>`` when the previous token did not carry that type and
    ``I-<type>`` when it did, regardless of which mix of ``B-``/``I-`` tags
    was supplied for it. Labels that do not start with ``B-`` or ``I-``
    (for example ``O``) are kept unchanged.

    Args:
        result: List of dicts, each holding a ``"labels"`` list of tag
            strings. The dicts are updated in place.

    Returns:
        The same ``result`` list. Every ``"labels"`` entry is replaced by a
        de-duplicated, normalized list that keeps first-occurrence order.
    """
    # Entity types carried by the previous token; empty at sequence start
    # and after any token without a B-/I- tag.
    prev_types = set()

    for token_data in result:
        normalized = []
        current_types = set()

        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (a plain set() would make the output depend on hash ordering).
        for label in dict.fromkeys(token_data["labels"]):
            if not label.startswith(("B-", "I-")):
                normalized.append(label)
                continue

            entity_type = label[2:]
            if entity_type in current_types:
                # Both B-<type> and I-<type> were predicted: emit one tag only.
                continue
            current_types.add(entity_type)

            # Continue the span if the previous token had this type,
            # otherwise open a new one.
            prefix = "I-" if entity_type in prev_types else "B-"
            normalized.append(f"{prefix}{entity_type}")

        token_data["labels"] = normalized
        prev_types = current_types

    return result
|
62 |
+
|
63 |
+
|
64 |
+
def generate_json(sentence):
|
65 |
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
66 |
input_ids = inputs['input_ids'].to(model.device)
|
67 |
attention_mask = inputs['attention_mask'].to(model.device)
|
|
|
70 |
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
71 |
logits = outputs.logits
|
72 |
probabilities = torch.sigmoid(logits)
|
73 |
+
predicted_labels = (probabilities > 0.5).int()
|
74 |
|
75 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
76 |
+
result = []
|
|
|
|
|
77 |
for i, token in enumerate(tokens):
|
78 |
if token not in tokenizer.all_special_tokens:
|
|
|
79 |
label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
|
80 |
+
labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
|
81 |
+
result.append({"token": token.replace("##", ""), "labels": labels})
|
82 |
+
|
83 |
+
result = post_process_entities(result)
|
84 |
+
|
85 |
+
return json.dumps(result, indent=4)
|
86 |
+
|
87 |
+
def predict_ner_tags_with_json(sentence):
    """Render the predicted tags for *sentence* as an HTML matrix plus raw JSON.

    Calls ``generate_json`` and parses its JSON output into a list of
    ``{"token": ..., "labels": [...]}`` entries, then builds a four-row table
    (text, generalizations, unfairness, stereotypes) with one column per
    token. The JSON text is appended below the table inside ``<pre>``.

    Args:
        sentence: The text passed through to ``generate_json``.

    Returns:
        An HTML string. Token text and the JSON dump are HTML-escaped because
        they originate from the input sentence.
    """
    import html  # local import: only this function needs HTML escaping

    json_result = generate_json(sentence)
    result = json.loads(json_result)

    def label_cell(labels, entity_type):
        # Keep the labels mentioning this entity type and drop their
        # two-character "B-"/"I-" prefix for display.
        names = [label[2:] for label in labels if entity_type in label]
        if not names:
            return " "
        return (
            f"<span style='background:{label_colors[entity_type]}; "
            f"border-radius:6px; padding:2px 5px;'>{', '.join(names)}</span>"
        )

    def table_row(title, cells):
        # One <tr>: a bold row heading followed by one <td> per token.
        tds = "".join(f"<td>{cell}</td>" for cell in cells)
        return f"<tr>\n<td><strong>{title}</strong></td>\n{tds}\n</tr>"

    word_row = []
    stereo_row = []
    gen_row = []
    unfair_row = []

    for token_data in result:
        labels = token_data["labels"]
        # Escape the token: it comes from the input and may contain <, > or &.
        token = html.escape(token_data["token"])

        word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
        stereo_row.append(label_cell(labels, "STEREO"))
        gen_row.append(label_cell(labels, "GEN"))
        unfair_row.append(label_cell(labels, "UNFAIR"))

    rows = "\n".join(
        (
            table_row("Text Sequence", word_row),
            table_row("Generalizations", gen_row),
            table_row("Unfairness", unfair_row),
            table_row("Stereotypes", stereo_row),
        )
    )
    matrix_html = (
        "\n<table style='border-collapse:collapse; width:100%; "
        "font-family:monospace; text-align:left;'>\n"
        f"{rows}\n"
        "</table>\n"
    )

    # The JSON dump repeats the tokens, so it is escaped for the <pre> block too.
    return f"{matrix_html}<br><pre>{html.escape(json_result)}</pre>"
|
143 |
+
|
144 |
iface = gr.Interface(
|
145 |
+
fn=predict_ner_tags_with_json,
|
146 |
+
inputs=[gr.Textbox(label="Input Sentence")],
|
147 |
+
outputs=[gr.HTML(label="Entity Matrix and JSON Output")],
|
148 |
title="Social Bias Named Entity Recognition (with BERT) 🕵",
|
149 |
description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
|
150 |
"<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
|