# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class ANLS(BaseMetric):
"""ANLS metric.

    Compute the Average Normalized Levenshtein Similarity (ANLS).

Args:
        threshold (float): Score threshold below which a per-sample ANLS
            score is set to 0. It separates answers that were correctly
            selected but imperfectly recognized (small edit distance) from
            answers that are simply wrong. Defaults to 0.5.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            ('ANLS') will be used instead. Defaults to None.
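
    Examples:
        A minimal usage sketch; the answer strings below are made up, and it
        assumes ``ANLS`` is importable from ``mmpretrain.evaluation``:

        >>> from mmpretrain.evaluation import ANLS
        >>> metric = ANLS(threshold=0.5)
        >>> metric.process(
        ...     data_batch=None,
        ...     data_samples=[{'pred_answer': 'Forty-Two',
        ...                    'gt_answer': ['forty two', '42']}])
        >>> round(metric.compute_metrics(metric.results)['ANLS'], 4)
        0.8889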
"""
    default_prefix = 'ANLS'

def __init__(self,
threshold: float = 0.5,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
        self.threshold = threshold

def process(self, data_batch, data_samples) -> None:
"""Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
gt_answer = sample.get('gt_answer')
result = {
'pred_answer': sample.get('pred_answer'),
'gt_answer': gt_answer
}
            self.results.append(result)

def compute_metrics(self, results: List) -> dict:
"""Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
                and the values are corresponding results.
"""
total_score = 0.
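        # For each sample, take the smallest normalized Levenshtein distance
        # over all ground-truth answers, turn it into a similarity score
        # ``1 - NL``, and zero out scores below ``self.threshold``; the final
        # metric is the mean of these per-sample scores.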
for result in results:
sample_score_list = []
            # Normalize case and whitespace before comparison.
            pred = ' '.join(result['pred_answer'].strip().lower().split())
            for gt in result['gt_answer']:
                gt = ' '.join(gt.strip().lower().split())
                dist = levenshtein_distance(gt, pred)
                # Normalize the edit distance by the longer of the two
                # normalized strings.
                length = max(len(gt), len(pred))
                sample_score_list.append(
                    0.0 if length == 0 else float(dist) / float(length))
per_sample_score = 1. - min(sample_score_list)
if per_sample_score < self.threshold:
per_sample_score = 0.
total_score += per_sample_score
total_score = total_score / len(results)
        return {'ANLS': total_score}


def levenshtein_distance(s1, s2):
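    """Compute the Levenshtein (edit) distance between two strings.

    Uses the classic dynamic-programming recurrence while keeping only the
    previous row of the distance table in memory.
    """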
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1],
distances_[-1])))
distances = distances_
return distances[-1]
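

# Illustrative self-check: running this module directly evaluates ANLS on two
# made-up VQA samples (all strings below are hypothetical).
if __name__ == '__main__':
    metric = ANLS(threshold=0.5)
    metric.process(
        data_batch=None,
        data_samples=[
            # Close match: small edit distance, score stays above threshold.
            {'pred_answer': 'Forty-Two', 'gt_answer': ['forty two']},
            # No character overlap: similarity drops below 0.5 and is zeroed.
            {'pred_answer': 'seven', 'gt_answer': ['forty two']},
        ])
    # Prints roughly {'ANLS': 0.4444}, i.e. ((1 - 1/9) + 0.0) / 2.
    print(metric.compute_metrics(metric.results))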