adel67460 committed
Commit b44a4e9 (verified)
Parent(s): e3505af

Update app.py

Files changed (1)
  1. app.py +12 -20
app.py CHANGED
@@ -4,10 +4,14 @@ import pandas as pd
 import json
 from transformers import AutoModel, AutoProcessor
 
+# Check whether CUDA (GPU) is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🔹 Utilisation du périphérique : {device}")
+
 # Load the Marqo model, with error handling
 model_name = "Marqo/marqo-ecommerce-embeddings-L"
 try:
-    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
     processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 except Exception as e:
     print(f"❌ Erreur lors du chargement du modèle : {e}")
@@ -19,7 +23,7 @@ def load_products_from_json():
     with open("products.json", "r", encoding="utf-8") as f:
         data = json.load(f)
 
-    products = data.get("products", [])
+    products = data.get("products", [])[:1000]  # Cap at 1000 products to avoid memory crashes
     structured_products = []
 
     for product in products:
@@ -28,7 +32,7 @@ def load_products_from_json():
         structured_products.append({
             "id": product.get("id", "N/A"),
             "title": product.get("title", "Produit inconnu"),
-            "description": product.get("description", "")[:200],
+            "description": product.get("description", "")[:100],
             "category": product.get("product_type", "Inconnu"),
             "brand": product.get("brand", "Sans marque"),
             "price": price,
@@ -50,7 +54,7 @@ def load_products_from_json():
 # Load the products
 products_data = load_products_from_json()
 
-def get_text_embeddings(texts, batch_size=8):
+def get_text_embeddings(texts, batch_size=16):
     if model is None or processor is None:
         print("❌ Modèle non chargé, impossible de générer les embeddings.")
         return torch.empty((0, 1024))
@@ -58,10 +62,10 @@ def get_text_embeddings(texts, batch_size=8):
     embeddings = []
     for i in range(0, len(texts), batch_size):
         batch = texts[i:i+batch_size]
-        processed = processor(text=batch, return_tensors="pt", truncation=True, max_length=64, padding=True)
+        processed = processor(text=batch, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
         with torch.no_grad():
             batch_embeddings = model.get_text_features(processed["input_ids"], normalize=True)
-        embeddings.extend(batch_embeddings)
+        embeddings.extend(batch_embeddings.cpu())  # Keep embeddings on the CPU to free GPU memory
     return torch.stack(embeddings)
 
 if not products_data:
@@ -71,8 +75,6 @@ if not products_data:
 else:
     products_df = pd.DataFrame(products_data)
     title_embeddings = get_text_embeddings([prod["title"] for prod in products_data])
-    description_embeddings = get_text_embeddings([prod["description"] for prod in products_data])
-    category_embeddings = get_text_embeddings([prod["category"] for prod in products_data])
 
 def search_products(query, category, min_price, max_price):
     print(f"🔎 Recherche déclenchée avec: {query}, Catégorie: {category}, Prix: {min_price}-{max_price}")
@@ -81,20 +83,12 @@ def search_products(query, category, min_price, max_price):
         return pd.DataFrame()
 
     query_embedding = get_text_embeddings([query])[0]
-
-    # Weighted similarity scores
-    title_sim = torch.nn.functional.cosine_similarity(query_embedding, title_embeddings, dim=1) * 0.5
-    desc_sim = torch.nn.functional.cosine_similarity(query_embedding, description_embeddings, dim=1) * 0.3
-    cat_sim = torch.nn.functional.cosine_similarity(query_embedding, category_embeddings, dim=1) * 0.2
-    total_sim = title_sim + desc_sim + cat_sim
-
-    # Normalize the scores
-    normalized_similarities = (total_sim - total_sim.min()) / (total_sim.max() - total_sim.min())
+    title_sim = torch.nn.functional.cosine_similarity(query_embedding, title_embeddings, dim=1)
+    normalized_similarities = (title_sim - title_sim.min()) / (title_sim.max() - title_sim.min())
 
     results = products_df.copy()
     results["score"] = normalized_similarities.cpu().numpy()
 
-    # Strict filtering of the results
    filtered_results = results[
        (results["price"] >= min_price) &
        (results["price"] <= max_price) &
@@ -130,5 +124,3 @@ if model is not None and not products_df.empty:
     app = create_ui()
     app.queue()
     app.launch(server_name="0.0.0.0", server_port=7860, share=True)
-else:
-    print("❌ L'application n'a pas pu être initialisée, vérifiez les erreurs ci-dessus.")
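A quick way to sanity-check the loading and encoding path this commit settles on is to run it outside Gradio. The sketch below mirrors the calls visible in the diff (device selection, trust_remote_code, normalized text features); the two product titles are made-up placeholders, and the (2, 1024) output shape is inferred from the 1024-dim fallback tensor used in app.py.

import torch
from transformers import AutoModel, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "Marqo/marqo-ecommerce-embeddings-L"

# Same loading pattern as app.py; the Marqo checkpoint ships custom code.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

titles = ["running shoes", "leather handbag"]  # hypothetical product titles
processed = processor(text=titles, return_tensors="pt",
                      truncation=True, max_length=64, padding=True).to(device)

with torch.no_grad():
    # normalize=True returns unit-length vectors, so cosine similarity reduces to a dot product.
    embeddings = model.get_text_features(processed["input_ids"], normalize=True)

print(embeddings.shape)  # expected: (2, 1024), matching the fallback size in app.py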
 
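The scoring change itself is small: the weighted 0.5/0.3/0.2 blend of title, description, and category similarities is replaced by a single title similarity followed by a min-max rescale. A standalone sketch with random stand-in embeddings (assumed unit-length and 1024-dimensional, like the app's):

import torch
import torch.nn.functional as F

# Stand-ins for what get_text_embeddings() returns: normalized 1024-dim vectors.
query_embedding = F.normalize(torch.randn(1024), dim=0)
title_embeddings = F.normalize(torch.randn(5, 1024), dim=1)

# Cosine similarity of the query against every product title (broadcasts over the batch).
title_sim = F.cosine_similarity(query_embedding, title_embeddings, dim=1)

# Min-max rescale to [0, 1], as in the commit. Note that this divides by zero
# if every similarity is identical, e.g. with a single product in the catalogue.
normalized = (title_sim - title_sim.min()) / (title_sim.max() - title_sim.min())

print(title_sim)   # raw cosine similarities, shape (5,)
print(normalized)  # best match scores 1.0, worst scores 0.0

Dropping the description and category matrices also means only one embedding table has to stay resident, which fits the memory-saving direction of the rest of the commit.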