import ast
import io
import tokenize
from pathlib import Path

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch import nn


# Model id whose tokenizer is used to encode classifier inputs.
model_id = "answerdotai/ModernBERT-base"
# Fine-tuned sequence-classification checkpoint that produces the predictions
# (presumably fine-tuned from model_id, hence the shared tokenizer -- confirm).
path = "DanGalt/modernbert-code-comrel-synthetic"
# Both are loaded at module level, so importing this file loads the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
# Separator placed between the function definition, the code context and the
# comment when prepare_input builds the classifier's input string.
sep = "[SEP]"


def prepare_input(example):
    """Encode one parsed comment record as classifier input.

    The function definition, code context and comment are joined with ``sep``
    and passed through ``tokenizer`` (truncated to 1024 tokens, returned as
    PyTorch tensors).

    Args:
        example: mapping with string values under "function_definition",
            "code" and "comment" (as produced by parse_text).

    Returns:
        The tokenizer output, ready to be unpacked into the model by predict.
    """
    fields = (example["function_definition"], example["code"], example["comment"])
    text = sep.join(fields)
    return tokenizer(text, truncation=True, max_length=1024, return_tensors="pt")


def _scan_tokens(text):
    """Locate comments and signature-closing colons with the Python tokenizer.

    Args:
        text: Python source code.

    Returns:
        (comments, header_ends): ``comments`` is a list of
        ``(row, col, comment_text)`` with 0-based rows; ``header_ends`` maps the
        0-based row of each ``def`` keyword to the 0-based row of the colon
        that closes that function's signature.

    Raises:
        tokenize.TokenError, SyntaxError: if ``text`` cannot be tokenized.
    """
    comments = []
    header_ends = {}
    def_row = None  # row of a ``def`` whose closing colon has not been seen yet
    depth = 0
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        row = tok.start[0] - 1
        if tok.type == tokenize.COMMENT:
            comments.append((row, tok.start[1], tok.string))
        elif def_row is None:
            if tok.type == tokenize.NAME and tok.string == "def":
                def_row, depth = row, 0
        elif tok.type == tokenize.OP:
            if tok.string in ("(", "[", "{"):
                depth += 1
            elif tok.string in (")", "]", "}"):
                depth -= 1
            elif tok.string == ":" and depth == 0:
                # Colons inside brackets belong to annotations, lambdas or slices,
                # so only a depth-0 colon ends the signature.
                header_ends[def_row] = row
                def_row = None
    return comments, header_ends


def _scan_lines_for_comments(lines):
    """Fallback comment finder: everything from the first '#' on a line."""
    found = []
    for row, line in enumerate(lines):
        col = line.find('#')
        if col != -1:
            found.append((row, col, line[col:]))
    return found


def _function_spans(tree, n_lines):
    """Return ``(start, end, limit, col)`` 0-based spans of top-level functions.

    ``start``/``end`` are the rows of the ``def`` line and of the last body
    statement. ``limit`` is the last row before the next top-level statement
    (or its first decorator), so indented comments trailing the body can
    still be attributed to the function.
    """
    body = tree.body
    spans = []
    for i, node in enumerate(body):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if i + 1 < len(body):
            nxt = body[i + 1]
            first_row = min(
                [nxt.lineno] + [d.lineno for d in getattr(nxt, "decorator_list", [])]
            )
            limit = first_row - 2  # 0-based row just before the next statement
        else:
            limit = n_lines - 1
        spans.append((node.lineno - 1, node.end_lineno - 1, limit, node.col_offset))
    return spans


def _enclosing_function(spans, row, col):
    """Return ``(start, col)`` of the function owning a comment, or None."""
    for start, end, limit, def_col in spans:
        if start <= row <= end or (end < row <= limit and col > def_col):
            return start, def_col
    return None


def parse_text(text):
    """Extract every comment in ``text`` with the context the classifier needs.

    Comments are found with ``tokenize``, so '#' characters inside string
    literals are ignored; if the text cannot be tokenized, everything from
    the first '#' on a line is treated as a comment instead. A comment inside
    (or indented below) a top-level function is paired with that function's
    signature. Source that does not parse still yields comments, just with
    no function definitions.

    NOTE: only top-level functions are considered; comments inside class
    methods get "None" as their function definition.

    Args:
        text: Python source code.

    Returns:
        list[dict]: one dict per comment, in line order, with keys
        "function_definition" (signature text or the string "None"),
        "code" (up to 4 lines before the comment line, the line itself and up
        to 3 after), "comment" (from '#' to end of line) and "lineno"
        (0-based index into the lines of ``text``).
    """
    lines = text.split('\n')
    try:
        comments, header_ends = _scan_tokens(text)
    except (tokenize.TokenError, SyntaxError, ValueError):
        comments, header_ends = _scan_lines_for_comments(lines), {}
    try:
        spans = _function_spans(ast.parse(text), len(lines))
    except (SyntaxError, ValueError):
        spans = []

    inputs = []
    for row, col, comment in comments:
        fdef = "None"
        owner = _enclosing_function(spans, row, col)
        if owner is not None:
            start, def_col = owner
            header_end = header_ends.get(start, start)
            header = [lines[start][def_col:], *lines[start + 1:header_end + 1]]
            fdef = '\n'.join(header).strip()
        inputs.append({
            "function_definition": fdef,
            # Clamp the start: a negative index would wrap to the end of the file.
            "code": '\n'.join(lines[max(0, row - 4):row + 4]),
            "comment": comment,
            "lineno": row,
        })
    return inputs


def predict(inp, model=model):
    """Return the classifier's softmax probability for label 1.

    Label 1 is what the rest of this file reports as the probability that a
    comment is correct.

    Args:
        inp: tokenizer output (e.g. from prepare_input), unpacked into the
            model as keyword arguments.
        model: sequence-classification model; defaults to the module-level one.

    Returns:
        float: probability of label 1 for the first item in the batch.
    """
    with torch.no_grad():
        logits = model(**inp).logits
    probs = nn.functional.softmax(logits, dim=-1)
    return probs[0, 1].item()


def parse_and_predict(text, thrd=0.0):
    """Score every comment found in ``text``.

    Args:
        text: Python source code.
        thrd: if positive, each value is a bool that is True when the
            comment's probability is below ``thrd``; otherwise each value is
            the raw probability.

    Returns:
        list of ``(lineno, value)`` tuples, one per comment, with the 0-based
        line numbers reported by parse_text.
    """
    records = parse_text(text)
    scores = [predict(prepare_input(record)) for record in records]
    return [
        (record["lineno"], (thrd > score) if thrd > 0 else score)
        for record, score in zip(records, scores)
    ]


def parse_and_predict_file(path, thrd=0.0):
    """Read a Python source file and score every comment in it.

    Args:
        path: path to the file (str or os.PathLike).
        thrd: forwarded to parse_and_predict; 0 yields raw probabilities, a
            positive value yields booleans flagging comments below it.

    Returns:
        list of ``(lineno, value)`` tuples as produced by parse_and_predict.
    """
    # read_text closes the file handle; UTF-8 is the default encoding of
    # Python source files (PEP 3120), independent of the platform locale.
    text = Path(path).read_text(encoding="utf-8")
    return parse_and_predict(text, thrd)


def parse_and_predict_pretty_out(text, thrd=0.0):
    """Format comment predictions as human-readable text for the Gradio UI.

    Args:
        text: Python source code to analyse.
        thrd: 0 (or less) reports the probability for every comment; a
            positive value reports only comments whose probability is below
            it, as "incorrect".

    Returns:
        str: one message per reported comment, newline-separated, or a short
        notice when there is nothing to report. Line numbers are 1-based.
    """
    results = parse_and_predict(text, thrd=thrd)
    lines = text.split('\n')
    # Results carry 0-based indices into ``lines``; messages show 1-based
    # numbers so they match what editors display.
    if thrd > 0:
        output = [
            f"The comment on line {lineno + 1} is incorrect: '{lines[lineno]}'."
            for lineno, do_warn in results
            if do_warn
        ]
        if not output:
            return "No incorrect comments found."
    else:
        output = [
            f"The comment on line {lineno + 1} is estimated to be correct with probability {p:.2f}: '{lines[lineno]}'."
            for lineno, p in results
        ]
        if not output:
            return "No comments found."
    return '\n'.join(output)


# Sample input mixing accurate and deliberately inaccurate comments (e.g. the
# "sum" comment on a subtraction); used as the UI examples below.
example_text = """a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b  # Calculate and store the sum of a and b in c
d = a + b  # Calculate and store the sum of a and b in d
e = c * b  # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""

gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        gr.Textbox(label="Input", lines=7),
        # Labelled so users can match it to the "threshold" in the description.
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05, label="Threshold"),
    ],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    # Each example is [text, threshold]: 0.0 shows raw probabilities, a
    # positive threshold reports only comments scoring below it.
    examples=[[example_text, 0.0], [example_text, 0.53]],
    title="Comment \"Correctness\" Classifier",
    description='Calculates probabilities for each comment in text to be "correct"/"relevant". If the threshold is 0, outputs raw predictions. Otherwise, will report only "incorrect" comments.'
)

if __name__ == "__main__":
    gradio_app.launch()