"""Fit an embedding table and a linear head to random binary targets.

Every row of a 180 x 128 embedding table is pushed through a 128 -> 2048
linear layer followed by a sigmoid, and the result is regressed onto a fixed
random 0/1 matrix with mean-squared error, using Adam with default settings.
Running the file as a script trains indefinitely.
"""
import itertools
from typing import Optional

import torch
from torch import nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Creation order (embedding, targets, head) is kept so the sequence of draws
# from the global RNG is the same as before.
embedding = nn.Embedding(180, 128).to(device)
gt = torch.randint(0, 2, (180, 2048)).to(device)
head = nn.Linear(128, 2048).to(device)

# Hand Adam every parameter of both modules. Listing only the two weight
# matrices left head.bias out: backward() kept adding to head.bias.grad, but
# the bias was never updated and zero_grad() never cleared that gradient.
optimizer = optim.Adam([*embedding.parameters(), *head.parameters()])


def train(max_steps: Optional[int] = None) -> Optional[float]:
    """Run full-batch optimisation steps on the module-level model.

    Args:
        max_steps: Number of optimiser steps to take. ``None`` (the default)
            loops forever, which is what the script does when executed.

    Returns:
        The loss computed at the start of the last step that was executed, as
        a Python float, or ``None`` when ``max_steps`` is 0. Never returns
        when ``max_steps`` is ``None``.
    """
    # Loop invariants: build the criterion and cast the targets once instead
    # of on every iteration.
    criterion = nn.MSELoss()
    target = gt.float()

    steps = itertools.count() if max_steps is None else range(max_steps)
    last_loss = None
    for _ in steps:
        # The whole embedding table is the batch: one prediction row per entry.
        pred = head(embedding.weight).sigmoid()
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        last_loss = loss

    # Convert to a float only once, at the end, so the loop itself never
    # forces a device-to-host synchronisation.
    return None if last_loss is None else last_loss.item()


if __name__ == "__main__":
    train()