cheberle's picture
f
4b0ec5b
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Pfade zu deinem Basismodell und dem feingetunten LoRA-Adapter
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
ADAPTER = "cheberle/autotrain-llama-milch"
print("Lade Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True
)
print("Lade Basismodell...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
device_map="auto",
torch_dtype=torch.float16
)
print("Lade feingetunten Adapter...")
model = PeftModel.from_pretrained(
base_model,
ADAPTER,
torch_dtype=torch.float16
)
model.eval()
def klassifiziere_lebensmittel_fewshot(produkt_text):
"""
Verwendet einen Few-Shot-Prompt mit Beispielen auf Deutsch,
um das Modell zu einer einzigen, kurzen Lebensmittel-Kategorie
ohne zusätzliche Erklärungen zu führen.
"""
# Beispiele (Few-Shot).
# Du kannst die Beispiele anpassen, wenn du andere demonstrieren willst.
beispiele = (
"1) Produkt: \"Cailler Branches Milch, 44 x 46 g\"\n Kategorie: Schokolade\n\n"
"2) Produkt: \"Aeschbach Trinkschokolade Milch, 1 kg\"\n Kategorie: Trinkschokolade\n\n"
"3) Produkt: \"Biedermann Bio Vollmilch 3,8%, pasteurisiert\"\n Kategorie: Milch\n\n"
)
# Prompt mit Few-Shot und neuer Eingabe
prompt = (
"Du bist ein Modell zur Klassifikation von Lebensmitteln in deutsche Kategorien.\n"
"Hier sind einige Beispiele:\n\n"
f"{beispiele}"
f"Neues Produkt: \"{produkt_text}\"\n"
"Kategorie (NUR das Wort und keine Erklärung):"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200, # Begrenze die Antwort auf wenige Tokens
temperature=0.0, # So wenig "kreatives" Rauschen wie möglich
top_p=1.0,
do_sample=False
)
# Modell-Antwort dekodieren
decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
# Oft wiederholt das Modell das Prompt - wir nehmen daher nur die letzte Zeile
lines = decoded.split("\n")
label = lines[-1].strip()
return label
# Gradio-Interface
with gr.Blocks() as demo:
produkt_box = gr.Textbox(
lines=2,
label="Produktbeschreibung",
placeholder="z.B. 'Biedermann Bio Jogurt Schafmilch Himbeer, 5 x 120 g'"
)
output_box = gr.Textbox(
lines=1,
label="Predizierte Kategorie",
placeholder="Hier erscheint das Ergebnis"
)
classify_button = gr.Button("Kategorie bestimmen (Few-Shot)")
classify_button.click(
fn=klassifiziere_lebensmittel_fewshot,
inputs=produkt_box,
outputs=output_box
)
demo.launch()