OpenFactCheck-Prerelease / src/openfactcheck/solvers/webservice/factcheckgpt_utils/eval_utils.py
# code for general evaluation
import numpy as np
import evaluate
from sklearn.metrics import precision_recall_fscore_support, accuracy_score


def evaluate_classification(preds, gold):
    metric = evaluate.load("bstrai/classification_report")
    return metric.compute(predictions=preds, references=gold)
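
# Illustrative usage (not part of the original file): `bstrai/classification_report` is a
# community metric module on the Hugging Face Hub, so evaluate.load() needs network access
# the first time it runs. The labels below are made-up placeholders.
#
#   report = evaluate_classification(preds=[0, 1, 1, 0], gold=[0, 1, 0, 0])
#   print(report)  # per-class precision/recall/F1 plus overall scores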


def eval_classification(y_true, y_pred, average="macro"):
    precision, recall, F1, support = precision_recall_fscore_support(y_true, y_pred, average=average)
    accuracy = accuracy_score(y_true, y_pred)
    metrics = {
        "accuracy": round(accuracy, 3),
        "precision": round(precision, 3),
        "recall": round(recall, 3),
        "F1": round(F1, 3),
    }
    return metrics


def eval_binary(y_true, y_pred, pos_label=1, average="binary"):
    """pos_label: the positive label is machine text (1); human text is 0."""
    precision, recall, F1, support = precision_recall_fscore_support(
        y_true, y_pred, pos_label=pos_label, average=average)
    # accuracy
    accuracy = accuracy_score(y_true, y_pred)
    # precision
    # pre = precision_score(y_true, y_pred, pos_label=pos_label, average=average)
    # recall
    # rec = recall_score(y_true, y_pred, pos_label=pos_label, average=average)
    metrics = {
        "accuracy": round(accuracy, 3),
        "precision": round(precision, 3),
        "recall": round(recall, 3),
        "F1": round(F1, 3),
    }
    return metrics
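

# Minimal usage sketch (added for illustration; the toy labels are hypothetical).
# Only the sklearn-based helpers are exercised here, since they run offline;
# evaluate_classification is skipped because it downloads a Hub metric.
if __name__ == "__main__":
    gold = [1, 0, 1, 1, 0, 0]   # 1 = machine text, 0 = human text
    preds = [1, 0, 0, 1, 0, 1]

    # Macro-averaged metrics over both classes.
    print(eval_classification(gold, preds))
    # Binary metrics for the positive (machine-text) class only.
    print(eval_binary(gold, preds, pos_label=1))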