from typing import Any, Dict, List, Optional
from loguru import logger
import ast
import re
import json
from tqdm import tqdm


def get_parameter_names(prompt: str, entry_point: str) -> List[str]:
    """
    Extract parameter names from the function signature in the prompt.
    """
    # logger.debug(f"Prompt: {prompt}")
    # logger.debug(f"Entry point: {entry_point}")
    tree = ast.parse(prompt)
    for node in ast.walk(tree):
        # logger.debug(f"Node name: {node.name if hasattr(node, 'name') else None}")
        if isinstance(node, ast.FunctionDef) and node.name == entry_point:
            # Return the parameter names from the function definition that matches the entry point
            return [param.arg for param in node.args.args]
    return []
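# Illustrative usage (hypothetical prompt text, not from the dataset): the AST
# walk finds the matching `def` and returns its positional argument names.
#   get_parameter_names("def add(a, b):\n    return a + b", "add")  ->  ["a", "b"]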


def parse_tests(test: str, parameter_names: List[str], entry_point: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Parse the test string into a structured format using AST.
    """
    # Remove the METADATA section
    test = re.sub(r'METADATA = \{[^}]*\}', '', test)

    # Parse the entire test string
    tree = ast.parse(test)

    test_cases = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Assert):
            # Process each assert statement
            test_case = process_assert(node, entry_point, parameter_names)
            if test_case:
                test_cases.append(test_case)

    return {"test_cases": test_cases}


def process_assert(node: ast.Assert, entry_point: str, parameter_names: List[str]) -> Optional[Dict[str, Any]]:
    """
    Process a single assert statement and extract input and expected output.
    """
    # Only the simple shape `assert candidate(...) == <expr>` is handled here
    if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Eq):
        left = node.test.left
        right = node.test.comparators[0]

        if isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id in (entry_point, "candidate"):
            input_dict = process_input(left.args, parameter_names)

            try:
                # Attempt to evaluate using literal_eval; it raises ValueError or
                # TypeError for nodes that are not plain literals
                expected_output = ast.literal_eval(right)
            except (ValueError, TypeError):
                # Fall back to eval for non-literal expressions; note that this
                # executes the expression taken from the test string
                logger.warning("Falling back to eval due to failure in literal_eval")
                expected_output = eval(compile(ast.Expression(right), filename="<ast>", mode="eval"))

            return {"input": input_dict, "expected_output": expected_output}

    return None
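# Illustrative example (hypothetical assert, assuming a one-parameter signature):
#   node = ast.parse("assert candidate(5) == 25").body[0]
#   process_assert(node, "square", ["x"])
#   -> {"input": {"x": 5}, "expected_output": 25}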


def process_input(args: List[ast.expr], parameter_names: List[str]) -> Dict[str, Any]:
    """
    Process the input arguments and match them with parameter names.
    """
    input_dict = {}

    for i, arg in enumerate(args):
        try:
            # Attempt to evaluate using literal_eval for plain literal arguments
            evaluated_arg = ast.literal_eval(arg)
        except (ValueError, TypeError):
            # Fall back to eval for non-literal argument expressions
            logger.warning("Falling back to eval due to failure in literal_eval")
            evaluated_arg = eval(compile(ast.Expression(arg), filename="<ast>", mode="eval"))

        if i < len(parameter_names):
            input_dict[parameter_names[i]] = evaluated_arg
        else:
            # Handle extra arguments if any
            input_dict[f"arg_{i}"] = evaluated_arg

    return input_dict
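# Illustrative example (hypothetical call): positional arguments are zipped with
# the parameter names, and extras fall back to arg_<i> keys.
#   args = ast.parse("candidate([1, 2], 'x')").body[0].value.args
#   process_input(args, ["numbers"])  ->  {"numbers": [1, 2], "arg_1": "x"}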


def parse_all_problems(problems: List[str]) -> None:
    success_count = 0
    unhandled_failures = 0
    for problem in problems:
        # Decode outside the try block so the error handler can rely on a dict
        problem = json.loads(problem)
        try:
            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)

            # Rewrite entry-point calls as `candidate` so the given tests share
            # the assert shape that process_assert expects
            given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)

            # Parse the hidden test cases using the parameter names; the results
            # are discarded, since this driver only checks that parsing succeeds
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                unhandled_failures += 1
            continue

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {unhandled_failures}")


def parse_specific_problem(problem):
    try:
        if isinstance(problem, str):
            problem = json.loads(problem)

        logger.info(f"Problem: {problem}")
        logger.debug(f"Test: {problem['test']}")
        logger.debug(f"Given Test: {problem['given_tests']}")

        entry_point = problem["entry_point"]
        parameter_names = get_parameter_names(problem["prompt"], entry_point)
        logger.debug(f"Parameter names: {parameter_names}")

        given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
        given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
        logger.debug(f"Given tests: {given_tests}")

        # Parse the test cases using the parameter names
        all_tests = parse_tests(problem["test"], parameter_names, entry_point)
        logger.debug(f"Parsed tests: {all_tests}")
        return all_tests
    except Exception:
        logger.exception(f"Error processing problem {problem['task_id']}")
        return None

# Known assert shapes that the parser does not handle (only plain
# `candidate(...) == <expr>` comparisons are supported):
#   assert next_smallest([]) is None
#   assert decode_cyclic(encode_cyclic("abc")) == "abc"
#   assert round(find_zero([-6, 11, -6, 1]), 2) == 1.0
#   assert abs(candidate(1.33) - 0.33) < 1e-6
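
# A minimal sketch (hypothetical helper, not wired into the parser above) of how
# the first shape could be supported: `is None` asserts are an ast.Compare with
# an ast.Is operator, so they pattern-match much like the Eq case.
def process_is_none_assert(node: ast.Assert, parameter_names: List[str]) -> Optional[Dict[str, Any]]:
    """Handle `assert candidate(...) is None` by recording None as the output."""
    if isinstance(node.test, ast.Compare) and isinstance(node.test.ops[0], ast.Is):
        left = node.test.left
        right = node.test.comparators[0]
        compares_to_none = isinstance(right, ast.Constant) and right.value is None
        if compares_to_none and isinstance(left, ast.Call) and isinstance(left.func, ast.Name) and left.func.id == "candidate":
            return {"input": process_input(left.args, parameter_names), "expected_output": None}
    return None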

def check_all_problems(problems: List[str]) -> None:
    problems_q = []
    success_count = 0
    fail_count = 0
    for problem in tqdm(problems):
        # Decode outside the try block so the error handler can rely on a dict
        problem = json.loads(problem)
        try:

            logger.info(f"Problem: {problem}")
            logger.debug(f"Test: {problem['test']}")
            logger.debug(f"All Test: {problem['given_tests']}")

            entry_point = problem["entry_point"]
            parameter_names = get_parameter_names(problem["prompt"], entry_point)
            logger.info(f"Parameter names: {parameter_names}")

            # given_tests_len = len(problem["given_tests"])
            # given_tests_raw = "\n".join(problem["given_tests"]).replace(entry_point, "candidate")
            # given_tests = parse_tests(given_tests_raw, parameter_names, entry_point)
            # parsed_given_tests_len = len(given_tests['test_cases'])
            # assert given_tests_len == parsed_given_tests_len
            # success_count += 1

            # Parse the test cases using the parameter names. A HumanEval test
            # string mentions `candidate` once in the `def check(candidate):`
            # header and (usually) once per assert, so the parser should recover
            # count('candidate') - 1 cases; asserts that mention candidate more
            # than once, or use other shapes, trip the assert below
            tests_len_candidate = problem["test"].count('candidate')
            parsed_tests = parse_tests(problem["test"], parameter_names, entry_point)
            parsed_test_len = len(parsed_tests['test_cases'])
            assert tests_len_candidate - 1 == parsed_test_len
            logger.info(f"Parsed tests: {parsed_tests}")
            success_count += 1
        except Exception:
            logger.exception(f"Error processing problem {problem['task_id']}")
            if not problem['is_solved']:
                fail_count += 1
                problems_q.append(problem['task_id'])
            continue
    
    # Cross-check the failing task ids against a previously debugged seed file
    with open('output_data/humaneval/seed/deepseek-coder-v2-lite-instruct/20240828-174550/dscoder_debugged_seeds_deepseek-coder-v2-lite-instruct_1_1_10.jsonl', "r") as f:
        fixed = f.readlines()
    for fix_problem in fixed:
        fix_problem = json.loads(fix_problem)
        if fix_problem['task_id'] in problems_q:
            logger.info(f"Unparsed problem {fix_problem['task_id']} appears in the debugged seeds")

    logger.info(f"Success count: {success_count}")
    logger.info(f"Total problems: {len(problems)}")
    logger.info(f"Unhandled failures: {fail_count}")

if __name__ == "__main__":
    input_seeds = "input_data/humaneval/seed/deepseek-coder-v2-lite-instruct/seed.jsonl"

    with open(input_seeds, "r") as f:
        problems = f.readlines()

    check_all_problems(problems)
    #parse_all_problems(problems)

    # parse the one with 'task_id': 'HumanEval/33'
    # for problem in problems:
    #     problem = json.loads(problem)
    #     if problem['task_id'] == 'HumanEval/33':
    #         parsed_tests = parse_specific_problem(problem)
    #         break