from typing import List, Dict, Any
import torch
from torchmetrics.functional import (
    retrieval_hit_rate, retrieval_reciprocal_rank, retrieval_recall, 
    retrieval_precision, retrieval_average_precision, retrieval_normalized_dcg, 
    retrieval_r_precision
)

class Evaluator:
    """Computes retrieval metrics (e.g. MRR, Hit@k, Recall@k) over a fixed candidate set."""

    def __init__(self, candidate_ids: List[int], device: str = 'cpu'):
        """
        Initializes the evaluator with the given candidate IDs.
        
        Args:
            candidate_ids (List[int]): List of candidate IDs.
        """
        self.candidate_ids = candidate_ids
        self.device = device

    def __call__(self, 
                 pred_dict: Dict[int, float], 
                 answer_ids: torch.LongTensor, 
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.
        
        Args:
            pred_dict (Dict[int, float]): Mapping from candidate ID to predicted score.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): List of metrics to be evaluated, including 'mrr', 'rprecision',
                                 'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                             
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        return self.evaluate(pred_dict, answer_ids, metrics)
    
    def evaluate(self, 
                 pred_dict: Dict[int, float], 
                 answer_ids: torch.LongTensor, 
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.
        
        Args:
            pred_dict (Dict[int, float]): Mapping from candidate ID to predicted score.
            answer_ids (torch.LongTensor): Ground-truth answer IDs.
            metrics (List[str]): A list of metrics to be evaluated, including 'mrr', 'rprecision',
                                 'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                             
        Returns:
            Dict[str, float]: Dictionary of evaluation metrics.
        """
        # Convert prediction dictionary to tensor
        pred_ids = torch.LongTensor(list(pred_dict.keys())).view(-1)
        pred = torch.FloatTensor(list(pred_dict.values())).view(-1)
        answer_ids = answer_ids.view(-1)

        # Build a dense score tensor indexed by candidate ID; unscored candidates get a
        # score lower than any prediction so they rank last
        all_pred = torch.ones(max(self.candidate_ids) + 1, dtype=torch.float) * (pred.min() - 1)
        all_pred[pred_ids] = pred
        all_pred = all_pred[self.candidate_ids]

        # Initialize ground truth boolean tensor
        bool_gd = torch.zeros(max(self.candidate_ids) + 1, dtype=torch.bool)
        bool_gd[answer_ids] = True
        bool_gd = bool_gd[self.candidate_ids]

        # Compute evaluation metrics
        eval_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if metric == 'mrr':
                result = retrieval_reciprocal_rank(all_pred, bool_gd)
            elif metric == 'rprecision':
                result = retrieval_r_precision(all_pred, bool_gd)
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred, bool_gd, top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred, bool_gd, top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred, bool_gd, top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred, bool_gd, top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred, bool_gd, top_k=k)
            else:
                raise ValueError(f'Unsupported metric: {metric}')
            eval_metrics[metric] = float(result)

        return eval_metrics
    
    def evaluate_batch(self,
                 pred_ids: torch.LongTensor,
                 pred: torch.Tensor,
                 answer_ids: List[Any],
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20']) -> List[Dict[str, float]]:
        """
        Evaluates a batch of queries scored over the same candidate pool.

        Args:
            pred_ids (torch.LongTensor): Candidate IDs of the rows in `pred`.
            pred (torch.Tensor): Scores of shape (len(pred_ids), num_queries).
            answer_ids (List[Any]): One LongTensor of ground-truth IDs per query.
            metrics (List[str]): Metrics to compute for each query.

        Returns:
            List[Dict[str, float]]: One dictionary of metrics per query.
        """
        # The dense tensor is sized by max candidate ID rather than by pred.shape so that
        # rows can be indexed directly by candidate ID; unscored candidates score below pred.min()
        all_pred = torch.ones((max(self.candidate_ids) + 1, pred.shape[1]), dtype=torch.float) * (pred.min() - 1)
        all_pred[pred_ids, :] = pred
        all_pred = all_pred[self.candidate_ids].t().to(self.device)

        # Mark ground-truth (candidate_id, query_index) pairs as relevant
        bool_gd = torch.zeros((max(self.candidate_ids) + 1, pred.shape[1]), dtype=torch.bool)
        query_index = torch.repeat_interleave(torch.arange(len(answer_ids)), torch.tensor(list(map(len, answer_ids))))
        bool_gd[torch.concat(answer_ids), query_index] = True
        bool_gd = bool_gd[self.candidate_ids].t().to(self.device)

        results = []
        for i in range(len(answer_ids)):
            eval_metrics = {}
            for metric in metrics:
                k = int(metric.split('@')[-1]) if '@' in metric else None
                if metric == 'mrr':
                    result = retrieval_reciprocal_rank(all_pred[i], bool_gd[i])
                elif metric == 'rprecision':
                    result = retrieval_r_precision(all_pred[i], bool_gd[i])
                elif 'hit' in metric:
                    result = retrieval_hit_rate(all_pred[i], bool_gd[i], top_k=k)
                elif 'recall' in metric:
                    result = retrieval_recall(all_pred[i], bool_gd[i], top_k=k)
                elif 'precision' in metric:
                    result = retrieval_precision(all_pred[i], bool_gd[i], top_k=k)
                elif 'map' in metric:
                    result = retrieval_average_precision(all_pred[i], bool_gd[i], top_k=k)
                elif 'ndcg' in metric:
                    result = retrieval_normalized_dcg(all_pred[i], bool_gd[i], top_k=k)
                else:
                    raise ValueError(f'Unsupported metric: {metric}')
                eval_metrics[metric] = float(result)
            results.append(eval_metrics)
        return results
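

if __name__ == '__main__':
    # Usage sketch (illustrative only): the candidate IDs, scores, and ground-truth answers
    # below are made-up placeholders showing the expected input shapes, not project data.
    evaluator = Evaluator(candidate_ids=list(range(100)))

    # Single-query evaluation: a mapping from candidate ID to score plus a tensor of answer IDs.
    pred_dict = {0: 0.9, 3: 0.7, 5: 0.4, 7: 0.1}
    answer_ids = torch.LongTensor([3, 7])
    print(evaluator(pred_dict, answer_ids, metrics=['mrr', 'hit@3', 'recall@20']))

    # Batch evaluation: pred has one column per query; answer_ids holds one tensor per query.
    pred_ids = torch.LongTensor([0, 3, 5, 7])
    pred = torch.rand(len(pred_ids), 2)
    batch_answers = [torch.LongTensor([3]), torch.LongTensor([0, 7])]
    print(evaluator.evaluate_batch(pred_ids, pred, batch_answers, metrics=['mrr', 'hit@3']))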