Spaces:
Runtime error
Runtime error
import os | |
import json | |
import re | |
import gradio as gr | |
import torch | |
import pandas as pd | |
from scipy.sparse import csr_matrix | |
from transformers import AutoModel, AutoProcessor | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
# 📌 Vérifier si CUDA est disponible | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"🔹 Utilisation du périphérique : {device}") | |
# 📌 Chargement du modèle Marqo Embeddings | |
model_name = "Marqo/marqo-ecommerce-embeddings-L" | |
print(f"🔄 Chargement du modèle {model_name}...") | |
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device) | |
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) | |
print("✅ Modèle chargé avec succès !") | |
# 📌 Définition des fichiers JSON | |
PRODUCTS_FILE = "products.json" | |
QA_FILE = "qa_sequences_output.json" | |
# 📌 Fonction pour charger les fichiers JSON | |
def safe_load_json(file_path): | |
if not os.path.exists(file_path): | |
print(f"⛔ Fichier introuvable : {file_path}") | |
return [] | |
try: | |
with open(file_path, "r", encoding="utf-8") as f: | |
data = json.load(f) | |
return data.get("products", []) if "products" in data else data | |
except json.JSONDecodeError: | |
print(f"⚠️ Erreur de décodage JSON dans {file_path}") | |
return [] | |
products_data = safe_load_json(PRODUCTS_FILE) | |
qa_data = safe_load_json(QA_FILE) | |
# 📌 Générer des embeddings pour les produits | |
def get_text_embeddings(texts): | |
"""Génère des embeddings à partir du modèle Marqo""" | |
with torch.no_grad(): | |
processed_texts = processor(text=texts, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device) | |
embeddings = model(**processed_texts).last_hidden_state.mean(dim=1) | |
return embeddings.cpu().numpy() | |
# Création des embeddings pour tous les produits | |
print("🛠️ Génération des embeddings des produits...") | |
product_embeddings = get_text_embeddings([prod["title"] + " " + prod["description"] for prod in products_data]) | |
print("✅ Embeddings générés et sauvegardés !") | |
# 📌 TF-IDF Vectorizer pour une recherche hybride | |
vectorizer = TfidfVectorizer(stop_words="english") | |
tfidf_matrix = vectorizer.fit_transform([prod["title"] + " " + prod["description"] for prod in products_data]) | |
# 📌 Recherche hybride avec Marqo embeddings + TF-IDF | |
def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_marqo=0.5): | |
if not query.strip(): | |
return "❌ Veuillez entrer un terme de recherche valide." | |
min_price = float(min_price) if isinstance(min_price, (int, float)) else 0 | |
max_price = float(max_price) if isinstance(max_price, (int, float)) else float("inf") | |
# 📌 Génération de l'embedding de la requête utilisateur | |
query_embedding = get_text_embeddings([query])[0] | |
# 📌 Calcul de similarité cosinus entre la requête et les produits | |
marqo_scores = (product_embeddings @ query_embedding).tolist() | |
# 📌 TF-IDF Similarité | |
query_vector_sparse = csr_matrix(vectorizer.transform([query])) | |
tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten() | |
# 📌 Normalisation des scores Marqo | |
if len(marqo_scores) > 0 and max(marqo_scores) != min(marqo_scores): | |
marqo_scores = (pd.Series(marqo_scores) - min(marqo_scores)) / (max(marqo_scores) - min(marqo_scores) + 1e-6) | |
else: | |
marqo_scores = [1] * len(products_data) | |
# 📌 Normalisation des scores TF-IDF | |
if len(tfidf_scores) > 0 and max(tfidf_scores) != min(tfidf_scores): | |
tfidf_scores_norm = (tfidf_scores - min(tfidf_scores)) / (max(tfidf_scores) - min(tfidf_scores) + 1e-6) | |
else: | |
tfidf_scores_norm = [1] * len(products_data) | |
# 📌 Fusionner les scores TF-IDF et Marqo embeddings | |
final_scores = weight_tfidf * tfidf_scores_norm[:len(products_data)] + weight_marqo * marqo_scores | |
# 📌 Création d'un DataFrame avec les scores finaux | |
results_df = pd.DataFrame(products_data) | |
results_df["score"] = final_scores | |
# 📌 Filtrage des résultats par prix et disponibilité | |
results_df = results_df[ | |
(results_df["price"].fillna(0).astype(float) >= min_price) & | |
(results_df["price"].fillna(0).astype(float) <= max_price) & | |
(results_df["availability"].fillna("").str.lower() == "in stock") | |
] | |
if category and category != "Toutes": | |
results_df = results_df[results_df["category"].str.contains(category, case=False, na=False)] | |
return results_df.sort_values(by="score", ascending=False).head(20) | |
# 📌 Interface Gradio | |
app = gr.Interface( | |
fn=search_products, | |
inputs=[ | |
gr.Textbox(label="Rechercher un produit"), | |
gr.Textbox(label="Catégorie"), | |
gr.Number(label="Prix min"), | |
gr.Number(label="Prix max") | |
], | |
outputs=[ | |
gr.Dataframe(headers=["ID", "Titre", "Description", "Prix", "Disponibilité", "Score"], | |
datatype=["str", "str", "str", "number", "str", "number"]) | |
] | |
) | |
app.launch() |