maximuspowers committed on
Commit
e0dca16
·
verified ·
1 Parent(s): 34ab835

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -46
app.py CHANGED
@@ -20,8 +20,6 @@ id2label = {
20
  6: 'I-UNFAIR'
21
  }
22
 
23
- label2id = {v: k for k, v in id2label.items()}
24
-
25
  # Entity colors for highlights
26
  label_colors = {
27
  "STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
@@ -34,38 +32,23 @@ def post_process_entities(result):
34
  prev_entity_type = None
35
  for token_data in result:
36
  labels = token_data["labels"]
37
- labels = list(set(labels))
38
-
39
- # Handle conflicting B- and I- tags for the same entity
40
- for entity_type in ["GEN", "UNFAIR", "STEREO"]:
41
- if f"B-{entity_type}" in labels and f"I-{entity_type}" in labels:
42
- labels.remove(f"I-{entity_type}")
43
 
44
  # Handle sequence rules
45
- current_entity_type = None
46
- current_label = None
47
- for label in labels:
48
- if label.startswith("B-") or label.startswith("I-"):
49
- current_label = label
50
- current_entity_type = label[2:]
51
-
52
- if current_entity_type:
53
- if current_label.startswith("B-") and prev_entity_type == current_entity_type:
54
- labels.remove(current_label)
55
- labels.append(f"I-{current_entity_type}")
56
- if current_label.startswith("I-") and prev_entity_type != current_entity_type:
57
- labels.remove(current_label)
58
- labels.append(f"B-{current_entity_type}")
59
-
60
- prev_entity_type = current_entity_type
61
- else:
62
- prev_entity_type = None
63
-
64
- token_data["labels"] = labels
65
  return result
66
 
67
- # Generate JSON results
68
- def generate_json(sentence):
69
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
70
  input_ids = inputs['input_ids'].to(model.device)
71
  attention_mask = inputs['attention_mask'].to(model.device)
@@ -74,26 +57,24 @@ def generate_json(sentence):
74
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
75
  logits = outputs.logits
76
  probabilities = torch.sigmoid(logits)
77
- predicted_labels = (probabilities > 0.5).int()
78
 
79
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
80
  result = []
81
  for i, token in enumerate(tokens):
82
  if token not in tokenizer.all_special_tokens:
83
- label_indices = (predicted_labels[0][i] == 1).nonzero(as_tuple=False).squeeze(-1)
84
- labels = [id2label[idx.item()] for idx in label_indices] if label_indices.numel() > 0 else ['O']
 
 
 
 
 
 
85
  result.append({"token": token.replace("##", ""), "labels": labels})
86
 
87
  result = post_process_entities(result)
88
 
89
- return json.dumps(result, indent=4)
90
-
91
- # Predict function
92
- def predict_ner_tags_with_json(sentence):
93
- json_result = generate_json(sentence)
94
-
95
- result = json.loads(json_result)
96
-
97
  word_row = []
98
  stereo_row = []
99
  gen_row = []
@@ -105,19 +86,28 @@ def predict_ner_tags_with_json(sentence):
105
 
106
  word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
107
 
108
- stereo_labels = [label[2:] for label in labels if "STEREO" in label]
 
 
 
109
  stereo_row.append(
110
  f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
111
  if stereo_labels else "&nbsp;"
112
  )
113
 
114
- gen_labels = [label[2:] for label in labels if "GEN" in label]
 
 
 
115
  gen_row.append(
116
  f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
117
  if gen_labels else "&nbsp;"
118
  )
119
 
120
- unfair_labels = [label[2:] for label in labels if "UNFAIR" in label]
 
 
 
121
  unfair_row.append(
122
  f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
123
  if unfair_labels else "&nbsp;"
@@ -144,7 +134,7 @@ def predict_ner_tags_with_json(sentence):
144
  </table>
145
  """
146
 
147
- return f"{matrix_html}<br><pre>{json_result}</pre>"
148
 
149
  # Gradio Interface
150
  iface = gr.Blocks()
@@ -168,7 +158,7 @@ with iface:
168
  with gr.Row():
169
  input_box = gr.Textbox(label="Input Sentence")
170
  with gr.Row():
171
- output_box = gr.HTML(label="Entity Matrix and JSON Output")
172
 
173
  input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
174
 
 
20
  6: 'I-UNFAIR'
21
  }
22
 
 
 
23
  # Entity colors for highlights
24
  label_colors = {
25
  "STEREO": "rgba(255, 0, 0, 0.2)", # Light Red
 
32
  prev_entity_type = None
33
  for token_data in result:
34
  labels = token_data["labels"]
 
 
 
 
 
 
35
 
36
  # Handle sequence rules
37
+ new_labels = []
38
+ for label_data in labels:
39
+ label = label_data['label']
40
+ if label.startswith("B-") and prev_entity_type == label[2:]:
41
+ new_labels.append({"label": f"I-{label[2:]}", "confidence": label_data["confidence"]})
42
+ elif label.startswith("I-") and prev_entity_type != label[2:]:
43
+ new_labels.append({"label": f"B-{label[2:]}", "confidence": label_data["confidence"]})
44
+ else:
45
+ new_labels.append(label_data)
46
+ prev_entity_type = label[2:]
47
+ token_data["labels"] = new_labels
 
 
 
 
 
 
 
 
 
48
  return result
49
 
50
+ # Generate JSON results with probabilities
51
+ def predict_ner_tags_with_json(sentence):
52
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
53
  input_ids = inputs['input_ids'].to(model.device)
54
  attention_mask = inputs['attention_mask'].to(model.device)
 
57
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
58
  logits = outputs.logits
59
  probabilities = torch.sigmoid(logits)
 
60
 
61
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
62
  result = []
63
  for i, token in enumerate(tokens):
64
  if token not in tokenizer.all_special_tokens:
65
+ label_indices = (probabilities[0][i] > 0.52).nonzero(as_tuple=False).squeeze(-1)
66
+ labels = [
67
+ {
68
+ "label": id2label[idx.item()],
69
+ "confidence": round(probabilities[0][i][idx].item() * 100, 2)
70
+ }
71
+ for idx in label_indices
72
+ ]
73
  result.append({"token": token.replace("##", ""), "labels": labels})
74
 
75
  result = post_process_entities(result)
76
 
77
+ # Create table rows
 
 
 
 
 
 
 
78
  word_row = []
79
  stereo_row = []
80
  gen_row = []
 
86
 
87
  word_row.append(f"<span style='font-weight:bold;'>{token}</span>")
88
 
89
+ # STEREO
90
+ stereo_labels = [
91
+ f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "STEREO" in label_data["label"]
92
+ ]
93
  stereo_row.append(
94
  f"<span style='background:{label_colors['STEREO']}; border-radius:6px; padding:2px 5px;'>{', '.join(stereo_labels)}</span>"
95
  if stereo_labels else "&nbsp;"
96
  )
97
 
98
+ # GEN
99
+ gen_labels = [
100
+ f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "GEN" in label_data["label"]
101
+ ]
102
  gen_row.append(
103
  f"<span style='background:{label_colors['GEN']}; border-radius:6px; padding:2px 5px;'>{', '.join(gen_labels)}</span>"
104
  if gen_labels else "&nbsp;"
105
  )
106
 
107
+ # UNFAIR
108
+ unfair_labels = [
109
+ f"{label_data['label'][2:]} ({label_data['confidence']}%)" for label_data in labels if "UNFAIR" in label_data["label"]
110
+ ]
111
  unfair_row.append(
112
  f"<span style='background:{label_colors['UNFAIR']}; border-radius:6px; padding:2px 5px;'>{', '.join(unfair_labels)}</span>"
113
  if unfair_labels else "&nbsp;"
 
134
  </table>
135
  """
136
 
137
+ return matrix_html
138
 
139
  # Gradio Interface
140
  iface = gr.Blocks()
 
158
  with gr.Row():
159
  input_box = gr.Textbox(label="Input Sentence")
160
  with gr.Row():
161
+ output_box = gr.HTML(label="Entity Matrix")
162
 
163
  input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])
164