VincentGOURBIN commited on
Commit
8b010cc
·
verified ·
1 Parent(s): 6052529

Upload step03_chatbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. step03_chatbot.py +1553 -0
step03_chatbot.py ADDED
@@ -0,0 +1,1553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Step 03 - Interface de chat RAG générique avec Gradio
4
+ Utilise les embeddings de Step 02 depuis Hugging Face Hub + Qwen3-4B-Instruct-2507 pour génération
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import numpy as np
10
+ import gradio as gr
11
+ from gradio import ChatMessage
12
+ from typing import List, Dict, Optional, Tuple
13
+ import time
14
+ import torch
15
+ import threading
16
+ import http.server
17
+ import socketserver
18
+ from pathlib import Path
19
+ from datetime import datetime
20
+
21
+ # ZeroGPU compatibility
22
+ try:
23
+ import spaces
24
+ ZEROGPU_AVAILABLE = True
25
+ print("🚀 ZeroGPU détecté - activation du support")
26
+ except ImportError:
27
+ ZEROGPU_AVAILABLE = False
28
+ # Fallback decorator for local usage
29
+ class MockSpaces:
30
+ @staticmethod
31
+ def GPU(duration=None):
32
+ def decorator(func):
33
+ return func
34
+ return decorator
35
+ spaces = MockSpaces()
36
+
37
+ def _check_dependencies():
38
+ """Vérifie les dépendances nécessaires."""
39
+ missing = []
40
+ try:
41
+ import torch
42
+ except ImportError:
43
+ missing.append("torch")
44
+
45
+ try:
46
+ import numpy as np
47
+ except ImportError:
48
+ missing.append("numpy")
49
+
50
+ try:
51
+ from safetensors.torch import load_file
52
+ except ImportError:
53
+ missing.append("safetensors")
54
+
55
+ try:
56
+ from huggingface_hub import hf_hub_download
57
+ except ImportError:
58
+ missing.append("huggingface-hub")
59
+
60
+ try:
61
+ import faiss
62
+ except ImportError:
63
+ missing.append("faiss-cpu")
64
+
65
+ try:
66
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
67
+ except ImportError:
68
+ missing.append("transformers")
69
+
70
+ try:
71
+ from sentence_transformers import SentenceTransformer
72
+ except ImportError:
73
+ missing.append("sentence-transformers")
74
+
75
+ if missing:
76
+ print(f"❌ Dépendances manquantes: {', '.join(missing)}")
77
+ print("📦 Installer avec: pip install " + " ".join(missing))
78
+ return False
79
+ return True
80
+
81
+
82
+ class Step03Config:
83
+ """Gestionnaire de configuration Step 03 basé sur la sortie Step 02."""
84
+
85
+ def __init__(self, config_file: str = "step03_config.json"):
86
+ self.config_file = Path(config_file)
87
+ self.config = self.load_config()
88
+
89
+ def load_config(self) -> Dict:
90
+ """Charge la configuration Step 03."""
91
+ if not self.config_file.exists():
92
+ raise FileNotFoundError(
93
+ f"❌ Configuration Step 03 non trouvée: {self.config_file}\n"
94
+ f"💡 Lancez d'abord: python step02_upload_embeddings.py"
95
+ )
96
+
97
+ try:
98
+ with open(self.config_file, 'r', encoding='utf-8') as f:
99
+ config = json.load(f)
100
+
101
+ # Vérification de la structure
102
+ if not config.get("step02_completed"):
103
+ raise ValueError("❌ Step 02 non complété selon la configuration")
104
+
105
+ required_keys = ["huggingface", "embeddings_info"]
106
+ for key in required_keys:
107
+ if key not in config:
108
+ raise ValueError(f"❌ Clé manquante dans configuration: {key}")
109
+
110
+ return config
111
+
112
+ except json.JSONDecodeError as e:
113
+ raise ValueError(f"❌ Configuration Step 03 malformée: {e}")
114
+
115
+ @property
116
+ def repo_id(self) -> str:
117
+ """Repository Hugging Face ID."""
118
+ return self.config["huggingface"]["repo_id"]
119
+
120
+ @property
121
+ def dataset_name(self) -> str:
122
+ """Nom du dataset."""
123
+ return self.config["huggingface"]["dataset_name"]
124
+
125
+ @property
126
+ def embeddings_file(self) -> str:
127
+ """Nom du fichier SafeTensors."""
128
+ return self.config["huggingface"]["files"]["embeddings"]
129
+
130
+ @property
131
+ def metadata_file(self) -> str:
132
+ """Nom du fichier métadonnées."""
133
+ return self.config["huggingface"]["files"]["metadata"]
134
+
135
+ @property
136
+ def total_vectors(self) -> int:
137
+ """Nombre total de vecteurs."""
138
+ return self.config["embeddings_info"]["total_vectors"]
139
+
140
+ @property
141
+ def vector_dimension(self) -> int:
142
+ """Dimension des vecteurs."""
143
+ return self.config["embeddings_info"]["vector_dimension"]
144
+
145
+ @property
146
+ def embedding_model(self) -> str:
147
+ """Modèle d'embedding utilisé."""
148
+ return self.config["embeddings_info"]["embedding_model"]
149
+
150
+
151
+ class Qwen3Reranker:
152
+ """
153
+ Reranker utilisant Qwen3-Reranker-4B pour améliorer la pertinence des résultats de recherche
154
+ """
155
+
156
+ def __init__(self, model_name: str = "Qwen/Qwen3-Reranker-4B", use_flash_attention: bool = True):
157
+ """
158
+ Initialise le reranker Qwen3
159
+
160
+ Args:
161
+ model_name: Nom du modèle HuggingFace à charger
162
+ use_flash_attention: Utiliser Flash Attention 2 si disponible (auto-d��sactivé sur Mac)
163
+ """
164
+ self.model_name = model_name
165
+ self.use_flash_attention = use_flash_attention
166
+
167
+ # Détection de l'environnement
168
+ self.is_mps = torch.backends.mps.is_available()
169
+ self.is_cuda = torch.cuda.is_available()
170
+ self.is_cpu = not self.is_mps and not self.is_cuda
171
+
172
+ print(f"🔄 Chargement du reranker {model_name}...")
173
+ self._detect_platform()
174
+ self._load_model()
175
+
176
+ def _detect_platform(self):
177
+ """Détecte la plateforme et ajuste les paramètres"""
178
+ if self.is_mps:
179
+ print(" - Plateforme: Mac MPS détecté")
180
+ self.use_flash_attention = False # Flash Attention non compatible MPS
181
+ self.batch_size = 1 # Traitement strictement individuel sur Mac
182
+ self.memory_cleanup_freq = 3 # Nettoyage mémoire fréquent
183
+ elif self.is_cuda:
184
+ print(f" - Plateforme: CUDA détecté ({torch.cuda.get_device_name()})")
185
+ self.batch_size = 1 # Garde traitement individuel pour stabilité
186
+ self.memory_cleanup_freq = 10 # Nettoyage moins fréquent
187
+ else:
188
+ print(" - Plateforme: CPU")
189
+ self.use_flash_attention = False
190
+ self.batch_size = 1
191
+ self.memory_cleanup_freq = 5
192
+
193
+ def _load_model(self):
194
+ """Charge le modèle et le tokenizer"""
195
+ try:
196
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
197
+
198
+ # Chargement du tokenizer
199
+ print(" - Chargement du tokenizer...")
200
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
201
+
202
+ # Configuration du modèle selon la plateforme
203
+ model_kwargs = self._get_model_config()
204
+
205
+ # Chargement du modèle
206
+ print(" - Chargement du modèle...")
207
+ self.model = AutoModelForSequenceClassification.from_pretrained(
208
+ self.model_name,
209
+ **model_kwargs
210
+ )
211
+
212
+ # Configuration du device
213
+ self._setup_device()
214
+
215
+ print(f"✅ Reranker chargé sur {self.device}")
216
+ print(f" - Flash Attention: {'✅' if self.use_flash_attention else '❌'}")
217
+ print(f" - Paramètres: {self.get_parameter_count():.1f}B")
218
+
219
+ except Exception as e:
220
+ print(f"❌ Erreur lors du chargement du reranker: {e}")
221
+ print("💡 Le reranking sera désactivé")
222
+ self.model = None
223
+ self.tokenizer = None
224
+ self.device = None
225
+
226
+ def _get_model_config(self) -> Dict:
227
+ """Retourne la configuration du modèle selon la plateforme"""
228
+ config = {}
229
+
230
+ if self.is_mps:
231
+ # Configuration pour Mac MPS
232
+ config["torch_dtype"] = torch.float32 # MPS fonctionne mieux avec float32
233
+ config["device_map"] = None # device_map peut causer des problèmes avec MPS
234
+ elif self.is_cuda:
235
+ # Configuration pour CUDA
236
+ config["torch_dtype"] = torch.float16
237
+ if self.use_flash_attention:
238
+ try:
239
+ config["attn_implementation"] = "flash_attention_2"
240
+ print(" - Flash Attention 2 activée")
241
+ except Exception:
242
+ print(" - Flash Attention 2 non disponible, utilisation standard")
243
+ self.use_flash_attention = False
244
+ else:
245
+ config["device_map"] = "auto"
246
+ else:
247
+ # Configuration pour CPU
248
+ config["torch_dtype"] = torch.float32
249
+ config["device_map"] = "cpu"
250
+
251
+ return config
252
+
253
+ def _setup_device(self):
254
+ """Configure le device pour le modèle"""
255
+ if self.is_mps:
256
+ self.device = torch.device("mps")
257
+ self.model = self.model.to(self.device)
258
+ elif self.is_cuda:
259
+ if hasattr(self.model, 'device'):
260
+ self.device = next(self.model.parameters()).device
261
+ else:
262
+ self.device = torch.device("cuda")
263
+ self.model = self.model.to(self.device)
264
+ else:
265
+ self.device = torch.device("cpu")
266
+ self.model = self.model.to(self.device)
267
+
268
+ def _format_pair(self, query: str, document: str, instruction: str = None) -> str:
269
+ """
270
+ Formate une paire query-document pour le reranker
271
+ """
272
+ if instruction:
273
+ return f"Instruction: {instruction}\nQuery: {query}\nDocument: {document}"
274
+ return f"Query: {query}\nDocument: {document}"
275
+
276
+ def _get_default_instruction(self) -> str:
277
+ """Retourne l'instruction par défaut pour la documentation technique"""
278
+ return (
279
+ "Évaluez la pertinence de ce document technique "
280
+ "par rapport à la requête en considérant : terminologie technique, "
281
+ "spécifications, normes, procédures de mise en œuvre."
282
+ )
283
+
284
+ def _process_single_document(self, query: str, document: str, instruction: str) -> float:
285
+ """
286
+ Traite un seul document et retourne son score de pertinence
287
+ """
288
+ # Formatage de la paire
289
+ pair_text = self._format_pair(query, document, instruction)
290
+
291
+ # Tokenisation (pas de problème de padding avec un seul document)
292
+ inputs = self.tokenizer(
293
+ pair_text,
294
+ truncation=True,
295
+ max_length=512,
296
+ return_tensors="pt",
297
+ padding=False
298
+ )
299
+
300
+ # Déplacement vers le device
301
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
302
+
303
+ # Inférence
304
+ with torch.no_grad():
305
+ outputs = self.model(**inputs)
306
+ logits = outputs.logits
307
+
308
+ # Le modèle Qwen3-Reranker retourne des logits de forme [1, 2]
309
+ # pour classification binaire : [non-pertinent, pertinent]
310
+ probs = torch.nn.functional.softmax(logits, dim=1)
311
+ score = probs[0, 1].cpu().item() # Classe 1 = pertinent
312
+
313
+ return float(score)
314
+
315
+ def _cleanup_memory(self):
316
+ """Nettoie la mémoire selon la plateforme"""
317
+ if self.is_mps:
318
+ if hasattr(torch.mps, 'empty_cache'):
319
+ torch.mps.empty_cache()
320
+ elif self.is_cuda:
321
+ torch.cuda.empty_cache()
322
+
323
+ import gc
324
+ gc.collect()
325
+
326
+ def rerank(self, query: str, documents: List[str], instruction: str = None) -> List[float]:
327
+ """
328
+ Reranke une liste de documents par rapport à une requête
329
+ """
330
+ if not documents:
331
+ return []
332
+
333
+ if self.model is None or self.tokenizer is None:
334
+ print(" - Reranker non disponible, scores neutres retournés")
335
+ return [0.5] * len(documents)
336
+
337
+ if instruction is None:
338
+ instruction = self._get_default_instruction()
339
+
340
+ print(f" - Reranking de {len(documents)} documents (traitement individuel)")
341
+
342
+ scores = []
343
+ successful_count = 0
344
+
345
+ for i, document in enumerate(documents):
346
+ try:
347
+ score = self._process_single_document(query, document, instruction)
348
+ score = max(0.0, min(1.0, score))
349
+ scores.append(score)
350
+ successful_count += 1
351
+
352
+ if (i + 1) % self.memory_cleanup_freq == 0:
353
+ self._cleanup_memory()
354
+
355
+ except Exception as doc_error:
356
+ print(f" ⚠️ Erreur document {i+1}: {doc_error}")
357
+ scores.append(0.5) # Score neutre en cas d'erreur
358
+
359
+ self._cleanup_memory()
360
+
361
+ print(f" ✅ Reranking terminé: {successful_count}/{len(documents)} documents traités")
362
+
363
+ if successful_count > 0:
364
+ valid_scores = [s for s in scores if s != 0.5]
365
+ if valid_scores:
366
+ top_scores = sorted(valid_scores, reverse=True)[:3]
367
+ print(f" 📈 Top 3 scores: {[f'{s:.3f}' for s in top_scores]}")
368
+
369
+ return scores
370
+
371
+ def get_parameter_count(self) -> float:
372
+ """Retourne le nombre de paramètres du modèle en milliards"""
373
+ if self.model is None:
374
+ return 0.0
375
+ try:
376
+ return sum(p.numel() for p in self.model.parameters()) / 1e9
377
+ except:
378
+ return 0.0
379
+
380
+ def is_available(self) -> bool:
381
+ """Vérifie si le reranker est disponible et fonctionnel"""
382
+ return self.model is not None and self.tokenizer is not None
383
+
384
+
385
+ class GenericRAGChatbot:
386
+ """Chatbot RAG générique utilisant les embeddings de Step 02 et Qwen3-4B-Instruct pour la génération"""
387
+
388
+ def __init__(self,
389
+ generation_model: str = "Qwen/Qwen3-4B-Instruct-2507",
390
+ initial_k: int = 20,
391
+ final_k: int = 3,
392
+ use_flash_attention: bool = True,
393
+ use_reranker: bool = True):
394
+ """
395
+ Initialise le système RAG générique
396
+
397
+ Args:
398
+ generation_model: Modèle Qwen3 pour la génération
399
+ initial_k: Nombre de candidats pour la recherche initiale
400
+ final_k: Nombre de documents finaux après reranking
401
+ use_flash_attention: Utiliser Flash Attention (désactivé automatiquement sur Mac)
402
+ use_reranker: Utiliser le reranking Qwen3
403
+ """
404
+ self.generation_model_name = generation_model
405
+ self.initial_k = initial_k
406
+ self.final_k = final_k
407
+ self.use_flash_attention = use_flash_attention
408
+ self.use_reranker = use_reranker
409
+
410
+ # Détection de l'environnement (local + ZeroGPU)
411
+ self.is_zerogpu = ZEROGPU_AVAILABLE and os.getenv("SPACE_ID") is not None
412
+ self.is_mps = torch.backends.mps.is_available() and not self.is_zerogpu
413
+ self.is_cuda = torch.cuda.is_available()
414
+
415
+ # Configuration du device
416
+ if self.is_mps:
417
+ self.device = torch.device("mps")
418
+ elif self.is_cuda:
419
+ self.device = torch.device("cuda")
420
+ else:
421
+ self.device = torch.device("cpu")
422
+
423
+ if self.is_zerogpu:
424
+ print("🚀 Environnement ZeroGPU détecté - optimisations cloud")
425
+ self.use_flash_attention = True # ZeroGPU supporte Flash Attention
426
+ elif self.is_mps and use_flash_attention:
427
+ print("🍎 Mac avec MPS détecté - désactivation automatique de Flash Attention")
428
+ self.use_flash_attention = False
429
+
430
+ # Chargement des composants
431
+ self._load_step03_config()
432
+ self._load_embeddings_from_hf()
433
+ self._load_embedding_model()
434
+ self._load_reranker()
435
+ self._load_generation_model()
436
+
437
+ def _load_step03_config(self):
438
+ """Charge la configuration Step 03"""
439
+ try:
440
+ self.config = Step03Config()
441
+ print(f"✅ Configuration Step 03 chargée")
442
+ print(f" 📦 Repository HF: {self.config.repo_id}")
443
+ print(f" 📊 Embeddings: {self.config.total_vectors:,} vecteurs")
444
+ print(f" 📏 Dimension: {self.config.vector_dimension}")
445
+ except Exception as e:
446
+ print(f"❌ Erreur de chargement de la configuration: {e}")
447
+ raise
448
+
449
+ def _load_embeddings_from_hf(self):
450
+ """Télécharge et charge les embeddings depuis Hugging Face Hub"""
451
+ try:
452
+ from huggingface_hub import hf_hub_download
453
+ from safetensors.torch import load_file
454
+ import numpy as np
455
+ import faiss
456
+
457
+ print(f"🔄 Téléchargement des embeddings depuis {self.config.repo_id}...")
458
+
459
+ # Télécharger les fichiers (sans token pour les repos publics)
460
+ try:
461
+ embeddings_file = hf_hub_download(
462
+ repo_id=self.config.repo_id,
463
+ filename=self.config.embeddings_file,
464
+ repo_type="dataset",
465
+ token=None # Forcer l'accès sans token pour les repos publics
466
+ )
467
+
468
+ metadata_file = hf_hub_download(
469
+ repo_id=self.config.repo_id,
470
+ filename=self.config.metadata_file,
471
+ repo_type="dataset",
472
+ token=None # Forcer l'accès sans token pour les repos publics
473
+ )
474
+ except Exception as auth_error:
475
+ print(f" ⚠️ Erreur d'authentification: {auth_error}")
476
+ print(" 🔑 Essai avec token depuis les variables d'environnement...")
477
+
478
+ # Essayer avec le token d'environnement
479
+ import os
480
+ hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')
481
+
482
+ if hf_token:
483
+ print(" 🔑 Token trouvé, nouvel essai...")
484
+ embeddings_file = hf_hub_download(
485
+ repo_id=self.config.repo_id,
486
+ filename=self.config.embeddings_file,
487
+ repo_type="dataset",
488
+ token=hf_token
489
+ )
490
+
491
+ metadata_file = hf_hub_download(
492
+ repo_id=self.config.repo_id,
493
+ filename=self.config.metadata_file,
494
+ repo_type="dataset",
495
+ token=hf_token
496
+ )
497
+ else:
498
+ print(" ❌ Aucun token trouvé dans les variables d'environnement")
499
+ print(" 💡 Solutions possibles:")
500
+ print(" 1. Vérifiez que le repository est bien public")
501
+ print(" 2. Connectez-vous avec: huggingface-cli login")
502
+ print(" 3. Définissez HF_TOKEN dans les variables d'environnement")
503
+ raise auth_error
504
+
505
+ print(" 📥 Chargement des embeddings SafeTensors...")
506
+ tensors = load_file(embeddings_file)
507
+ embeddings_tensor = tensors["embeddings"]
508
+ embeddings_np = embeddings_tensor.numpy().astype(np.float32)
509
+
510
+ print(" 📋 Chargement des métadonnées...")
511
+ with open(metadata_file, 'r', encoding='utf-8') as f:
512
+ self.metadata = json.load(f)
513
+
514
+ # Créer l'index FAISS (optimisé pour Mac)
515
+ print(" 🔧 Création de l'index FAISS...")
516
+ dimension = embeddings_np.shape[1]
517
+
518
+ # Configuration d'index FAISS selon l'environnement
519
+ if self.is_zerogpu:
520
+ print(" 🚀 Index FAISS optimisé pour ZeroGPU (IndexHNSWFlat)")
521
+ # Index sophistiqué pour ZeroGPU avec GPU puissant
522
+ self.faiss_index = faiss.IndexHNSWFlat(dimension, 32)
523
+ self.faiss_index.hnsw.efConstruction = 200
524
+ self.faiss_index.hnsw.efSearch = 50
525
+ elif self.is_mps:
526
+ print(" 🍎 Index FAISS optimisé pour Mac (IndexFlatIP)")
527
+ # Index simple mais efficace sur Mac
528
+ self.faiss_index = faiss.IndexFlatIP(dimension) # Inner Product (plus stable sur Mac)
529
+ else:
530
+ print(" 🐧 Index FAISS HNSW pour Linux/Windows")
531
+ # Index plus sophistiqué pour autres plateformes
532
+ self.faiss_index = faiss.IndexHNSWFlat(dimension, 32)
533
+ self.faiss_index.hnsw.efConstruction = 200
534
+ self.faiss_index.hnsw.efSearch = 50
535
+
536
+ # Normaliser les embeddings pour IndexFlatIP (équivalent à cosine similarity)
537
+ if self.is_mps:
538
+ # Normalisation L2 pour que IndexFlatIP = cosine similarity
539
+ norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
540
+ embeddings_np = embeddings_np / (norms + 1e-8) # Éviter division par 0
541
+
542
+ print(f" 📊 Ajout de {embeddings_np.shape[0]:,} vecteurs à l'index...")
543
+ # Ajouter les vecteurs à l'index
544
+ self.faiss_index.add(embeddings_np)
545
+
546
+ # Récupérer les mappings et métadonnées de contenu
547
+ self.ordered_ids = self.metadata.get('ordered_ids', [])
548
+ self.id_to_idx = self.metadata.get('id_to_idx', {})
549
+ self.content_metadata = self.metadata.get('content_metadata', {})
550
+
551
+
552
+ print(f"✅ Embeddings chargés: {embeddings_np.shape[0]:,} vecteurs de dimension {dimension}")
553
+
554
+ except Exception as e:
555
+ print(f"❌ Erreur lors du chargement des embeddings: {e}")
556
+ raise
557
+
558
+ def _load_embedding_model(self):
559
+ """Charge le modèle d'embeddings pour les requêtes"""
560
+ print(f"🔄 Chargement du modèle d'embeddings {self.config.embedding_model}...")
561
+
562
+ try:
563
+ from sentence_transformers import SentenceTransformer
564
+
565
+ if self.use_flash_attention and self.is_cuda:
566
+ print(" - Configuration avec Flash Attention 2 activée (CUDA)")
567
+ try:
568
+ self.embedding_model = SentenceTransformer(
569
+ self.config.embedding_model,
570
+ model_kwargs={
571
+ "attn_implementation": "flash_attention_2",
572
+ "device_map": "auto"
573
+ },
574
+ tokenizer_kwargs={"padding_side": "left"}
575
+ )
576
+ except Exception as flash_error:
577
+ print(f" - Flash Attention échoué: {flash_error}")
578
+ print(" - Fallback vers configuration standard")
579
+ self.embedding_model = SentenceTransformer(self.config.embedding_model)
580
+ self.use_flash_attention = False
581
+ else:
582
+ print(" - Configuration standard (MPS/CPU ou Flash Attention désactivé)")
583
+ model_kwargs = {}
584
+
585
+ if self.is_mps:
586
+ model_kwargs = {"torch_dtype": torch.float32}
587
+
588
+ if model_kwargs:
589
+ self.embedding_model = SentenceTransformer(
590
+ self.config.embedding_model,
591
+ model_kwargs=model_kwargs,
592
+ tokenizer_kwargs={"padding_side": "left"}
593
+ )
594
+ else:
595
+ self.embedding_model = SentenceTransformer(self.config.embedding_model)
596
+
597
+ print(f"✅ Modèle d'embeddings {self.config.embedding_model} chargé avec succès")
598
+
599
+ except Exception as e:
600
+ print(f"❌ Erreur avec {self.config.embedding_model}: {e}")
601
+ print("🔄 Fallback vers le modèle multilingual MiniLM...")
602
+ self.embedding_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
603
+ self.use_flash_attention = False
604
+
605
+ def _load_reranker(self):
606
+ """Charge le reranker Qwen3-Reranker-4B"""
607
+ if self.use_reranker:
608
+ try:
609
+ effective_flash_attention = self.use_flash_attention and not self.is_mps
610
+ self.reranker = Qwen3Reranker(use_flash_attention=effective_flash_attention)
611
+ except Exception as e:
612
+ print(f"❌ Erreur lors du chargement du reranker: {e}")
613
+ print("🔄 Désactivation du reranking")
614
+ self.use_reranker = False
615
+ self.reranker = None
616
+ else:
617
+ self.reranker = None
618
+ print("⚠️ Reranking désactivé par configuration")
619
+
620
+ def _load_generation_model(self):
621
+ """Charge le modèle de génération Qwen3-4B-Instruct"""
622
+ print(f"🔄 Chargement du modèle de génération {self.generation_model_name}...")
623
+
624
+ try:
625
+ from transformers import AutoTokenizer, AutoModelForCausalLM
626
+
627
+ # Chargement du tokenizer
628
+ print(" - Chargement du tokenizer...")
629
+ self.generation_tokenizer = AutoTokenizer.from_pretrained(self.generation_model_name)
630
+
631
+ # Configuration du modèle selon la plateforme
632
+ model_kwargs = self._get_generation_model_config()
633
+
634
+ # Chargement du modèle
635
+ print(" - Chargement du modèle...")
636
+ self.generation_model = AutoModelForCausalLM.from_pretrained(
637
+ self.generation_model_name,
638
+ **model_kwargs
639
+ )
640
+
641
+ # Configuration du device
642
+ self._setup_generation_device()
643
+
644
+ print(f"✅ Modèle de génération chargé sur {self.generation_device}")
645
+ print(f" - Paramètres: {self._get_generation_parameter_count():.1f}B")
646
+
647
+ except Exception as e:
648
+ print(f"❌ Erreur lors du chargement du modèle de génération: {e}")
649
+ print("💡 La génération sera désactivée")
650
+ self.generation_model = None
651
+ self.generation_tokenizer = None
652
+ self.generation_device = None
653
+
654
+ def _get_generation_model_config(self) -> Dict:
655
+ """Retourne la configuration du modèle de génération selon la plateforme"""
656
+ config = {}
657
+
658
+ if self.is_mps:
659
+ config["torch_dtype"] = torch.float32
660
+ config["device_map"] = None
661
+ elif self.is_cuda:
662
+ config["torch_dtype"] = torch.float16
663
+ if self.use_flash_attention:
664
+ try:
665
+ config["attn_implementation"] = "flash_attention_2"
666
+ print(" - Flash Attention 2 activée pour génération")
667
+ except Exception:
668
+ print(" - Flash Attention 2 non disponible pour génération")
669
+ config["device_map"] = "auto"
670
+ else:
671
+ config["torch_dtype"] = torch.float32
672
+ config["device_map"] = "cpu"
673
+
674
+ return config
675
+
676
+ def _setup_generation_device(self):
677
+ """Configure le device pour le modèle de génération"""
678
+ if self.is_mps:
679
+ self.generation_device = torch.device("mps")
680
+ self.generation_model = self.generation_model.to(self.generation_device)
681
+ elif self.is_cuda:
682
+ if hasattr(self.generation_model, 'device'):
683
+ self.generation_device = next(self.generation_model.parameters()).device
684
+ else:
685
+ self.generation_device = torch.device("cuda")
686
+ self.generation_model = self.generation_model.to(self.generation_device)
687
+ else:
688
+ self.generation_device = torch.device("cpu")
689
+ self.generation_model = self.generation_model.to(self.generation_device)
690
+
691
+ def _get_generation_parameter_count(self) -> float:
692
+ """Retourne le nombre de paramètres du modèle de génération en milliards"""
693
+ if self.generation_model is None:
694
+ return 0.0
695
+ try:
696
+ return sum(p.numel() for p in self.generation_model.parameters()) / 1e9
697
+ except:
698
+ return 0.0
699
+
700
+ def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
701
+ """
702
+ Recherche avancée avec reranking en deux étapes
703
+ """
704
+ k = final_k if final_k is not None else self.final_k
705
+ initial_k = max(self.initial_k, k * 3)
706
+ should_rerank = use_reranking if use_reranking is not None else self.use_reranker
707
+
708
+ print(f"🔍 Recherche en deux étapes: {initial_k} candidats → reranking → {k} finaux")
709
+
710
+ # Étape 1: Recherche par embedding avec FAISS
711
+ if hasattr(self.embedding_model, 'prompts') and 'query' in self.embedding_model.prompts:
712
+ query_embedding = self.embedding_model.encode([query], prompt_name="query")[0]
713
+ else:
714
+ query_embedding = self.embedding_model.encode([query])[0]
715
+
716
+ # Recherche dans l'index FAISS
717
+ query_vector = query_embedding.reshape(1, -1).astype('float32')
718
+
719
+ # Normaliser la requête sur Mac pour IndexFlatIP (consistency avec les embeddings)
720
+ if self.is_mps:
721
+ norm = np.linalg.norm(query_vector)
722
+ if norm > 0:
723
+ query_vector = query_vector / norm
724
+
725
+ distances, indices = self.faiss_index.search(query_vector, initial_k)
726
+
727
+ if len(indices[0]) == 0:
728
+ print("❌ Aucun document trouvé")
729
+ return []
730
+
731
+ print(f"📋 {len(indices[0])} candidats récupérés")
732
+
733
+ # Conversion en format intermédiaire
734
+ initial_results = []
735
+ for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
736
+ if idx < len(self.ordered_ids):
737
+ doc_id = self.ordered_ids[idx]
738
+ doc_metadata = self.content_metadata.get(doc_id, {})
739
+
740
+ # Ajustement des scores selon le type d'index
741
+ if self.is_mps:
742
+ # Sur Mac avec IndexFlatIP : distance = inner product (plus haut = plus similaire)
743
+ embedding_score = float(distance) # Inner product normalisé = cosine similarity
744
+ embedding_distance = 1.0 - embedding_score # Conversion en distance pour compatibilité
745
+ else:
746
+ # Sur autres plateformes avec IndexHNSWFlat : distance euclidienne
747
+ embedding_distance = float(distance)
748
+ embedding_score = 1 - embedding_distance
749
+
750
+ doc = {
751
+ 'content': doc_metadata.get('chunk_content', 'Contenu non disponible'),
752
+ 'metadata': doc_metadata,
753
+ 'embedding_distance': embedding_distance,
754
+ 'embedding_score': embedding_score,
755
+ 'source': doc_metadata.get('source_file', 'Inconnu'),
756
+ 'title': doc_metadata.get('title', 'Sans titre'),
757
+ 'heading': doc_metadata.get('heading', ''),
758
+ 'initial_rank': i + 1
759
+ }
760
+ initial_results.append(doc)
761
+
762
+ # Étape 2: Reranking si disponible
763
+ if should_rerank and self.reranker and self.reranker.model is not None:
764
+ print("🎯 Application du reranking Qwen3...")
765
+
766
+ documents = [doc['content'] for doc in initial_results]
767
+
768
+
769
+ rerank_scores = self.reranker.rerank(query, documents)
770
+
771
+ # Ajout des scores de reranking
772
+ for doc, rerank_score in zip(initial_results, rerank_scores):
773
+ doc['rerank_score'] = float(rerank_score)
774
+
775
+ # Tri par score de reranking
776
+ initial_results.sort(key=lambda x: x['rerank_score'], reverse=True)
777
+
778
+ # Mise à jour des positions finales
779
+ for i, doc in enumerate(initial_results):
780
+ doc['final_rank'] = i + 1
781
+
782
+ top_scores = [f"{doc['rerank_score']:.3f}" for doc in initial_results[:5]]
783
+ print(f"✅ Reranking appliqué, top 5 scores: {top_scores}")
784
+ else:
785
+ print("⚠️ Reranking désactivé, utilisation des scores d'embedding uniquement")
786
+ for doc in initial_results:
787
+ doc['rerank_score'] = doc['embedding_score']
788
+ doc['final_rank'] = doc['initial_rank']
789
+
790
+ # Retour des top-k résultats finaux
791
+ final_results = initial_results[:k]
792
+ print(f"📊 {len(final_results)} documents finaux sélectionnés")
793
+
794
+ return final_results
795
+
796
+ def generate_response_stream(self, query: str, context: str, history: List = None):
797
+ """
798
+ Génère une réponse streamée basée sur le contexte et l'historique
799
+ """
800
+ if self.generation_model is None or self.generation_tokenizer is None:
801
+ yield "❌ Modèle de génération non disponible"
802
+ return
803
+
804
+ # Construction du prompt système
805
+ system_prompt = """Tu es un assistant expert qui répond aux questions en te basant uniquement sur les documents fournis dans le contexte.
806
+
807
+ Instructions importantes:
808
+ - Réponds en français de manière claire et précise
809
+ - Base-toi uniquement sur les informations du contexte fourni
810
+ - Si l'information n'est pas dans le contexte, dis-le clairement
811
+ - Utilise un ton professionnel adapté au domaine
812
+ - Structure ta réponse avec des paragraphes clairs"""
813
+
814
+ # Construire le prompt complet
815
+ messages = [{"role": "system", "content": system_prompt}]
816
+
817
+ # Ajouter l'historique si fourni
818
+ if history:
819
+ for msg in history:
820
+ if hasattr(msg, 'role') and hasattr(msg, 'content'):
821
+ messages.append({"role": msg.role, "content": msg.content})
822
+
823
+ # Ajouter le contexte et la question
824
+ user_message = f"Contexte:\n{context}\n\nQuestion: {query}"
825
+ messages.append({"role": "user", "content": user_message})
826
+
827
+ try:
828
+ # Tokenisation
829
+ inputs = self.generation_tokenizer.apply_chat_template(
830
+ messages,
831
+ tokenize=True,
832
+ add_generation_prompt=True,
833
+ return_tensors="pt"
834
+ ).to(self.device)
835
+
836
+ # Génération streamée
837
+ from transformers import TextIteratorStreamer
838
+ import threading
839
+
840
+ streamer = TextIteratorStreamer(
841
+ self.generation_tokenizer,
842
+ timeout=10.0,
843
+ skip_prompt=True,
844
+ skip_special_tokens=True
845
+ )
846
+
847
+ generation_kwargs = {
848
+ "input_ids": inputs,
849
+ "streamer": streamer,
850
+ "max_new_tokens": 1024,
851
+ "temperature": 0.7,
852
+ "do_sample": True,
853
+ "pad_token_id": self.generation_tokenizer.eos_token_id,
854
+ "eos_token_id": self.generation_tokenizer.eos_token_id,
855
+ }
856
+
857
+ # Lancer la génération dans un thread séparé
858
+ thread = threading.Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
859
+ thread.start()
860
+
861
+ # Streamer les tokens
862
+ for new_token in streamer:
863
+ yield new_token
864
+
865
+ thread.join()
866
+
867
+ except Exception as e:
868
+ yield f"❌ Erreur lors de la génération: {str(e)}"
869
+
870
+ def generate_response(self, query: str, context: str, history: List = None) -> str:
871
+ """
872
+ Génère une réponse basée sur le contexte et l'historique
873
+ """
874
+ if self.generation_model is None or self.generation_tokenizer is None:
875
+ return "❌ Modèle de génération non disponible"
876
+
877
+ # Construction du prompt système
878
+ system_prompt = """Tu es un assistant expert qui répond aux questions en te basant uniquement sur les documents fournis dans le contexte.
879
+
880
+ Instructions importantes:
881
+ - Réponds en français de manière claire et précise
882
+ - Base-toi uniquement sur les informations du contexte fourni
883
+ - Si l'information n'est pas dans le contexte, dis-le clairement
884
+ - Utilise un ton professionnel adapté au domaine
885
+ - Structure ta réponse avec des paragraphes clairs"""
886
+
887
+ # Construire le prompt complet
888
+ messages = [{"role": "system", "content": system_prompt}]
889
+
890
+ # Ajouter l'historique si fourni
891
+ if history:
892
+ for msg in history:
893
+ if hasattr(msg, 'role') and hasattr(msg, 'content'):
894
+ if msg.role in ["user", "assistant"] and not getattr(msg, 'metadata', None):
895
+ messages.append({"role": msg.role, "content": msg.content})
896
+
897
+ # Ajouter la question courante avec le contexte
898
+ user_prompt = f"""Contexte documentaire:
899
+ {context}
900
+
901
+ Question: {query}
902
+
903
+ Réponds à cette question en te basant sur le contexte fourni."""
904
+
905
+ messages.append({"role": "user", "content": user_prompt})
906
+
907
+ # Formatage pour le modèle
908
+ try:
909
+ # Appliquer le template de chat du modèle
910
+ formatted_prompt = self.generation_tokenizer.apply_chat_template(
911
+ messages,
912
+ tokenize=False,
913
+ add_generation_prompt=True
914
+ )
915
+
916
+ # Tokenisation
917
+ inputs = self.generation_tokenizer(
918
+ formatted_prompt,
919
+ return_tensors="pt",
920
+ truncation=True,
921
+ max_length=4096
922
+ )
923
+
924
+ # Déplacement vers le device
925
+ inputs = {k: v.to(self.generation_device) for k, v in inputs.items()}
926
+
927
+ # Génération
928
+ with torch.no_grad():
929
+ outputs = self.generation_model.generate(
930
+ **inputs,
931
+ max_new_tokens=1024,
932
+ temperature=0.7,
933
+ do_sample=True,
934
+ pad_token_id=self.generation_tokenizer.eos_token_id,
935
+ eos_token_id=self.generation_tokenizer.eos_token_id,
936
+ )
937
+
938
+ # Décodage de la réponse
939
+ full_response = self.generation_tokenizer.decode(outputs[0], skip_special_tokens=True)
940
+
941
+ # Extraire seulement la nouvelle génération
942
+ response = full_response[len(formatted_prompt):].strip()
943
+
944
+ return response
945
+
946
+ except Exception as e:
947
+ print(f"❌ Erreur lors de la génération: {e}")
948
+ return f"❌ Erreur lors de la génération de la réponse: {str(e)}"
949
+
950
+ def stream_response_with_tools(self, query: str, history, top_k: int = None, use_reranking: bool = None):
951
+ """
952
+ Génère une réponse streamée avec affichage visuel des tools et reranking Qwen3
953
+ """
954
+ # 1. S'assurer que l'historique est une liste
955
+ if not history:
956
+ history = []
957
+
958
+ # 2. Ajouter le message utilisateur seulement s'il n'est pas déjà présent
959
+ if not history or history[-1].role != "user" or history[-1].content != query:
960
+ history.append(ChatMessage(role="user", content=query))
961
+ yield history
962
+ time.sleep(0.1)
963
+
964
+ # 3. Recherche des documents avec tool visuel
965
+ should_rerank = use_reranking if use_reranking is not None else self.use_reranker
966
+ search_method = "avec reranking Qwen3" if should_rerank else "par embedding seulement"
967
+
968
+ history.append(ChatMessage(
969
+ role="assistant",
970
+ content=f"Je recherche les documents les plus pertinents dans la base de données ({search_method})...",
971
+ metadata={"title": "🔍 Recherche sémantique avancée"}
972
+ ))
973
+ yield history
974
+
975
+ # Recherche des documents pertinents
976
+ relevant_docs = self.search_documents(query, top_k, use_reranking)
977
+
978
+ time.sleep(0.2)
979
+
980
+ if not relevant_docs:
981
+ history.append(ChatMessage(
982
+ role="assistant",
983
+ content="Aucun document pertinent trouvé dans la base de données."
984
+ ))
985
+ yield history
986
+ return
987
+
988
+ # 4. Affichage des documents trouvés avec scores détaillés
989
+ docs_summary = f"Trouvé {len(relevant_docs)} documents pertinents"
990
+ if should_rerank:
991
+ docs_summary += f"\n\n📊 **Reranking Qwen3 appliqué:**"
992
+ for i, doc in enumerate(relevant_docs):
993
+ embedding_score = doc.get('embedding_score', 0)
994
+ rerank_score = doc.get('rerank_score', 0)
995
+ rank_change = doc.get('initial_rank', i+1) - doc.get('final_rank', i+1)
996
+ rank_indicator = f" (#{doc.get('initial_rank', i+1)}→#{doc.get('final_rank', i+1)})" if rank_change != 0 else ""
997
+ docs_summary += f"\n• **{doc['title']}**{rank_indicator}"
998
+ docs_summary += f"\n └ Embedding: {embedding_score:.3f} | Reranking: {rerank_score:.3f}"
999
+ else:
1000
+ for i, doc in enumerate(relevant_docs):
1001
+ embedding_score = doc.get('embedding_score', doc.get('distance', 0))
1002
+ docs_summary += f"\n• **{doc['title']}** - Score: {embedding_score:.3f}"
1003
+
1004
+ history.append(ChatMessage(
1005
+ role="assistant",
1006
+ content=docs_summary,
1007
+ metadata={"title": f"📚 Documents sélectionnés ({len(relevant_docs)} total)"}
1008
+ ))
1009
+ yield history
1010
+
1011
+ time.sleep(0.2)
1012
+
1013
+ # 5. Construction du contexte
1014
+ context_parts = []
1015
+ sources_with_scores = []
1016
+
1017
+ for i, doc in enumerate(relevant_docs):
1018
+ context_parts.append(f"[Document {i+1}] {doc['title']} - {doc['heading']}\n{doc['content']}")
1019
+ sources_with_scores.append({
1020
+ 'title': doc['title'],
1021
+ 'source': doc['source'],
1022
+ 'embedding_score': doc.get('embedding_score', 1 - doc.get('distance', 0)),
1023
+ 'rerank_score': doc.get('rerank_score'),
1024
+ 'final_rank': doc.get('final_rank', i+1)
1025
+ })
1026
+
1027
+ context = "\n\n".join(context_parts)
1028
+
1029
+ # 6. Génération de la réponse avec Qwen3-4B
1030
+ history.append(ChatMessage(
1031
+ role="assistant",
1032
+ content="Génération de la réponse basée sur les documents sélectionnés...",
1033
+ metadata={"title": "🤖 Génération avec Qwen3-4B"}
1034
+ ))
1035
+ yield history
1036
+
1037
+ time.sleep(0.2)
1038
+
1039
+ # Génération streamée de la réponse
1040
+ history.append(ChatMessage(
1041
+ role="assistant",
1042
+ content="", # Commencer avec un contenu vide
1043
+ metadata={"title": "🤖 Réponse générée"}
1044
+ ))
1045
+
1046
+ # Streamer la réponse token par token
1047
+ current_response = ""
1048
+ for token in self.generate_response_stream(query, context, history[:-1]): # Exclure le dernier message vide
1049
+ current_response += token
1050
+ # Mettre à jour le dernier message avec la réponse en cours
1051
+ history[-1] = ChatMessage(
1052
+ role="assistant",
1053
+ content=current_response,
1054
+ metadata={"title": "🤖 Réponse générée"}
1055
+ )
1056
+ yield history
1057
+ time.sleep(0.01) # Petit délai pour un streaming fluide
1058
+
1059
+ time.sleep(0.2)
1060
+
1061
+ # 7. Ajout des sources consultées avec scores détaillés
1062
+ sources_text = []
1063
+ for i, source_info in enumerate(sources_with_scores):
1064
+ embedding_score = source_info['embedding_score']
1065
+ rerank_score = source_info.get('rerank_score')
1066
+ source_file = source_info['source']
1067
+
1068
+ if rerank_score is not None:
1069
+ score_display = f"Embedding: {embedding_score:.3f} | **Reranking: {rerank_score:.3f}**"
1070
+ else:
1071
+ score_display = f"Score: {embedding_score:.3f}"
1072
+
1073
+ sources_text.append(f"• **[{i+1}]** {source_info['title']} ({source_file})\n └ {score_display}")
1074
+
1075
+ sources_display = "\n".join(sources_text)
1076
+
1077
+ # Titre adaptatif selon la méthode utilisée
1078
+ sources_title = f"📚 Sources avec reranking Qwen3 ({len(relevant_docs)} documents)" if should_rerank else f"📚 Sources par embedding ({len(relevant_docs)} documents)"
1079
+
1080
+ history.append(ChatMessage(
1081
+ role="assistant",
1082
+ content=sources_display,
1083
+ metadata={"title": sources_title}
1084
+ ))
1085
+ yield history
1086
+
1087
+
1088
+ def _create_rag_system():
1089
+ """Créé et configure le système RAG avec paramètres optimaux"""
1090
+
1091
+ # Détection automatique d'environnement
1092
+ is_zerogpu = ZEROGPU_AVAILABLE and os.getenv("SPACE_ID") is not None
1093
+ is_mac = torch.backends.mps.is_available() and not is_zerogpu
1094
+ is_cuda = torch.cuda.is_available()
1095
+
1096
+ if is_zerogpu:
1097
+ print("🚀 ZeroGPU détecté - optimisations cloud appliquées")
1098
+ elif is_mac:
1099
+ print("🍎 Mac avec MPS détecté - optimisations automatiques appliquées")
1100
+ elif is_cuda:
1101
+ print("🐧 CUDA détecté - optimisations GPU appliquées")
1102
+ else:
1103
+ print("💻 CPU détecté - optimisations processeur appliquées")
1104
+
1105
+ # Paramètres par défaut optimisés selon l'environnement
1106
+ if is_zerogpu:
1107
+ default_config = {
1108
+ 'use_flash_attention': True, # ZeroGPU supporte Flash Attention
1109
+ 'use_reranker': True, # GPU puissant, reranking activé
1110
+ 'initial_k': 30, # Plus de candidats avec GPU puissant
1111
+ 'final_k': 5 # Plus de documents finaux
1112
+ }
1113
+ elif is_mac:
1114
+ default_config = {
1115
+ 'use_flash_attention': False, # MPS ne supporte pas Flash Attention
1116
+ 'use_reranker': True, # Reranking OK sur Mac
1117
+ 'initial_k': 20, # Valeurs modérées
1118
+ 'final_k': 3
1119
+ }
1120
+ else:
1121
+ default_config = {
1122
+ 'use_flash_attention': is_cuda, # Flash Attention seulement sur CUDA
1123
+ 'use_reranker': True, # Reranking par défaut
1124
+ 'initial_k': 20, # Candidats pour la première étape
1125
+ 'final_k': 3 # Documents finaux par défaut
1126
+ }
1127
+
1128
+ print("🚀 Initialisation du chatbot RAG générique...")
1129
+ return GenericRAGChatbot(**default_config)
1130
+
1131
+
1132
+ def _clear_message():
1133
+ """Fonction utilitaire interne pour effacer le message d'entrée."""
1134
+ return ""
1135
+
1136
+ def _clear_chat():
1137
+ """Fonction utilitaire interne pour effacer l'historique de chat."""
1138
+ return []
1139
+
1140
+ def _ensure_chatmessages(history):
1141
+ """Convertit une liste en objets ChatMessage si besoin."""
1142
+ result = []
1143
+ for m in history or []:
1144
+ if isinstance(m, ChatMessage):
1145
+ result.append(m)
1146
+ elif isinstance(m, dict):
1147
+ result.append(ChatMessage(
1148
+ role=m.get("role", ""),
1149
+ content=m.get("content", ""),
1150
+ metadata=m.get("metadata", None)
1151
+ ))
1152
+ elif isinstance(m, (list, tuple)) and len(m) >= 2:
1153
+ result.append(ChatMessage(role=m[0], content=m[1]))
1154
+ return result
1155
+
1156
+
1157
+ @spaces.GPU(duration=180) # ZeroGPU: alloue GPU pour toute la pipeline
1158
+ def chat_with_generic_rag(message, history, top_k, use_reranking):
1159
+ """
1160
+ Interface entre Gradio et le système RAG générique avec contrôles avancés.
1161
+
1162
+ Cette fonction gère l'interface de chat interactive avec streaming en temps réel
1163
+ et affichage des étapes de traitement (recherche, reranking, génération).
1164
+
1165
+ Args:
1166
+ message (str): Le message ou question de l'utilisateur à traiter
1167
+ history (list): L'historique de la conversation sous forme de liste de messages
1168
+ top_k (int): Nombre de documents finaux à utiliser pour la génération de réponse
1169
+ use_reranking (bool): Activation du reranking Qwen3 pour améliorer la sélection
1170
+
1171
+ Yields:
1172
+ list: Historique mis à jour avec les nouveaux messages et étapes de traitement
1173
+ """
1174
+ history = _ensure_chatmessages(history)
1175
+ response_generator = rag_system.stream_response_with_tools(message, history, top_k, use_reranking)
1176
+ for updated_history in response_generator:
1177
+ yield updated_history
1178
+
1179
+
1180
+ def ask_rag_question(question: str = "Qu'est-ce que Swift MLX?", num_documents: int = 3, use_reranking: bool = True) -> str:
1181
+ """
1182
+ Pose une question au système RAG LocalRAG et retourne la réponse avec les documents sources.
1183
+
1184
+ Cette fonction utilise un système de recherche sémantique avancé avec des modèles Qwen3
1185
+ pour interroger une base de connaissances et générer des réponses contextualisées.
1186
+
1187
+ Args:
1188
+ question (str): La question à poser au système RAG en langage naturel
1189
+ num_documents (int): Nombre de documents à utiliser pour générer la réponse (entre 1 et 10)
1190
+ use_reranking (bool): Utiliser le reranking Qwen3-Reranker-4B pour améliorer la sélection des documents
1191
+
1192
+ Returns:
1193
+ str: Réponse générée incluant la réponse contextuelle et les sources avec leurs scores de pertinence
1194
+ """
1195
+ global rag_system
1196
+
1197
+ try:
1198
+ # Validation des paramètres
1199
+ num_documents = max(1, min(10, int(num_documents)))
1200
+
1201
+ print(f"🔍 Question MCP: {question}")
1202
+ print(f"📊 Paramètres: {num_documents} documents, reranking: {use_reranking}")
1203
+
1204
+ # Recherche des documents pertinents
1205
+ relevant_docs = rag_system.search_documents(question, num_documents, use_reranking)
1206
+
1207
+ if not relevant_docs:
1208
+ return "❌ Aucun document pertinent trouvé dans la base de données pour répondre à cette question."
1209
+
1210
+ # Construction du contexte pour la génération
1211
+ context_parts = []
1212
+ for i, doc in enumerate(relevant_docs):
1213
+ context_parts.append(f"[Document {i+1}] {doc['title']} - {doc['heading']}\n{doc['content']}")
1214
+
1215
+ context = "\n\n".join(context_parts)
1216
+
1217
+ # Génération de la réponse
1218
+ response = rag_system.generate_response(question, context, None)
1219
+
1220
+ # Formatage de la réponse avec les sources
1221
+ sources_info = []
1222
+ search_method = "avec reranking Qwen3" if use_reranking else "par embedding seulement"
1223
+
1224
+ sources_info.append(f"\n\n📚 **Documents sources utilisés ({search_method}):**\n")
1225
+
1226
+ for i, doc in enumerate(relevant_docs):
1227
+ embedding_score = doc.get('embedding_score', 0)
1228
+ rerank_score = doc.get('rerank_score')
1229
+ initial_rank = doc.get('initial_rank', i+1)
1230
+ final_rank = doc.get('final_rank', i+1)
1231
+
1232
+ # Formatage des scores
1233
+ if rerank_score is not None and use_reranking:
1234
+ score_display = f"Embedding: {embedding_score:.3f} | **Reranking: {rerank_score:.3f}**"
1235
+ if initial_rank != final_rank:
1236
+ rank_change = f" (#{initial_rank}→#{final_rank})"
1237
+ else:
1238
+ rank_change = ""
1239
+ else:
1240
+ score_display = f"Score: {embedding_score:.3f}"
1241
+ rank_change = ""
1242
+
1243
+ sources_info.append(f"• **[{i+1}]** {doc['title']}{rank_change}")
1244
+ sources_info.append(f" └ {score_display}")
1245
+ sources_info.append(f" └ Source: {doc['source']}")
1246
+
1247
+ # Assemblage de la réponse finale
1248
+ final_response = response + "\n".join(sources_info)
1249
+
1250
+ print(f"✅ Réponse MCP générée ({len(relevant_docs)} documents utilisés)")
1251
+ return final_response
1252
+
1253
+ except Exception as e:
1254
+ error_msg = f"❌ Erreur lors du traitement de la question: {str(e)}"
1255
+ print(error_msg)
1256
+ return error_msg
1257
+
1258
+
1259
+ def create_gradio_interface():
1260
+ """Créé l'interface Gradio pour utilisation externe (Spaces)"""
1261
+ # Initialisation du système RAG
1262
+ global rag_system
1263
+ try:
1264
+ rag_system = _create_rag_system()
1265
+ except Exception as e:
1266
+ raise RuntimeError(f"Erreur d'initialisation RAG: {e}")
1267
+
1268
+ # Configuration de l'interface Gradio avec thème Glass
1269
+ with gr.Blocks(
1270
+ title="🤖 LocalRAG Chat Générique",
1271
+ theme=gr.themes.Glass(),
1272
+ ) as demo:
1273
+
1274
+ # En-tête simplifié avec composants Gradio natifs
1275
+ with gr.Row():
1276
+ with gr.Column():
1277
+ gr.Markdown("# 🤖 Assistant RAG Générique LocalRAG")
1278
+ gr.Markdown(f"📦 Repository: `{rag_system.config.repo_id}` | 📊 Vecteurs: **{rag_system.config.total_vectors:,}**")
1279
+
1280
+ with gr.Row():
1281
+ with gr.Column(scale=4):
1282
+ chatbot = gr.Chatbot(
1283
+ label="💬 Conversation avec l'assistant",
1284
+ show_label=True,
1285
+ height=600,
1286
+ type="messages"
1287
+ )
1288
+
1289
+ msg = gr.Textbox(
1290
+ label="Votre question",
1291
+ placeholder="Posez votre question ici...",
1292
+ lines=1,
1293
+ max_lines=3
1294
+ )
1295
+
1296
+ with gr.Row():
1297
+ send_btn = gr.Button("Envoyer", variant="primary")
1298
+ clear_btn = gr.Button("Effacer", variant="secondary")
1299
+
1300
+ with gr.Column(scale=1):
1301
+ gr.Markdown("### ⚙️ Paramètres")
1302
+ top_k_slider = gr.Slider(
1303
+ minimum=1,
1304
+ maximum=20,
1305
+ value=5,
1306
+ step=1,
1307
+ label="Nombre de documents (top-k)",
1308
+ info="Plus élevé = plus de contexte"
1309
+ )
1310
+
1311
+ reranking_checkbox = gr.Checkbox(
1312
+ label="Activer reranking Qwen3",
1313
+ value=True,
1314
+ info="Améliore la pertinence"
1315
+ )
1316
+
1317
+ gr.Markdown("### 📊 Statistiques")
1318
+ gr.Markdown(f"""
1319
+ - **Modèle embedding:** Qwen3-Embedding-4B
1320
+ - **Modèle reranking:** Qwen3-Reranker-4B
1321
+ - **Modèle génération:** Qwen3-4B-Instruct-2507
1322
+ - **Index FAISS:** HNSW optimisé
1323
+ - **Vecteurs:** {rag_system.config.total_vectors:,}
1324
+ """)
1325
+
1326
+ # Interactions
1327
+ def _clear_message():
1328
+ return ""
1329
+
1330
+ def _clear_chat():
1331
+ return []
1332
+
1333
+ # Envoi par Entrée
1334
+ msg.submit(
1335
+ chat_with_generic_rag,
1336
+ [msg, chatbot, top_k_slider, reranking_checkbox],
1337
+ chatbot
1338
+ ).then(
1339
+ _clear_message,
1340
+ outputs=msg
1341
+ )
1342
+
1343
+ # Envoi par bouton
1344
+ send_btn.click(
1345
+ chat_with_generic_rag,
1346
+ [msg, chatbot, top_k_slider, reranking_checkbox],
1347
+ chatbot
1348
+ ).then(
1349
+ _clear_message,
1350
+ outputs=msg
1351
+ )
1352
+
1353
+ # Effacement de la conversation
1354
+ clear_btn.click(_clear_chat, outputs=chatbot)
1355
+
1356
+ return demo
1357
+
1358
+
1359
+ def main():
1360
+ """Point d'entrée principal."""
1361
+ print("🚀 LocalRAG Step 03 - Interface de chat générique")
1362
+ print("=" * 50)
1363
+
1364
+ # Vérification des dépendances
1365
+ if not _check_dependencies():
1366
+ return 1
1367
+
1368
+ # Initialisation du système RAG
1369
+ global rag_system
1370
+ try:
1371
+ rag_system = _create_rag_system()
1372
+ except Exception as e:
1373
+ print(f"❌ Erreur d'initialisation: {e}")
1374
+ return 1
1375
+
1376
+ # Configuration de l'interface Gradio avec thème Glass
1377
+ with gr.Blocks(
1378
+ title="🤖 LocalRAG Chat Générique",
1379
+ theme=gr.themes.Glass(),
1380
+ ) as demo:
1381
+
1382
+ # En-tête simplifié avec composants Gradio natifs
1383
+ with gr.Row():
1384
+ with gr.Column():
1385
+ gr.Markdown("# 🤖 Assistant RAG Générique LocalRAG")
1386
+
1387
+ # Affichage de l'environnement d'exécution
1388
+ env_info = ""
1389
+ if ZEROGPU_AVAILABLE and os.getenv("SPACE_ID"):
1390
+ env_info = "🚀 **Powered by ZeroGPU** - GPU gratuit Hugging Face"
1391
+ elif torch.backends.mps.is_available():
1392
+ env_info = "🍎 **Apple Silicon optimisé** - MPS accelerated"
1393
+ elif torch.cuda.is_available():
1394
+ env_info = f"🐧 **CUDA accelerated** - {torch.cuda.get_device_name()}"
1395
+ else:
1396
+ env_info = "💻 **CPU optimisé** - Traitement local"
1397
+
1398
+ gr.Markdown(f"**Système RAG complet avec modèles Qwen3 de dernière génération**")
1399
+ gr.Markdown(env_info)
1400
+ gr.Markdown(f"🧠 {rag_system.config.embedding_model.split('/')[-1]} • 🎯 Qwen3-Reranker-4B • 💬 Qwen3-4B • ⚡ Recherche en 2 étapes")
1401
+ gr.Markdown(f"📦 Repository: `{rag_system.config.repo_id}` | 📊 Vecteurs: **{rag_system.config.total_vectors:,}**")
1402
+
1403
+ # Interface de chat
1404
+ chatbot = gr.Chatbot(
1405
+ height=500,
1406
+ show_label=False,
1407
+ container=True,
1408
+ show_copy_button=True,
1409
+ autoscroll=True,
1410
+ avatar_images=(None, "🤖"),
1411
+ type="messages"
1412
+ )
1413
+
1414
+ # Zone de saisie
1415
+ with gr.Row():
1416
+ msg = gr.Textbox(
1417
+ placeholder="Posez votre question...",
1418
+ show_label=False,
1419
+ container=False,
1420
+ scale=4
1421
+ )
1422
+ send_btn = gr.Button("📤 Envoyer", variant="primary", scale=1)
1423
+
1424
+ # Panneau de contrôle avancé simplifié
1425
+ with gr.Accordion("🎛️ Contrôles avancés", open=True):
1426
+ with gr.Row():
1427
+ top_k_slider = gr.Slider(
1428
+ minimum=1,
1429
+ maximum=10,
1430
+ value=3,
1431
+ step=1,
1432
+ label="📊 Nombre de documents finaux",
1433
+ info="Documents qui seront utilisés pour générer la réponse"
1434
+ )
1435
+
1436
+ reranking_checkbox = gr.Checkbox(
1437
+ value=True,
1438
+ label="🎯 Activer le reranking Qwen3",
1439
+ info="Améliore la pertinence avec un modèle de reranking spécialisé"
1440
+ )
1441
+
1442
+ # Bouton pour effacer
1443
+ clear_btn = gr.Button("🗑️ Effacer la conversation", variant="secondary", size="lg")
1444
+
1445
+ # Informations en pied de page avec Accordion pour économiser l'espace
1446
+ with gr.Accordion("ℹ️ Informations sur l'architecture", open=False):
1447
+ env_docs = ""
1448
+ if ZEROGPU_AVAILABLE and os.getenv("SPACE_ID"):
1449
+ env_docs = """
1450
+ ### 🚀 Optimisations ZeroGPU
1451
+
1452
+ - **Allocation dynamique :** GPU alloué automatiquement pour le reranking et la génération
1453
+ - **NVIDIA H200 :** 70GB VRAM disponible pour les calculs intensifs
1454
+ - **Décorateurs intelligents :** `@spaces.GPU()` pour optimiser l'usage GPU
1455
+ - **Cache optimisé :** Stockage temporaire en `/tmp` pour performances maximales
1456
+ """
1457
+ elif torch.backends.mps.is_available():
1458
+ env_docs = """
1459
+ ### 🍎 Optimisations Apple Silicon
1460
+
1461
+ - **Metal Performance Shaders :** Accélération native Apple
1462
+ - **Index FAISS adapté :** IndexFlatIP pour éviter les segfaults
1463
+ - **Mémoire unifiée :** Partage efficace CPU/GPU
1464
+ - **Float32 :** Précision optimisée pour MPS
1465
+ """
1466
+ else:
1467
+ env_docs = """
1468
+ ### ⚡ Optimisations locales
1469
+
1470
+ - **Multi-plateforme :** Support CPU, CUDA, MPS selon disponibilité
1471
+ - **Flash Attention :** Activé automatiquement sur CUDA
1472
+ - **Gestion mémoire :** Cleanup automatique pour stabilité
1473
+ """
1474
+
1475
+ gr.Markdown(f"""
1476
+ ### 🚀 Architecture LocalRAG Step 03
1477
+
1478
+ - **📥 Step 02 :** Embeddings chargés depuis Hugging Face Hub au format SafeTensors
1479
+ - **🔍 Recherche :** Index FAISS reconstructé pour recherche vectorielle haute performance
1480
+ - **🎯 Reranking :** Qwen3-Reranker-4B pour affiner la sélection des documents
1481
+ - **💬 Génération :** Qwen3-4B-Instruct-2507 pour des réponses contextuelles optimisées
1482
+ {env_docs}
1483
+ ### 📊 Lecture des scores
1484
+
1485
+ - **Score Embedding :** Similarité vectorielle initiale (0.0-1.0, plus haut = plus pertinent)
1486
+ - **Score Reranking :** Score de pertinence final après analyse contextuelle
1487
+ - **Changement de rang :** Evolution de la position du document après reranking
1488
+ """)
1489
+
1490
+ # Gestionnaire de likes
1491
+ def like_response(evt: gr.LikeData):
1492
+ print(f"Réaction utilisateur: {'👍' if evt.liked else '👎'} sur le message #{evt.index}")
1493
+ print(f"Contenu: {evt.value[:100]}...")
1494
+
1495
+ chatbot.like(like_response)
1496
+
1497
+ # Envoi par touche Entrée
1498
+ msg.submit(
1499
+ chat_with_generic_rag,
1500
+ [msg, chatbot, top_k_slider, reranking_checkbox],
1501
+ chatbot
1502
+ ).then(
1503
+ _clear_message,
1504
+ outputs=msg
1505
+ )
1506
+
1507
+ # Envoi par bouton
1508
+ send_btn.click(
1509
+ chat_with_generic_rag,
1510
+ [msg, chatbot, top_k_slider, reranking_checkbox],
1511
+ chatbot
1512
+ ).then(
1513
+ _clear_message,
1514
+ outputs=msg
1515
+ )
1516
+
1517
+ # Effacement de la conversation
1518
+ clear_btn.click(_clear_chat, outputs=chatbot)
1519
+
1520
+ print("🌐 Lancement de l'interface Gradio...")
1521
+
1522
+ # Configuration HTTPS pour Claude Desktop
1523
+ ssl_keyfile = os.getenv("SSL_KEYFILE")
1524
+ ssl_certfile = os.getenv("SSL_CERTFILE")
1525
+
1526
+ if ssl_keyfile and ssl_certfile:
1527
+ print("🔒 Mode HTTPS activé")
1528
+ print("🔗 Serveur MCP : /gradio_api/mcp/sse")
1529
+
1530
+ demo.launch(
1531
+ mcp_server=True, # Toujours activer MCP
1532
+ inbrowser=True,
1533
+ show_error=True,
1534
+ ssl_keyfile=ssl_keyfile,
1535
+ ssl_certfile=ssl_certfile
1536
+ )
1537
+ else:
1538
+ print("🔗 Serveur MCP : /gradio_api/mcp/sse")
1539
+ print("💡 Pour HTTPS : python step03_ssl_generator_optional.py")
1540
+
1541
+ demo.launch(
1542
+ mcp_server=True, # Toujours activer MCP
1543
+ inbrowser=True,
1544
+ show_error=True
1545
+ )
1546
+
1547
+ print("📋 Outil MCP exposé : ask_rag_question")
1548
+
1549
+ return 0
1550
+
1551
+
1552
+ if __name__ == "__main__":
1553
+ exit(main())