from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate import find_executable_batch_size
from evaluate import load
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np
import os


def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """
    Returns a dataloader for the given files.
    """

    def collate_fn(batch):
        return list(map(list, zip(*batch)))

    reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
    dataloader = DataLoader(
        reader, batch_size=batch_size, collate_fn=collate_fn, num_workers=0
    )
    return dataloader
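

# Note: ParallelTextReader (from the local dataset module) is assumed to yield
# (prediction, reference) pairs read in parallel from the two files; collate_fn
# above transposes each batch into two parallel lists, which is the shape the
# metric loop in eval_files expects.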


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """
    Evaluates the predictions in pred_path against the references in gold_path
    and reports sacrebleu, rouge, bleu, meteor, ter and BertScore.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            "We will use the CPU to calculate BertScore; this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)

    print("Loading sacrebleu...")
    sacrebleu = load("sacrebleu")
    print("Loading rouge...")
    rouge = load("rouge")
    print("Loading bleu...")
    bleu = load("bleu")
    print("Loading meteor...")
    meteor = load("meteor")
    print("Loading ter...")
    ter = load("ter")
    print("Loading BertScore...")
    bert_score = load("bertscore")

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            sacrebleu.add_batch(predictions=predictions, references=references)
            rouge.add_batch(predictions=predictions, references=references)
            bleu.add_batch(predictions=predictions, references=references)
            meteor.add_batch(predictions=predictions, references=references)
            ter.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))
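
    # The metrics above accumulate prediction/reference pairs with add_batch and
    # are only aggregated below via compute(), so the whole corpus never has to
    # be held as a single in-memory batch.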

    result_dictionary = {"path": pred_path}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = rouge.compute(
        use_aggregator=True, rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"]
    )
    print("Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print("Computing ter score")
    result_dictionary["ter"] = ter.compute()
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        nonlocal bert_score, bert_score_model
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )
        results["precision"] = np.average(results["precision"])
        results["recall"] = np.average(results["recall"])
        results["f1"] = np.average(results["f1"])
        return results

    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
            os.makedirs(os.path.abspath(os.path.dirname(output_path)))
        with open(output_path, "w") as f:
            json.dump(result_dictionary, f, indent=4)

    print(f"Results: {json.dumps(result_dictionary, indent=4)}")

    return result_dictionary


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )
    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )
    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore; it is automatically reduced if an OOM error occurs.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results will be printed to the console.",
    )
    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore"
        " and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )
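
# Example invocation (the script and file names below are illustrative
# assumptions, not taken from the repository):
#   python eval.py \
#       --pred_path predictions.txt \
#       --gold_path gold.txt \
#       --output_path results.json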