| import torch | |
class Predictor:
    """Collect predicted class indices from a model over a data loader.

    Args:
        model: Object exposing ``eval()`` and callable on a 2-D batch of
            shape ``(batch, input_size)``. Its output is reduced with
            ``torch.max(outputs, 1)``, so dimension 1 is treated as the
            class dimension.
        input_size: Number of features each sample is flattened to before
            the batch is handed to the model. Defaults to ``28 * 28``.
    """

    def __init__(self, model, input_size=28 * 28):
        self.model = model
        self.input_size = input_size

    def _model_device(self):
        """Return the device of the model's first parameter, or ``None``.

        ``None`` means the device cannot be determined (the model has no
        ``parameters()`` method or no parameters); batches are then passed
        through without being moved.
        """
        parameters = getattr(self.model, "parameters", None)
        if parameters is None:
            return None
        first = next(iter(parameters()), None)
        return None if first is None else first.device

    def predict(self, test_loader):
        """Return the predicted class index for every sample in *test_loader*.

        The model is switched to evaluation mode (and left in it), and
        inference runs under ``torch.no_grad()``.

        Args:
            test_loader: Iterable of 2-item batches ``(images, _)``; the
                second item is ignored. Each ``images`` tensor must hold a
                multiple of ``input_size`` elements.

        Returns:
            A flat list with one NumPy integer scalar per sample, in loader
            order. Empty if the loader yields no batches.
        """
        self.model.eval()
        device = self._model_device()
        predictions = []
        with torch.no_grad():
            for images, _ in test_loader:
                if device is not None:
                    # Avoid a device-mismatch error when the loader yields
                    # tensors on a different device than the model.
                    images = images.to(device)
                # reshape (not view) so non-contiguous batches, e.g. ones
                # produced by a transpose, are flattened instead of raising.
                flat = images.reshape(-1, self.input_size)
                outputs = self.model(flat)
                _, predicted = torch.max(outputs, 1)
                predictions.extend(predicted.cpu().numpy())
        return predictions