Spaces:
Runtime error
Runtime error
| import ast | |
| from pathlib import Path | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| from torch import nn | |
# Hub identifiers: the tokenizer is loaded from `model_id`, the
# sequence-classification weights from `path`.
model_id = "answerdotai/ModernBERT-base"
path = "DanGalt/modernbert-code-comrel-synthetic"

# Literal separator inserted between the text fields joined in prepare_input.
sep = "[SEP]"

# Both objects are created once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(path)
def prepare_input(example):
    """Tokenize one parsed comment record for the classifier.

    The ``function_definition``, ``code`` and ``comment`` entries of
    *example* are concatenated in that order, with the module-level ``sep``
    string between them, and passed to the module-level ``tokenizer``.

    Args:
        example: Mapping with string values under the keys
            ``"function_definition"``, ``"code"`` and ``"comment"``.

    Returns:
        The tokenizer output, requested as PyTorch tensors and truncated to
        at most 1024 tokens.
    """
    fields = (
        example["function_definition"],
        example["code"],
        example["comment"],
    )
    return tokenizer(
        sep.join(fields),
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
def _signature_text(lines, def_start, def_end, def_col):
    """Return the header text of the function occupying lines def_start..def_end.

    The header is taken to end on the first line that contains ``"):"`` or
    ``"->"``; this is a textual heuristic, not a real parse. If no such line
    exists inside the function, only the ``def`` line itself is returned.
    """
    header = [lines[def_start][def_col:]]
    for cur in range(def_start, def_end + 1):
        if "):" in lines[cur] or "->" in lines[cur]:
            # Pull in the continuation lines of a multi-line signature.
            header += lines[def_start + 1:cur + 1]
            break
    return '\n'.join(header).strip()


def parse_text(text):
    """Collect every comment in *text* together with its surrounding context.

    Args:
        text: Python source code. It must be syntactically valid, because it
            is parsed with ``ast.parse``.

    Returns:
        A list with one dict per comment, in source order, with the keys:

        * ``"function_definition"`` - header text of the top-level function
          whose line range contains the comment, or the string ``"None"``
          when the comment is not inside a top-level function.
        * ``"code"`` - up to four lines before the comment line, the comment
          line itself and up to three lines after it, joined with newlines.
        * ``"comment"`` - the comment line from the ``#`` to its end.
        * ``"lineno"`` - zero-based index of the comment line in
          ``text.split('\\n')``.

    Raises:
        SyntaxError: If *text* is not valid Python source.

    NOTE: Only functions defined directly at module level are considered;
    methods and nested functions are not collected.
    """
    # Local imports keep this block self-contained; both are standard library.
    import io
    import tokenize

    lines = text.split('\n')

    # Zero-based (first line, last line, column) of each top-level function.
    defs = [
        (el.lineno - 1, el.end_lineno - 1, el.col_offset)
        for el in ast.parse(text).body
        if isinstance(el, ast.FunctionDef)
    ]

    inputs = []
    # Use the tokenizer rather than searching for '#', so that a '#' inside a
    # string literal is not mistaken for a comment.
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type != tokenize.COMMENT:
            continue
        lineno = tok.start[0] - 1
        offset = tok.start[1]

        fdef = "None"
        for def_l, def_el, def_off in defs:
            # A comment belongs to a function only if it lies inside the
            # function's own line range, not merely somewhere after its start.
            if def_l <= lineno <= def_el:
                fdef = _signature_text(lines, def_l, def_el, def_off)
                break

        inputs.append({
            "function_definition": fdef,
            # Clamp the window start: a negative slice index would wrap
            # around and select lines from the end of the file instead.
            "code": '\n'.join(lines[max(0, lineno - 4):lineno + 4]),
            "comment": lines[lineno][offset:],
            "lineno": lineno,
        })
    return inputs
def predict(inp, model=model):
    """Run the classifier on one tokenized input and return a single score.

    Args:
        inp: Keyword arguments for the model's forward call, e.g. the value
            returned by ``prepare_input``.
        model: Model to evaluate; defaults to the module-level ``model``.

    Returns:
        A Python float: the softmax probability at class index 1 for the
        first row of the model's logits.
    """
    # Inference only, so no autograd graph is needed for the forward pass.
    with torch.no_grad():
        logits = model(**inp).logits
    probabilities = nn.functional.softmax(logits, dim=-1)
    return probabilities[0, 1].item()
def parse_and_predict(text, thrd=0.0):
    """Score every comment found in *text*.

    Args:
        text: Python source code, forwarded to ``parse_text``.
        thrd: Threshold. When greater than zero, each score is replaced by
            the boolean ``thrd > score``; otherwise raw scores are returned.

    Returns:
        A list of ``(lineno, value)`` tuples, one per parsed comment, where
        ``lineno`` is the value stored by ``parse_text`` and ``value`` is
        either the score from ``predict`` or the thresholded boolean.
    """
    entries = parse_text(text)
    scores = [predict(prepare_input(entry)) for entry in entries]

    if thrd > 0:
        # True marks a score that falls below the threshold.
        return [
            (entry["lineno"], thrd > score)
            for entry, score in zip(entries, scores)
        ]
    return [(entry["lineno"], score) for entry, score in zip(entries, scores)]
def parse_and_predict_file(path, thrd=0.0):
    """Read the file at *path* and score every comment in it.

    Args:
        path: Location of a Python source file; anything ``pathlib.Path``
            accepts.
        thrd: Threshold forwarded unchanged to ``parse_and_predict``.

    Returns:
        Whatever ``parse_and_predict`` returns for the file's text.
    """
    # read_text() opens, reads and closes the file in one call, so the
    # handle is not left open the way a bare ``open(...).read()`` leaves it.
    text = Path(path).read_text()
    return parse_and_predict(text, thrd)
def parse_and_predict_pretty_out(text, thrd=0.0):
    """Render the predictions for *text* as human-readable lines.

    Args:
        text: Python source code to analyse.
        thrd: Threshold passed to ``parse_and_predict``. When greater than
            zero, only comments flagged by the thresholded result are
            reported; otherwise every comment is listed with its score.

    Returns:
        One message per reported comment, joined with newlines. The line
        numbers in the messages are the values returned by
        ``parse_and_predict`` and are used directly as indices into
        ``text.split('\\n')``.
    """
    predictions = parse_and_predict(text, thrd=thrd)
    lines = text.split('\n')

    if thrd > 0:
        messages = [
            f"The comment on line {lineno} is incorrect: '{lines[lineno]}'."
            for lineno, do_warn in predictions
            if do_warn
        ]
    else:
        messages = [
            f"The comment on line {lineno} is estimated to be correct with probability {p:.2f}: '{lines[lineno]}'."
            for lineno, p in predictions
        ]
    return '\n'.join(messages)
# Demo snippet offered as an example input in the UI. It is only ever used as
# text; the comments inside it are part of the sample and are left exactly as
# written.
example_text = (
    "a = 3\n"
    "b = 2\n"
    "# The code below does some calculations based on a predefined rule that is very important\n"
    "c = a - b # Calculate and store the sum of a and b in c\n"
    "d = a + b # Calculate and store the sum of a and b in d\n"
    "e = c * b # Calculate and store the product of c and d in e\n"
    'print(f"Wow, maths: {[a, b, c, d, e]}")'
)
# Gradio front end: a source-code text box and a threshold slider feed
# parse_and_predict_pretty_out, whose text result is shown in a second box.
gradio_app = gr.Interface(
    fn=parse_and_predict_pretty_out,
    inputs=[
        gr.Textbox(label="Input", lines=7),
        gr.Slider(value=0.8, minimum=0.0, maximum=1.0, step=0.05),
    ],
    outputs=[gr.Textbox(label="Predictions", lines=7)],
    # The same snippet is offered twice: once with threshold 0 and once
    # with threshold 0.53.
    examples=[
        [example_text, 0.0],
        [example_text, 0.53],
    ],
    title='Comment "Correctness" Classifier',
    description='Calculates probabilities for each comment in text to be "correct"/"relevant". If the threshold is 0, outputs raw predictions. Otherwise, will report only "incorrect" comments.',
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    gradio_app.launch()