Avoid dict overrides for entity-level

Files changed:
- FairEval.py (+38, -18)
- HFFE_use_cases.pdf (binary, deleted)

FairEval.py
CHANGED
@@ -204,38 +204,58 @@ class FairEval(evaluate.Metric):
         assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
         assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
 
-        # append entity-level errors
-        for k, v in results['per_label']['fair'].items():
-            output[k] = {'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
-                         'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
-                         'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider,}
-
-        # append entity-level scores (depending on mode)
+        # append entity-level errors and scores
         if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
-                output[k] …
+                output[k] = {# traditional scores
+                             'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+
+                             # traditional errors
+                             'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
         elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
-                output[k] …
+                output[k] = {# fair/weighted scores
+                             'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+
+                             # traditional scores
+                             'trad_prec': results['per_label']['traditional'][k]['Prec'],
+                             'trad_rec': results['per_label']['traditional'][k]['Rec'],
+                             'trad_f1': results['per_label']['traditional'][k]['F1'],
 
-
+                             # fair/weighted errors
+                             'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
+                             'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
+
+        # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
         output['overall_recall'] = results['overall'][mode]['Rec']
         output['overall_f1'] = results['overall'][mode]['F1']
 
-        # append overall error counts (…
-        …
-        …
-        …
-        …
-        …
-        …
+        # append overall error counts (and trad scores if mode is fair)
+        if mode == 'traditional':
+            output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else \
+                results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / trad_divider
+            output['FN'] = results['overall'][mode]['FN'] / trad_divider
+        elif mode == 'fair' or mode == 'weighted':
+            output['overall_trad_prec'] = results['overall']['traditional']['Prec']
+            output['overall_trad_rec'] = results['overall']['traditional']['Rec']
+            output['overall_trad_f1'] = results['overall']['traditional']['F1']
+            output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else \
+                results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / fair_divider
+            output['FN'] = results['overall'][mode]['FN'] / fair_divider
+            output['LE'] = results['overall'][mode]['LE'] / fair_divider
+            output['BE'] = results['overall'][mode]['BE'] / fair_divider
+            output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
 
         return output
 
 
 def seq_to_fair(seq_sentences):
-    "Transforms input …
+    "Transforms input annotated sentences from seqeval span format to FairEval span format"
     out = []
     for seq_sentence in seq_sentences:
         sentence = []
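For context on the commit title: the old code assigned `output[k]` in one loop (error counts) and then assigned `output[k]` again in a second loop (scores), so the second dict replaced the first instead of extending it. A minimal sketch of the pitfall and the fix, using hypothetical `scores` and `errors` dicts in place of the real `results['per_label']` tables:

# Hypothetical per-label inputs standing in for results['per_label'].
errors = {'PER': {'TP': 10, 'FP': 2, 'FN': 1}}
scores = {'PER': {'precision': 0.83, 'recall': 0.91, 'f1': 0.87}}

# Before: two loops, and the second assignment silently overrides the first.
output = {}
for k, v in errors.items():
    output[k] = dict(v)   # error counts
for k, v in scores.items():
    output[k] = dict(v)   # replaces the dict above; the counts are lost
assert 'TP' not in output['PER']

# After: build each entity's dict once, with scores and errors together,
# which is what the updated FairEval.py now does inline.
output = {}
for k, v in scores.items():
    output[k] = {**v, **errors[k]}
assert output['PER'] == {'precision': 0.83, 'recall': 0.91, 'f1': 0.87,
                         'TP': 10, 'FP': 2, 'FN': 1}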
HFFE_use_cases.pdf
DELETED
Binary file (86.4 kB)
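To see the merged per-label output in use, a hedged sketch of calling the metric through `evaluate`: the Hub repo id and the `PER` label are assumptions, while the `mode`/`error_format` options and the output keys come straight from this diff.

import evaluate

# Assumed Hub path for this metric; substitute the actual repo id.
faireval = evaluate.load("hpi-dhc/FairEval")

# seqeval-style IOB2 tag sequences (cf. seq_to_fair above).
references  = [['O', 'B-PER', 'I-PER', 'O']]
predictions = [['O', 'B-PER', 'O', 'O']]

results = faireval.compute(predictions=predictions, references=references,
                           mode='fair', error_format='count')

print(results['overall_precision'], results['overall_recall'], results['overall_f1'])
print(results['PER'])  # one dict per label: precision/recall/f1 plus TP/FP/FN/LE/BE/LBE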