import json
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    pipeline
)
# ── 1) Pick your model ────────────────────────────────────────────────
# ▸ For a causal code model (no T5 errors):
MODEL_ID = "Salesforce/codegen-350M-multi"
# ▸ Or, for a seq-to-seq model:
# MODEL_ID = "google/flan-t5-base"
# MODEL_ID = "google/flan-t5-small"
# ── 2) Load tokenizer + model ─────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Detect T5 vs causal:
if "t5" in MODEL_ID.lower() or MODEL_ID.startswith("google/"):
    # seq-to-seq
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    task = "text2text-generation"
else:
    # causal
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    task = "text-generation"
# ── 3) Build the pipeline ─────────────────────────────────────────────
pipe = pipeline(
    task,
    model=model,
    tokenizer=tokenizer,
    device=-1,          # CPU
    max_new_tokens=256,
    do_sample=True,     # temperature only takes effect when sampling is enabled
    temperature=0.2
)
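
# For reference, a bare pipeline call returns a list of dicts keyed by
# "generated_text" for both task types, e.g. (illustrative prompt only):
#   pipe("def add(a, b):")[0]["generated_text"]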
# ── 4) Your review function ───────────────────────────────────────────
def review_code(diff: str, guidelines: str):
    prompt = (
        "You are an expert code reviewer. Return *only* a valid JSON array of objects\n"
        "with two fields each: `line` (number) and `comment` (string).\n\n"
        f"DIFF:\n{diff}\n\n"
        f"GUIDELINES:\n{guidelines}\n\n"
        "OUTPUT FORMAT EXAMPLE:\n"
        '[{"line":12,"comment":"…"}]\n\n'
        "End your answer with the marker <<<END_JSON>>>.\n"
    )
    # Run the model
    out = pipe(prompt)[0]["generated_text"]
    # Causal pipelines echo the prompt in generated_text; strip it so we only
    # parse the completion (the prompt itself contains example JSON and the marker)
    if out.startswith(prompt):
        out = out[len(prompt):]
    # Truncate at our stop marker, if present
    if "<<<END_JSON>>>" in out:
        out = out.split("<<<END_JSON>>>")[0]
    # Extract the first "[" … last "]" span and parse it as JSON
    start = out.find("[")
    end = out.rfind("]")
    if start == -1 or end == -1:
        return {"error": "No JSON array found", "raw": out}
    snippet = out[start:end + 1]
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        return {"error": "JSON parse failed", "raw": snippet}
# ── 5) Gradio interface ───────────────────────────────────────────────
iface = gr.Interface(
    fn=review_code,
    inputs=[
        gr.Textbox(lines=10, label="Git Diff"),
        gr.Textbox(lines=5, label="Review Guidelines")
    ],
    outputs=gr.JSON(label="Comments"),
    title="🤖 Code Review LLM",
    description="Paste your git diff and guidelines; get back JSON comments."
)
if __name__ == "__main__":
    # launch() returns (app, local_url, share_url); prevent_thread_lock lets the
    # script continue past launch so we can print the URL, then block_thread()
    # keeps the server running
    app, local_url, share_url = iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        prevent_thread_lock=True
    )
    # print so you can curl it immediately
    print(f"🔗 endpoint → {local_url}")
    iface.block_thread()
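
# ──────────────────────────────────────────────────────────────────────
# Example of calling the running app programmatically (a sketch, assuming the
# gradio_client package is installed and the server above is up; the diff and
# guideline strings are illustrative only):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       "diff --git a/app.py b/app.py\n+print('hi')",  # Git Diff textbox
#       "Prefer logging over print statements.",        # Review Guidelines textbox
#       api_name="/predict"
#   )
#   print(result)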