adel67460 commited on
Commit
7da6546
·
verified ·
1 Parent(s): e3561d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -149
app.py CHANGED
@@ -6,182 +6,112 @@ import torch
6
  import pandas as pd
7
  import time
8
  import marqo
 
9
  from scipy.sparse import csr_matrix
10
  from transformers import AutoModel, AutoProcessor
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
 
13
- # Vérifier si CUDA est disponible (GPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(f"🔹 Utilisation du périphérique : {device}")
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # Définition des fichiers JSON
18
  PRODUCTS_FILE = "products.json"
19
  QA_FILE = "qa_sequences_output.json"
20
 
21
- # Chargement sécurisé du modèle Marqo
22
- MAX_RETRIES = 3
23
- model_name = "Marqo/marqo-ecommerce-embeddings-L"
 
 
24
 
25
- for attempt in range(MAX_RETRIES):
26
  try:
27
- print(f"🔄 Chargement du modèle Marqo... (Tentative {attempt + 1}/{MAX_RETRIES})")
28
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
29
- processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
30
- print("✅ Modèle chargé avec succès !")
31
- break
32
- except Exception as e:
33
- print(f"❌ Erreur de chargement : {e}")
34
- if attempt < MAX_RETRIES - 1:
35
- print("🔁 Nouvelle tentative dans 5 secondes...")
36
- time.sleep(5)
37
- else:
38
- print("⛔ Échec final du chargement du modèle.")
39
- model, processor = None, None
40
-
41
- # Fonction pour charger les fichiers JSON
42
- def load_data():
43
- products_data, qa_data = [], []
44
-
45
- if os.path.exists(PRODUCTS_FILE):
46
- with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
47
- products_data = json.load(f).get("products", [])
48
- else:
49
- print(f"⛔ Fichier introuvable : {PRODUCTS_FILE}")
50
-
51
- if os.path.exists(QA_FILE):
52
- with open(QA_FILE, "r", encoding="utf-8") as f:
53
- qa_data = json.load(f)
54
- else:
55
- print(f"⛔ Fichier introuvable : {QA_FILE}")
56
-
57
- return products_data, qa_data
58
-
59
- products_data, qa_data = load_data()
60
-
61
- # Associer les questions-réponses aux produits
62
- def associate_qa_with_products(products, qa_data):
63
- for product in products:
64
- product["qa_info"] = []
65
- product_name = product.get("title", "").lower()
66
- product_desc = product.get("description", "").lower()
67
-
68
- for qa in qa_data:
69
- question = qa.get("question", "").lower()
70
- if product_name in question or product_desc in question:
71
- product["qa_info"].append(qa)
72
-
73
- return products
74
-
75
- products_data = associate_qa_with_products(products_data, qa_data)
76
-
77
- # Connexion au serveur Marqo
78
- mq = marqo.Client(url="http://localhost:8882") # Port par défaut de Marqo
79
- INDEX_NAME = "ecommerce_products"
80
 
81
- # Supprimer et recréer l'index
82
- try:
 
 
 
 
83
  mq.delete_index(INDEX_NAME)
84
- except:
85
- pass
86
 
87
- # Création de l'index
88
- mq.create_index(INDEX_NAME)
89
  print("✅ Index Marqo créé avec succès !")
90
 
91
  # Ajouter les produits à Marqo
92
- documents = []
93
- for product in products_data:
94
- doc = {
95
- "id": product["id"],
96
- "title": product["title"],
97
- "description": product["description"],
98
- "price": product["price"],
99
- "availability": product["availability"],
100
- "category": product["category"],
101
- "qa_info": product.get("qa_info", []),
102
- "_model": "open_clip/ViT-B-32/laion2B-b79K", # Spécifier le modèle ici
103
  }
104
- documents.append(doc)
 
105
 
106
  mq.index(INDEX_NAME).add_documents(documents, tensor_fields=["title", "description"])
107
  print("✅ Produits indexés dans Marqo avec succès !")
108
 
109
- # Prétraitement du texte
110
- def preprocess(text: str) -> str:
111
- text = text.lower()
112
- text = re.sub(r'\s+', ' ', text)
113
- text = re.sub(r'[^\w\s]', '', text)
114
- return text.strip()
115
-
116
- # TF-IDF Vectorizer
117
- vectorizer = TfidfVectorizer(stop_words="english")
118
- tfidf_matrix = vectorizer.fit_transform([prod["title"] + " " + prod["description"] for prod in products_data])
119
-
120
- # Recherche hybride avec Marqo + TF-IDF
121
- def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_marqo=0.5):
122
- query = preprocess(query)
123
-
124
- # Recherche Marqo (top 50 résultats)
125
- marqo_results = mq.index(INDEX_NAME).search(query, searchable_attributes=["title", "description"], limit=50)
126
-
127
- # Récupérer les résultats Marqo
128
- marqo_products = []
129
- marqo_scores = []
130
- for hit in marqo_results["hits"]:
131
- marqo_products.append(hit)
132
- marqo_scores.append(hit["_score"])
133
-
134
- # Normaliser les scores Marqo
135
- if len(marqo_scores) > 0:
136
- marqo_scores = (pd.Series(marqo_scores) - min(marqo_scores)) / (max(marqo_scores) - min(marqo_scores) + 1e-6)
137
- else:
138
- marqo_scores = [0] * len(marqo_products)
139
-
140
- # TF-IDF Similarité
141
- query_vector_sparse = csr_matrix(vectorizer.transform([query]))
142
- tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
143
-
144
- # Normaliser les scores TF-IDF
145
- if len(tfidf_scores) > 0:
146
- tfidf_scores_norm = (tfidf_scores - min(tfidf_scores)) / (max(tfidf_scores) - min(tfidf_scores) + 1e-6)
147
- else:
148
- tfidf_scores_norm = [0] * len(marqo_products)
149
-
150
- # Fusionner les scores TF-IDF et Marqo
151
- final_scores = weight_tfidf * tfidf_scores_norm[:len(marqo_products)] + weight_marqo * marqo_scores
152
-
153
- # Ajouter le score final aux produits
154
- for i, product in enumerate(marqo_products):
155
- product["score"] = final_scores[i]
156
-
157
- # Convertir en DataFrame
158
- results_df = pd.DataFrame(marqo_products)
159
-
160
- # Filtrer les résultats par prix et disponibilité
161
- results_df = results_df[
162
- (results_df["price"] >= min_price) &
163
- (results_df["price"] <= max_price) &
164
- (results_df["availability"] == "in stock")
165
- ]
166
-
167
- if category and category != "Toutes":
168
- results_df = results_df[results_df["category"].str.contains(category, case=False, na=False)]
169
-
170
- return results_df.sort_values(by="score", ascending=False).head(20)
171
-
172
  # Interface Gradio
 
 
 
 
 
 
 
 
 
 
 
 
173
  app = gr.Interface(
174
  fn=search_products,
175
- inputs=[
176
- gr.Textbox(label="Rechercher un produit"),
177
- gr.Textbox(label="Catégorie"),
178
- gr.Number(label="Prix min"),
179
- gr.Number(label="Prix max")
180
- ],
181
- outputs=[
182
- gr.Dataframe(headers=["ID", "Titre", "Description", "Prix", "Disponibilité", "Questions/Réponses"],
183
- datatype=["str", "str", "str", "number", "str", "json"])
184
- ]
185
  )
186
 
187
  app.launch()
 
6
  import pandas as pd
7
  import time
8
  import marqo
9
+ import requests
10
  from scipy.sparse import csr_matrix
11
  from transformers import AutoModel, AutoProcessor
12
  from sklearn.feature_extraction.text import TfidfVectorizer
13
 
14
+ # Vérifier si CUDA est disponible
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"🔹 Utilisation du périphérique : {device}")
17
 
18
+ # Lancer Marqo si nécessaire
19
+ os.system("docker run -d -p 8882:8882 marqoai/marqo")
20
+
21
+ # Vérifier que Marqo est bien lancé
22
+ def wait_for_marqo(timeout=30):
23
+ start_time = time.time()
24
+ while time.time() - start_time < timeout:
25
+ try:
26
+ response = requests.get("http://localhost:8882")
27
+ if response.status_code == 200:
28
+ print("✅ Marqo est prêt !")
29
+ return True
30
+ except requests.exceptions.ConnectionError:
31
+ print("⏳ En attente du démarrage de Marqo...")
32
+ time.sleep(3)
33
+ print("⛔ Marqo ne répond pas après 30 secondes. Vérifiez le démarrage.")
34
+ return False
35
+
36
+ if not wait_for_marqo():
37
+ exit(1)
38
+
39
+ # Connexion à Marqo avec gestion des erreurs
40
+ MAX_RETRIES = 5
41
+ for attempt in range(MAX_RETRIES):
42
+ try:
43
+ mq = marqo.Client(url="http://localhost:8882")
44
+ print("✅ Connexion à Marqo réussie !")
45
+ break
46
+ except marqo.errors.BackendCommunicationError:
47
+ print(f"⚠️ Erreur de connexion à Marqo (tentative {attempt + 1}/{MAX_RETRIES})")
48
+ time.sleep(5)
49
+ else:
50
+ print("⛔ Impossible de se connecter à Marqo après plusieurs tentatives.")
51
+ exit(1)
52
+
53
  # Définition des fichiers JSON
54
  PRODUCTS_FILE = "products.json"
55
  QA_FILE = "qa_sequences_output.json"
56
 
57
+ # Fonction pour charger les fichiers JSON de manière sécurisée
58
+ def safe_load_json(file_path):
59
+ if not os.path.exists(file_path):
60
+ print(f"⛔ Fichier introuvable : {file_path}")
61
+ return []
62
 
 
63
  try:
64
+ with open(file_path, "r", encoding="utf-8") as f:
65
+ data = json.load(f)
66
+ return data.get("products", []) if "products" in data else data
67
+ except json.JSONDecodeError:
68
+ print(f"⚠️ Erreur de décodage JSON dans {file_path}")
69
+ return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ products_data = safe_load_json(PRODUCTS_FILE)
72
+ qa_data = safe_load_json(QA_FILE)
73
+
74
+ # Création de l'index Marqo avec la bonne configuration
75
+ INDEX_NAME = "ecommerce_products"
76
+ if INDEX_NAME in [index["index_name"] for index in mq.get_indexes()["results"]]:
77
  mq.delete_index(INDEX_NAME)
 
 
78
 
79
+ mq.create_index(INDEX_NAME, model="open_clip/ViT-B-32/laion2B-b79K", normalize_embeddings=True)
 
80
  print("✅ Index Marqo créé avec succès !")
81
 
82
  # Ajouter les produits à Marqo
83
+ documents = [
84
+ {
85
+ "id": product.get("id", ""),
86
+ "title": product.get("title", ""),
87
+ "description": product.get("description", ""),
88
+ "price": product.get("price", 0),
89
+ "availability": product.get("availability", ""),
90
+ "category": product.get("category", ""),
 
 
 
91
  }
92
+ for product in products_data
93
+ ]
94
 
95
  mq.index(INDEX_NAME).add_documents(documents, tensor_fields=["title", "description"])
96
  print("✅ Produits indexés dans Marqo avec succès !")
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # Interface Gradio
99
+ def search_products(query, category, min_price, max_price):
100
+ query = query.strip()
101
+ if not query:
102
+ return "❌ Veuillez entrer un terme de recherche valide."
103
+
104
+ min_price = float(min_price) if isinstance(min_price, (int, float)) else 0
105
+ max_price = float(max_price) if isinstance(max_price, (int, float)) else float("inf")
106
+
107
+ marqo_results = mq.index(INDEX_NAME).search(query, searchable_attributes=["title", "description"], limit=20)
108
+ results_df = pd.DataFrame(marqo_results["hits"])
109
+ return results_df
110
+
111
  app = gr.Interface(
112
  fn=search_products,
113
+ inputs=[gr.Textbox(label="Rechercher un produit"), gr.Textbox(label="Catégorie"), gr.Number(label="Prix min"), gr.Number(label="Prix max")],
114
+ outputs=gr.Dataframe(),
 
 
 
 
 
 
 
 
115
  )
116
 
117
  app.launch()