# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class ANLS(BaseMetric):
    """ANLS metric.

    Compute the Average Normalized Levenshtein Similarity (ANLS).

    Args:
        threshold (float): ANLS threshold used for determining if the answer
            has been correctly selected but not properly recognized, or on
            the contrary, the output is a wrong text selected from the
            options and given as an answer. Defaults to 0.5.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different
            evaluators. If prefix is not provided in the argument,
            ``self.default_prefix`` will be used instead. Defaults to None.
    """
    default_prefix = 'ANLS'

    def __init__(self,
                 threshold: float = 0.5,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.threshold = threshold

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer
            }
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        total_score = 0.
        for result in results:
            sample_score_list = []
            # Normalize answers: collapse whitespace and lowercase.
            pred = ' '.join(result['pred_answer'].strip().lower().split())
            for gt in result['gt_answer']:
                gt = ' '.join(gt.strip().lower().split())
                dist = levenshtein_distance(gt, pred)
                # Normalize the edit distance by the length of the longer of
                # the two normalized strings.
                length = max(len(gt), len(pred))
                sample_score_list.append(
                    0.0 if length == 0 else float(dist) / float(length))

            # Similarity is 1 minus the smallest normalized distance over all
            # ground-truth answers; similarities below the threshold score 0.
            per_sample_score = 1. - min(sample_score_list)
            if per_sample_score < self.threshold:
                per_sample_score = 0.

            total_score += per_sample_score

        total_score = total_score / len(results)
        return {'ANLS': total_score}
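

# Worked example of the per-sample ANLS score (illustrative numbers, not from
# the source): for pred_answer 'hello wrld' and gt_answer ['hello world'], the
# Levenshtein distance is 1 and the longer string has 11 characters, so the
# normalized distance is 1 / 11 ~= 0.091 and the similarity is ~0.909. Since
# 0.909 >= the default threshold of 0.5, the sample contributes ~0.909 to the
# average; a similarity below the threshold would contribute 0 instead.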


def levenshtein_distance(s1, s2):
    """Compute the Levenshtein (edit) distance between strings s1 and s2."""
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    # Single-row dynamic programming over the edit-distance matrix:
    # `distances` holds the previous row, `distances_` the row being built.
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1],
                                           distances_[-1])))
        distances = distances_
    return distances[-1]
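

# A minimal, self-contained usage sketch (not part of the original module).
# The sample dicts below are made-up illustrations; they only need to carry
# the 'pred_answer' (str) and 'gt_answer' (list of str) keys that `process`
# reads above. `compute_metrics` is normally invoked through the mmengine
# evaluation loop; calling it directly on the collected results keeps the
# example standalone.
if __name__ == '__main__':
    metric = ANLS(threshold=0.5)
    fake_samples = [
        {'pred_answer': 'hello wrld', 'gt_answer': ['hello world']},
        {'pred_answer': 'cat', 'gt_answer': ['dog', 'puppy']},
    ]
    metric.process(data_batch=None, data_samples=fake_samples)
    # Expected output: roughly {'ANLS': 0.45}, since the first sample scores
    # ~0.909 and the second falls below the threshold and scores 0.
    print(metric.compute_metrics(metric.results))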