waseoke commited on
Commit
2bffc21
·
verified ·
1 Parent(s): 09ec91f

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +27 -14
train_model.py CHANGED
@@ -1,16 +1,29 @@
1
- for epoch in range(num_epochs):
2
- optimizer.zero_grad()
3
- anchor_vec = product_model(anchor_data)
4
- positive_vec = product_model(positive_data)
5
- negative_vec = product_model(negative_data)
6
 
7
- # 트립렛 손실 계산
8
- positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
9
- negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
10
- triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
11
-
12
- # 역전파와 최적화
13
- triplet_loss.backward()
14
- optimizer.step()
15
 
16
- print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.optim import Adam
4
+ from torch.utils.data import DataLoader
 
5
 
6
+ def train_triplet_model(product_model, anchor_data, positive_data, negative_data, num_epochs=10, learning_rate=0.001, margin=1.0):
7
+ optimizer = Adam(product_model.parameters(), lr=learning_rate)
 
 
 
 
 
 
8
 
9
+ for epoch in range(num_epochs):
10
+ product_model.train()
11
+ optimizer.zero_grad()
12
+
13
+ # Forward pass
14
+ anchor_vec = product_model(anchor_data)
15
+ positive_vec = product_model(positive_data)
16
+ negative_vec = product_model(negative_data)
17
+
18
+ # Triplet loss calculation
19
+ positive_distance = F.pairwise_distance(anchor_vec, positive_vec)
20
+ negative_distance = F.pairwise_distance(anchor_vec, negative_vec)
21
+ triplet_loss = torch.clamp(positive_distance - negative_distance + margin, min=0).mean()
22
+
23
+ # Backward pass and optimization
24
+ triplet_loss.backward()
25
+ optimizer.step()
26
+
27
+ print(f"Epoch {epoch + 1}, Loss: {triplet_loss.item()}")
28
+
29
+ return product_model