import os
import re
import subprocess
import tempfile
import multiprocessing
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass


class PythonREPL:
    """Run Python snippets in a fresh ``python3`` subprocess with a time limit.

    Each call prepends ``import math``, ``import numpy as np`` and
    ``import sympy as sp`` to the snippet. If the final line does not already
    contain ``print(``, that line is wrapped in ``print(...)`` (after dropping
    any ``#`` comment) so its value is echoed. Calls return
    ``(success, output)``.
    """

    def __init__(self, timeout=5):
        # Wall-clock limit in seconds for a single snippet (None = no limit).
        self.timeout = timeout

    @staticmethod
    def _clean_traceback(error_msg, temp_file_path):
        """Reduce a stderr traceback to the parts relevant to the user's script.

        Keeps the ``Traceback`` header, every frame line that refers to
        ``temp_file_path``, the source line printed right after such a frame,
        and the final line (the exception message). Frames from other files
        are dropped. The temporary directory is replaced by the bare file
        name, so the line number survives but the random path does not leak.
        """
        lines = error_msg.split("\n")
        last_index = len(lines) - 1
        short_name = os.path.basename(temp_file_path)
        kept = []
        want_next = False
        for index, line in enumerate(lines):
            if "Traceback" in line or index == last_index:
                kept.append(line)
            elif temp_file_path in line:
                kept.append(line.replace(temp_file_path, short_name))
                # The interpreter prints the offending source line next.
                want_next = True
            elif want_next:
                kept.append(line)
                want_next = False
        return "\n".join(kept).strip()

    @staticmethod
    def _run_code(temp_file_path, timeout=None):
        """Execute ``temp_file_path`` with ``python3``.

        Args:
            temp_file_path: path of the script to run.
            timeout: optional wall-clock limit in seconds.

        Returns:
            ``(True, stripped stdout)`` on exit status 0, otherwise
            ``(False, cleaned traceback)``.

        Raises:
            subprocess.TimeoutExpired: when ``timeout`` elapses. ``subprocess.run``
                kills the child process before re-raising, so nothing is left
                running.
        """
        result = subprocess.run(
            ["python3", temp_file_path],
            capture_output=True,
            check=False,
            text=True,
            timeout=timeout,
        )
        if result.returncode == 0:
            return True, result.stdout.strip()
        return False, PythonREPL._clean_traceback(result.stderr.strip(), temp_file_path)

    def __call__(self, query):
        """Run ``query`` and return ``(success, output)``.

        On timeout, returns ``(False, "Timed out after <timeout> seconds.")``.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        program = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(program)
            try:
                # Enforce the limit on the child process itself. Killing a pool
                # worker instead would orphan the python3 process it spawned.
                return self._run_code(temp_file_path, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, f"Timed out after {self.timeout} seconds."


def execute_completion(executor, completion, return_status, last_code_block):
    """Execute the ```python fenced blocks in ``completion`` and report the last one.

    Args:
        executor: callable taking a code string and returning
            ``(success, output)``, e.g. a ``PythonREPL`` instance.
        completion: text to scan for ```python ... ``` blocks.
        return_status: if true, return ``(output, success)``. Otherwise return
            only the output string, replaced by ``""`` when execution failed.
        last_code_block: if true, execute only the final code block.

    Returns:
        The stripped output of the last processed block, paired with its
        success flag when ``return_status`` is true. When ``completion`` has no
        code blocks, ``completion`` itself is returned, paired with ``False``
        when ``return_status`` is true. Blocks mentioning ``subprocess`` or
        ``venv`` are not executed; they yield ``"<lib> is not allowed"`` with
        success ``False``.
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if not executions:
        # The parentheses matter: the conditional must choose between the
        # tuple and the bare string, not between two second tuple elements.
        return (completion, False) if return_status else completion
    if last_code_block:
        executions = executions[-1:]
    outputs = []
    successes = []
    for code in executions:
        blocked = next((lib for lib in ("subprocess", "venv") if lib in code), None)
        if blocked is not None:
            # Refuse to run this block at all and move on to the next one.
            outputs.append(f"{blocked} is not allowed")
            successes.append(False)
            continue
        success = False
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)
    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output


def postprocess_completion(text, return_status, last_code_block):
    """Run the code blocks in ``text`` with a freshly built ``PythonREPL``.

    Returns exactly what ``execute_completion`` returns for the same flags.
    """
    repl = PythonREPL()
    return execute_completion(
        repl,
        text,
        return_status=return_status,
        last_code_block=last_code_block,
    )


def get_majority_vote(answers):
    """Return the most frequent value in ``answers``, or ``0`` if there are none.

    Accepts any iterable of hashable values, including generators. Ties go to
    the value encountered first, per ``Counter.most_common`` ordering.
    """
    counts = Counter(answers)
    if not counts:
        return 0
    value, _ = counts.most_common(1)[0]
    return value