Spaces:
Sleeping
Sleeping
Update embed_data.py
Browse files- embed_data.py +47 -24
embed_data.py
CHANGED
|
@@ -2,14 +2,14 @@ from pymongo import MongoClient
|
|
| 2 |
from transformers import BertTokenizer, BertModel
|
| 3 |
import torch
|
| 4 |
from torch.nn import Embedding
|
| 5 |
-
import numpy as np
|
| 6 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
|
| 8 |
# MongoDB Atlas 연결 설정
|
| 9 |
-
client = MongoClient(
|
|
|
|
|
|
|
| 10 |
db = client["two_tower_model"]
|
| 11 |
product_collection = db["product_tower"]
|
| 12 |
-
user_collection = db[
|
| 13 |
product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
|
| 14 |
user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
|
| 15 |
|
|
@@ -23,6 +23,7 @@ max_height = 250
|
|
| 23 |
min_weight = 30
|
| 24 |
max_weight = 200
|
| 25 |
|
|
|
|
| 26 |
# 상품 타워: 데이터 임베딩
|
| 27 |
def embed_product_data(product_data):
|
| 28 |
# 상품명과 상세 정보 임베딩 (BERT)
|
|
@@ -45,47 +46,63 @@ def embed_product_data(product_data):
|
|
| 45 |
|
| 46 |
# 모든 임베딩 벡터 차원 맞추기
|
| 47 |
category_embedding = category_embedding.view(1, -1) # 2D로 변환
|
| 48 |
-
color_embedding = color_embedding.view(1, -1)
|
| 49 |
|
| 50 |
# 최종 임베딩 벡터 결합
|
| 51 |
-
combined_embedding = torch.cat(
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
return product_embedding.detach().numpy()
|
| 55 |
|
|
|
|
| 56 |
# 사용자 타워: 데이터 임베딩
|
| 57 |
def embed_user_data(user_data):
|
| 58 |
# 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
|
| 59 |
-
embedding_layer = Embedding(
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
|
| 62 |
-
gender_id = 0 if user_data[
|
| 63 |
|
| 64 |
# 스케일링 적용
|
| 65 |
-
height = user_data[
|
| 66 |
-
weight = user_data[
|
| 67 |
|
| 68 |
if not (min_height <= height <= max_height):
|
| 69 |
-
raise ValueError(
|
|
|
|
|
|
|
| 70 |
if not (min_weight <= weight <= max_weight):
|
| 71 |
-
raise ValueError(
|
|
|
|
|
|
|
| 72 |
|
| 73 |
scaled_height = (height - min_height) * 99 // (max_height - min_height)
|
| 74 |
scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
|
| 75 |
-
|
| 76 |
-
age_embedding = embedding_layer(torch.tensor([user_data[
|
| 77 |
gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
|
| 78 |
height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
|
| 79 |
weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
|
| 80 |
|
| 81 |
# 최종 임베딩 벡터 결합
|
| 82 |
-
combined_embedding = torch.cat(
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
return user_embedding.detach().numpy()
|
| 86 |
|
|
|
|
| 87 |
# MongoDB Atlas에서 데이터 가져오기
|
| 88 |
-
all_products = product_collection.find()
|
| 89 |
all_users = user_collection.find() # 모든 사용자 데이터 가져오기
|
| 90 |
|
| 91 |
# 상품 임베딩 수행
|
|
@@ -96,10 +113,14 @@ for product_data in all_products:
|
|
| 96 |
# MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
|
| 97 |
product_embedding_collection.update_one(
|
| 98 |
{"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
|
| 99 |
-
{
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
-
print(f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}.")
|
| 103 |
|
| 104 |
# 사용자 임베딩 수행
|
| 105 |
for user_data in all_users:
|
|
@@ -110,9 +131,11 @@ for user_data in all_users:
|
|
| 110 |
# MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
|
| 111 |
user_embedding_collection.update_one(
|
| 112 |
{"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
|
| 113 |
-
{
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
)
|
| 116 |
print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
|
| 117 |
except ValueError as e:
|
| 118 |
-
print(f"Skipping user_id {user_data['user_id']} due to error: {e}")
|
|
|
|
| 2 |
from transformers import BertTokenizer, BertModel
|
| 3 |
import torch
|
| 4 |
from torch.nn import Embedding
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# MongoDB Atlas 연결 설정
|
| 7 |
+
client = MongoClient(
|
| 8 |
+
"mongodb+srv://waseoke:[email protected]/test?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
|
| 9 |
+
)
|
| 10 |
db = client["two_tower_model"]
|
| 11 |
product_collection = db["product_tower"]
|
| 12 |
+
user_collection = db["user_tower"]
|
| 13 |
product_embedding_collection = db["product_embeddings"] # 상품 임베딩을 저장할 컬렉션
|
| 14 |
user_embedding_collection = db["user_embeddings"] # 사용자 임베딩을 저장할 컬렉션
|
| 15 |
|
|
|
|
| 23 |
min_weight = 30
|
| 24 |
max_weight = 200
|
| 25 |
|
| 26 |
+
|
| 27 |
# 상품 타워: 데이터 임베딩
|
| 28 |
def embed_product_data(product_data):
|
| 29 |
# 상품명과 상세 정보 임베딩 (BERT)
|
|
|
|
| 46 |
|
| 47 |
# 모든 임베딩 벡터 차원 맞추기
|
| 48 |
category_embedding = category_embedding.view(1, -1) # 2D로 변환
|
| 49 |
+
color_embedding = color_embedding.view(1, -1) # 2D로 변환
|
| 50 |
|
| 51 |
# 최종 임베딩 벡터 결합
|
| 52 |
+
combined_embedding = torch.cat(
|
| 53 |
+
(text_embedding, category_embedding, color_embedding), dim=1
|
| 54 |
+
)
|
| 55 |
+
product_embedding = torch.nn.functional.adaptive_avg_pool1d(
|
| 56 |
+
combined_embedding.unsqueeze(0), 512
|
| 57 |
+
).squeeze(0)
|
| 58 |
|
| 59 |
return product_embedding.detach().numpy()
|
| 60 |
|
| 61 |
+
|
| 62 |
# 사용자 타워: 데이터 임베딩
|
| 63 |
def embed_user_data(user_data):
|
| 64 |
# 나이, 성별, 키, 몸무게 임베딩 (임베딩 레이어)
|
| 65 |
+
embedding_layer = Embedding(
|
| 66 |
+
num_embeddings=100, embedding_dim=128
|
| 67 |
+
) # 임의로 설정된 예시 값
|
| 68 |
|
| 69 |
# 예를 들어 성별을 'M'은 0, 'F'는 1로 인코딩
|
| 70 |
+
gender_id = 0 if user_data["gender"] == "M" else 1
|
| 71 |
|
| 72 |
# 스케일링 적용
|
| 73 |
+
height = user_data["height"]
|
| 74 |
+
weight = user_data["weight"]
|
| 75 |
|
| 76 |
if not (min_height <= height <= max_height):
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"Invalid height value: {height}. Expected range: {min_height}-{max_height}"
|
| 79 |
+
)
|
| 80 |
if not (min_weight <= weight <= max_weight):
|
| 81 |
+
raise ValueError(
|
| 82 |
+
f"Invalid weight value: {weight}. Expected range: {min_weight}-{max_weight}"
|
| 83 |
+
)
|
| 84 |
|
| 85 |
scaled_height = (height - min_height) * 99 // (max_height - min_height)
|
| 86 |
scaled_weight = (weight - min_weight) * 99 // (max_weight - min_weight)
|
| 87 |
+
|
| 88 |
+
age_embedding = embedding_layer(torch.tensor([user_data["age"]])).view(1, -1)
|
| 89 |
gender_embedding = embedding_layer(torch.tensor([gender_id])).view(1, -1)
|
| 90 |
height_embedding = embedding_layer(torch.tensor([scaled_height])).view(1, -1)
|
| 91 |
weight_embedding = embedding_layer(torch.tensor([scaled_weight])).view(1, -1)
|
| 92 |
|
| 93 |
# 최종 임베딩 벡터 결합
|
| 94 |
+
combined_embedding = torch.cat(
|
| 95 |
+
(age_embedding, gender_embedding, height_embedding, weight_embedding), dim=1
|
| 96 |
+
)
|
| 97 |
+
user_embedding = torch.nn.functional.adaptive_avg_pool1d(
|
| 98 |
+
combined_embedding.unsqueeze(0), 512
|
| 99 |
+
).squeeze(0)
|
| 100 |
|
| 101 |
return user_embedding.detach().numpy()
|
| 102 |
|
| 103 |
+
|
| 104 |
# MongoDB Atlas에서 데이터 가져오기
|
| 105 |
+
all_products = product_collection.find() # 모든 상품 데이터 가져오기
|
| 106 |
all_users = user_collection.find() # 모든 사용자 데이터 가져오기
|
| 107 |
|
| 108 |
# 상품 임베딩 수행
|
|
|
|
| 113 |
# MongoDB Atlas의 product_embeddings 컬렉션에 임베딩 저장
|
| 114 |
product_embedding_collection.update_one(
|
| 115 |
{"product_id": product_data["product_id"]}, # product_id 기준으로 찾기
|
| 116 |
+
{
|
| 117 |
+
"$set": {"embedding": product_embedding.tolist()}
|
| 118 |
+
}, # 벡터를 리스트 형태로 저장
|
| 119 |
+
upsert=True, # 기존 항목이 없으면 새로 삽입
|
| 120 |
+
)
|
| 121 |
+
print(
|
| 122 |
+
f"Embedding saved to MongoDB Atlas for Product ID {product_data['product_id']}."
|
| 123 |
)
|
|
|
|
| 124 |
|
| 125 |
# 사용자 임베딩 수행
|
| 126 |
for user_data in all_users:
|
|
|
|
| 131 |
# MongoDB Atlas의 user_embeddings 컬렉션에 임베딩 저장
|
| 132 |
user_embedding_collection.update_one(
|
| 133 |
{"user_id": user_data["user_id"]}, # user_id 기준으로 찾기
|
| 134 |
+
{
|
| 135 |
+
"$set": {"embedding": user_embedding.tolist()}
|
| 136 |
+
}, # 벡터를 리스트 형태로 저장
|
| 137 |
+
upsert=True, # 기존 항목이 없으면 새로 삽입
|
| 138 |
)
|
| 139 |
print(f"Embedding saved to MongoDB Atlas for user_id {user_data['user_id']}.")
|
| 140 |
except ValueError as e:
|
| 141 |
+
print(f"Skipping user_id {user_data['user_id']} due to error: {e}")
|