import ast
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
from contextlib import redirect_stdout

import gradio as gr

# =============================================================================
# Python To JavaScript 轉換器
# =============================================================================

class PythonToJavaScriptConverter(ast.NodeVisitor):
    """Translate a subset of the Python AST into JavaScript source text.

    Statement visitors return one or more complete, indented lines;
    expression visitors return an unindented fragment.  Nodes without a
    dedicated visitor are rendered as a ``/* Unhandled node type */``
    comment by :meth:`generic_visit`.

    Known limitations: keyword arguments in calls, ``for``/``while``
    ``else`` clauses, slices and class inheritance are not translated.
    Statements of a ``try`` ... ``else`` clause are appended to the ``try``
    block, so unlike in Python their exceptions reach the handlers.
    """

    # Expression nodes whose emitted form never needs wrapping parentheses.
    _ATOMIC_NODES = (
        ast.Name, ast.Constant, ast.Attribute, ast.Subscript, ast.Call,
        ast.List, ast.Tuple, ast.Dict, ast.ListComp, ast.GeneratorExp,
        ast.IfExp,
    )
    # Nodes that open a new Python scope; hoisting never looks inside them.
    _SCOPE_BOUNDARIES = (
        ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda,
    )
    # JavaScript binding strength of the binary operators emitted infix.
    _BINARY_RANK = {
        ast.Mult: 12, ast.Div: 12, ast.FloorDiv: 12, ast.Mod: 12,
        ast.Add: 11, ast.Sub: 11,
        ast.LShift: 10, ast.RShift: 10,
        ast.BitAnd: 7, ast.BitXor: 6, ast.BitOr: 5,
    }
    _BITWISE_OPS = (ast.BitAnd, ast.BitOr, ast.BitXor)
    _UNARY_SYMBOLS = {ast.Not: "!", ast.USub: "-", ast.UAdd: "+", ast.Invert: "~"}
    # Python's catch-all exception classes map onto JavaScript's Error.
    _EXCEPTION_ALIASES = {"Exception": "Error", "BaseException": "Error"}

    def __init__(self):
        # Names already declared in the scope currently being emitted.
        self.variables = set()
        # Names of the functions enclosing the node being visited.
        self.current_scope = []
        self.indent_level = 0
        # Classes seen so far; calls to them are emitted with `new`.
        self.class_names = set()
        # One entry per enclosing function: the parameter that stands for
        # `this` inside a method, or None for a plain function.
        self._this_aliases = []

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def indent(self):
        """Return the whitespace prefix for the current nesting depth."""
        return '    ' * self.indent_level

    def _visit_block(self, statements):
        """Emit *statements* one indentation level deeper, joined by newlines."""
        self.indent_level += 1
        try:
            return "\n".join(self.visit(stmt) for stmt in statements)
        finally:
            self.indent_level -= 1

    def _operand(self, node):
        """Emit *node*, parenthesised unless its output is self-delimiting."""
        text = self.visit(node)
        if isinstance(node, self._ATOMIC_NODES):
            return text
        # Floor division and power are emitted as Math.*() calls.
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.FloorDiv, ast.Pow)):
            return text
        return f"({text})"

    def _binop_side(self, child, parent_rank, is_right):
        """Emit one operand of a binary operator, adding parentheses only
        where JavaScript would otherwise regroup the expression."""
        text = self.visit(child)
        if isinstance(child, self._ATOMIC_NODES):
            return text
        if isinstance(child, ast.BinOp):
            if isinstance(child.op, (ast.FloorDiv, ast.Pow)):
                return text
            rank = self._BINARY_RANK.get(type(child.op))
            if rank is not None and parent_rank is not None:
                # Operators are left-associative: an equal-rank child is
                # only safe without parentheses on the left-hand side.
                if rank > parent_rank or (rank == parent_rank and not is_right):
                    return text
        return f"({text})"

    def _compare_side(self, node):
        """Emit one operand of a comparison operator."""
        text = self.visit(node)
        # Bitwise operators bind looser than comparisons in JavaScript but
        # tighter in Python, so they need explicit grouping.
        needs_parens = isinstance(node, (ast.BoolOp, ast.Compare)) or (
            isinstance(node, ast.BinOp) and isinstance(node.op, self._BITWISE_OPS)
        )
        return f"({text})" if needs_parens else text

    def _target_names(self, target):
        """Return the plain variable names bound by an assignment target."""
        if isinstance(target, ast.Name):
            return [target.id]
        if isinstance(target, (ast.Tuple, ast.List)):
            names = []
            for element in target.elts:
                names.extend(self._target_names(element))
            return names
        return []

    def _is_name_pattern(self, target):
        """Return True when *target* consists solely of (nested) plain names."""
        if isinstance(target, (ast.Tuple, ast.List)):
            return all(self._is_name_pattern(element) for element in target.elts)
        return isinstance(target, ast.Name)

    def _assigned_names(self, statements):
        """Collect names assigned anywhere below *statements*, in source
        order, without descending into nested functions or classes."""
        found = []
        pending = list(reversed(statements))
        while pending:
            node = pending.pop()
            if isinstance(node, self._SCOPE_BOUNDARIES):
                continue
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    found.extend(self._target_names(target))
            elif isinstance(node, ast.With):
                for item in node.items:
                    if item.optional_vars is not None:
                        found.extend(self._target_names(item.optional_vars))
            children = [
                child for child in ast.iter_child_nodes(node)
                if isinstance(child, (ast.stmt, ast.ExceptHandler))
            ]
            pending.extend(reversed(children))
        return found

    def _hoist(self, node):
        """Return a ``let a, b;`` line (newline-terminated) for names first
        assigned inside the compound statement *node*, or an empty string.

        Python variables are function-scoped while ``let`` is block-scoped,
        so such names must be declared before the block opens.
        """
        fresh = []
        for name in self._assigned_names([node]):
            if name not in self.variables and name not in fresh:
                fresh.append(name)
        if not fresh:
            return ""
        self.variables.update(fresh)
        return f"{self.indent()}let {', '.join(fresh)};\n"

    def _emit_assignment(self, target, value):
        """Emit ``target = value;`` and declare any not-yet-declared names."""
        pad = self.indent()
        if isinstance(target, ast.Name):
            keyword = "" if target.id in self.variables else "let "
            self.variables.add(target.id)
            return f"{pad}{keyword}{self.visit(target)} = {value};"
        if isinstance(target, (ast.Tuple, ast.List)):
            pattern = self.visit(target)
            raw_names = self._target_names(target)
            names = list(dict.fromkeys(raw_names))
            fresh = [name for name in names if name not in self.variables]
            self.variables.update(fresh)
            all_new = (
                fresh
                and len(fresh) == len(names) == len(raw_names)
                and self._is_name_pattern(target)
            )
            if all_new:
                return f"{pad}let {pattern} = {value};"
            # Mixed case: declare the new names, then destructure into all.
            prefix = f"{pad}let {', '.join(fresh)};\n" if fresh else ""
            return f"{prefix}{pad}{pattern} = {value};"
        # Attribute and subscript targets are never declared.
        return f"{pad}{self.visit(target)} = {value};"

    def _is_range_call(self, node):
        """Return True when *node* is a call to the builtin-looking ``range``."""
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "range"
        )

    def _range_to_js(self, arg_nodes):
        """Emit an ``Array.from`` expression equivalent to ``range(*args)``."""
        if len(arg_nodes) == 1:
            stop = self.visit(arg_nodes[0])
            return f"Array.from({{length: {stop}}}, (_, _i) => _i)"
        if len(arg_nodes) == 2:
            start, stop = (self._operand(arg) for arg in arg_nodes)
            return f"Array.from({{length: {stop} - {start}}}, (_, _i) => _i + {start})"
        if len(arg_nodes) == 3:
            start, stop, step = (self._operand(arg) for arg in arg_nodes)
            return (
                f"Array.from({{length: Math.max(0, Math.ceil(({stop} - {start}) / {step}))}}, "
                f"(_, _i) => {start} + _i * {step})"
            )
        return "/* Unsupported range() call */ []"

    def _comprehension_parts(self, node):
        """Return ``(target, iterable, condition)`` for the single generator
        of a comprehension; *condition* is None when there is no filter."""
        gen = node.generators[0]
        target = self.visit(gen.target)
        if self._is_range_call(gen.iter):
            iterable = self._range_to_js(gen.iter.args)
        else:
            iterable = self._operand(gen.iter)
        if len(gen.ifs) > 1:
            condition = " && ".join(self._operand(test) for test in gen.ifs)
        elif gen.ifs:
            condition = self.visit(gen.ifs[0])
        else:
            condition = None
        return target, iterable, condition

    # ------------------------------------------------------------------
    # Functions and classes
    # ------------------------------------------------------------------
    def visit_FunctionDef(self, node):
        """Emit a ``function name(params) { ... }`` declaration."""
        return self._emit_function(node)

    def _emit_function(self, node, as_method=False):
        """Emit *node* as a function declaration or, when *as_method* is
        true, as a class member (``__init__`` becomes ``constructor``)."""
        spec = node.args
        positional = list(getattr(spec, "posonlyargs", [])) + list(spec.args)
        defaults = [self.visit(default) for default in spec.defaults]
        first_default = len(positional) - len(defaults)
        params = [
            f"{arg.arg} = {defaults[i - first_default]}" if i >= first_default else arg.arg
            for i, arg in enumerate(positional)
        ]
        if spec.vararg:
            params.append(f"...{spec.vararg.arg}")

        header = f"function {node.name}"
        this_alias = None
        if as_method:
            decorators = {d.id for d in node.decorator_list if isinstance(d, ast.Name)}
            header = "constructor" if node.name == "__init__" else node.name
            if decorators & {"staticmethod", "classmethod"}:
                header = f"static {header}"
            if "staticmethod" not in decorators and positional:
                # The receiver parameter is implicit (`this`) in JavaScript.
                this_alias = positional[0].arg
                params = params[1:]
        else:
            self.variables.add(node.name)

        local_names = {arg.arg for arg in positional}
        local_names.update(arg.arg for arg in spec.kwonlyargs)
        if spec.vararg:
            local_names.add(spec.vararg.arg)
        if spec.kwarg:
            local_names.add(spec.kwarg.arg)

        pad = self.indent()
        outer_variables = self.variables
        # A function body is its own scope: only its parameters start declared.
        self.variables = local_names
        self.current_scope.append(node.name)
        self._this_aliases.append(this_alias)
        try:
            body = self._visit_block(node.body)
        finally:
            self._this_aliases.pop()
            self.current_scope.pop()
            self.variables = outer_variables
        return f"{pad}{header}({', '.join(params)}) {{\n{body}\n{pad}}}"

    def visit_ClassDef(self, node):
        """Emit a ``class`` declaration; base classes are ignored."""
        self.class_names.add(node.name)
        self.variables.add(node.name)
        pad = self.indent()
        self.indent_level += 1
        try:
            members = [self._emit_class_member(stmt) for stmt in node.body]
        finally:
            self.indent_level -= 1
        body_str = "\n".join(member for member in members if member)
        return f"{pad}class {node.name} {{\n{body_str}\n{pad}}}"

    def _emit_class_member(self, stmt):
        """Emit one class-body statement as a JavaScript class member."""
        pad = self.indent()
        if isinstance(stmt, ast.FunctionDef):
            return self._emit_function(stmt, as_method=True)
        if isinstance(stmt, ast.Assign) and all(isinstance(t, ast.Name) for t in stmt.targets):
            value = self.visit(stmt.value)
            return "\n".join(f"{pad}static {t.id} = {value};" for t in stmt.targets)
        if isinstance(stmt, ast.Pass):
            return ""
        is_docstring = (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        )
        if is_docstring:
            return ""
        return f"{pad}/* Unsupported class member: {type(stmt).__name__} */"

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------
    def visit_Return(self, node):
        if node.value:
            return f"{self.indent()}return {self.visit(node.value)};"
        return f"{self.indent()}return;"

    def visit_Pass(self, node):
        return f"{self.indent()}/* pass */"

    def visit_Break(self, node):
        return f"{self.indent()}break;"

    def visit_Continue(self, node):
        return f"{self.indent()}continue;"

    def visit_Global(self, node):
        # The names live in an outer scope, so they must not get a local `let`.
        self.variables.update(node.names)
        return f"{self.indent()}/* global {', '.join(node.names)} */"

    def visit_Nonlocal(self, node):
        self.variables.update(node.names)
        return f"{self.indent()}/* nonlocal {', '.join(node.names)} */"

    def visit_Assign(self, node):
        """Emit an assignment; chained targets (``a = b = v``) are assigned
        from the first target so the value is evaluated only once."""
        value = self.visit(node.value)
        first, *rest = node.targets
        lines = [self._emit_assignment(first, value)]
        if rest:
            source = self.visit(first)
            lines.extend(self._emit_assignment(target, source) for target in rest)
        return "\n".join(lines)

    def visit_AugAssign(self, node):
        target = self.visit(node.target)
        pad = self.indent()
        # `//=` and `**=` have no infix form in the emitted dialect.
        if isinstance(node.op, ast.FloorDiv):
            return f"{pad}{target} = Math.floor({target} / {self._operand(node.value)});"
        if isinstance(node.op, ast.Pow):
            return f"{pad}{target} = Math.pow({target}, {self.visit(node.value)});"
        op = self.visit(node.op)
        value = self.visit(node.value)
        return f"{pad}{target} {op}= {value};"

    def visit_Expr(self, node):
        expr = self.visit(node.value)
        return f"{self.indent()}{expr};"

    def visit_If(self, node):
        """Emit ``if``/``else if``/``else``, hoisting branch-local names."""
        return self._hoist(node) + self._emit_if(node)

    def _emit_if(self, node):
        pad = self.indent()
        test = self.visit(node.test)
        code = f"{pad}if ({test}) {{\n{self._visit_block(node.body)}\n{pad}}}"
        if not node.orelse:
            return code
        if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If):
            # `elif` is parsed as a lone nested `if` inside `orelse`.
            return f"{code} else {self._emit_if(node.orelse[0]).lstrip()}"
        return f"{code} else {{\n{self._visit_block(node.orelse)}\n{pad}}}"

    def visit_While(self, node):
        hoisted = self._hoist(node)
        pad = self.indent()
        test = self.visit(node.test)
        return f"{hoisted}{pad}while ({test}) {{\n{self._visit_block(node.body)}\n{pad}}}"

    def visit_For(self, node):
        """Emit a ``for (let target of iterable)`` loop."""
        # The loop variable counts as declared only while emitting the body.
        loop_names = [n for n in self._target_names(node.target) if n not in self.variables]
        self.variables.update(loop_names)
        try:
            hoisted = self._hoist(node)
            target = self.visit(node.target)
            if self._is_range_call(node.iter):
                iterable = self._range_to_js(node.iter.args)
            else:
                iterable = self.visit(node.iter)
            body = self._visit_block(node.body)
        finally:
            self.variables.difference_update(loop_names)
        pad = self.indent()
        return f"{hoisted}{pad}for (let {target} of {iterable}) {{\n{body}\n{pad}}}"

    def visit_Try(self, node):
        """Emit ``try``/``catch``/``finally``.

        A ``catch`` clause is produced only when the Python code has
        handlers, so ``try``/``finally`` keeps propagating exceptions.
        """
        hoisted = self._hoist(node)
        pad = self.indent()
        try_body = self._visit_block(node.body + node.orelse)
        code = f"{hoisted}{pad}try {{\n{try_body}\n{pad}}}"
        if node.handlers:
            code += f" catch (e) {{\n{self._emit_handlers(node.handlers)}\n{pad}}}"
        if node.finalbody:
            code += f" finally {{\n{self._visit_block(node.finalbody)}\n{pad}}}"
        return code

    def _emit_handlers(self, handlers):
        """Emit the body of the ``catch (e)`` clause for *handlers*."""
        typed = [h for h in handlers if h.type is not None]
        fallback = next((h for h in handlers if h.type is None), None)
        if not typed:
            return self._handler_block(fallback)
        self.indent_level += 1
        try:
            pad = self.indent()
            branches = [
                f"if ({self._exception_test(h.type)}) {{\n{self._handler_block(h)}\n{pad}}}"
                for h in typed
            ]
            if fallback is not None:
                rest = self._handler_block(fallback)
            else:
                # No handler matched: propagate, as Python would.
                rest = f"{pad}    throw e;"
            branches.append(f"{{\n{rest}\n{pad}}}")
            return pad + " else ".join(branches)
        finally:
            self.indent_level -= 1

    def _handler_block(self, handler):
        """Emit a handler's statements one level deeper, binding the
        ``as`` name to the caught value when it differs from ``e``."""
        alias = handler.name
        temporary = alias is not None and alias not in self.variables
        if temporary:
            self.variables.add(alias)
        self.indent_level += 1
        try:
            lines = []
            if alias not in (None, "e"):
                lines.append(f"{self.indent()}let {alias} = e;")
            lines.extend(self.visit(stmt) for stmt in handler.body)
            return "\n".join(lines)
        finally:
            self.indent_level -= 1
            if temporary:
                self.variables.discard(alias)

    def _exception_test(self, type_node):
        """Emit the ``instanceof`` test for one ``except`` type expression."""
        kinds = type_node.elts if isinstance(type_node, ast.Tuple) else [type_node]
        tests = []
        for kind in kinds:
            name = self.visit(kind)
            tests.append(f"e instanceof {self._EXCEPTION_ALIASES.get(name, name)}")
        return " || ".join(tests)

    def visit_With(self, node):
        """Emit the context expressions and the body in sequence.

        JavaScript has no context managers: the expression is evaluated and
        bound to the ``as`` name, but no cleanup is performed.
        """
        pad = self.indent()
        described = []
        setup = []
        for item in node.items:
            context_expr = self.visit(item.context_expr)
            if item.optional_vars is None:
                described.append(context_expr)
                setup.append(f"{pad}{context_expr};")
            else:
                described.append(f"{context_expr} as {self.visit(item.optional_vars)}")
                setup.append(self._emit_assignment(item.optional_vars, context_expr))
        lines = [f"{pad}// with: {', '.join(described)}"]
        lines.extend(setup)
        lines.extend(self.visit(stmt) for stmt in node.body)
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------
    def visit_Call(self, node):
        """Emit a call, mapping ``print``/``range``/``list`` and class
        instantiation onto their JavaScript counterparts."""
        func_name = node.func.id if isinstance(node.func, ast.Name) else ""
        if func_name == "range":
            return self.handle_range(node)
        args = [self.visit(arg) for arg in node.args]
        if func_name == "list":
            # list() with no argument builds an empty list.
            return f"Array.from({args[0]})" if args else "[]"
        if func_name == "print":
            func = "console.log"
        elif func_name in self.class_names:
            func = f"new {func_name}"
        else:
            func = self.visit(node.func)
        return f"{func}({', '.join(args)})"

    def handle_range(self, node):
        """Emit the array equivalent of a ``range(...)`` call node."""
        return self._range_to_js(node.args)

    def visit_ListComp(self, node):
        if len(node.generators) > 1:
            return "/* Unsupported list comprehension with multiple generators */"
        target, iterable, condition = self._comprehension_parts(node)
        elt = self.visit(node.elt)
        if condition is not None:
            iterable = f"{iterable}.filter(({target}) => {condition})"
        return f"{iterable}.map(({target}) => ({elt}))"

    def visit_GeneratorExp(self, node):
        if len(node.generators) > 1:
            return "/* Unsupported generator expression with multiple generators */"
        target, iterable, condition = self._comprehension_parts(node)
        elt = self.visit(node.elt)
        lines = [f"for (let {target} of {iterable}) {{"]
        if condition is not None:
            lines.append(f"    if (!({condition})) continue;")
        lines.append(f"    yield ({elt});")
        lines.append("}")
        gen_body = "\n".join(lines)
        return f"(function* () {{\n{gen_body}\n}})()"

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Pow):
            return f"Math.pow({self.visit(node.left)}, {self.visit(node.right)})"
        rank = self._BINARY_RANK.get(type(node.op))
        left = self._binop_side(node.left, rank, False)
        right = self._binop_side(node.right, rank, True)
        if isinstance(node.op, ast.FloorDiv):
            return f"Math.floor({left} / {right})"
        return f"{left} {self.visit(node.op)} {right}"

    def visit_Compare(self, node):
        """Emit a comparison; chains such as ``a < b < c`` become
        ``a < b && b < c`` (the shared operand is emitted twice)."""
        operands = [node.left, *node.comparators]
        clauses = []
        for op, lhs, rhs in zip(node.ops, operands, operands[1:]):
            if isinstance(op, (ast.In, ast.NotIn)):
                # Membership is emitted as .includes(), which exists on
                # JavaScript arrays and strings.
                negate = "!" if isinstance(op, ast.NotIn) else ""
                clauses.append(f"{negate}{self._operand(rhs)}.includes({self.visit(lhs)})")
            else:
                clauses.append(
                    f"{self._compare_side(lhs)} {self.visit(op)} {self._compare_side(rhs)}"
                )
        return " && ".join(clauses)

    def visit_BoolOp(self, node):
        op = "&&" if isinstance(node.op, ast.And) else "||"
        values = [
            f"({self.visit(v)})" if isinstance(v, ast.BoolOp) else self.visit(v)
            for v in node.values
        ]
        return f" {op} ".join(values)

    def visit_UnaryOp(self, node):
        op = self._UNARY_SYMBOLS.get(type(node.op), "/* Unsupported UnaryOp */")
        return f"{op}{self._operand(node.operand)}"

    def visit_IfExp(self, node):
        test = self.visit(node.test)
        return f"({test} ? {self.visit(node.body)} : {self.visit(node.orelse)})"

    def visit_Attribute(self, node):
        return f"{self._operand(node.value)}.{node.attr}"

    def visit_Subscript(self, node):
        return f"{self._operand(node.value)}[{self.visit(node.slice)}]"

    def visit_Index(self, node):
        # Only reached on Python versions whose AST still wraps subscripts
        # in an Index node.
        return self.visit(node.value)

    def visit_Name(self, node):
        # Inside a method the receiver parameter is spelled `this`.
        if self._this_aliases and self._this_aliases[-1] == node.id:
            return "this"
        return node.id

    def visit_Constant(self, node):
        value = node.value
        if value is None:
            return "null"
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, str):
            # json.dumps escapes quotes, backslashes and control characters.
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    def visit_List(self, node):
        elements = [self.visit(e) for e in node.elts]
        return f"[{', '.join(elements)}]"

    def visit_Tuple(self, node):
        elements = [self.visit(e) for e in node.elts]
        return f"[{', '.join(elements)}]"

    def visit_Dict(self, node):
        pairs = []
        for k, v in zip(node.keys, node.values):
            value = self.visit(v)
            if k is None:
                # A None key marks `**mapping` unpacking.
                pairs.append(f"...{value}")
            elif isinstance(k, ast.Constant):
                pairs.append(f"{self.visit(k)}: {value}")
            else:
                # Non-literal keys must be computed, not used as identifiers.
                pairs.append(f"[{self.visit(k)}]: {value}")
        return f"{{{', '.join(pairs)}}}"

    # ------------------------------------------------------------------
    # Operator tokens
    # ------------------------------------------------------------------
    def visit_Add(self, node):
        return "+"

    def visit_Sub(self, node):
        return "-"

    def visit_Mult(self, node):
        return "*"

    def visit_Div(self, node):
        return "/"

    def visit_Mod(self, node):
        return "%"

    def visit_LShift(self, node):
        return "<<"

    def visit_RShift(self, node):
        return ">>"

    def visit_BitAnd(self, node):
        return "&"

    def visit_BitOr(self, node):
        return "|"

    def visit_BitXor(self, node):
        return "^"

    def visit_Eq(self, node):
        return "==="

    def visit_NotEq(self, node):
        return "!=="

    def visit_Is(self, node):
        return "==="

    def visit_IsNot(self, node):
        return "!=="

    def visit_Lt(self, node):
        return "<"

    def visit_LtE(self, node):
        return "<="

    def visit_Gt(self, node):
        return ">"

    def visit_GtE(self, node):
        return ">="

    def generic_visit(self, node):
        """Render any node without a dedicated visitor as a comment."""
        return f"/* Unhandled node type: {type(node).__name__} */"

def convert_python_to_javascript(python_code):
    """Convert Python source text to JavaScript.

    Each top-level statement is rendered by a fresh
    ``PythonToJavaScriptConverter``; statements that render to an empty
    string are dropped and the rest are joined with newlines.  Syntax errors
    raised by ``ast.parse`` propagate to the caller.
    """
    module = ast.parse(python_code)
    converter = PythonToJavaScriptConverter()
    fragments = []
    for statement in module.body:
        rendered = converter.visit(statement)
        if rendered:
            fragments.append(rendered)
    return "\n".join(fragments)

# =============================================================================
# 執行 Python 程式碼,並捕捉其輸出
# =============================================================================
def run_python_code(code):
    """Execute Python source and return everything it printed to stdout.

    The code runs as a script (``__name__ == "__main__"``) in a fresh
    namespace.  If it raises, or calls ``sys.exit()``, the text printed so
    far is returned followed by ``Error: <ExceptionType>[: message]``.

    NOTE(security): ``exec`` runs the submitted code with the full
    privileges of this process; only expose this to trusted users.
    """
    stdout = io.StringIO()
    # Run as a script so `if __name__ == "__main__":` guards execute.
    namespace = {"__name__": "__main__"}
    try:
        with redirect_stdout(stdout):
            exec(code, namespace)
    except (Exception, SystemExit) as e:
        # SystemExit is caught too so user code calling exit() cannot
        # terminate the caller.
        detail = str(e)
        label = type(e).__name__
        stdout.write(f"Error: {label}: {detail}" if detail else f"Error: {label}")
    return stdout.getvalue()

# =============================================================================
# 執行 JavaScript 程式碼(需要系統有 node.js 環境)
# =============================================================================
def run_js_code(js_code):
    """Run JavaScript source with Node.js and return its output.

    Returns the script's stdout, followed by an ``Error:`` section holding
    stderr when the script wrote any.  If ``node`` is not on the PATH, or
    the script cannot be written or run (including exceeding the 5 second
    timeout), an error message is returned instead of raising.
    """
    if not shutil.which("node"):
        return "Error: Node.js is not installed or not available in the system PATH."
    try:
        # The directory and the script inside it are removed when the block
        # exits, whether or not node succeeded.
        with tempfile.TemporaryDirectory() as work_dir:
            script_path = os.path.join(work_dir, "script.js")
            # Node reads source files as UTF-8 regardless of the locale.
            with open(script_path, "w", encoding="utf-8") as script:
                script.write(js_code)
            result = subprocess.run(
                ["node", script_path],
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=5,
            )
    except Exception as e:
        return f"Error executing JS code: {str(e)}"
    output = result.stdout
    if result.stderr:
        output += "\nError:\n" + result.stderr
    return output

# =============================================================================
# 定義多組範例測試代碼 (覆蓋所有測試情況)
# =============================================================================
# Sample programs offered in the UI, keyed by the label shown in the dropdown.
# Every sample is a complete script that runs on its own under Python.
sample_dict = {
    "Simple Function": '''\
def add(a, b):
    return a + b
print(add(2, 3))
''',

    "If/Else": '''\
def max_num(a, b):
    if a > b:
        return a
    else:
        return b
print(max_num(10, 20))
''',

    "While Loop": '''\
i = 0
while i < 5:
    print(i)
    i += 1
''',

    "For Loop": '''\
for x in range(3):
    print(x)
''',

    "List Comprehension (without if)": '''\
numbers = [x * 2 for x in range(5)]
print(numbers)
''',

    "List Comprehension (with if)": '''\
numbers = [x for x in range(10) if x % 2 == 0]
print(numbers)
''',

    "Dictionary Access": '''\
person = {"name": "Alice", "age": 30}
print(person["name"])
''',

    "Augmented Assignment": '''\
count = 10
count += 5
print(count)
''',

    "Multiple Assignments": '''\
a, b = 1, 2
print(a)
print(b)
''',

    "Boolean Operations": '''\
a = True
b = False
if a and not b:
    print("Both conditions met")
''',

    "Nested Functions": '''\
def outer(x):
    def inner(y):
        return y * 2
    return inner(x) + 1
print(outer(5))
''',

    "None and Boolean": '''\
result = None
flag = True
if flag:
    result = "Success"
else:
    result = "Failure"
print(result)
''',

    "Floor Div and Pow": '''\
a = 5 // 2
b = 2 ** 3
print(a)
print(b)
''',

    "Class Definition": '''\
class Person:
    def __init__(self, name):
        self.name = name
    def greet(self):
        print("Hello, " + self.name)
p = Person("Alice")
p.greet()
''',

    "Default Parameter": '''\
def greet(name="World"):
    print("Hello, " + name)
greet()
''',

    # a, b and c are defined here so the sample runs instead of raising NameError.
    "Nested If/Else": '''\
a = 3
b = 2
c = 1
if a > b:
    if b > c:
        print("a > b > c")
else:
    print("Not in order")
''',

    "Nested For Loops": '''\
for i in range(3):
    for j in range(2):
        print(i, j)
''',

    "Generator Expression": '''\
gen = (x * 2 for x in range(5))
print(list(gen))
'''
}

# Dropdown entries, in the order the samples are defined above.
sample_choices = list(sample_dict.keys())

def load_sample(sample_name):
    """Return the source of the sample called *sample_name*.

    An unknown name yields an empty string.
    """
    try:
        return sample_dict[sample_name]
    except KeyError:
        return ""

# =============================================================================
# Gradio 介面主函數
# =============================================================================
def process_code(python_code):
    """Produce the four outputs shown in the UI for *python_code*.

    Returns a tuple ``(ast_tree, python_output, js_code, js_output)``:
    the dumped AST (or an ``AST Error`` message), the output of running the
    Python code, the converted JavaScript (or a ``Conversion Error``
    message), and the output of running that JavaScript.  No step raises;
    failures are reported in the corresponding string.
    """
    try:
        tree = ast.parse(python_code)
        ast_tree = ast.dump(tree, indent=4)
    except Exception as e:
        ast_tree = f"AST Error: {e}"
    python_output = run_python_code(python_code)
    try:
        js_code = convert_python_to_javascript(python_code)
    except Exception as e:
        js_code = f"Conversion Error: {e}"
        # The error text is not JavaScript; running it would only produce a
        # misleading syntax error from node.
        js_output = "JavaScript was not executed because the conversion failed."
    else:
        js_output = run_js_code(js_code)
    return ast_tree, python_output, js_code, js_output

# =============================================================================
# 建立 Gradio 介面,包含下拉選單、Code 輸入與多個輸出區塊
# =============================================================================
# Build the Gradio page: a sample picker and code editor on top, a run button,
# and four read-only output panes below.
# NOTE(review): the CSS rule targets "#component-7", which looks like an
# auto-generated element id — presumably it depends on the order in which the
# components below are created; confirm before adding or reordering components.
with gr.Blocks(css="#component-7 {font-family: 'Consolas', monospace;}") as demo:
    gr.Markdown("# Python ↔ JavaScript Interactive converters")
    gr.Markdown("Select the sample code from the drop-down menu or paste the Python code directly to see the AST, the execution result, and the converted JavaScript code and execution result.")
    
    # Input row: sample picker (narrow column) next to the code editor (wide column).
    with gr.Row():
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(label="Select a Sample", choices=sample_choices, value=sample_choices[0])
        with gr.Column(scale=3):
            # The editor starts out holding the first sample.
            python_code_input = gr.Code(
                label="Python Code Input", 
                language="python",
                value=sample_dict[sample_choices[0]],
                interactive=True)
    
    # Choosing a sample replaces the editor content with that sample's source.
    sample_dropdown.change(fn=load_sample, inputs=sample_dropdown, outputs=python_code_input)

    run_button = gr.Button("Run Code")
    
    # Output row: Python-side results on the left, JavaScript-side on the right.
    with gr.Row():
        with gr.Column():
            ast_output = gr.Code(label="AST Tree", language="python", interactive=False)
            python_exec_output = gr.Textbox(label="Python Execution Output", interactive=False)
        with gr.Column():
            js_code_output = gr.Code(label="Converted JavaScript Code", language="javascript", interactive=False)
            js_exec_output = gr.Textbox(label="JavaScript Execution Output", interactive=False)
    
    # process_code returns its four results in the same order as `outputs`.
    run_button.click(fn=process_code,
                     inputs=python_code_input,
                     outputs=[ast_output, python_exec_output, js_code_output, js_exec_output])

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()