maximuspowers commited on
Commit
2c53668
·
verified ·
1 Parent(s): 5b65826

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -52
app.py CHANGED
@@ -3,13 +3,11 @@ import torch
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
6
- # Initialize important things
7
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
8
  model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
9
  model.eval()
10
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
- # IDs to labels we want to display
13
  id2label = {
14
  0: 'O',
15
  1: 'B-STEREO',
@@ -20,24 +18,50 @@ id2label = {
20
  6: 'I-UNFAIR'
21
  }
22
 
23
- # Color map for entities
 
24
  label_colors = {
25
- "STEREO": "rgba(255, 0, 0, 0.3)", # Red
26
- "GEN": "rgba(0, 0, 255, 0.3)", # Blue
27
- "UNFAIR": "rgba(0, 255, 0, 0.3)" # Green
28
  }
29
 
30
- # Helper to wrap a token in a span with color
31
- def wrap_token_with_color(token, labels):
32
- # Build nested highlights
33
- style = "position: relative;"
34
- for label in labels:
35
- if label != "O" and label in label_colors:
36
- style += f"background: {label_colors[label]};"
37
- return f"<span style='{style}'>{token}</span>"
38
-
39
- # Predict function
40
- def predict_ner_tags(sentence):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
42
  input_ids = inputs['input_ids'].to(model.device)
43
  attention_mask = inputs['attention_mask'].to(model.device)
@@ -46,47 +70,81 @@ def predict_ner_tags(sentence):
46
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
47
  logits = outputs.logits
48
  probabilities = torch.sigmoid(logits)
49
- predicted_labels = (probabilities > 0.5).int() # Threshold
50
 
51
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
52
- highlighted_sentence = ""
53
- prev_labels = []
54
-
55
  for i, token in enumerate(tokens):
56
  if token not in tokenizer.all_special_tokens:
57
- # Extract the labels for this token
58
  label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
59
- labels = [id2label[idx.item()][2:] for idx in label_indices if idx.item() in id2label] # Safe lookup
60
-
61
- if not labels: # Handle empty labels gracefully
62
- labels = ["O"]
63
-
64
- # Check if labels are the same as the previous token (for seamless highlighting)
65
- if labels != prev_labels:
66
- if prev_labels: # Close the previous span if needed
67
- highlighted_sentence += "</span>"
68
-
69
- # Start a new span
70
- if labels != ["O"]:
71
- highlight_colors = [label_colors[label] for label in labels if label in label_colors]
72
- if highlight_colors: # Only create gradient if valid colors exist
73
- highlighted_sentence += f"<span style='background: linear-gradient({', '.join(highlight_colors)});'>"
74
-
75
- # Add the token to the span
76
- highlighted_sentence += token.replace("##", "")
77
- prev_labels = labels
78
-
79
- # Close any open spans
80
- if prev_labels and prev_labels != ["O"]:
81
- highlighted_sentence += "</span>"
82
-
83
- return highlighted_sentence
84
-
85
- # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  iface = gr.Interface(
87
- fn=predict_ner_tags,
88
- inputs=gr.Textbox(label="Input Sentence"),
89
- outputs=gr.HTML(label="Highlighted Sentence"),
90
  title="Social Bias Named Entity Recognition (with BERT) 🕵",
91
  description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
92
  "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
 
3
  from transformers import BertTokenizerFast, BertForTokenClassification
4
  import gradio as gr
5
 
 
6
  tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
7
  model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
8
  model.eval()
9
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
10
 
 
11
  id2label = {
12
  0: 'O',
13
  1: 'B-STEREO',
 
18
  6: 'I-UNFAIR'
19
  }
20
 
21
+ label2id = {v: k for k, v in id2label.items()}
22
+
23
  label_colors = {
24
+ "STEREO": "rgba(255, 0, 0, 0.2)",
25
+ "GEN": "rgba(0, 0, 255, 0.2)",
26
+ "UNFAIR": "rgba(0, 255, 0, 0.2)"
27
  }
28
 
29
+ def post_process_entities(result):
30
+ prev_entity_type = None
31
+
32
+ for i, token_data in enumerate(result):
33
+ labels = token_data["labels"]
34
+
35
+ labels = list(set(labels))
36
+ for entity_type in ["GEN", "UNFAIR", "STEREO"]:
37
+ if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
38
+ labels.remove(f"I-{entity_type}")
39
+
40
+ current_entity_type = None
41
+ current_label = None
42
+ for label in labels:
43
+ if label.startswith("B-") or label.startswith("I-"):
44
+ current_label = label
45
+ current_entity_type = label[2:]
46
+
47
+ if current_entity_type:
48
+ if current_label.startswith("B-") and prev_entity_type == current_entity_type:
49
+ labels.remove(current_label)
50
+ labels.append(f"I-{current_entity_type}")
51
+
52
+ if current_label.startswith("I-") and prev_entity_type != current_entity_type:
53
+ labels.remove(current_label)
54
+ labels.append(f"B-{current_entity_type}")
55
+
56
+ prev_entity_type = current_entity_type
57
+ else:
58
+ prev_entity_type = None
59
+
60
+ token_data["labels"] = labels
61
+ return result
62
+
63
+
64
+ def generate_json(sentence):
65
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
66
  input_ids = inputs['input_ids'].to(model.device)
67
  attention_mask = inputs['attention_mask'].to(model.device)
 
70
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
71
  logits = outputs.logits
72
  probabilities = torch.sigmoid(logits)
73
+ predicted_labels = (probabilities > 0.5).int()
74
 
75
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
76
+ result = []
 
 
77
  for i, token in enumerate(tokens):
78
  if token not in tokenizer.all_special_tokens:
 
79
  label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
80
+ labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
81
+ result.append({"token": token.replace("##", ""), "labels": labels})
82
+
83
+ result = post_process_entities(result)
84
+
85
+ return json.dumps(result, indent=4)
86
+
87
+ def predict_ner_tags_with_json(sentence):
88
+ json_result = generate_json(sentence)
89
+
90
+ result = json.loads(json_result)
91
+
92
+ word_row = []
93
+ stereo_row = []
94
+ gen_row = []
95
+ unfair_row = []
96
+
97
+ for token_data in result:
98
+ token = token_data["token"]
99
+ labels = token_data["labels"]
100
+
101
+ word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
102
+
103
+ stereo_labels = [label[2:] for label in labels if "STEREO" in label]
104
+ stereo_row.append(
105
+ f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
106
+ if stereo_labels else "&nbsp;"
107
+ )
108
+
109
+ gen_labels = [label[2:] for label in labels if "GEN" in label]
110
+ gen_row.append(
111
+ f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
112
+ if gen_labels else "&nbsp;"
113
+ )
114
+
115
+ unfair_labels = [label[2:] for label in labels if "UNFAIR" in label]
116
+ unfair_row.append(
117
+ f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
118
+ if unfair_labels else "&nbsp;"
119
+ )
120
+
121
+ matrix_html = f"""
122
+ <table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>
123
+ <tr>
124
+ <td><strong>Text Sequence</strong></td>
125
+ {''.join(f"<td>{word}</td>" for word in word_row)}
126
+ </tr>
127
+ <tr>
128
+ <td><strong>Generalizations</strong></td>
129
+ {''.join(f"<td>{cell}</td>" for cell in gen_row)}
130
+ </tr>
131
+ <tr>
132
+ <td><strong>Unfairness</strong></td>
133
+ {''.join(f"<td>{cell}</td>" for cell in unfair_row)}
134
+ </tr>
135
+ <tr>
136
+ <td><strong>Stereotypes</strong></td>
137
+ {''.join(f"<td>{cell}</td>" for cell in stereo_row)}
138
+ </tr>
139
+ </table>
140
+ """
141
+
142
+ return f"{matrix_html}<br><pre>{json_result}</pre>"
143
+
144
  iface = gr.Interface(
145
+ fn=predict_ner_tags_with_json,
146
+ inputs=[gr.Textbox(label="Input Sentence")],
147
+ outputs=[gr.HTML(label="Entity Matrix and JSON Output")],
148
  title="Social Bias Named Entity Recognition (with BERT) 🕵",
149
  description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
150
  "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."