# code-reviewer / app.py
# Last commit by Erpg12: 47f1e3a "feat: upload train sft file" (3.19 kB)
import json

import gradio as gr
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)
# ── 1) Pick your model ────────────────────────────────────────────────
# ▸ Default: a causal (decoder-only) code model, which avoids T5-specific errors.
MODEL_ID = "Salesforce/codegen-350M-multi"
# ▸ Or, for a seq-to-seq model:
# MODEL_ID = "google/flan-t5-base"
# MODEL_ID = "google/flan-t5-small"
# ── 2) Load tokenizer + model ────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Choose the model class from the checkpoint's own config instead of from
# its name. A name heuristic such as startswith("google/") misclassifies
# decoder-only Google checkpoints as seq-to-seq.
_config = AutoConfig.from_pretrained(MODEL_ID)
if _config.is_encoder_decoder:
    # seq-to-seq (e.g. T5 / FLAN-T5)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, config=_config)
    task = "text2text-generation"
else:
    # causal / decoder-only (e.g. CodeGen)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, config=_config)
    task = "text-generation"
# ── 3) Build the pipeline ────────────────────────────────────────────
# Greedy decoding (do_sample=False) keeps the JSON output deterministic.
# `temperature` is deliberately not passed: it only applies when
# do_sample=True, and in greedy mode transformers ignores it and warns.
# To sample instead, set do_sample=True together with a temperature.
pipe = pipeline(
    task,
    model=model,
    tokenizer=tokenizer,
    device=-1,  # CPU
    max_new_tokens=256,
    do_sample=False,
)
# ── 4) Your review function ──────────────────────────────────────────
def review_code(diff: str, guidelines: str):
    """Ask the model to review *diff* against *guidelines* and parse its reply.

    Args:
        diff: Git diff text pasted by the user.
        guidelines: Free-form review guidelines included in the prompt.

    Returns:
        The parsed JSON array (expected to be a list of objects with `line`
        and `comment` fields) on success. Otherwise a dict with "error" and
        "raw" keys, either when no bracketed span is found in the model
        output or when that span is not valid JSON.
    """
    prompt = (
        "You are an expert code reviewer. Return *only* a valid JSON array of objects\n"
        "with two fields each: `line` (number) and `comment` (string).\n\n"
        f"DIFF:\n{diff}\n\n"
        f"GUIDELINES:\n{guidelines}\n\n"
        "OUTPUT FORMAT EXAMPLE:\n"
        '[{"line":12,"comment":"…"}]\n\n'
        "<<<END_JSON>>>\n"
    )

    # Run the model
    out = pipe(prompt)[0]["generated_text"]

    # The causal "text-generation" pipeline echoes the prompt in
    # generated_text by default (return_full_text=True). The prompt itself
    # contains the example array and the stop marker. Without stripping it,
    # we would parse the example instead of the model's answer.
    if out.startswith(prompt):
        out = out[len(prompt):]

    # Truncate at our stop marker, if the model emitted one.
    out = out.split("<<<END_JSON>>>", 1)[0]

    start = out.find("[")
    end = out.rfind("]")
    # find/rfind return -1 when the character is absent. Also reject a "]"
    # that only appears before the first "[".
    if start < 0 or end < start:
        return {"error": "No JSON array found", "raw": out}

    snippet = out[start:end + 1]
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        return {"error": "JSON parse failed", "raw": snippet}
# ── 5) Gradio interface ──────────────────────────────────────────────
# Two text inputs (diff, guidelines) map positionally onto review_code's
# parameters; the returned value is rendered as JSON.
iface = gr.Interface(
    title="🤖 Code Review LLM",
    description="Paste your git diff and guidelines; get back JSON comments.",
    fn=review_code,
    outputs=gr.JSON(label="Comments"),
    inputs=[
        gr.Textbox(lines=10, label="Git Diff"),
        gr.Textbox(lines=5, label="Review Guidelines"),
    ],
)
if __name__ == "__main__":
    # Launch without blocking so the URL can be printed while the server is
    # up. By default launch() blocks until shutdown and then returns an
    # (app, local_url, share_url) tuple rather than a plain URL string.
    _app, local_url, _share_url = iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        prevent_thread_lock=True,
    )
    # Print so you can curl it immediately. local_url ends with "/", so it
    # is stripped to avoid a "//predict" path.
    # NOTE(review): the REST route for the default endpoint differs between
    # Gradio versions; confirm "/predict" against the installed version.
    print(f"🔗 endpoint → {local_url.rstrip('/')}/predict")
    # Keep the main thread alive, which launch() would otherwise have done.
    iface.block_thread()