Added model usage and included number of trainable parameters
Browse files- README.md +35 -1
- config.json +2 -1
README.md
CHANGED
@@ -11,5 +11,39 @@ This is the repo for a custom-made neural network based on ResNet architecture t
|
|
11 |
- **Architecture:** ResNet
|
12 |
- **Input shape:** 3x32x32
|
13 |
- **Output classes:** 10
|
|
|
14 |
- **Dataset:** CIFAR-10
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
- **Architecture:** ResNet
|
12 |
- **Input shape:** 3x32x32
|
13 |
- **Output classes:** 10
|
14 |
+
- **Parameters:** 272,474
|
15 |
- **Dataset:** CIFAR-10
|
16 |
+
|
17 |
+
More details can be found in the `config.json` file inside this repository and the original [git repository](https://github.com/spolivin/cifar10-website/blob/master/nn_dev/pytorch_models/architectures.py) from which the model originated.
|
18 |
+
|
19 |
+
## Model usage
|
20 |
+
|
21 |
+
Since the model for which the weights loaded in this repository are intended is a part of a custom Python package, one needs to firstly clone the project locally:
|
22 |
+
|
23 |
+
```bash
|
24 |
+
git clone https://github.com/spolivin/cifar10-website.git
|
25 |
+
cd cifar10-website/nn_dev
|
26 |
+
```
|
27 |
+
|
28 |
+
Next, in a Python script we can make imports and load the weights:
|
29 |
+
|
30 |
+
```python
|
31 |
+
import torch
|
32 |
+
|
33 |
+
from pytorch_models import resnet20
|
34 |
+
|
35 |
+
# URL from which to load the weights
|
36 |
+
URL = "https://huggingface.co/spolivin/cnn-cifar10/resolve/main/resnet20_weights.pth"
|
37 |
+
|
38 |
+
# Building the ResNet-20 model
|
39 |
+
resnet20_model = resnet20()
|
40 |
+
|
41 |
+
# Loading the pretrained weights to the model
|
42 |
+
resnet20_model.load_state_dict(
|
43 |
+
torch.hub.load_state_dict_from_url(
|
44 |
+
url=URL,
|
45 |
+
weights_only=True,
|
46 |
+
map_location=torch.device("cpu"),
|
47 |
+
)
|
48 |
+
)
|
49 |
+
```
|
config.json
CHANGED
@@ -7,6 +7,7 @@
|
|
7 |
],
|
8 |
"num_layers": 20,
|
9 |
"num_classes": 10,
|
|
|
10 |
"architecture": "ResNet"
|
11 |
},
|
12 |
"training_parameters": {
|
@@ -53,4 +54,4 @@
|
|
53 |
"version": "v1.0",
|
54 |
"purpose": "Image classification on CIFAR-10"
|
55 |
}
|
56 |
-
}
|
|
|
7 |
],
|
8 |
"num_layers": 20,
|
9 |
"num_classes": 10,
|
10 |
+
"num_trainable_parameters": 272474,
|
11 |
"architecture": "ResNet"
|
12 |
},
|
13 |
"training_parameters": {
|
|
|
54 |
"version": "v1.0",
|
55 |
"purpose": "Image classification on CIFAR-10"
|
56 |
}
|
57 |
+
}
|