maximuspowers committed
Commit 7cd8165 · verified · 1 Parent(s): 2dab34c

Update app.py

Files changed (1)
  1. app.py +24 -9
app.py CHANGED
@@ -3,13 +3,13 @@ import torch
 from transformers import BertTokenizerFast, BertForTokenClassification
 import gradio as gr
 
-# init important things
+# Initialize important things
 tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
 model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
 model.eval()
 model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
-# ids to labels we want to display
+# IDs to labels we want to display
 id2label = {
     0: 'O',
     1: 'B-STEREO',
@@ -20,7 +20,18 @@ id2label = {
     6: 'I-UNFAIR'
 }
 
-# predict function you'll want to use if using in your own code
+# Color map for entities
+label_colors = {
+    "B-STEREO": "#FFCDD2",
+    "I-STEREO": "#E57373",
+    "B-GEN": "#C8E6C9",
+    "I-GEN": "#81C784",
+    "B-UNFAIR": "#BBDEFB",
+    "I-UNFAIR": "#64B5F6",
+    "O": "#FFFFFF"  # Default for no label
+}
+
+# Predict function
 def predict_ner_tags(sentence):
     inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
     input_ids = inputs['input_ids'].to(model.device)
@@ -30,23 +41,27 @@ def predict_ner_tags(sentence):
     outputs = model(input_ids=input_ids, attention_mask=attention_mask)
     logits = outputs.logits
     probabilities = torch.sigmoid(logits)
-    predicted_labels = (probabilities > 0.5).int()  # remember to try your own threshold
+    predicted_labels = (probabilities > 0.5).int()  # Threshold
 
     result = []
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+    highlighted_sentence = ""
     for i, token in enumerate(tokens):
         if token not in tokenizer.all_special_tokens:
             label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
             labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
-            result.append({"token": token, "labels": labels})
+            # Get the most prominent label for coloring (arbitrary choice for multiple labels)
+            primary_label = labels[0] if labels else "O"
+            color = label_colors.get(primary_label, "#FFFFFF")
+            highlighted_sentence += f"<span style='background-color:{color}'>{token}</span> "
 
-    return json.dumps(result, indent=4)
+    return highlighted_sentence.strip()
 
-# startup gradio
+# Gradio Interface
 iface = gr.Interface(
     fn=predict_ner_tags,
     inputs="text",
-    outputs="text",
+    outputs=gr.outputs.HTML(label="Highlighted Sentence"),
     title="Social Bias Named Entity Recognition (with BERT) 🕵",
     description=("Enter a sentence to predict biased parts of speech tags. This model uses multi-label BertForTokenClassification, to label the entities: (GEN)eralizations, (UNFAIR)ness, and (STEREO)types. Labels follow BIO format. Try it out :)."
                  "<br><br>Read more about how this model was trained in this <a href='https://huggingface.co/blog/maximuspowers/bias-entity-recognition' target='_blank'>blog post</a>."
@@ -55,4 +70,4 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
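
For context, here is a minimal sketch of how the updated predict_ner_tags could be smoke-tested outside the Space. It assumes this commit's app.py is on the import path and that the tokenizer and model checkpoints download successfully; importing app builds the Gradio interface but does not launch it, since launch() is guarded by if __name__ == "__main__". The example sentence is purely illustrative.

# Sketch: call the Space's prediction function directly (assumes app.py from this commit).
from app import predict_ner_tags

html = predict_ner_tags("Tall people are naturally better leaders.")
print(html)  # a string of <span> tags, one per word-piece token, colored by its primary BIO label

Note that gr.outputs.HTML comes from the legacy Gradio 3.x namespace; on Gradio 4.x the gr.outputs module was removed, and the equivalent output component would be gr.HTML(label="Highlighted Sentence").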