Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,7 @@ import faiss
|
|
10 |
import json
|
11 |
import numpy as np
|
12 |
from dotenv import load_dotenv
|
13 |
-
import
|
14 |
|
15 |
# Carica le variabili d'ambiente
|
16 |
load_dotenv()
|
@@ -24,8 +24,8 @@ logging.basicConfig(
|
|
24 |
logger = logging.getLogger(__name__)
|
25 |
|
26 |
# Recupera la chiave API
|
27 |
-
|
28 |
-
if not
|
29 |
logger.error("HF_API_KEY non impostata.")
|
30 |
raise EnvironmentError("HF_API_KEY non impostata.")
|
31 |
|
@@ -46,6 +46,9 @@ except Exception as e:
|
|
46 |
logger.error(f"Errore nel caricamento del modello SentenceTransformer: {e}")
|
47 |
raise e
|
48 |
|
|
|
|
|
|
|
49 |
def create_data_directory():
|
50 |
"""Crea la directory 'data/' se non esiste."""
|
51 |
os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True)
|
@@ -81,17 +84,17 @@ def create_faiss_index(documents_file: str, index_file: str, embedding_model_ins
|
|
81 |
document = json.load(f)
|
82 |
lines = document['lines']
|
83 |
logger.info(f"{len(lines)} linee caricate da {documents_file}.")
|
84 |
-
|
85 |
# Genera embedding
|
86 |
embeddings = embedding_model_instance.encode(lines, convert_to_numpy=True, show_progress_bar=True)
|
87 |
logger.info("Embedding generati con SentenceTransformer.")
|
88 |
-
|
89 |
# Crea l'indice FAISS
|
90 |
dimension = embeddings.shape[1]
|
91 |
index = faiss.IndexFlatL2(dimension)
|
92 |
index.add(embeddings)
|
93 |
logger.info(f"Indice FAISS creato con dimensione: {dimension}.")
|
94 |
-
|
95 |
# Salva l'indice
|
96 |
faiss.write_index(index, index_file)
|
97 |
logger.info(f"Indice FAISS salvato in {index_file}.")
|
@@ -142,23 +145,23 @@ def retrieve_relevant_lines(query: str, top_k: int = 5, embedding_model_instance
|
|
142 |
document = json.load(f)
|
143 |
lines = document['lines']
|
144 |
logger.info(f"{len(lines)} linee caricate da {DOCUMENTS_FILE}.")
|
145 |
-
|
146 |
# Carica l'indice FAISS
|
147 |
index = faiss.read_index(FAISS_INDEX_FILE)
|
148 |
logger.info(f"Indice FAISS caricato da {FAISS_INDEX_FILE}.")
|
149 |
-
|
150 |
# Genera embedding della query
|
151 |
if embedding_model_instance is None:
|
152 |
embedding_model_instance = SentenceTransformer('all-MiniLM-L6-v2')
|
153 |
logger.info("Modello SentenceTransformer caricato per l'embedding della query.")
|
154 |
-
|
155 |
query_embedding = embedding_model_instance.encode([query], convert_to_numpy=True)
|
156 |
logger.info("Embedding della query generati.")
|
157 |
-
|
158 |
# Ricerca nell'indice
|
159 |
distances, indices = index.search(query_embedding, top_k)
|
160 |
logger.info(f"Ricerca FAISS completata. Risultati ottenuti: {len(indices[0])}")
|
161 |
-
|
162 |
# Recupera le linee rilevanti
|
163 |
relevant_texts = [lines[idx] for idx in indices[0] if idx < len(lines)]
|
164 |
retrieved_docs = "\n".join(relevant_texts)
|
@@ -205,38 +208,30 @@ Ora fornisci una breve spiegazione museale (massimo ~10 righe), senza inventare
|
|
205 |
"""
|
206 |
|
207 |
async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int = 150) -> str:
|
208 |
-
"""Chiama il modello Hugging Face tramite
|
209 |
logger.debug("Chiamo HF con il seguente prompt:")
|
210 |
content_preview = (prompt[:300] + '...') if len(prompt) > 300 else prompt
|
211 |
logger.debug(f"PROMPT => {content_preview}")
|
212 |
|
213 |
-
headers = {
|
214 |
-
"Authorization": f"Bearer {API_KEY}"
|
215 |
-
}
|
216 |
-
payload = {
|
217 |
-
"inputs": prompt,
|
218 |
-
"parameters": {
|
219 |
-
"temperature": temperature,
|
220 |
-
"max_new_tokens": max_tokens,
|
221 |
-
"top_p": 0.9
|
222 |
-
}
|
223 |
-
}
|
224 |
-
|
225 |
try:
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
)
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
raw =
|
238 |
-
elif "generated_text" in data:
|
239 |
-
raw = data["generated_text"]
|
240 |
else:
|
241 |
raise ValueError("Nessun campo 'generated_text' nella risposta.")
|
242 |
|
@@ -245,7 +240,7 @@ async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int =
|
|
245 |
logger.debug(f"Risposta HF single-line: {single_line}")
|
246 |
return single_line.strip()
|
247 |
except Exception as e:
|
248 |
-
logger.error(f"Errore nella chiamata all'API Hugging Face tramite
|
249 |
raise HTTPException(status_code=500, detail=str(e))
|
250 |
|
251 |
# Variabile globale per le etichette delle entità
|
|
|
10 |
import json
|
11 |
import numpy as np
|
12 |
from dotenv import load_dotenv
|
13 |
+
from huggingface_hub import InferenceClient
|
14 |
|
15 |
# Carica le variabili d'ambiente
|
16 |
load_dotenv()
|
|
|
24 |
logger = logging.getLogger(__name__)
|
25 |
|
26 |
# Recupera la chiave API
|
27 |
+
HF_API_KEY = os.getenv("HF_API_KEY")
|
28 |
+
if not HF_API_KEY:
|
29 |
logger.error("HF_API_KEY non impostata.")
|
30 |
raise EnvironmentError("HF_API_KEY non impostata.")
|
31 |
|
|
|
46 |
logger.error(f"Errore nel caricamento del modello SentenceTransformer: {e}")
|
47 |
raise e
|
48 |
|
49 |
+
# Inizializza il client di Hugging Face
|
50 |
+
client = InferenceClient(api_key=HF_API_KEY)
|
51 |
+
|
52 |
def create_data_directory():
|
53 |
"""Crea la directory 'data/' se non esiste."""
|
54 |
os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True)
|
|
|
84 |
document = json.load(f)
|
85 |
lines = document['lines']
|
86 |
logger.info(f"{len(lines)} linee caricate da {documents_file}.")
|
87 |
+
|
88 |
# Genera embedding
|
89 |
embeddings = embedding_model_instance.encode(lines, convert_to_numpy=True, show_progress_bar=True)
|
90 |
logger.info("Embedding generati con SentenceTransformer.")
|
91 |
+
|
92 |
# Crea l'indice FAISS
|
93 |
dimension = embeddings.shape[1]
|
94 |
index = faiss.IndexFlatL2(dimension)
|
95 |
index.add(embeddings)
|
96 |
logger.info(f"Indice FAISS creato con dimensione: {dimension}.")
|
97 |
+
|
98 |
# Salva l'indice
|
99 |
faiss.write_index(index, index_file)
|
100 |
logger.info(f"Indice FAISS salvato in {index_file}.")
|
|
|
145 |
document = json.load(f)
|
146 |
lines = document['lines']
|
147 |
logger.info(f"{len(lines)} linee caricate da {DOCUMENTS_FILE}.")
|
148 |
+
|
149 |
# Carica l'indice FAISS
|
150 |
index = faiss.read_index(FAISS_INDEX_FILE)
|
151 |
logger.info(f"Indice FAISS caricato da {FAISS_INDEX_FILE}.")
|
152 |
+
|
153 |
# Genera embedding della query
|
154 |
if embedding_model_instance is None:
|
155 |
embedding_model_instance = SentenceTransformer('all-MiniLM-L6-v2')
|
156 |
logger.info("Modello SentenceTransformer caricato per l'embedding della query.")
|
157 |
+
|
158 |
query_embedding = embedding_model_instance.encode([query], convert_to_numpy=True)
|
159 |
logger.info("Embedding della query generati.")
|
160 |
+
|
161 |
# Ricerca nell'indice
|
162 |
distances, indices = index.search(query_embedding, top_k)
|
163 |
logger.info(f"Ricerca FAISS completata. Risultati ottenuti: {len(indices[0])}")
|
164 |
+
|
165 |
# Recupera le linee rilevanti
|
166 |
relevant_texts = [lines[idx] for idx in indices[0] if idx < len(lines)]
|
167 |
retrieved_docs = "\n".join(relevant_texts)
|
|
|
208 |
"""
|
209 |
|
210 |
async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int = 150) -> str:
|
211 |
+
"""Chiama il modello Hugging Face tramite InferenceClient e gestisce la risposta."""
|
212 |
logger.debug("Chiamo HF con il seguente prompt:")
|
213 |
content_preview = (prompt[:300] + '...') if len(prompt) > 300 else prompt
|
214 |
logger.debug(f"PROMPT => {content_preview}")
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
try:
|
217 |
+
# Utilizza il metodo chat.completions.create per interagire con il modello
|
218 |
+
response = client.chat.completions.create(
|
219 |
+
model=HF_MODEL,
|
220 |
+
messages=[
|
221 |
+
{"role": "user", "content": prompt}
|
222 |
+
],
|
223 |
+
temperature=temperature,
|
224 |
+
max_tokens=max_tokens,
|
225 |
+
top_p=0.7,
|
226 |
+
stream=False # Imposta su True se desideri gestire lo stream
|
227 |
)
|
228 |
+
logger.debug(f"Risposta completa dal modello: {response}")
|
229 |
+
|
230 |
+
# Estrai il testo generato
|
231 |
+
if isinstance(response, list) and len(response) > 0 and "generated_text" in response[0]:
|
232 |
+
raw = response[0]["generated_text"]
|
233 |
+
elif "generated_text" in response:
|
234 |
+
raw = response["generated_text"]
|
|
|
|
|
235 |
else:
|
236 |
raise ValueError("Nessun campo 'generated_text' nella risposta.")
|
237 |
|
|
|
240 |
logger.debug(f"Risposta HF single-line: {single_line}")
|
241 |
return single_line.strip()
|
242 |
except Exception as e:
|
243 |
+
logger.error(f"Errore nella chiamata all'API Hugging Face tramite InferenceClient: {e}")
|
244 |
raise HTTPException(status_code=500, detail=str(e))
|
245 |
|
246 |
# Variabile globale per le etichette delle entità
|