Spaces:
Running
Running
| import os.path as osp | |
| import argparse | |
| import json | |
| from data import Tasks, DATASET_TASK_DICT | |
| from utils import preprocess_path | |
def process_result(entry, name, task):
    """Flatten one raw lm-eval result entry into a plain metrics dict.

    Args:
        entry: dict of raw metric values keyed as 'metric,filter'
               (e.g. 'acc,none'), as emitted by the evaluation harness.
        name: dataset name; used for dataset-specific metric handling.
        task: a Tasks enum member selecting which metrics to extract.

    Returns:
        dict with 'name', 'task', and the task-appropriate metric fields.
        A task matching no branch yields only 'name' and 'task'.

    Raises:
        KeyError: if a required metric key is missing from ``entry``.
    """
    processed = {
        'name': name,
        'task': str(task),
    }
    if task == Tasks.EXTRACTIVE_QUESTION_ANSWERING:
        # mkqa_tr reports 'em' already on a 0-1 scale; the other QA
        # datasets report 'exact' as a percentage, so scale it down.
        # One boolean drives both choices so they cannot drift apart.
        is_mkqa = name == 'mkqa_tr'
        key = 'em,none' if is_mkqa else 'exact,none'
        scale = 1 if is_mkqa else 0.01
        processed['exact_match'] = scale * entry[key]
        processed['f1'] = scale * entry['f1,none']
    elif task == Tasks.SUMMARIZATION:
        processed['rouge1'] = entry['rouge1,none']
        processed['rouge2'] = entry['rouge2,none']
        processed['rougeL'] = entry['rougeL,none']
    elif task in (
        Tasks.MULTIPLE_CHOICE,
        Tasks.NATURAL_LANGUAGE_INFERENCE,
        Tasks.TEXT_CLASSIFICATION,
    ):
        processed['acc'] = entry['acc,none']
        # Fall back to plain accuracy when the normalized variant is absent.
        processed['acc_norm'] = entry.get('acc_norm,none', processed['acc'])
    elif task == Tasks.MACHINE_TRANSLATION:
        processed['wer'] = entry['wer,none']
        processed['bleu'] = entry['bleu,none']
    elif task == Tasks.GRAMMATICAL_ERROR_CORRECTION:
        processed['exact_match'] = entry['exact_match,none']
    return processed
def main():
    """Read a raw lm-eval results JSON, flatten it, and write formatted JSON.

    Command line:
        -i/--input-file: raw results JSON produced by the eval harness.
        -o/--output-file: destination path for the formatted results.
    """
    parser = argparse.ArgumentParser(description='Results file formatter.')
    parser.add_argument('-i', '--input-file', type=str, help='Input JSON file for the results.')
    parser.add_argument('-o', '--output-file', type=str, help='Output JSON file for the formatted results.')
    args = parser.parse_args()
    with open(preprocess_path(args.input_file)) as f:
        raw_data = json.load(f)
    # first, get model args: "k1=v1,k2=v2,..." -> dict. Split on the FIRST
    # '=' only (maxsplit=1) so values that themselves contain '=' do not
    # produce a 3-tuple and crash the dict() constructor.
    model_args = dict(
        pair.split('=', 1)
        for pair in raw_data['config']['model_args'].split(',')
    )
    processed = dict()
    # rename 'pretrained' -> 'model' for the formatted output
    model_args['model'] = model_args.pop('pretrained')
    processed['model'] = model_args
    processed['model']['api'] = raw_data['config']['model']
    # then, process results; skip datasets with no known task mapping
    results = raw_data['results']
    processed['results'] = list()
    for dataset, entry in results.items():
        if dataset not in DATASET_TASK_DICT:
            continue
        task = DATASET_TASK_DICT[dataset]
        processed['results'].append(process_result(entry, dataset, task))
    with open(preprocess_path(args.output_file), 'w') as f:
        json.dump(processed, f, indent=4)
    print('done')


if __name__ == '__main__':
    main()