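# FastAPI service that evaluates code solutions (BigCodeBench-style sample
# files) by executing each one against its test suite in sandboxed worker
# processes via untrusted_check.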
import logging
import os
from collections import Counter, defaultdict
import multiprocessing
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Tuple
import gc

from fastapi import FastAPI
from fastapi.responses import RedirectResponse

from api.bigcodebench_data import load_solutions
from api.code_execution import untrusted_check

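# One check outcome: a (status, per-test pass/fail details) pair; it is stored
# under the "base" key of each result dict and unpacked in evaluate() below.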
Result = Tuple[str, List[bool]]

def create_app() -> FastAPI:

    level = os.environ.get("LOG_LEVEL", "INFO")
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    # Defined as a plain (sync) endpoint so FastAPI runs it in a worker thread;
    # the blocking as_completed() loop below would otherwise stall the event loop.
    def evaluate(
        samples: str,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples file.
        """
        if parallel < 1:
            n_workers = max(1, multiprocessing.cpu_count() // 2)
        else:
            n_workers = parallel

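        # Per-task ground-truth runtimes; left empty when no_gt is set, in
        # which case a fixed 20s limit per task is used further below.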
        if not no_gt:
            expected_time = get_groundtruth()
        else:
            expected_time = {}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of result dicts
            remainings = set()

            for sample in load_solutions(samples):
                task_id = sample["task_id"]
                solution = sample["solution"]

                if "sanitized-calibrated" in samples:
                    solution = sample["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    expected_time.get(task_id) or 20  # gt_time_limit; default 20s
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Duplicate sample identifiers detected"
            #assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                del future, result
                gc.collect()
        
        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
        return results

    return app
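
# To serve the app with uvicorn (assuming this module is importable as api.main):
#   uvicorn api.main:create_app --factory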

def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> dict:
    ret = {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
    }
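    # untrusted_check executes the solution against its test suite under the
    # given resource and time limits; "base" carries the (status, details) pair.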
    ret["base"] = untrusted_check(
        solution,
        test,
        entry_point,
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret

def get_groundtruth():
    raise NotImplementedError("Ground-truth execution is not implemented yet.")