"""Gradio demo that scores each comment in a piece of Python source code.

Every comment is paired with its enclosing function signature (if any) and a
few surrounding lines of code, then passed to a sequence-classification model.
"""

import ast
import io
import tokenize
from pathlib import Path

import gradio as gr
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"
path = "DanGalt/modernbert-code-comrel-synthetic"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
sep = "[SEP]"


def prepare_input(example):
    """Tokenize one example produced by ``parse_text``.

    The function definition, the code context and the comment are joined with
    ``sep`` into a single string, truncated to at most 1024 tokens and
    returned as PyTorch tensors.
    """
    joined = sep.join(
        (example["function_definition"], example["code"], example["comment"])
    )
    return tokenizer(
        joined,
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )


def _function_spans(tree):
    """Return ``(first_line, last_line, col_offset)`` for every function in *tree*.

    Line numbers are converted to 0-based indices. The whole tree is walked so
    that methods, nested functions and ``async def`` are included.
    """
    return [
        (node.lineno - 1, node.end_lineno - 1, node.col_offset)
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]


def _innermost_span(spans, lineno):
    """Return the span that contains *lineno* and starts latest, or ``None``."""
    containing = [span for span in spans if span[0] <= lineno <= span[1]]
    return max(containing, key=lambda span: span[0], default=None)


def _signature_text(lines, span):
    """Return the text of the function header described by *span*.

    Heuristic: the header is assumed to end on the first line of the function
    that contains ``"):"`` or ``"->"``. If no such line exists, only the first
    line of the function is returned.
    """
    first, last, col = span
    header = [lines[first][col:]]
    for cur in range(first, last + 1):
        if "):" in lines[cur] or "->" in lines[cur]:
            header += lines[first + 1:cur + 1]
            break
    return '\n'.join(header).strip()


def _iter_comments(text):
    """Yield ``(lineno, comment)`` for every comment token in *text*.

    Using the tokenizer instead of searching for ``'#'`` keeps hash characters
    inside string literals from being mistaken for comments. ``lineno`` is
    0-based.
    """
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type == tokenize.COMMENT:
            yield tok.start[0] - 1, tok.string


def parse_text(text):
    """Extract one model example per comment found in *text*.

    Args:
        text: Python source code. It must be parseable by ``ast.parse``;
            otherwise the ``SyntaxError`` raised there propagates.

    Returns:
        A list of dicts with the keys ``"function_definition"`` (header of the
        innermost enclosing function, or the string ``"None"``), ``"code"``
        (up to four lines before the comment line through three lines after
        it), ``"comment"`` (the comment text starting at ``#``) and
        ``"lineno"`` (0-based line index of the comment).
    """
    spans = _function_spans(ast.parse(text))
    lines = text.split('\n')

    inputs = []
    for lineno, comment in _iter_comments(text):
        enclosing = _innermost_span(spans, lineno)
        fdef = "None" if enclosing is None else _signature_text(lines, enclosing)
        # Clamp the start: a negative index would wrap around to the end of
        # the file and produce the wrong context for the first few lines.
        code = '\n'.join(lines[max(0, lineno - 4):lineno + 4])
        inputs.append({
            "function_definition": fdef,
            "code": code,
            "comment": comment,
            "lineno": lineno,
        })
    return inputs


def predict(inp, model=model):
    """Return the softmax probability of class index 1 for the first batch item.

    Args:
        inp: Keyword arguments for the model, e.g. the output of
            ``prepare_input``.
        model: Model to run; defaults to the module-level ``model``.
    """
    with torch.no_grad():
        out = model(**inp)
    return nn.functional.softmax(out.logits, dim=-1)[0, 1].item()


def parse_and_predict(text, thrd=0.0):
    """Score every comment in *text*.

    Args:
        text: Python source code.
        thrd: If greater than 0, each score is replaced by a boolean that is
            ``True`` when the predicted probability is below ``thrd``.

    Returns:
        A list of ``(lineno, score_or_flag)`` tuples with 0-based ``lineno``.
    """
    result = []
    for example in parse_text(text):
        score = predict(prepare_input(example))
        if thrd > 0:
            score = thrd > score
        result.append((example["lineno"], score))
    return result


def parse_and_predict_file(path, thrd=0.0):
    """Read the file at *path* and run ``parse_and_predict`` on its contents."""
    # read_text closes the file for us; Python source defaults to UTF-8.
    text = Path(path).read_text(encoding="utf-8")
    return parse_and_predict(text, thrd)


def parse_and_predict_pretty_out(text, thrd=0.0):
    """Format the predictions for *text* as one human-readable line per comment.

    With ``thrd > 0`` only comments flagged by ``parse_and_predict`` are
    reported; otherwise every comment is listed with its probability. The line
    numbers in the messages are the 0-based indices returned by
    ``parse_and_predict``.
    """
    results = parse_and_predict(text, thrd=thrd)
    lines = text.split('\n')
    if thrd > 0:
        output = [
            f"The comment on line {lineno} is incorrect: '{lines[lineno]}'."
            for lineno, do_warn in results
            if do_warn
        ]
    else:
        output = [
            f"The comment on line {lineno} is estimated to be correct with probability {p:.2f}: '{lines[lineno]}'."
            for lineno, p in results
        ]
    return '\n'.join(output)


example_text = """a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b # Calculate and store the sum of a and b in c
d = a + b # Calculate and store the sum of a and b in d
e = c * b # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""

gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        gr.Textbox(label="Input", lines=7),
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05),
    ],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    examples=[[example_text, 0.0], [example_text, 0.53]],
    title="Comment \"Correctness\" Classifier",
    description='Calculates probabilities for each comment in text to be "correct"/"relevant". If the threshold is 0, outputs raw predictions. Otherwise, will report only "incorrect" comments.',
)

if __name__ == "__main__":
    gradio_app.launch()