VincentGOURBIN commited on
Commit
b4c82ce
·
verified ·
1 Parent(s): c892afd

Upload step03_chatbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. step03_chatbot.py +43 -40
step03_chatbot.py CHANGED
@@ -571,7 +571,18 @@ class GenericRAGChatbot:
571
  try:
572
  from sentence_transformers import SentenceTransformer
573
 
574
- if self.use_flash_attention and self.is_cuda:
 
 
 
 
 
 
 
 
 
 
 
575
  print(" - Configuration avec Flash Attention 2 activée (CUDA)")
576
  try:
577
  self.embedding_model = SentenceTransformer(
@@ -711,6 +722,7 @@ class GenericRAGChatbot:
711
  except:
712
  return 0.0
713
 
 
714
  def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
715
  """
716
  Recherche avancée avec reranking en deux étapes
@@ -724,10 +736,13 @@ class GenericRAGChatbot:
724
  # Les modèles d'embedding fonctionnent bien sur CPU sur ZeroGPU
725
 
726
  # Étape 1: Recherche par embedding avec FAISS
 
727
  if hasattr(self.embedding_model, 'prompts') and 'query' in self.embedding_model.prompts:
728
- query_embedding = self.embedding_model.encode([query], prompt_name="query")[0]
729
  else:
730
- query_embedding = self.embedding_model.encode([query])[0]
 
 
731
 
732
  # Recherche dans l'index FAISS
733
  query_vector = query_embedding.reshape(1, -1).astype('float32')
@@ -842,27 +857,19 @@ Instructions importantes:
842
  messages.append({"role": "user", "content": user_message})
843
 
844
  try:
845
- # Formatage manuel plus stable pour ZeroGPU
846
- formatted_messages = []
847
- for msg in messages:
848
- if msg["role"] == "system":
849
- formatted_messages.append(f"<|im_start|>system\n{msg['content']}<|im_end|>")
850
- elif msg["role"] == "user":
851
- formatted_messages.append(f"<|im_start|>user\n{msg['content']}<|im_end|>")
852
- elif msg["role"] == "assistant":
853
- formatted_messages.append(f"<|im_start|>assistant\n{msg['content']}<|im_end|>")
854
-
855
- # Ajouter le prompt de génération
856
- formatted_messages.append("<|im_start|>assistant\n")
857
- formatted_prompt = "\n".join(formatted_messages)
858
 
859
  # Tokenisation
860
  inputs = self.generation_tokenizer(
861
  formatted_prompt,
862
  return_tensors="pt",
863
  truncation=True,
864
- max_length=4096,
865
- padding=True
866
  )
867
 
868
  # Déplacement vers le device
@@ -883,8 +890,10 @@ Instructions importantes:
883
  "input_ids": inputs["input_ids"],
884
  "attention_mask": inputs["attention_mask"],
885
  "streamer": streamer,
886
- "max_new_tokens": 512,
887
- "temperature": 0.7,
 
 
888
  "do_sample": True,
889
  "pad_token_id": self.generation_tokenizer.pad_token_id,
890
  "eos_token_id": self.generation_tokenizer.eos_token_id,
@@ -943,39 +952,33 @@ Réponds à cette question en te basant sur le contexte fourni."""
943
 
944
  # Formatage pour le modèle
945
  try:
946
- # Formatage manuel plus stable pour ZeroGPU
947
- formatted_messages = []
948
- for msg in messages:
949
- if msg["role"] == "system":
950
- formatted_messages.append(f"<|im_start|>system\n{msg['content']}<|im_end|>")
951
- elif msg["role"] == "user":
952
- formatted_messages.append(f"<|im_start|>user\n{msg['content']}<|im_end|>")
953
- elif msg["role"] == "assistant":
954
- formatted_messages.append(f"<|im_start|>assistant\n{msg['content']}<|im_end|>")
955
-
956
- # Ajouter le prompt de génération
957
- formatted_messages.append("<|im_start|>assistant\n")
958
- formatted_prompt = "\n".join(formatted_messages)
959
-
960
- # Tokenisation avec padding et attention mask appropriés
961
  inputs = self.generation_tokenizer(
962
  formatted_prompt,
963
  return_tensors="pt",
964
  truncation=True,
965
- max_length=4096,
966
- padding=True
967
  )
968
 
969
  # Déplacement vers le device
970
  inputs = {k: v.to(self.generation_device) for k, v in inputs.items()}
971
 
972
- # Génération avec paramètres simplifiés
973
  with torch.no_grad():
974
  outputs = self.generation_model.generate(
975
  input_ids=inputs["input_ids"],
976
  attention_mask=inputs["attention_mask"],
977
- max_new_tokens=512,
978
- temperature=0.7,
 
 
979
  do_sample=True,
980
  pad_token_id=self.generation_tokenizer.pad_token_id,
981
  eos_token_id=self.generation_tokenizer.eos_token_id,
 
571
  try:
572
  from sentence_transformers import SentenceTransformer
573
 
574
+ if os.getenv("SPACE_ID"):
575
+ print(" - Configuration ZeroGPU optimisée")
576
+ # Sur ZeroGPU, utiliser float16 et device auto pour les performances
577
+ self.embedding_model = SentenceTransformer(
578
+ self.config.embedding_model,
579
+ model_kwargs={
580
+ "torch_dtype": torch.float16,
581
+ "device_map": "auto"
582
+ },
583
+ tokenizer_kwargs={"padding_side": "left"}
584
+ )
585
+ elif self.use_flash_attention and self.is_cuda:
586
  print(" - Configuration avec Flash Attention 2 activée (CUDA)")
587
  try:
588
  self.embedding_model = SentenceTransformer(
 
722
  except:
723
  return 0.0
724
 
725
+ @spaces.GPU(duration=120) # ZeroGPU: GPU nécessaire pour embedding
726
  def search_documents(self, query: str, final_k: int = None, use_reranking: bool = None) -> List[Dict]:
727
  """
728
  Recherche avancée avec reranking en deux étapes
 
736
  # Les modèles d'embedding fonctionnent bien sur CPU sur ZeroGPU
737
 
738
  # Étape 1: Recherche par embedding avec FAISS
739
+ print(" 🎯 Calcul de l'embedding de la requête...")
740
  if hasattr(self.embedding_model, 'prompts') and 'query' in self.embedding_model.prompts:
741
+ query_embedding = self.embedding_model.encode([query], prompt_name="query", show_progress_bar=False)[0]
742
  else:
743
+ query_embedding = self.embedding_model.encode([query], show_progress_bar=False)[0]
744
+
745
+ print(f" 📐 Embedding calculé: shape={query_embedding.shape}, norm={np.linalg.norm(query_embedding):.3f}")
746
 
747
  # Recherche dans l'index FAISS
748
  query_vector = query_embedding.reshape(1, -1).astype('float32')
 
857
  messages.append({"role": "user", "content": user_message})
858
 
859
  try:
860
+ # Utiliser le template officiel Qwen3 (documentation officielle)
861
+ formatted_prompt = self.generation_tokenizer.apply_chat_template(
862
+ messages,
863
+ tokenize=False,
864
+ add_generation_prompt=True
865
+ )
 
 
 
 
 
 
 
866
 
867
  # Tokenisation
868
  inputs = self.generation_tokenizer(
869
  formatted_prompt,
870
  return_tensors="pt",
871
  truncation=True,
872
+ max_length=4096
 
873
  )
874
 
875
  # Déplacement vers le device
 
890
  "input_ids": inputs["input_ids"],
891
  "attention_mask": inputs["attention_mask"],
892
  "streamer": streamer,
893
+ "max_new_tokens": 1024, # Recommandation officielle
894
+ "temperature": 0.7, # Recommandation officielle
895
+ "top_p": 0.8, # Recommandation officielle
896
+ "top_k": 20, # Recommandation officielle
897
  "do_sample": True,
898
  "pad_token_id": self.generation_tokenizer.pad_token_id,
899
  "eos_token_id": self.generation_tokenizer.eos_token_id,
 
952
 
953
  # Formatage pour le modèle
954
  try:
955
+ # Utiliser le template officiel Qwen3 (documentation officielle)
956
+ formatted_prompt = self.generation_tokenizer.apply_chat_template(
957
+ messages,
958
+ tokenize=False,
959
+ add_generation_prompt=True
960
+ )
961
+
962
+ # Tokenisation avec les bonnes options
 
 
 
 
 
 
 
963
  inputs = self.generation_tokenizer(
964
  formatted_prompt,
965
  return_tensors="pt",
966
  truncation=True,
967
+ max_length=4096
 
968
  )
969
 
970
  # Déplacement vers le device
971
  inputs = {k: v.to(self.generation_device) for k, v in inputs.items()}
972
 
973
+ # Génération avec paramètres officiels Qwen3
974
  with torch.no_grad():
975
  outputs = self.generation_model.generate(
976
  input_ids=inputs["input_ids"],
977
  attention_mask=inputs["attention_mask"],
978
+ max_new_tokens=1024, # Recommandation officielle
979
+ temperature=0.7, # Recommandation officielle
980
+ top_p=0.8, # Recommandation officielle
981
+ top_k=20, # Recommandation officielle
982
  do_sample=True,
983
  pad_token_id=self.generation_tokenizer.pad_token_id,
984
  eos_token_id=self.generation_tokenizer.eos_token_id,