AshenClock commited on
Commit
cf9b229
·
verified ·
1 Parent(s): 964c6d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -38
app.py CHANGED
@@ -10,7 +10,7 @@ import faiss
10
  import json
11
  import numpy as np
12
  from dotenv import load_dotenv
13
- import requests
14
 
15
  # Carica le variabili d'ambiente
16
  load_dotenv()
@@ -24,8 +24,8 @@ logging.basicConfig(
24
  logger = logging.getLogger(__name__)
25
 
26
  # Recupera la chiave API
27
- API_KEY = os.getenv("HF_API_KEY")
28
- if not API_KEY:
29
  logger.error("HF_API_KEY non impostata.")
30
  raise EnvironmentError("HF_API_KEY non impostata.")
31
 
@@ -46,6 +46,9 @@ except Exception as e:
46
  logger.error(f"Errore nel caricamento del modello SentenceTransformer: {e}")
47
  raise e
48
 
 
 
 
49
  def create_data_directory():
50
  """Crea la directory 'data/' se non esiste."""
51
  os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True)
@@ -81,17 +84,17 @@ def create_faiss_index(documents_file: str, index_file: str, embedding_model_ins
81
  document = json.load(f)
82
  lines = document['lines']
83
  logger.info(f"{len(lines)} linee caricate da {documents_file}.")
84
-
85
  # Genera embedding
86
  embeddings = embedding_model_instance.encode(lines, convert_to_numpy=True, show_progress_bar=True)
87
  logger.info("Embedding generati con SentenceTransformer.")
88
-
89
  # Crea l'indice FAISS
90
  dimension = embeddings.shape[1]
91
  index = faiss.IndexFlatL2(dimension)
92
  index.add(embeddings)
93
  logger.info(f"Indice FAISS creato con dimensione: {dimension}.")
94
-
95
  # Salva l'indice
96
  faiss.write_index(index, index_file)
97
  logger.info(f"Indice FAISS salvato in {index_file}.")
@@ -142,23 +145,23 @@ def retrieve_relevant_lines(query: str, top_k: int = 5, embedding_model_instance
142
  document = json.load(f)
143
  lines = document['lines']
144
  logger.info(f"{len(lines)} linee caricate da {DOCUMENTS_FILE}.")
145
-
146
  # Carica l'indice FAISS
147
  index = faiss.read_index(FAISS_INDEX_FILE)
148
  logger.info(f"Indice FAISS caricato da {FAISS_INDEX_FILE}.")
149
-
150
  # Genera embedding della query
151
  if embedding_model_instance is None:
152
  embedding_model_instance = SentenceTransformer('all-MiniLM-L6-v2')
153
  logger.info("Modello SentenceTransformer caricato per l'embedding della query.")
154
-
155
  query_embedding = embedding_model_instance.encode([query], convert_to_numpy=True)
156
  logger.info("Embedding della query generati.")
157
-
158
  # Ricerca nell'indice
159
  distances, indices = index.search(query_embedding, top_k)
160
  logger.info(f"Ricerca FAISS completata. Risultati ottenuti: {len(indices[0])}")
161
-
162
  # Recupera le linee rilevanti
163
  relevant_texts = [lines[idx] for idx in indices[0] if idx < len(lines)]
164
  retrieved_docs = "\n".join(relevant_texts)
@@ -205,38 +208,30 @@ Ora fornisci una breve spiegazione museale (massimo ~10 righe), senza inventare
205
  """
206
 
207
  async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int = 150) -> str:
208
- """Chiama il modello Hugging Face tramite l'API REST e gestisce la risposta."""
209
  logger.debug("Chiamo HF con il seguente prompt:")
210
  content_preview = (prompt[:300] + '...') if len(prompt) > 300 else prompt
211
  logger.debug(f"PROMPT => {content_preview}")
212
 
213
- headers = {
214
- "Authorization": f"Bearer {API_KEY}"
215
- }
216
- payload = {
217
- "inputs": prompt,
218
- "parameters": {
219
- "temperature": temperature,
220
- "max_new_tokens": max_tokens,
221
- "top_p": 0.9
222
- }
223
- }
224
-
225
  try:
226
- response = requests.post(
227
- f"https://api-inference.huggingface.co/models/{HF_MODEL}",
228
- headers=headers,
229
- json=payload
 
 
 
 
 
 
230
  )
231
- if response.status_code != 200:
232
- logger.error(f"Errore nella chiamata all'API Hugging Face: {response.status_code} - {response.text}")
233
- raise HTTPException(status_code=500, detail=f"Errore nell'API Hugging Face: {response.text}")
234
- data = response.json()
235
- logger.debug(f"Risposta completa dal modello: {data}")
236
- if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
237
- raw = data[0]["generated_text"]
238
- elif "generated_text" in data:
239
- raw = data["generated_text"]
240
  else:
241
  raise ValueError("Nessun campo 'generated_text' nella risposta.")
242
 
@@ -245,7 +240,7 @@ async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int =
245
  logger.debug(f"Risposta HF single-line: {single_line}")
246
  return single_line.strip()
247
  except Exception as e:
248
- logger.error(f"Errore nella chiamata all'API Hugging Face tramite requests: {e}")
249
  raise HTTPException(status_code=500, detail=str(e))
250
 
251
  # Variabile globale per le etichette delle entità
 
10
  import json
11
  import numpy as np
12
  from dotenv import load_dotenv
13
+ from huggingface_hub import InferenceClient
14
 
15
  # Carica le variabili d'ambiente
16
  load_dotenv()
 
24
  logger = logging.getLogger(__name__)
25
 
26
  # Recupera la chiave API
27
+ HF_API_KEY = os.getenv("HF_API_KEY")
28
+ if not HF_API_KEY:
29
  logger.error("HF_API_KEY non impostata.")
30
  raise EnvironmentError("HF_API_KEY non impostata.")
31
 
 
46
  logger.error(f"Errore nel caricamento del modello SentenceTransformer: {e}")
47
  raise e
48
 
49
+ # Inizializza il client di Hugging Face
50
+ client = InferenceClient(api_key=HF_API_KEY)
51
+
52
  def create_data_directory():
53
  """Crea la directory 'data/' se non esiste."""
54
  os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True)
 
84
  document = json.load(f)
85
  lines = document['lines']
86
  logger.info(f"{len(lines)} linee caricate da {documents_file}.")
87
+
88
  # Genera embedding
89
  embeddings = embedding_model_instance.encode(lines, convert_to_numpy=True, show_progress_bar=True)
90
  logger.info("Embedding generati con SentenceTransformer.")
91
+
92
  # Crea l'indice FAISS
93
  dimension = embeddings.shape[1]
94
  index = faiss.IndexFlatL2(dimension)
95
  index.add(embeddings)
96
  logger.info(f"Indice FAISS creato con dimensione: {dimension}.")
97
+
98
  # Salva l'indice
99
  faiss.write_index(index, index_file)
100
  logger.info(f"Indice FAISS salvato in {index_file}.")
 
145
  document = json.load(f)
146
  lines = document['lines']
147
  logger.info(f"{len(lines)} linee caricate da {DOCUMENTS_FILE}.")
148
+
149
  # Carica l'indice FAISS
150
  index = faiss.read_index(FAISS_INDEX_FILE)
151
  logger.info(f"Indice FAISS caricato da {FAISS_INDEX_FILE}.")
152
+
153
  # Genera embedding della query
154
  if embedding_model_instance is None:
155
  embedding_model_instance = SentenceTransformer('all-MiniLM-L6-v2')
156
  logger.info("Modello SentenceTransformer caricato per l'embedding della query.")
157
+
158
  query_embedding = embedding_model_instance.encode([query], convert_to_numpy=True)
159
  logger.info("Embedding della query generati.")
160
+
161
  # Ricerca nell'indice
162
  distances, indices = index.search(query_embedding, top_k)
163
  logger.info(f"Ricerca FAISS completata. Risultati ottenuti: {len(indices[0])}")
164
+
165
  # Recupera le linee rilevanti
166
  relevant_texts = [lines[idx] for idx in indices[0] if idx < len(lines)]
167
  retrieved_docs = "\n".join(relevant_texts)
 
208
  """
209
 
210
  async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int = 150) -> str:
211
+ """Chiama il modello Hugging Face tramite InferenceClient e gestisce la risposta."""
212
  logger.debug("Chiamo HF con il seguente prompt:")
213
  content_preview = (prompt[:300] + '...') if len(prompt) > 300 else prompt
214
  logger.debug(f"PROMPT => {content_preview}")
215
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  try:
217
+ # Utilizza il metodo chat.completions.create per interagire con il modello
218
+ response = client.chat.completions.create(
219
+ model=HF_MODEL,
220
+ messages=[
221
+ {"role": "user", "content": prompt}
222
+ ],
223
+ temperature=temperature,
224
+ max_tokens=max_tokens,
225
+ top_p=0.7,
226
+ stream=False # Imposta su True se desideri gestire lo stream
227
  )
228
+ logger.debug(f"Risposta completa dal modello: {response}")
229
+
230
+ # Estrai il testo generato
231
+ if isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
232
+ raw = response[0]["generated_text"]
233
+ elif "generated_text" in response:
234
+ raw = response["generated_text"]
 
 
235
  else:
236
  raise ValueError("Nessun campo 'generated_text' nella risposta.")
237
 
 
240
  logger.debug(f"Risposta HF single-line: {single_line}")
241
  return single_line.strip()
242
  except Exception as e:
243
+ logger.error(f"Errore nella chiamata all'API Hugging Face tramite InferenceClient: {e}")
244
  raise HTTPException(status_code=500, detail=str(e))
245
 
246
  # Variabile globale per le etichette delle entità