adel67460 commited on
Commit
7a9155e
·
verified ·
1 Parent(s): 58ebef2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -10,14 +10,15 @@ from scipy.sparse import csr_matrix
10
  from transformers import AutoModel, AutoProcessor
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
 
13
- # Installation automatique des dépendances
14
- os.system("pip install faiss-cpu")
15
-
16
  # Vérifier si CUDA est disponible (GPU)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  print(f"🔹 Utilisation du périphérique : {device}")
19
 
20
- # Chargement sécurisé du modèle Marqo avec plusieurs tentatives
 
 
 
 
21
  MAX_RETRIES = 3
22
  model_name = "Marqo/marqo-ecommerce-embeddings-L"
23
 
@@ -39,33 +40,31 @@ for attempt in range(MAX_RETRIES):
39
 
40
  # Fonction pour charger et fusionner les données
41
  def load_data():
42
- # Charger les produits de products.json
43
- try:
44
- with open("/mnt/data/products.json", "r", encoding="utf-8") as f:
 
 
45
  products_data = json.load(f).get("products", [])
46
- except Exception as e:
47
- print(f"⚠️ Erreur lors du chargement de products.json: {e}")
48
- products_data = []
49
 
50
- # Charger les questions-réponses de qa_sequences_output.json
51
- try:
52
- with open("/mnt/data/qa_sequences_output.json", "r", encoding="utf-8") as f:
53
  qa_data = json.load(f)
54
- except Exception as e:
55
- print(f"⚠️ Erreur lors du chargement de qa_sequences_output.json: {e}")
56
- qa_data = []
57
 
58
- # Associer les informations QA aux produits
59
- enriched_products = []
60
- for product in products_data:
61
- product_name = product.get("title", "").lower()
62
- related_qa = [qa for qa in qa_data if product_name in qa.get("question", "").lower()]
63
- product["qa_info"] = related_qa # Ajouter les questions/réponses au produit
64
- enriched_products.append(product)
65
 
66
- return enriched_products
67
 
68
- products_data = load_data()
 
 
 
 
69
 
70
  # Prétraitement du texte
71
  def preprocess(text: str) -> str:
@@ -76,6 +75,10 @@ def preprocess(text: str) -> str:
76
 
77
  # Génération des embeddings des produits optimisée
78
  def get_text_embeddings(texts, batch_size=32):
 
 
 
 
79
  with torch.no_grad():
80
  processed_texts = processor(text=texts, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
81
  embeddings = model.get_text_features(processed_texts["input_ids"], normalize=True)
 
10
  from transformers import AutoModel, AutoProcessor
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
 
 
 
 
13
  # Vérifier si CUDA est disponible (GPU)
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(f"🔹 Utilisation du périphérique : {device}")
16
 
17
+ # Définition des chemins des fichiers JSON
18
+ PRODUCTS_FILE = "products.json"
19
+ QA_FILE = "qa_sequences_output.json"
20
+
21
+ # Chargement sécurisé du modèle Marqo
22
  MAX_RETRIES = 3
23
  model_name = "Marqo/marqo-ecommerce-embeddings-L"
24
 
 
40
 
41
  # Fonction pour charger et fusionner les données
42
  def load_data():
43
+ products_data, qa_data = [], []
44
+
45
+ # Charger les produits
46
+ if os.path.exists(PRODUCTS_FILE):
47
+ with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
48
  products_data = json.load(f).get("products", [])
49
+ else:
50
+ print(f" Fichier introuvable : {PRODUCTS_FILE}")
 
51
 
52
+ # Charger les questions-réponses
53
+ if os.path.exists(QA_FILE):
54
+ with open(QA_FILE, "r", encoding="utf-8") as f:
55
  qa_data = json.load(f)
56
+ else:
57
+ print(f" Fichier introuvable : {QA_FILE}")
 
58
 
59
+ return products_data, qa_data
 
 
 
 
 
 
60
 
61
+ products_data, qa_data = load_data()
62
 
63
+ # Associer les questions-réponses aux produits
64
+ for product in products_data:
65
+ product_name = product.get("title", "").lower()
66
+ related_qa = [qa for qa in qa_data if product_name in qa.get("question", "").lower()]
67
+ product["qa_info"] = related_qa # Ajouter les questions/réponses au produit
68
 
69
  # Prétraitement du texte
70
  def preprocess(text: str) -> str:
 
75
 
76
  # Génération des embeddings des produits optimisée
77
  def get_text_embeddings(texts, batch_size=32):
78
+ if not texts: # Vérifier que la liste de textes n'est pas vide
79
+ print("⚠️ Avertissement : Aucun texte à encoder. Retour d'une matrice vide.")
80
+ return torch.zeros((0, model.config.hidden_size)).numpy()
81
+
82
  with torch.no_grad():
83
  processed_texts = processor(text=texts, return_tensors="pt", truncation=True, max_length=64, padding=True).to(device)
84
  embeddings = model.get_text_features(processed_texts["input_ids"], normalize=True)