NathanPap commited on
Commit
1b225dc
·
verified ·
1 Parent(s): 20a3dbe

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +43 -31
utils.py CHANGED
@@ -2,23 +2,34 @@ import pandas as pd
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- class CSVAnalyzer:
 
 
 
 
 
 
 
 
 
 
6
  def __init__(self):
7
  self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
8
  try:
9
- # Tokenizer initialization with specific configuration
 
10
  self.tokenizer = AutoTokenizer.from_pretrained(
11
  self.model_name,
12
  trust_remote_code=True,
13
  use_fast=False
14
  )
15
 
16
- # Padding token configuration
17
  if self.tokenizer.pad_token is None:
18
  self.tokenizer.pad_token = self.tokenizer.eos_token
19
  self.tokenizer.padding_side = "right"
20
 
21
- # Model initialization
22
  self.model = AutoModelForCausalLM.from_pretrained(
23
  self.model_name,
24
  torch_dtype=torch.float16,
@@ -26,42 +37,43 @@ class CSVAnalyzer:
26
  trust_remote_code=True
27
  )
28
 
29
- # Ensure model knows the pad_token
30
  self.model.config.pad_token_id = self.tokenizer.pad_token_id
 
31
 
32
  except Exception as e:
33
  print(f"Initialisierungsfehler: {str(e)}")
34
  raise
35
 
36
- def prepare_context(self, df: pd.DataFrame) -> str:
37
  """Bereitet den Kontext mit den DataFrame-Daten vor."""
38
  try:
39
- context = "E-Mail-Informationen:\n\n"
40
 
41
- # Convert DataFrame to string and handle missing values
42
  df_str = df.fillna("Keine Angabe").astype(str)
43
 
44
- # Process each row
45
  for index in range(len(df_str)):
46
- row = df_str.iloc[index]
47
- context += f"E-Mail {index + 1}:\n"
48
- for column in df_str.columns:
49
- context += f"{column}: {row[column]}\n"
50
- context += "---\n"
51
 
52
- return context.strip()
53
 
54
  except Exception as e:
55
  raise Exception(f"Fehler bei der Kontextvorbereitung: {str(e)}")
56
 
57
- def generate_response(self, context: str, query: str) -> str:
58
  """Generiert eine Antwort auf die Frage unter Verwendung des Kontexts."""
59
  prompt = f"""<s>[INST] Sie sind ein deutscher Assistent für Facility Management Datenanalyse.
60
  Analysieren Sie die folgenden E-Mail-Daten:
61
 
62
- {context}
63
 
64
- Frage: {query}
65
 
66
  Wichtige Anweisungen:
67
  1. Antworten Sie AUSSCHLIEßLICH auf Deutsch
@@ -72,7 +84,7 @@ Wichtige Anweisungen:
72
  Ihre deutsche Antwort: [/INST]"""
73
 
74
  try:
75
- inputs = self.tokenizer(
76
  prompt,
77
  return_tensors="pt",
78
  padding=True,
@@ -83,9 +95,9 @@ Ihre deutsche Antwort: [/INST]"""
83
  ).to(self.model.device)
84
 
85
  with torch.no_grad():
86
- outputs = self.model.generate(
87
- input_ids=inputs["input_ids"],
88
- attention_mask=inputs["attention_mask"],
89
  max_new_tokens=512,
90
  temperature=0.7,
91
  top_p=0.95,
@@ -96,21 +108,21 @@ Ihre deutsche Antwort: [/INST]"""
96
  eos_token_id=self.tokenizer.eos_token_id
97
  )
98
 
99
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
100
- response = response.split("[/INST]")[-1].strip()
101
 
102
- return response
103
 
104
  except Exception as e:
105
- return f"Fehler bei der Analyse: {str(e)}"
106
 
107
- def analyze_csv(df: pd.DataFrame, query: str) -> str:
108
  """Hauptfunktion zur CSV-Analyse und Fragenbeantwortung."""
109
  try:
110
- analyzer = CSVAnalyzer()
111
- context = analyzer.prepare_context(df)
112
- response = analyzer.generate_response(context, query)
113
- return response
114
 
115
  except Exception as e:
116
  return f"Fehler bei der Analyse: {str(e)}"
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Globale Analyseinstanz
6
+ _analyzer = None
7
+
8
+ def get_analyzer():
9
+ """Singleton-Muster zur Vermeidung der Modellneuinitialisierung bei jedem Aufruf"""
10
+ global _analyzer
11
+ if _analyzer is None:
12
+ _analyzer = CSVAnalysierer()
13
+ return _analyzer
14
+
15
+ class CSVAnalysierer:
16
  def __init__(self):
17
  self.model_name = "mistralai/Mistral-7B-Instruct-v0.2"
18
  try:
19
+ print("Modell wird initialisiert...")
20
+ # Tokenizer-Initialisierung mit spezifischer Konfiguration
21
  self.tokenizer = AutoTokenizer.from_pretrained(
22
  self.model_name,
23
  trust_remote_code=True,
24
  use_fast=False
25
  )
26
 
27
+ # Padding-Token-Konfiguration
28
  if self.tokenizer.pad_token is None:
29
  self.tokenizer.pad_token = self.tokenizer.eos_token
30
  self.tokenizer.padding_side = "right"
31
 
32
+ # Modell-Initialisierung
33
  self.model = AutoModelForCausalLM.from_pretrained(
34
  self.model_name,
35
  torch_dtype=torch.float16,
 
37
  trust_remote_code=True
38
  )
39
 
40
+ # Sicherstellen, dass das Modell das Padding-Token kennt
41
  self.model.config.pad_token_id = self.tokenizer.pad_token_id
42
+ print("Modell erfolgreich initialisiert!")
43
 
44
  except Exception as e:
45
  print(f"Initialisierungsfehler: {str(e)}")
46
  raise
47
 
48
+ def kontext_vorbereiten(self, df: pd.DataFrame) -> str:
49
  """Bereitet den Kontext mit den DataFrame-Daten vor."""
50
  try:
51
+ kontext = "E-Mail-Informationen:\n\n"
52
 
53
+ # DataFrame in String umwandeln und fehlende Werte behandeln
54
  df_str = df.fillna("Keine Angabe").astype(str)
55
 
56
+ # Jede Zeile verarbeiten
57
  for index in range(len(df_str)):
58
+ zeile = df_str.iloc[index]
59
+ kontext += f"E-Mail {index + 1}:\n"
60
+ for spalte in df_str.columns:
61
+ kontext += f"{spalte}: {zeile[spalte]}\n"
62
+ kontext += "---\n"
63
 
64
+ return kontext.strip()
65
 
66
  except Exception as e:
67
  raise Exception(f"Fehler bei der Kontextvorbereitung: {str(e)}")
68
 
69
+ def antwort_generieren(self, kontext: str, frage: str) -> str:
70
  """Generiert eine Antwort auf die Frage unter Verwendung des Kontexts."""
71
  prompt = f"""<s>[INST] Sie sind ein deutscher Assistent für Facility Management Datenanalyse.
72
  Analysieren Sie die folgenden E-Mail-Daten:
73
 
74
+ {kontext}
75
 
76
+ Frage: {frage}
77
 
78
  Wichtige Anweisungen:
79
  1. Antworten Sie AUSSCHLIEßLICH auf Deutsch
 
84
  Ihre deutsche Antwort: [/INST]"""
85
 
86
  try:
87
+ eingabe = self.tokenizer(
88
  prompt,
89
  return_tensors="pt",
90
  padding=True,
 
95
  ).to(self.model.device)
96
 
97
  with torch.no_grad():
98
+ ausgabe = self.model.generate(
99
+ input_ids=eingabe["input_ids"],
100
+ attention_mask=eingabe["attention_mask"],
101
  max_new_tokens=512,
102
  temperature=0.7,
103
  top_p=0.95,
 
108
  eos_token_id=self.tokenizer.eos_token_id
109
  )
110
 
111
+ antwort = self.tokenizer.decode(ausgabe[0], skip_special_tokens=True)
112
+ antwort = antwort.split("[/INST]")[-1].strip()
113
 
114
+ return antwort
115
 
116
  except Exception as e:
117
+ return f"Generierungsfehler: {str(e)}"
118
 
119
+ def csv_analysieren(df: pd.DataFrame, frage: str) -> str:
120
  """Hauptfunktion zur CSV-Analyse und Fragenbeantwortung."""
121
  try:
122
+ analysierer = get_analyzer() # Verwendet die einzige Instanz
123
+ kontext = analysierer.kontext_vorbereiten(df)
124
+ antwort = analysierer.antwort_generieren(kontext, frage)
125
+ return antwort
126
 
127
  except Exception as e:
128
  return f"Fehler bei der Analyse: {str(e)}"