Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
# Pfade zu deinem Basismodell und dem feingetunten LoRA-Adapter | |
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" | |
ADAPTER = "cheberle/autotrain-llama-milch" | |
print("Lade Tokenizer...") | |
tokenizer = AutoTokenizer.from_pretrained( | |
BASE_MODEL, | |
trust_remote_code=True | |
) | |
print("Lade Basismodell...") | |
base_model = AutoModelForCausalLM.from_pretrained( | |
BASE_MODEL, | |
trust_remote_code=True, | |
device_map="auto", | |
torch_dtype=torch.float16 | |
) | |
print("Lade feingetunten Adapter...") | |
model = PeftModel.from_pretrained( | |
base_model, | |
ADAPTER, | |
torch_dtype=torch.float16 | |
) | |
model.eval() | |
def klassifiziere_lebensmittel_fewshot(produkt_text): | |
""" | |
Verwendet einen Few-Shot-Prompt mit Beispielen auf Deutsch, | |
um das Modell zu einer einzigen, kurzen Lebensmittel-Kategorie | |
ohne zusätzliche Erklärungen zu führen. | |
""" | |
# Beispiele (Few-Shot). | |
# Du kannst die Beispiele anpassen, wenn du andere demonstrieren willst. | |
beispiele = ( | |
"1) Produkt: \"Cailler Branches Milch, 44 x 46 g\"\n Kategorie: Schokolade\n\n" | |
"2) Produkt: \"Aeschbach Trinkschokolade Milch, 1 kg\"\n Kategorie: Trinkschokolade\n\n" | |
"3) Produkt: \"Biedermann Bio Vollmilch 3,8%, pasteurisiert\"\n Kategorie: Milch\n\n" | |
) | |
# Prompt mit Few-Shot und neuer Eingabe | |
prompt = ( | |
"Du bist ein Modell zur Klassifikation von Lebensmitteln in deutsche Kategorien.\n" | |
"Hier sind einige Beispiele:\n\n" | |
f"{beispiele}" | |
f"Neues Produkt: \"{produkt_text}\"\n" | |
"Kategorie (NUR das Wort und keine Erklärung):" | |
) | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
with torch.no_grad(): | |
output = model.generate( | |
**inputs, | |
max_new_tokens=200, # Begrenze die Antwort auf wenige Tokens | |
temperature=0.0, # So wenig "kreatives" Rauschen wie möglich | |
top_p=1.0, | |
do_sample=False | |
) | |
# Modell-Antwort dekodieren | |
decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip() | |
# Oft wiederholt das Modell das Prompt - wir nehmen daher nur die letzte Zeile | |
lines = decoded.split("\n") | |
label = lines[-1].strip() | |
return label | |
# Gradio-Interface | |
with gr.Blocks() as demo: | |
produkt_box = gr.Textbox( | |
lines=2, | |
label="Produktbeschreibung", | |
placeholder="z.B. 'Biedermann Bio Jogurt Schafmilch Himbeer, 5 x 120 g'" | |
) | |
output_box = gr.Textbox( | |
lines=1, | |
label="Predizierte Kategorie", | |
placeholder="Hier erscheint das Ergebnis" | |
) | |
classify_button = gr.Button("Kategorie bestimmen (Few-Shot)") | |
classify_button.click( | |
fn=klassifiziere_lebensmittel_fewshot, | |
inputs=produkt_box, | |
outputs=output_box | |
) | |
demo.launch() |