import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet-style classifier built from 3-D convolutions.

    The layout follows AlexNet with every 2-D operator replaced by its 3-D
    counterpart:
      - five ``Conv3d`` layers and three ``MaxPool3d`` layers;
      - an adaptive pool to a fixed 6x6x6 grid;
      - a three-layer fully-connected head with dropout.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input volume (previously hard-coded to 1).
        hidden_dim: Width of the two hidden fully-connected layers.

    Input is a batched 5-D tensor ``(N, in_channels, D, H, W)``. The output
    has shape ``(N, num_classes)``. Each spatial dimension must be at least
    63 so the three max-pools have a non-empty input.
    """

    def __init__(self, num_classes: int = 3, in_channels: int = 1,
                 hidden_dim: int = 4096):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
            nn.Conv3d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
            nn.Conv3d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2),
        )
        # Force the feature map to 6x6x6 so the classifier's input width no
        # longer depends on the input volume size. For inputs that already
        # yield 6x6x6 this is an exact identity, and it has no parameters,
        # so existing checkpoints and seeded initialisation are unaffected.
        self.avgpool = nn.AdaptiveAvgPool3d((6, 6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6 * 6, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter (weights and biases) from U(-0.1, 0.1)."""
        # NOTE(review): a fixed +/-0.1 range is not scaled by fan-in; with
        # 256*6*6*6 inputs to the first Linear layer this may give very large
        # activations. Consider nn.init.kaiming_* if training is unstable.
        for param in self.parameters():
            # nn.init.uniform_ fills in place under no_grad, avoiding `.data`.
            nn.init.uniform_(param, -0.1, 0.1)

    def forward(self, x):
        """Map a batch ``(N, C, D, H, W)`` to logits ``(N, num_classes)``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension; flatten channels and the 6x6x6 grid.
        x = x.flatten(1)
        return self.classifier(x)