spolivin commited on
Commit
e3ae3d1
·
verified ·
1 Parent(s): ee360bb

Added model usage and included number of trainable parameters

Browse files
Files changed (2) hide show
  1. README.md +35 -1
  2. config.json +2 -1
README.md CHANGED
@@ -11,5 +11,39 @@ This is the repo for a custom-made neural network based on ResNet architecture t
11
  - **Architecture:** ResNet
12
  - **Input shape:** 3x32x32
13
  - **Output classes:** 10
 
14
  - **Dataset:** CIFAR-10
15
- > More details can be found in the `config.json` file inside this repository.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  - **Architecture:** ResNet
12
  - **Input shape:** 3x32x32
13
  - **Output classes:** 10
14
+ - **Parameters:** 272,474
15
  - **Dataset:** CIFAR-10
16
+
17
+ More details can be found in the `config.json` file inside this repository and the original [git repository](https://github.com/spolivin/cifar10-website/blob/master/nn_dev/pytorch_models/architectures.py) from which the model originated.
18
+
19
+ ## Model usage
20
+
21
+ Since the model for which the weights loaded in this repository are intended is a part of a custom Python package, one needs to firstly clone the project locally:
22
+
23
+ ```bash
24
+ git clone https://github.com/spolivin/cifar10-website.git
25
+ cd cifar10-website/nn_dev
26
+ ```
27
+
28
+ Next, in a Python script we can make imports and load the weights:
29
+
30
+ ```python
31
+ import torch
32
+
33
+ from pytorch_models import resnet20
34
+
35
+ # URL from which to load the weights
36
+ URL = "https://huggingface.co/spolivin/cnn-cifar10/resolve/main/resnet20_weights.pth"
37
+
38
+ # Building the ResNet-20 model
39
+ resnet20_model = resnet20()
40
+
41
+ # Loading the pretrained weights to the model
42
+ resnet20_model.load_state_dict(
43
+ torch.hub.load_state_dict_from_url(
44
+ url=URL,
45
+ weights_only=True,
46
+ map_location=torch.device("cpu"),
47
+ )
48
+ )
49
+ ```
config.json CHANGED
@@ -7,6 +7,7 @@
7
  ],
8
  "num_layers": 20,
9
  "num_classes": 10,
 
10
  "architecture": "ResNet"
11
  },
12
  "training_parameters": {
@@ -53,4 +54,4 @@
53
  "version": "v1.0",
54
  "purpose": "Image classification on CIFAR-10"
55
  }
56
- }
 
7
  ],
8
  "num_layers": 20,
9
  "num_classes": 10,
10
+ "num_trainable_parameters": 272474,
11
  "architecture": "ResNet"
12
  },
13
  "training_parameters": {
 
54
  "version": "v1.0",
55
  "purpose": "Image classification on CIFAR-10"
56
  }
57
+ }