"""Fit a learnable embedding table plus a linear head to random binary targets.

Each row of an embedding table is mapped by a linear head to one output per
target bit; the sigmoid of those outputs is regressed onto a fixed random 0/1
matrix with mean-squared error. Run as a script it trains until interrupted,
printing the loss periodically.
"""
import itertools
from typing import Optional, Union

import torch
from torch import nn
import torch.optim as optim


def train(
    num_steps: Optional[int] = None,
    *,
    num_embeddings: int = 180,
    embedding_dim: int = 128,
    num_targets: int = 2048,
    lr: float = 1e-3,
    log_every: int = 1000,
    seed: Optional[int] = None,
    device: Union[torch.device, str, None] = None,
) -> float:
    """Train the embedding + head on random bits and return the last loss.

    Args:
        num_steps: Number of optimizer steps. ``None`` trains forever, like
            the original ``while True`` script.
        num_embeddings: Rows in the embedding table (one target row each).
        embedding_dim: Width of each embedding vector / head input size.
        num_targets: Number of random 0/1 target values per row.
        lr: Adam learning rate (1e-3 is Adam's own default).
        log_every: Print the loss every this many steps; ``<= 0`` disables.
        seed: If given, passed to ``torch.manual_seed`` before the parameters
            and targets are created, making the run reproducible.
        device: Device to train on; defaults to CUDA when available, else CPU.

    Returns:
        The MSE loss from the final forward pass (computed just before the
        last optimizer update), as a Python float.

    Raises:
        ValueError: If ``num_steps`` is given and is less than 1.
    """
    if num_steps is not None and num_steps < 1:
        raise ValueError(f"num_steps must be >= 1 or None, got {num_steps}")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed is not None:
        # Global seed: governs both the parameter init and the random targets.
        torch.manual_seed(seed)

    # Creation order (embedding, targets, head) matches the original script.
    embedding = nn.Embedding(num_embeddings, embedding_dim).to(device)
    # Targets are fixed for the whole run, so convert to float once up front.
    target = torch.randint(0, 2, (num_embeddings, num_targets)).to(device).float()
    head = nn.Linear(embedding_dim, num_targets).to(device)

    # Optimise *all* parameters. The original list omitted head.bias, which
    # received gradients every step but was never updated.
    optimizer = optim.Adam([*embedding.parameters(), *head.parameters()], lr=lr)
    # Stateless, so build it once instead of on every iteration.
    # NOTE: for 0/1 targets nn.BCEWithLogitsLoss on the raw outputs is the
    # usual, better-conditioned choice; MSE on sigmoid outputs is kept here to
    # preserve the original objective.
    criterion = nn.MSELoss()

    steps = itertools.count(1) if num_steps is None else range(1, num_steps + 1)
    for step in steps:
        # Every embedding row is used, so the weight matrix itself is the
        # batch; no index lookup is needed.
        pred = head(embedding.weight).sigmoid()
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() forces a device sync, so only call it when actually logging.
        if log_every > 0 and step % log_every == 0:
            print(f"step {step}: loss {loss.item():.6f}")
    return loss.item()


if __name__ == "__main__":
    train()