|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
import timm |
|
|
|
|
|
class CNNMedium(nn.Module):
    """Small three-stage convolutional classifier.

    The network has two parts:

    * ``module`` -- three ``Conv2d(kernel=3) -> MaxPool2d(2, 2) -> LeakyReLU``
      stages followed by a flatten.
    * ``head`` -- a two-layer fully connected classifier.

    The first linear layer of ``head`` expects 60 input features. With a
    32x32 input the spatial size evolves 32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2,
    so the flattened feature vector is 15 channels * 2 * 2 = 60. Other input
    resolutions will not match the head and will raise a shape error.

    Args:
        in_channels: Number of channels of the input images (default 3).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()
        # Feature extractor. Layer order and Sequential indices are kept
        # stable so existing checkpoints (state_dict keys) remain loadable.
        self.module = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Conv2d(32, 15, 3),
            nn.MaxPool2d(2, 2),
            nn.LeakyReLU(),
            nn.Flatten(start_dim=1),
        )
        # Classifier head; 60 = 15 channels * 2 * 2 for 32x32 inputs.
        self.head = nn.Sequential(
            nn.Linear(60, 20),
            nn.LeakyReLU(),
            nn.Linear(20, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor and then the classifier head.

        Args:
            x: Batch of images shaped ``(N, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the raw head outputs.
        """
        x = self.module(x)
        x = self.head(x)
        return x
|
|
|
|
|
def Model():
    """Build a ``CNNMedium`` and return it together with its head.

    Returns:
        A ``(network, head)`` tuple where ``head`` is the ``head`` sub-module
        of the very same network instance (not a copy).
    """
    net = CNNMedium()
    classifier = net.head
    return net, classifier
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: push a dummy batch of four 3x32x32 images through the
    # network, then report the output shape, the architecture and the total
    # number of parameters.
    net, _ = Model()
    batch = torch.ones(4, 3, 32, 32)
    out = net(batch)
    print(out.shape)
    print(net)
    total = sum(p.numel() for p in net.parameters())
    print("num_param:", total)
|
|