DanGalt's picture
Create app.py
ea87393 verified
raw
history blame contribute delete
4.16 kB
import ast
import io
import tokenize
from pathlib import Path

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Checkpoint name the tokenizer is loaded from.
model_id = "answerdotai/ModernBERT-base"
# Checkpoint name the sequence-classification model is loaded from
# (presumably a Hugging Face Hub repository id -- confirm).
path = "DanGalt/modernbert-code-comrel-synthetic"
# NOTE: both objects are created at import time, so importing this module
# already loads the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
# Literal separator that prepare_input places between the function
# definition, the code window and the comment.
sep = "[SEP]"
def prepare_input(example):
    """Tokenize one parsed comment record for the classifier.

    Args:
        example: Mapping with the string entries ``"function_definition"``,
            ``"code"`` and ``"comment"``.

    Returns:
        The tokenizer output for the three fields joined by ``sep``,
        truncated to at most 1024 tokens and returned as PyTorch tensors.
    """
    fields = (
        example["function_definition"],
        example["code"],
        example["comment"],
    )
    model_text = sep.join(fields)
    return tokenizer(
        model_text,
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
def _first_source_line(node):
    """Return the zero-based first line of ``node``, decorators included."""
    decorators = getattr(node, "decorator_list", [])
    return min([node.lineno] + [dec.lineno for dec in decorators]) - 1


def parse_text(text):
    """Collect every comment in ``text`` together with its surrounding context.

    Comments are located with the ``tokenize`` module, so a ``#`` inside a
    string literal is not mistaken for a comment.

    NOTE: only top-level ``def`` statements are considered when looking for
    the function a comment belongs to; methods, nested functions and
    ``async def`` are not matched and yield ``"None"``.

    Args:
        text: Python source code.

    Returns:
        A list with one dict per comment, in source order, with the keys:
          * ``"function_definition"``: signature text of the top-level
            function that contains the comment, or the string ``"None"``.
          * ``"code"``: up to four lines before the comment line, the comment
            line itself and up to three lines after it, joined by newlines.
          * ``"comment"``: the comment text, starting at ``#``.
          * ``"lineno"``: zero-based line index of the comment.

    Raises:
        SyntaxError: If ``text`` is not valid Python (raised by ``ast.parse``).
    """
    tree = ast.parse(text)
    lines = text.split('\n')

    # For each top-level function: zero-based first line, last line, column of
    # the ``def`` keyword, and the first line of the next top-level statement.
    # The last value bounds where comments trailing the function's final
    # statement may still appear.
    defs = []
    for idx, el in enumerate(tree.body):
        if not isinstance(el, ast.FunctionDef):
            continue
        if idx + 1 < len(tree.body):
            next_start = _first_source_line(tree.body[idx + 1])
        else:
            next_start = len(lines)
        defs.append((el.lineno - 1, el.end_lineno - 1, el.col_offset, next_start))

    inputs = []
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type != tokenize.COMMENT:
            continue
        lineno = tok.start[0] - 1  # tokenize rows are one-based
        offset = tok.start[1]

        corresponding_def = None
        for def_l, def_el, def_off, next_start in defs:
            in_body = def_l <= lineno <= def_el and def_off <= offset
            # An indented comment after the last statement still belongs to
            # the function even though ast's end_lineno does not cover it.
            trailing = def_el < lineno < next_start and def_off < offset
            if in_body or trailing:
                corresponding_def = (def_l, def_el, def_off)
                break

        # Clamp the start at 0: a negative start would make the slice count
        # from the end of the file for comments in the first four lines.
        code = '\n'.join(lines[max(0, lineno - 4):lineno + 4])

        fdef = "None"
        if corresponding_def is not None:
            def_l, def_el, def_off = corresponding_def
            signature = [lines[def_l][def_off:]]
            # Heuristic: the signature ends on the first line that contains
            # "):" or "->"; if no such line exists only the first line is kept.
            for cur_lineno in range(def_l, def_el + 1):
                if "):" in lines[cur_lineno] or "->" in lines[cur_lineno]:
                    signature += lines[def_l + 1:cur_lineno + 1]
                    break
            fdef = '\n'.join(signature).strip()

        inputs.append({
            "function_definition": fdef,
            "code": code,
            "comment": tok.string,
            "lineno": lineno,
        })
    return inputs
def predict(inp, model=model):
    """Score one tokenized example with the classifier.

    Args:
        inp: Mapping of keyword arguments for ``model`` (as produced by
            ``prepare_input``).
        model: Callable whose result exposes a ``logits`` attribute; defaults
            to the module-level classifier.

    Returns:
        A Python float: the softmax probability at row 0, class index 1 of
        the logits. Callers in this file report it as the probability that
        the comment is correct.
    """
    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        logits = model(**inp).logits
    class_probs = nn.functional.softmax(logits, dim=-1)
    return class_probs[0, 1].item()
def parse_and_predict(text, thrd=0.0):
    """Score every comment found in ``text``.

    Args:
        text: Python source code to analyse.
        thrd: Decision threshold. When it is not greater than 0 the raw
            probabilities are returned.

    Returns:
        A list of ``(lineno, value)`` tuples, one per comment, where
        ``lineno`` is the zero-based line index reported by ``parse_text``.
        ``value`` is the predicted probability when ``thrd`` is not greater
        than 0; otherwise it is a bool that is True when the probability is
        below ``thrd``.
    """
    records = parse_text(text)
    # All predictions are computed up front, before any result is assembled.
    scores = [predict(prepare_input(record)) for record in records]
    if thrd > 0:
        return [
            (record["lineno"], thrd > score)
            for record, score in zip(records, scores)
        ]
    return [(record["lineno"], score) for record, score in zip(records, scores)]
def parse_and_predict_file(path, thrd=0.0):
    """Score every comment in the Python file at ``path``.

    Args:
        path: Location of the source file; anything ``pathlib.Path`` accepts.
        thrd: Decision threshold, forwarded unchanged to ``parse_and_predict``.

    Returns:
        Whatever ``parse_and_predict`` returns for the file's contents.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # read_text opens the file in text mode with the default encoding and
    # closes the handle before returning, so no descriptor is left open.
    text = Path(path).read_text()
    return parse_and_predict(text, thrd)
def parse_and_predict_pretty_out(text, thrd=0.0):
    """Render the predictions for ``text`` as human-readable lines.

    Args:
        text: Python source code to analyse.
        thrd: Decision threshold. When greater than 0 only the comments
            flagged by ``parse_and_predict`` are listed; otherwise every
            comment is listed with its probability.

    Returns:
        One sentence per reported comment, joined by newlines. The line
        numbers in the sentences are the zero-based indices returned by
        ``parse_and_predict``.
    """
    results = parse_and_predict(text, thrd=thrd)
    source_lines = text.split('\n')

    if thrd > 0:
        # Thresholded mode: the second tuple element is a warning flag.
        messages = [
            f"The comment on line {lineno} is incorrect: '{source_lines[lineno]}'."
            for lineno, do_warn in results
            if do_warn
        ]
    else:
        # Raw mode: the second tuple element is the probability itself.
        messages = [
            f"The comment on line {lineno} is estimated to be correct with probability {p:.2f}: '{source_lines[lineno]}'."
            for lineno, p in results
        ]
    return '\n'.join(messages)
# Sample snippet offered as a ready-made input in the demo UI. The comments
# on the `c` and `e` lines do not match the code next to them (presumably on
# purpose, so the classifier has something to flag).
example_text = """a = 3
b = 2
# The code below does some calculations based on a predefined rule that is very important
c = a - b # Calculate and store the sum of a and b in c
d = a + b # Calculate and store the sum of a and b in d
e = c * b # Calculate and store the product of c and d in e
print(f"Wow, maths: {[a, b, c, d, e]}")"""
# Gradio UI definition. The values of the two input widgets are handed to
# parse_and_predict_pretty_out in order, i.e. as (text, thrd).
gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        # Source code to analyse.
        gr.Textbox(label="Input", lines=7),
        # Threshold (thrd); 0 makes the function output raw probabilities.
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05)],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    # One example per output mode: raw probabilities (threshold 0.0) and
    # thresholded warnings (threshold 0.53).
    examples=[[example_text, 0.0], [example_text, 0.53]],
    title="Comment \"Correctness\" Classifier",
    description='Calculates probabilities for each comment in text to be "correct"/"relevant". If the threshold is 0, outputs raw predictions. Otherwise, will report only "incorrect" comments.'
)

# Launch the interface only when the file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()