Spaces:
Sleeping
Sleeping
File size: 5,562 Bytes
2bffc21 6c15b41 09ec91f 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 2bffc21 09ec91f 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 2bffc21 6c15b41 45b2ac8 2bffc21 6c15b41 45b2ac8 2bffc21 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 45b2ac8 6c15b41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import os

import numpy as np
import torch
import torch.nn.functional as F
from pymongo import MongoClient
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
# MongoDB Atlas connection settings.
# NOTE(review): the fallback URI below embeds a username and password in the
# source file. It is kept only so existing deployments keep working; set the
# MONGODB_URI environment variable to override it, then rotate the committed
# credential and remove the fallback.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://waseoke:rookies3@cluster0.ps7gq.mongodb.net/test?retryWrites=true&w=majority",
)
client = MongoClient(MONGODB_URI)
db = client["two_tower_model"]
train_dataset = db["train_dataset"]

# Load the KoBERT model and tokenizer.
tokenizer = BertTokenizer.from_pretrained("monologg/kobert")
model = BertModel.from_pretrained("monologg/kobert")
# Product embedding function
def embed_product_data(product):
    """Embed a product's name and description with KoBERT.

    The ``product_name`` and ``product_description`` values (each defaulting
    to an empty string when the key is absent) are joined with a single
    space, tokenized with truncation at ``max_length=128``, and passed through
    the module-level ``model``. The per-token vectors of ``last_hidden_state``
    are then averaged into one vector.

    Args:
        product: Mapping that may contain ``product_name`` and
            ``product_description`` strings.

    Returns:
        A flattened NumPy array holding the mean-pooled embedding.
    """
    name = product.get("product_name", "")
    description = product.get("product_description", "")
    text = name + " " + description

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )

    # Inference only: disable autograd so no computation graph is built and
    # retained for every embedded product.
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling over the token dimension.
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()
# PyTorch Dataset definition
class TripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) embedding triplets.

    Wraps an indexable collection of records, each holding the keys
    ``anchor_embedding``, ``positive_embedding`` and ``negative_embedding``.
    """

    def __init__(self, dataset):
        """Keep a reference to the collection of embedding records."""
        self.dataset = dataset

    def __len__(self):
        """Return the number of stored records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return record ``idx`` as a tuple of three float32 tensors.

        The tuple order is (anchor, positive, negative).
        """
        record = self.dataset[idx]
        field_names = ("anchor_embedding", "positive_embedding", "negative_embedding")
        return tuple(
            torch.tensor(record[field], dtype=torch.float32) for field in field_names
        )
# Load the dataset from MongoDB and convert it to embeddings
def prepare_training_data(verbose=False):
    """Build a ``TripletDataset`` of embeddings from the MongoDB collection.

    Every document of the module-level ``train_dataset`` collection is read,
    and the ``product`` sub-document of its ``anchor``, ``positive`` and
    ``negative`` entries is embedded with ``embed_product_data``. A document
    whose embedding fails is reported on stdout and skipped.

    Args:
        verbose: When true, print the first five values and the shape of each
            embedding of every successfully embedded sample.

    Returns:
        A ``TripletDataset`` wrapping one record per embedded sample, with the
        keys ``anchor_embedding``, ``positive_embedding`` and
        ``negative_embedding``.

    Raises:
        ValueError: If the collection is empty, or if no sample at all could
            be embedded (an empty dataset would only fail later in training).
    """
    dataset = list(train_dataset.find())
    if not dataset:
        raise ValueError("No training data found in MongoDB.")

    embedded_dataset = []
    for sample_no, entry in enumerate(dataset, start=1):
        try:
            # Embed the anchor, positive and negative products.
            anchor_embedding = embed_product_data(entry["anchor"]["product"])
            positive_embedding = embed_product_data(entry["positive"]["product"])
            negative_embedding = embed_product_data(entry["negative"]["product"])
        except Exception as e:
            # Best effort: one malformed document must not abort the whole
            # load, so report it and move on to the next sample.
            print(f"Error embedding data at sample {sample_no}: {e}")
            continue

        # Optional preview of the embeddings.
        if verbose:
            print(f"Sample {sample_no}:")
            previews = (
                ("Anchor", anchor_embedding),
                ("Positive", positive_embedding),
                ("Negative", negative_embedding),
            )
            for label, embedding in previews:
                print(
                    f"{label} Embedding: {embedding[:5]}... (shape: {embedding.shape})"
                )

        # Store the embedded triplet.
        embedded_dataset.append(
            {
                "anchor_embedding": anchor_embedding,
                "positive_embedding": positive_embedding,
                "negative_embedding": negative_embedding,
            }
        )

    if not embedded_dataset:
        raise ValueError("No training samples could be embedded.")

    return TripletDataset(embedded_dataset)
# Dataset validation helper
def validate_embeddings():
    """Generate the dataset embeddings and print part of each for inspection.

    Runs ``prepare_training_data`` in verbose mode, prints the resulting
    number of samples, and returns the dataset it produced.
    """
    print("Validating embeddings...")
    embedded_triplets = prepare_training_data(verbose=True)
    sample_count = len(embedded_triplets)
    print(f"Total samples: {sample_count}")
    return embedded_triplets
# Train a model with triplet loss
def train_triplet_model(
    product_model, train_loader, num_epochs=10, learning_rate=0.001, margin=0.05
):
    """Train ``product_model`` with a margin-based triplet loss.

    For each batch the anchor, positive and negative inputs are passed through
    the model, and the loss is the batch mean of
    ``max(d(anchor, positive) - d(anchor, negative) + margin, 0)``, where ``d``
    is ``F.pairwise_distance``. Parameters are updated with Adam, and the mean
    batch loss is printed after every epoch.

    Args:
        product_model: Module applied to each element of a triplet.
        train_loader: Iterable yielding ``(anchor, positive, negative)``
            batches. It is iterated once per epoch and does not need to
            support ``len()``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to Adam.
        margin: Margin added to the distance difference before clamping.

    Returns:
        The same ``product_model`` instance, trained in place.
    """
    optimizer = Adam(product_model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        product_model.train()
        total_loss = 0.0
        num_batches = 0

        for anchor, positive, negative in train_loader:
            optimizer.zero_grad()

            # Forward pass
            anchor_vec = product_model(anchor)
            positive_vec = product_model(positive)
            negative_vec = product_model(negative)

            # Triplet loss
            positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
            negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
            triplet_loss = torch.clamp(
                positive_distance - negative_distance + margin, min=0
            ).mean()

            # Backpropagation and optimization step
            triplet_loss.backward()
            optimizer.step()

            total_loss += triplet_loss.item()
            num_batches += 1

        # Batches are counted during iteration instead of using
        # len(train_loader): an empty loader would otherwise divide by zero,
        # and a loader without __len__ would raise TypeError. With no batches
        # the mean is undefined, so it is reported as nan.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss:.4f}")

    return product_model
# Model training pipeline
def main():
    """Train the product model on the MongoDB triplets and save its weights.

    Builds an example feed-forward model, embeds the training data, trains
    with the triplet loss, writes the state dict to ``product_model.pth``, and
    finally prints the result of ``validate_embeddings()``.
    """
    # Example model; 768 is the KoBERT embedding dimension.
    layers = [
        torch.nn.Linear(768, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
    ]
    product_model = torch.nn.Sequential(*layers)

    # Prepare the data.
    train_loader = DataLoader(prepare_training_data(), batch_size=16, shuffle=True)

    # Train the model.
    trained_model = train_triplet_model(product_model, train_loader)

    # Save the trained weights.
    weights = trained_model.state_dict()
    torch.save(weights, "product_model.pth")
    print("Model training completed and saved.")

    print(validate_embeddings())


if __name__ == "__main__":
    main()
|