NathanPap's picture
Update utils.py
7a70056 verified
raw
history blame
3.92 kB
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class CSVAnalyzer:
def __init__(self):
self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
try:
# Initialize tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
# Initialize model with lower precision for efficiency
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
except Exception as e:
print(f"Initialisierungsfehler: {str(e)}")
raise
def prepare_context(self, df: pd.DataFrame) -> str:
"""Bereitet den Kontext mit DataFrame-Daten vor."""
try:
context = "Dateninhalt:\n\n"
# Zeilen begrenzen, um Kontextüberlauf zu vermeiden
max_rows = min(len(df), 50)
# Sichere Konvertierung der Indexwerte zu Strings
for i in range(max_rows):
row = df.iloc[i]
row_text = ""
for col in df.columns:
if pd.notna(row[col]):
row_text += f"{col}: {str(row[col]).strip()}\n"
context += f"Eintrag {str(i + 1)}:\n{row_text}\n---\n"
return context.strip()
except Exception as e:
raise Exception(f"Fehler bei der Kontextvorbereitung: {str(e)}")
def generate_response(self, context: str, query: str) -> str:
"""Generiert eine Antwort auf die Frage unter Verwendung des Kontexts."""
# Spezifisches Format für TinyLlama Chat
prompt = f"""<|system|>Du bist ein Assistent, der auf Datenanalyse in einem Facility Management Unternehmen spezialisiert ist.
Antworte präzise und knapp, basierend ausschließlich auf den bereitgestellten Informationen.
Gib das betreffende E-Mail inklusive Datum und Absender an.
Erstelle bei Bedarf Analyse-Tabellen, um die Informationen strukturiert darzustellen.
<|user|>Kontext:
{context}
Frage: {query}
<|assistant|>"""
try:
# Tokenisierung
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048,
padding=True,
return_attention_mask=True
).to(self.model.device)
# Antwortgenerierung
with torch.no_grad():
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=512,
temperature=0.7,
top_p=0.95,
repetition_penalty=1.15,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# Antwort dekodieren und bereinigen
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("<|assistant|>")[-1].strip()
return response
except Exception as e:
return f"Generierungsfehler: {str(e)}"
def analyze_csv(df: pd.DataFrame, query: str) -> str:
"""Hauptfunktion für CSV-Analyse und Fragenbeantwortung."""
try:
analyzer = CSVAnalyzer()
context = analyzer.prepare_context(df)
response = analyzer.generate_response(context, query)
return response
except Exception as e:
return f"Analysefehler: {str(e)}"