adel67460 commited on
Commit
ce6bc00
·
verified ·
1 Parent(s): 7a9155e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -65
app.py CHANGED
@@ -4,8 +4,8 @@ import re
4
  import gradio as gr
5
  import torch
6
  import pandas as pd
7
- import faiss
8
  import time
 
9
  from scipy.sparse import csr_matrix
10
  from transformers import AutoModel, AutoProcessor
11
  from sklearn.feature_extraction.text import TfidfVectorizer
@@ -38,18 +38,16 @@ for attempt in range(MAX_RETRIES):
38
  print("⛔ Échec final du chargement du modèle.")
39
  model, processor = None, None
40
 
41
- # Fonction pour charger et fusionner les données
42
  def load_data():
43
  products_data, qa_data = [], []
44
 
45
- # Charger les produits
46
  if os.path.exists(PRODUCTS_FILE):
47
  with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
48
  products_data = json.load(f).get("products", [])
49
  else:
50
  print(f"⛔ Fichier introuvable : {PRODUCTS_FILE}")
51
 
52
- # Charger les questions-réponses
53
  if os.path.exists(QA_FILE):
54
  with open(QA_FILE, "r", encoding="utf-8") as f:
55
  qa_data = json.load(f)
@@ -61,10 +59,49 @@ def load_data():
61
  products_data, qa_data = load_data()
62
 
63
  # Associer les questions-réponses aux produits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  for product in products_data:
65
- product_name = product.get("title", "").lower()
66
- related_qa = [qa for qa in qa_data if product_name in qa.get("question", "").lower()]
67
- product["qa_info"] = related_qa # Ajouter les questions/réponses au produit
 
 
 
 
 
 
 
 
 
 
68
 
69
  # Prétraitement du texte
70
  def preprocess(text: str) -> str:
@@ -73,74 +110,61 @@ def preprocess(text: str) -> str:
73
  text = re.sub(r'[^\w\s]', '', text)
74
  return text.strip()
75
 
76
- # Génération des embeddings des produits optimisée
77
- def get_text_embeddings(texts, batch_size=32):
78
- if not texts: # Vérifier que la liste de textes n'est pas vide
79
- print("⚠️ Avertissement : Aucun texte à encoder. Retour d'une matrice vide.")
80
- return torch.zeros((0, model.config.hidden_size)).numpy()
81
-
82
- with torch.no_grad():
83
- processed_texts = processor(text=texts, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
84
- embeddings = model.get_text_features(processed_texts["input_ids"], normalize=True)
85
- return embeddings.cpu().numpy()
86
-
87
- print("🛠️ Génération des embeddings des produits...")
88
- product_embeddings = get_text_embeddings([prod["title"] + " " + prod["description"] for prod in products_data])
89
- print("✅ Embeddings générés et sauvegardés !")
90
-
91
- # Optimisation FAISS avec index dynamique
92
- d = product_embeddings.shape[1]
93
- nlist = max(10, len(product_embeddings) // 10) # Ajuste dynamiquement le nombre de centroids
94
-
95
- if len(product_embeddings) < 4000:
96
- index = faiss.IndexFlatL2(d) # Index simple pour petits datasets
97
- else:
98
- index = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, nlist)
99
- index.train(product_embeddings)
100
-
101
- index.add(product_embeddings)
102
-
103
  # TF-IDF Vectorizer
104
  vectorizer = TfidfVectorizer(stop_words="english")
105
  tfidf_matrix = vectorizer.fit_transform([prod["title"] + " " + prod["description"] for prod in products_data])
106
 
107
- # Recherche hybride optimisée avec pondération
108
- def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_faiss=0.5):
109
  query = preprocess(query)
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # TF-IDF Similarité
112
- query_vector_sparse = csr_matrix(vectorizer.transform([query]))
113
  tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
114
-
115
- # FAISS Similarité
116
- query_embedding = get_text_embeddings([query])[0].reshape(1, -1)
117
- _, indices = index.search(query_embedding, 50)
118
-
119
- # Récupérer les produits similaires
120
- similar_products = pd.DataFrame([products_data[i] for i in indices[0]])
121
-
122
- # Normalisation des scores
123
- tfidf_scores_norm = (tfidf_scores - tfidf_scores.min()) / (tfidf_scores.max() - tfidf_scores.min() + 1e-6)
124
-
125
- # Ajout du score pondéré
126
- similar_products["score"] = weight_tfidf * tfidf_scores_norm[indices[0]] + weight_faiss * (1 - tfidf_scores_norm[indices[0]])
127
-
128
- # Filtrage par prix et disponibilité
129
- filtered_results = similar_products[
130
- (similar_products["price"] >= min_price) &
131
- (similar_products["price"] <= max_price) &
132
- (similar_products["availability"] == "in stock")
 
 
 
133
  ]
134
-
135
  if category and category != "Toutes":
136
- filtered_results = filtered_results[filtered_results["category"].str.contains(category, case=False, na=False)]
137
 
138
- # Ajouter les réponses aux résultats
139
- filtered_results["qa_info"] = filtered_results["id"].apply(
140
- lambda prod_id: [prod["qa_info"] for prod in products_data if prod["id"] == prod_id][0]
141
- )
142
-
143
- return filtered_results.sort_values(by="score", ascending=False).head(20)
144
 
145
  # Interface Gradio
146
  app = gr.Interface(
 
4
  import gradio as gr
5
  import torch
6
  import pandas as pd
 
7
  import time
8
+ import marqo
9
  from scipy.sparse import csr_matrix
10
  from transformers import AutoModel, AutoProcessor
11
  from sklearn.feature_extraction.text import TfidfVectorizer
 
38
  print("⛔ Échec final du chargement du modèle.")
39
  model, processor = None, None
40
 
41
+ # Fonction pour charger les fichiers JSON
42
  def load_data():
43
  products_data, qa_data = [], []
44
 
 
45
  if os.path.exists(PRODUCTS_FILE):
46
  with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
47
  products_data = json.load(f).get("products", [])
48
  else:
49
  print(f"⛔ Fichier introuvable : {PRODUCTS_FILE}")
50
 
 
51
  if os.path.exists(QA_FILE):
52
  with open(QA_FILE, "r", encoding="utf-8") as f:
53
  qa_data = json.load(f)
 
59
  products_data, qa_data = load_data()
60
 
61
  # Associer les questions-réponses aux produits
62
+ def associate_qa_with_products(products, qa_data):
63
+ for product in products:
64
+ product["qa_info"] = []
65
+ product_name = product.get("title", "").lower()
66
+ product_desc = product.get("description", "").lower()
67
+
68
+ for qa in qa_data:
69
+ question = qa.get("question", "").lower()
70
+ if product_name in question or product_desc in question:
71
+ product["qa_info"].append(qa)
72
+
73
+ return products
74
+
75
+ products_data = associate_qa_with_products(products_data, qa_data)
76
+
77
+ # Connexion au serveur Marqo
78
+ mq = marqo.Client(url="http://localhost:8882") # Port par défaut de Marqo
79
+ INDEX_NAME = "ecommerce_products"
80
+
81
+ # Supprimer et recréer l'index
82
+ try:
83
+ mq.delete_index(INDEX_NAME)
84
+ except:
85
+ pass
86
+
87
+ mq.create_index(INDEX_NAME, settings={"index_defaults": {"model": "open_clip/ViT-B-32/laion2B-s34B-b79K"}})
88
+
89
+ # Ajouter les produits à Marqo
90
+ documents = []
91
  for product in products_data:
92
+ doc = {
93
+ "id": product["id"],
94
+ "title": product["title"],
95
+ "description": product["description"],
96
+ "price": product["price"],
97
+ "availability": product["availability"],
98
+ "category": product["category"],
99
+ "qa_info": product.get("qa_info", []),
100
+ }
101
+ documents.append(doc)
102
+
103
+ mq.index(INDEX_NAME).add_documents(documents, tensor_fields=["title", "description"])
104
+ print("✅ Produits indexés dans Marqo !")
105
 
106
  # Prétraitement du texte
107
  def preprocess(text: str) -> str:
 
110
  text = re.sub(r'[^\w\s]', '', text)
111
  return text.strip()
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  # TF-IDF Vectorizer
114
  vectorizer = TfidfVectorizer(stop_words="english")
115
  tfidf_matrix = vectorizer.fit_transform([prod["title"] + " " + prod["description"] for prod in products_data])
116
 
117
+ # Recherche hybride avec Marqo + TF-IDF
118
+ def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_marqo=0.5):
119
  query = preprocess(query)
120
+
121
+ # Recherche Marqo (top 50 résultats)
122
+ marqo_results = mq.index(INDEX_NAME).search(query, searchable_attributes=["title", "description"], limit=50)
123
+
124
+ # Récupérer les résultats Marqo
125
+ marqo_products = []
126
+ marqo_scores = []
127
+ for hit in marqo_results["hits"]:
128
+ marqo_products.append(hit)
129
+ marqo_scores.append(hit["_score"])
130
+
131
+ # Normaliser les scores Marqo
132
+ if len(marqo_scores) > 0:
133
+ marqo_scores = (pd.Series(marqo_scores) - min(marqo_scores)) / (max(marqo_scores) - min(marqo_scores) + 1e-6)
134
+ else:
135
+ marqo_scores = [0] * len(marqo_products)
136
+
137
  # TF-IDF Similarité
138
+ query_vector_sparse = csr_matrix(vectorizer.transform([query]))
139
  tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
140
+
141
+ # Normaliser les scores TF-IDF
142
+ if len(tfidf_scores) > 0:
143
+ tfidf_scores_norm = (tfidf_scores - min(tfidf_scores)) / (max(tfidf_scores) - min(tfidf_scores) + 1e-6)
144
+ else:
145
+ tfidf_scores_norm = [0] * len(marqo_products)
146
+
147
+ # Fusionner les scores TF-IDF et Marqo
148
+ final_scores = weight_tfidf * tfidf_scores_norm[:len(marqo_products)] + weight_marqo * marqo_scores
149
+
150
+ # Ajouter le score final aux produits
151
+ for i, product in enumerate(marqo_products):
152
+ product["score"] = final_scores[i]
153
+
154
+ # Convertir en DataFrame
155
+ results_df = pd.DataFrame(marqo_products)
156
+
157
+ # Filtrer les résultats par prix et disponibilité
158
+ results_df = results_df[
159
+ (results_df["price"] >= min_price) &
160
+ (results_df["price"] <= max_price) &
161
+ (results_df["availability"] == "in stock")
162
  ]
163
+
164
  if category and category != "Toutes":
165
+ results_df = results_df[results_df["category"].str.contains(category, case=False, na=False)]
166
 
167
+ return results_df.sort_values(by="score", ascending=False).head(20)
 
 
 
 
 
168
 
169
  # Interface Gradio
170
  app = gr.Interface(