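"""Gradio demo that flags potentially incorrect or irrelevant comments in Python code.

Each comment found in the input is paired with the enclosing function signature and a
small window of surrounding code, then scored by a ModernBERT-based sequence classifier
("DanGalt/modernbert-code-comrel-synthetic") for how likely the comment is to be
correct/relevant.
"""
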
import ast
from pathlib import Path

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch import nn


model_id = "answerdotai/ModernBERT-base"
path = "DanGalt/modernbert-code-comrel-synthetic"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
sep = "[SEP]"


def prepare_input(example):
    # Concatenate the enclosing function signature, the surrounding code and
    # the comment into a single sequence for the classifier.
    tokens = tokenizer(
        example["function_definition"] + sep + example["code"] + sep + example["comment"],
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
    return tokens


def parse_text(text):
    # Collect the line span and indentation of every top-level function definition.
    defs = []
    tree = ast.parse(text)
    for el in tree.body:
        if isinstance(el, ast.FunctionDef):
            defs.append((el.lineno - 1, el.end_lineno - 1, el.col_offset))

    inputs = []
    lines = text.split('\n')
    for lineno, line in enumerate(lines):
        if (offset := line.find('#')) != -1:
            # Find the last function definition that starts before the comment
            # and is not indented further than it.
            corresponding_def = None
            for (def_l, def_el, def_off) in defs:
                if def_l <= lineno and def_off <= offset:
                    corresponding_def = (def_l, def_el, def_off)

            comment = line[offset:]
            # A small window of code around the comment (clamped so comments near
            # the top of the file do not produce a negative, wrapping slice).
            code = '\n'.join(lines[max(lineno - 4, 0):lineno + 4])
            fdef = "None"
            if corresponding_def is not None:
                # Reconstruct the (possibly multi-line) signature of the enclosing function.
                fdef = [lines[corresponding_def[0]][corresponding_def[2]:]]
                cur_lineno = corresponding_def[0]
                while cur_lineno <= corresponding_def[1]:
                    if lines[cur_lineno].find("):") != -1 or lines[cur_lineno].find("->") != -1:
                        fdef += lines[corresponding_def[0] + 1:cur_lineno + 1]
                        break
                    cur_lineno += 1

                fdef = '\n'.join(fdef).strip()

            inputs.append({
                "function_definition": fdef,
                "code": code,
                "comment": comment,
                "lineno": lineno
            })
    return inputs
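
# For illustration, on a snippet such as "x = 1  # set x to one", parse_text
# returns one entry per comment line, roughly:
#   [{"function_definition": "None", "code": "x = 1  # set x to one",
#     "comment": "# set x to one", "lineno": 0}]
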
def predict(inp, model=model):
    with torch.no_grad():
        out = model(**inp)
    # Probability that the comment is "correct"/"relevant" (class 1).
    return nn.functional.softmax(out.logits, dim=-1)[0, 1].item()


def parse_and_predict(text, thrd=0.0):
    parsed = parse_text(text)
    preds = [predict(prepare_input(p)) for p in parsed]
    result = []
    for i, p in enumerate(preds):
        if thrd > 0:
            # With a positive threshold, report a boolean flag instead: True means
            # the predicted probability of correctness is below the threshold.
            p = p < thrd
        result.append((parsed[i]["lineno"], p))

    return result


def parse_and_predict_file(path, thrd=0.0):
    text = Path(path).read_text()
    return parse_and_predict(text, thrd)


def parse_and_predict_pretty_out(text, thrd=0.0):
    results = parse_and_predict(text, thrd=thrd)
    lines = text.split('\n')
    output = []
    if thrd > 0:
        for lineno, do_warn in results:
            if do_warn:
                output.append(f"The comment on line {lineno} is incorrect: '{lines[lineno]}'.")
    else:
        for lineno, p in results:
            output.append(f"The comment on line {lineno} is estimated to be correct with probability {p:.2f}: '{lines[lineno]}'.")
    return '\n'.join(output)


example_text = """a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b # Calculate and store the sum of a and b in c
d = a + b # Calculate and store the sum of a and b in d
e = c * b # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""


gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        gr.Textbox(label="Input", lines=7),
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05, label="Threshold"),
    ],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    examples=[[example_text, 0.0], [example_text, 0.53]],
    title='Comment "Correctness" Classifier',
    description='Calculates the probability that each comment in the text is "correct"/"relevant". If the threshold is 0, raw probabilities are shown; otherwise only comments flagged as "incorrect" are reported.',
)
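
# The classifier can also be exercised without the UI, e.g.:
#   print(parse_and_predict_pretty_out(example_text, thrd=0.0))
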
if __name__ == "__main__":
    gradio_app.launch()