# Prefer the ipytorch logger when it is installed; otherwise fall back to the
# standard library's logging module.
try:
    from ipytorch import logging
except Exception:
    import logging

from typing import Any, Optional, Protocol, Iterable, Callable
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
import evaluate
import numpy as np
import datasets
import pandas as pd
from .tasks import *
from .utils import *
from itertools import chain
from copy import deepcopy


class ReasoningMetric(evaluate.Metric):
    """TODO: Short description of my evaluation module."""

    def _info(self):
        # if self.config_name in ["cmmlu"]:
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                # "responses": datasets.Sequence(datasets.Value("float")),
                "references": datasets.Value("string"),
            }
        )

        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.EvaluationModuleInfo(
            # This is the description that will appear on the modules page.
            # module_type="measurement",
            description="",
            citation="",
            inputs_description="",
            # This defines the format of each prediction and reference
            features=features,
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, responses, references):
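        # Dispatch to the scoring function in `Metrics` that shares this metric's
        # config name (e.g. Metrics.gsm8k), then normalise its return value into
        # a {metric_name: score} dict below.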
        return_value = getattr(Metrics, self.config_name)(responses, references)
        match return_value:
            case tuple():
                extract_responses, extract_references = return_value
                results = {
                    self.config_name: np.mean(
                        sync_pipe(lambda x, y: x == y)(
                            zip(extract_responses, extract_references)
                        )
                    )
                }
            case dict():
                results = return_value

            case list():
                results = {self.config_name: np.mean(return_value)}

            case _:
                raise NotImplementedError

        return results
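
# Illustrative usage (an assumption about how the module is deployed, not part
# of this file): once published as a Hugging Face evaluation module, the metric
# could be loaded by config name, e.g.
#     metric = evaluate.load("sustech/tlem", "gsm8k")
#     metric.compute(responses=[...], references=[...])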


class Suite(EvaluationSuite):
    task_class = Task

    def __getitem__(self, key) -> Task:
        match key:
            case str():
                return self.suite[key]
            # case _:
            #     return list(chain(*self.suite.values()))[key]

    def aggregate(self, suite):
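        # Recursively collapse the nested {category: tasks} mapping: each leaf
        # list of tasks is replaced in place by the mean of its tasks' results,
        # while nested categories keep their structure.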
        for cate, tasks in suite.items():
            if isinstance(tasks, dict):
                suite[cate] = self.aggregate(tasks)
            else:
                result = []
                for task in tasks:
                    result.extend(task.result.values())
                suite[cate] = np.mean(result)

        return suite

    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, float]:
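        # Run every registered task against the model/pipeline, then aggregate a
        # deep copy of the suite so the Task objects (and their raw results)
        # remain available on self.suite.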
        self.assert_suite_nonempty()

        self.suite: dict[str, list[Task]]
        for task in (bar := tqdm(self.tasks)):
            bar.desc = f"Running {task.name}"
            _ = task.run(model_or_pipeline)
        return self.aggregate(deepcopy(self.suite))

    def get_suite(self, name) -> dict[str, Task]:
        # Suite names containing "chat" (e.g. "mmlu-chat") select the chat variant.
        chat = "chat" in name
        match name:
            case _ if name.startswith("mmlu"):
                suite = MMLU.suite(chat=chat)
            case _ if name.startswith("cmmlu"):
                suite = CMMLU.suite(chat=chat)
            case _ if name.startswith("ceval"):
                suite = CEVAL.suite(chat=chat)
            case "gsm8k":
                suite = Task(
                    dataset_name=("gsm8k", "main"),
                    metric_name=("sustech/tlem", "gsm8k"),
                    input_column="question",
                    label_column="answer",
                )
            case "bbh":
                suite = BBH.suite()
            case "arc":
                suite = ARC.suite()
            case "hellaswag":
                suite = HellaSwag.suite()
            case "drop":
                suite = DROP.suite()
            case "winogrande":
                suite = Winogrande.suite()

            case "mt_bench":
                suite = Task(
                    dataset_name="SUSTech/mt_bench_judge",
                    split="train",
                    prompt=mt_bench_prompt
                    # metric_name=("sustech/tlem", "gsm8k"),
                )
            case "MATH" | "competition_math":
                suite = Task(
                    dataset_name="hendrycks/competition_math",
                    prompt="This is a math problem, please think step by step and slove it: {input_column}. Simplify your final answer as much as possible and surround them with '$' in TeX form",
                    metric_name=("sustech/tlem", "MATH"),
                    input_column="problem",
                    label_column="solution",
                )

            case "open-leaderboard":
                suite = {}
                # Use a distinct loop variable so the `name` argument is not shadowed.
                for member in [
                    "arc",
                    "hellaswag",
                    "mmlu-chat",
                    "winogrande",
                    "gsm8k",
                    # "truthful_qa",
                    "drop",
                ]:
                    suite[member] = self.get_suite(member)
            case _:
                raise ValueError(f"Unknown suite name: {name}")

        if isinstance(suite, Task):
            suite = [suite]
        if isinstance(suite, list):
            suite = {name: suite}

        return suite

    def singleton(self, task):
        # Return the registered task equal to `task` if one exists; otherwise
        # register this task and return it.
        try:
            return self.tasks[self.tasks.index(task)]
        except ValueError:
            self.tasks.append(task)
            return self.tasks[-1]

    def drop_duplicates(self, suite):
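        # Walk the nested suite and route every task through `singleton`, so a
        # task that appears in several categories is backed by a single instance
        # (and therefore runs only once).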
        for category, tasks in suite.items():
            if isinstance(tasks, dict):
                suite[category] = self.drop_duplicates(tasks)
            else:
                suite[category] = [self.singleton(task) for task in tasks]
        return suite

    def load(self, name):
        self.suite.update(self.get_suite(name))
        self.suite = self.drop_duplicates(self.suite)

    def __init__(self, name="tlem"):
        super().__init__(name)
        self.tasks = []
        self.suite = {}
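

# A minimal usage sketch (illustrative only; `my_pipeline` stands in for whatever
# model or pipeline object the underlying Task.run implementation accepts):
#
#     suite = Suite()
#     suite.load("gsm8k")
#     results = suite.run(my_pipeline)
#     # `results` mirrors the suite layout, with each leaf being a mean score.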