BenCzechMark / model_compare.py
jstetina's picture
Ranks
49d6897
raw
history blame
1.62 kB
from functools import cmp_to_key
from compare_significance import check_significance
class ModelCompare:
    """Order models per task from pairwise significance comparisons.

    The rankings are a nested mapping of the form
    ``ranks[model_a_id][model_b_id][task] -> bool`` where ``True`` is taken
    to mean *model_a is significantly better than model_b on that task*
    (assumption inferred from usage here — confirm against the producer
    of the rankings, e.g. ``check_significance``).
    """

    # Benchmark task identifiers the models are ranked on.
    TASKS = ["propaganda_demonizace",
             "propaganda_vina",
             "propaganda_relativizace",
             "propaganda_argumentace",
             "propaganda_lokace",
             "propaganda_nazor",
             "propaganda_emoce",
             "propaganda_fabulace",
             "propaganda_nalepkovani",
             "propaganda_zamereni",
             "propaganda_zanr",
             "propaganda_rusko",
             "propaganda_strach",
             "benczechmark_sentiment"]

    def __init__(self, ranks: dict = None):
        # Pairwise significance results; may also be supplied later via
        # get_tasks_ranks().
        self.ranks = ranks
        # Task currently being sorted on. Previously this attribute was only
        # created inside get_tasks_ranks(), so calling compare_models() first
        # raised AttributeError; initialize it here for a clear failure mode.
        self.current_task = None

    def compare_models(self, modelA_id, modelB_id):
        """Comparator (for functools.cmp_to_key) on ``self.current_task``.

        Returns:
            int: 1 if modelA is marked better than modelB on the current
            task, otherwise -1 (this comparator never returns 0, so models
            with no significant difference keep an arbitrary relative order).

        Raises:
            ValueError: if no rankings were provided or no task is selected.
        """
        if not self.ranks:
            raise ValueError("Missing model rankings")
        if self.current_task is None:
            raise ValueError("No task selected; call get_tasks_ranks() first")
        res = self.ranks[modelA_id][modelB_id][self.current_task]
        # res is expected to be a bool; any value other than True
        # (False, None, ...) is treated as "not significantly better",
        # matching the original True/False/else branches.
        return 1 if res is True else -1

    def get_tasks_ranks(self, ranks: dict) -> dict:
        """Order models based on the significance improvement.

        Args:
            ranks: nested mapping ``ranks[model_a][model_b][task] -> bool``;
                its top-level keys are the model ids to be ordered.

        Returns:
            dict: task name -> list of model ids, sorted so that models
            marked significantly better sort towards the end of the list.
        """
        self.ranks = ranks
        tasks_ranks = {}
        models = ranks.keys()
        for task in self.TASKS:
            self.current_task = task
            tasks_ranks[task] = sorted(models, key=cmp_to_key(self.compare_models))
        return tasks_ranks
# Example of the per-model ranking structure (model -> task -> order index):
# models = {
#     model1: {
#         task1: order_idx,
#         task2: order_idx,
#         task3: order_idx,
#     }
# }