Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,8 +4,8 @@ import re
|
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
import pandas as pd
|
7 |
-
import faiss
|
8 |
import time
|
|
|
9 |
from scipy.sparse import csr_matrix
|
10 |
from transformers import AutoModel, AutoProcessor
|
11 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
@@ -38,18 +38,16 @@ for attempt in range(MAX_RETRIES):
|
|
38 |
print("⛔ Échec final du chargement du modèle.")
|
39 |
model, processor = None, None
|
40 |
|
41 |
-
# Fonction pour charger
|
42 |
def load_data():
|
43 |
products_data, qa_data = [], []
|
44 |
|
45 |
-
# Charger les produits
|
46 |
if os.path.exists(PRODUCTS_FILE):
|
47 |
with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
|
48 |
products_data = json.load(f).get("products", [])
|
49 |
else:
|
50 |
print(f"⛔ Fichier introuvable : {PRODUCTS_FILE}")
|
51 |
|
52 |
-
# Charger les questions-réponses
|
53 |
if os.path.exists(QA_FILE):
|
54 |
with open(QA_FILE, "r", encoding="utf-8") as f:
|
55 |
qa_data = json.load(f)
|
@@ -61,10 +59,49 @@ def load_data():
|
|
61 |
products_data, qa_data = load_data()
|
62 |
|
63 |
# Associer les questions-réponses aux produits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
for product in products_data:
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
# Prétraitement du texte
|
70 |
def preprocess(text: str) -> str:
|
@@ -73,74 +110,61 @@ def preprocess(text: str) -> str:
|
|
73 |
text = re.sub(r'[^\w\s]', '', text)
|
74 |
return text.strip()
|
75 |
|
76 |
-
# Génération des embeddings des produits optimisée
|
77 |
-
def get_text_embeddings(texts, batch_size=32):
    """Encode a list of texts into an embedding matrix.

    Args:
        texts: Texts to encode. They are tokenized together in one call to
            the module-level ``processor`` (truncated to 64 tokens, padded).
        batch_size: Accepted for API compatibility.
            NOTE(review): this value is never read by the body; all texts
            are encoded in a single pass.

    Returns:
        The result of ``model.get_text_features(..., normalize=True)`` moved
        to the CPU and converted with ``.numpy()``. When ``texts`` is empty,
        a ``(0, model.config.hidden_size)`` array of zeros.
    """
    if texts:
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            encoded = processor(
                text=texts,
                return_tensors="pt",
                truncation=True,
                max_length=64,
                padding=True,
            )
            encoded = encoded.to(device)
            # Only the token ids are forwarded to the model.
            features = model.get_text_features(encoded["input_ids"], normalize=True)
        return features.cpu().numpy()

    # Nothing to encode: warn and hand back an empty matrix.
    print("⚠️ Avertissement : Aucun texte à encoder. Retour d'une matrice vide.")
    empty = torch.zeros((0, model.config.hidden_size))
    return empty.numpy()
|
86 |
-
|
87 |
-
# Embed every product's "title description" text.
# NOTE(review): the second message says the embeddings were saved, but
# nothing in this section writes them anywhere.
print("🛠️ Génération des embeddings des produits...")
product_texts = [" ".join((prod["title"], prod["description"])) for prod in products_data]
product_embeddings = get_text_embeddings(product_texts)
print("✅ Embeddings générés et sauvegardés !")

# Build the FAISS index, choosing its type from the dataset size.
d = product_embeddings.shape[1]
# Number of centroids; only consumed by the IVF branch below.
nlist = max(10, len(product_embeddings) // 10)

if len(product_embeddings) >= 4000:
    # Larger datasets: inverted-file index, which must be trained first.
    index = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, nlist)
    index.train(product_embeddings)
else:
    # Smaller datasets: plain flat L2 index.
    index = faiss.IndexFlatL2(d)

index.add(product_embeddings)

# Fit the TF-IDF vectorizer on the same product texts.
vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = vectorizer.fit_transform(product_texts)
|
106 |
|
107 |
-
# Recherche hybride
|
108 |
-
def search_products(query, category, min_price, max_price, weight_tfidf=0.5,
|
109 |
query = preprocess(query)
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# TF-IDF Similarité
|
112 |
-
query_vector_sparse = csr_matrix(vectorizer.transform([query]))
|
113 |
tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
|
114 |
-
|
115 |
-
#
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
#
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
133 |
]
|
134 |
-
|
135 |
if category and category != "Toutes":
|
136 |
-
|
137 |
|
138 |
-
|
139 |
-
filtered_results["qa_info"] = filtered_results["id"].apply(
|
140 |
-
lambda prod_id: [prod["qa_info"] for prod in products_data if prod["id"] == prod_id][0]
|
141 |
-
)
|
142 |
-
|
143 |
-
return filtered_results.sort_values(by="score", ascending=False).head(20)
|
144 |
|
145 |
# Interface Gradio
|
146 |
app = gr.Interface(
|
|
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
import pandas as pd
|
|
|
7 |
import time
|
8 |
+
import marqo
|
9 |
from scipy.sparse import csr_matrix
|
10 |
from transformers import AutoModel, AutoProcessor
|
11 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
|
38 |
print("⛔ Échec final du chargement du modèle.")
|
39 |
model, processor = None, None
|
40 |
|
41 |
+
# Fonction pour charger les fichiers JSON
|
42 |
def load_data():
|
43 |
products_data, qa_data = [], []
|
44 |
|
|
|
45 |
if os.path.exists(PRODUCTS_FILE):
|
46 |
with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
|
47 |
products_data = json.load(f).get("products", [])
|
48 |
else:
|
49 |
print(f"⛔ Fichier introuvable : {PRODUCTS_FILE}")
|
50 |
|
|
|
51 |
if os.path.exists(QA_FILE):
|
52 |
with open(QA_FILE, "r", encoding="utf-8") as f:
|
53 |
qa_data = json.load(f)
|
|
|
59 |
products_data, qa_data = load_data()
|
60 |
|
61 |
# Associer les questions-réponses aux produits
|
62 |
+
def associate_qa_with_products(products, qa_data):
    """Attach to each product the Q&A entries whose question mentions it.

    A Q&A entry matches a product when the product's title or description
    (case-insensitive, surrounding whitespace ignored) occurs as a substring
    of the entry's ``"question"`` text.

    Args:
        products: List of product dicts. Each one is mutated in place: its
            ``"qa_info"`` key is (re)set to the list of matching entries.
        qa_data: Iterable of Q&A dicts; only the ``"question"`` key is read.

    Returns:
        The same ``products`` list that was passed in.
    """
    # Lower-case every question once rather than once per product.
    indexed_qa = [(qa, (qa.get("question") or "").lower()) for qa in qa_data]

    for product in products:
        # Blank needles are dropped: "" is a substring of every string, so a
        # product without a title or description would otherwise be linked
        # to every single Q&A entry. `or ""` also tolerates explicit None.
        needles = [
            text
            for text in (
                (product.get("title") or "").strip().lower(),
                (product.get("description") or "").strip().lower(),
            )
            if text
        ]
        product["qa_info"] = [
            qa
            for qa, question in indexed_qa
            if any(needle in question for needle in needles)
        ]

    return products
|
74 |
+
|
75 |
+
products_data = associate_qa_with_products(products_data, qa_data)

# Connect to the Marqo server.
mq = marqo.Client(url="http://localhost:8882")  # Marqo's default port
INDEX_NAME = "ecommerce_products"

# Drop any previous index so it is rebuilt from scratch. Deletion is
# best-effort (for example, the index may not exist yet), but the reason is
# reported, and KeyboardInterrupt/SystemExit are no longer swallowed as they
# were by a bare `except:`.
try:
    mq.delete_index(INDEX_NAME)
except Exception as exc:
    print(f"ℹ️ Suppression de l'index ignorée : {exc}")

mq.create_index(INDEX_NAME, settings={"index_defaults": {"model": "open_clip/ViT-B-32/laion2B-s34B-b79K"}})

# Build one Marqo document per product.
# NOTE(review): Marqo appears to identify documents by `_id`; `id` here would
# then be stored as an ordinary field - confirm that is intended.
documents = [
    {
        "id": product["id"],
        "title": product["title"],
        "description": product["description"],
        "price": product["price"],
        "availability": product["availability"],
        "category": product["category"],
        "qa_info": product.get("qa_info", []),
    }
    for product in products_data
]

# Only the title and description are embedded as tensor fields.
mq.index(INDEX_NAME).add_documents(documents, tensor_fields=["title", "description"])
print("✅ Produits indexés dans Marqo !")
|
105 |
|
106 |
# Prétraitement du texte
|
107 |
def preprocess(text: str) -> str:
|
|
|
110 |
text = re.sub(r'[^\w\s]', '', text)
|
111 |
return text.strip()
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
# Fit the TF-IDF vectorizer on each product's "title description" text.
product_texts = [" ".join((prod["title"], prod["description"])) for prod in products_data]
vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = vectorizer.fit_transform(product_texts)
|
116 |
|
117 |
+
# Recherche hybride avec Marqo + TF-IDF
|
118 |
+
def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_marqo=0.5):
    """Hybrid product search blending Marqo relevance with TF-IDF similarity.

    Args:
        query: Free-text search query; it is passed through ``preprocess``.
        category: Category filter. Falsy values and ``"Toutes"`` disable it;
            otherwise it is matched as a case-insensitive literal substring.
        min_price: Inclusive lower price bound.
        max_price: Inclusive upper price bound.
        weight_tfidf: Weight of the normalized TF-IDF score.
        weight_marqo: Weight of the normalized Marqo score.

    Returns:
        A DataFrame of at most 20 in-stock hits within the price range,
        sorted by descending ``"score"``. It is empty when Marqo returns no
        hits.
    """
    query = preprocess(query)

    # Marqo search (top 50 hits).
    marqo_results = mq.index(INDEX_NAME).search(
        query, searchable_attributes=["title", "description"], limit=50
    )
    hits = list(marqo_results["hits"])

    if not hits:
        # No candidates: return an empty frame instead of failing on the
        # score arithmetic and the column filters below.
        return pd.DataFrame(
            columns=["id", "title", "description", "price",
                     "availability", "category", "qa_info", "score"]
        )

    # Min-max normalize the Marqo scores over the returned hits.
    raw_scores = pd.Series([hit["_score"] for hit in hits], dtype="float64")
    marqo_scores = (raw_scores - raw_scores.min()) / (raw_scores.max() - raw_scores.min() + 1e-6)

    # TF-IDF similarity of the query against every product, one score per
    # row of `products_data` (the corpus `tfidf_matrix` was fitted on).
    query_vector_sparse = csr_matrix(vectorizer.transform([query]))
    tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()
    if len(tfidf_scores) > 0:
        lowest, highest = tfidf_scores.min(), tfidf_scores.max()
        tfidf_scores_norm = (tfidf_scores - lowest) / (highest - lowest + 1e-6)
    else:
        tfidf_scores_norm = tfidf_scores

    # Marqo hits come back in relevance order, not in `products_data` order,
    # so each hit must be matched to its own TF-IDF row through the product
    # id rather than by position.
    row_by_id = {prod.get("id"): row for row, prod in enumerate(products_data)}

    scored_hits = []
    for hit, marqo_score in zip(hits, marqo_scores):
        row = row_by_id.get(hit.get("id"))
        in_corpus = row is not None and row < len(tfidf_scores_norm)
        tfidf_score = float(tfidf_scores_norm[row]) if in_corpus else 0.0
        scored_hits.append(
            {**hit, "score": weight_tfidf * tfidf_score + weight_marqo * float(marqo_score)}
        )

    results_df = pd.DataFrame(scored_hits)

    # Keep in-stock products inside the requested price range.
    keep = (
        (results_df["price"] >= min_price)
        & (results_df["price"] <= max_price)
        & (results_df["availability"] == "in stock")
    )

    if category and category != "Toutes":
        # regex=False: category names are literal text, so characters such
        # as "+" or "(" must not be interpreted as a pattern.
        keep &= results_df["category"].str.contains(category, case=False, na=False, regex=False)

    return results_df[keep].sort_values(by="score", ascending=False).head(20)
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
# Interface Gradio
|
170 |
app = gr.Interface(
|