File size: 5,058 Bytes
b037067
58ebef2
 
ba6ee1a
d12b30b
fcc0fae
ac94672
d12b30b
c6f84e0
71f2bcc
e1e8719
b44a4e9
 
 
e1e8719
 
 
 
 
 
 
 
 
 
7a9155e
 
 
e1e8719
7da6546
 
 
 
ac94672
7da6546
 
 
 
 
 
ce6bc00
7da6546
 
 
e1e8719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7da6546
 
 
 
 
e1e8719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7da6546
e1e8719
71f2bcc
 
e1e8719
 
 
 
 
 
 
 
 
 
71f2bcc
ac94672
3d1180c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import json
import re
import gradio as gr
import torch
import pandas as pd
from scipy.sparse import csr_matrix
from transformers import AutoModel, AutoProcessor
from sklearn.feature_extraction.text import TfidfVectorizer

# 📌 Vérifier si CUDA est disponible
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔹 Utilisation du périphérique : {device}")

# 📌 Chargement du modèle Marqo Embeddings
model_name = "Marqo/marqo-ecommerce-embeddings-L"
print(f"🔄 Chargement du modèle {model_name}...")

model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

print("✅ Modèle chargé avec succès !")

# 📌 Définition des fichiers JSON
PRODUCTS_FILE = "products.json"
QA_FILE = "qa_sequences_output.json"

# 📌 Fonction pour charger les fichiers JSON
def safe_load_json(file_path):
    if not os.path.exists(file_path):
        print(f"⛔ Fichier introuvable : {file_path}")
        return []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            return data.get("products", []) if "products" in data else data
    except json.JSONDecodeError:
        print(f"⚠️ Erreur de décodage JSON dans {file_path}")
        return []

products_data = safe_load_json(PRODUCTS_FILE)
qa_data = safe_load_json(QA_FILE)

# 📌 Générer des embeddings pour les produits
def get_text_embeddings(texts):
    """Génère des embeddings à partir du modèle Marqo"""
    with torch.no_grad():
        processed_texts = processor(text=texts, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
        embeddings = model(**processed_texts).last_hidden_state.mean(dim=1)
    return embeddings.cpu().numpy()

# Création des embeddings pour tous les produits
print("🛠️ Génération des embeddings des produits...")
product_embeddings = get_text_embeddings([prod["title"] + " " + prod["description"] for prod in products_data])
print("✅ Embeddings générés et sauvegardés !")

# 📌 TF-IDF Vectorizer pour une recherche hybride
vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = vectorizer.fit_transform([prod["title"] + " " + prod["description"] for prod in products_data])

# 📌 Recherche hybride avec Marqo embeddings + TF-IDF
def search_products(query, category, min_price, max_price, weight_tfidf=0.5, weight_marqo=0.5):
    if not query.strip():
        return "❌ Veuillez entrer un terme de recherche valide."

    min_price = float(min_price) if isinstance(min_price, (int, float)) else 0
    max_price = float(max_price) if isinstance(max_price, (int, float)) else float("inf")

    # 📌 Génération de l'embedding de la requête utilisateur
    query_embedding = get_text_embeddings([query])[0]

    # 📌 Calcul de similarité cosinus entre la requête et les produits
    marqo_scores = (product_embeddings @ query_embedding).tolist()

    # 📌 TF-IDF Similarité
    query_vector_sparse = csr_matrix(vectorizer.transform([query]))
    tfidf_scores = (tfidf_matrix * query_vector_sparse.T).toarray().flatten()

    # 📌 Normalisation des scores Marqo
    if len(marqo_scores) > 0 and max(marqo_scores) != min(marqo_scores):
        marqo_scores = (pd.Series(marqo_scores) - min(marqo_scores)) / (max(marqo_scores) - min(marqo_scores) + 1e-6)
    else:
        marqo_scores = [1] * len(products_data)

    # 📌 Normalisation des scores TF-IDF
    if len(tfidf_scores) > 0 and max(tfidf_scores) != min(tfidf_scores):
        tfidf_scores_norm = (tfidf_scores - min(tfidf_scores)) / (max(tfidf_scores) - min(tfidf_scores) + 1e-6)
    else:
        tfidf_scores_norm = [1] * len(products_data)

    # 📌 Fusionner les scores TF-IDF et Marqo embeddings
    final_scores = weight_tfidf * tfidf_scores_norm[:len(products_data)] + weight_marqo * marqo_scores

    # 📌 Création d'un DataFrame avec les scores finaux
    results_df = pd.DataFrame(products_data)
    results_df["score"] = final_scores

    # 📌 Filtrage des résultats par prix et disponibilité
    results_df = results_df[
        (results_df["price"].fillna(0).astype(float) >= min_price) &
        (results_df["price"].fillna(0).astype(float) <= max_price) &
        (results_df["availability"].fillna("").str.lower() == "in stock")
    ]

    if category and category != "Toutes":
        results_df = results_df[results_df["category"].str.contains(category, case=False, na=False)]
    
    return results_df.sort_values(by="score", ascending=False).head(20)

# 📌 Interface Gradio
app = gr.Interface(
    fn=search_products,
    inputs=[
        gr.Textbox(label="Rechercher un produit"), 
        gr.Textbox(label="Catégorie"), 
        gr.Number(label="Prix min"), 
        gr.Number(label="Prix max")
    ],
    outputs=[
        gr.Dataframe(headers=["ID", "Titre", "Description", "Prix", "Disponibilité", "Score"], 
                     datatype=["str", "str", "str", "number", "str", "number"])
    ]
)

app.launch()