louiecerv committed
Commit 1cc1116 · 1 Parent(s): 37c7e43

save changes

Files changed (6)
  1. __pycache__/utils.cpython-313.pyc +0 -0
  2. app.py +80 -128
  3. backup.py +204 -0
  4. model_repo +1 -0
  5. requirements.txt +3 -2
  6. utils.py +53 -0
__pycache__/utils.cpython-313.pyc ADDED
Binary file (3.24 kB)
 
app.py CHANGED
@@ -7,113 +7,72 @@ from torch.utils.data import DataLoader
 from datasets import load_dataset
 from huggingface_hub import HfApi, Repository
 import os
+import matplotlib.pyplot as plt
+
+import utils
 
 # Hugging Face Hub credentials
 HF_TOKEN = os.getenv("HF_TOKEN")
-MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
-DATASET_REPO_ID = "louiecerv/american_sign_language"
+MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
+DATASET_REPO_ID = "louiecerv/american_sign_language"
 
 # Device configuration
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Define the CNN model
-class CNN(nn.Module):
-    def __init__(self):
-        super(CNN, self).__init__()
-        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
-        self.relu1 = nn.ReLU()
-        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
-        self.relu2 = nn.ReLU()
-        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flatten = nn.Flatten()
-        self.fc = nn.Linear(64 * 7 * 7, 128)  # Adjusted for 28x28 images
-        self.relu3 = nn.ReLU()
-        self.fc2 = nn.Linear(128, 25)  # 25 classes (A-Y)
+st.write(f"Device: {device}")
+
+# Define the new CNN model
+IMG_HEIGHT = 28
+IMG_WIDTH = 28
+IMG_CHS = 1
+N_CLASSES = 24
+
+class MyConvBlock(nn.Module):
+    def __init__(self, in_ch, out_ch, dropout_p):
+        kernel_size = 3
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(),
+            nn.Dropout(dropout_p),
+            nn.MaxPool2d(2, stride=2)
+        )
 
     def forward(self, x):
-        x = self.pool1(self.relu1(self.conv1(x)))
-        x = self.pool2(self.relu2(self.conv2(x)))
-        x = self.flatten(x)
-        x = self.relu3(self.fc(x))
-        x = self.fc2(x)
-        return x
-
-# Create a model card
-def create_model_card():
-    model_card = """
----
-language: en
-tags:
-- image-classification
-- deep-learning
-- cnn
-license: apache-2.0
-datasets:
-Network (CNN) designed to recognize American Sign Language (ASL) letters from images. It was trained on the `louiecerv/american_sign_language` dataset.
-
-## Model Description
-
-The model consists of two convolutional layers followed by max-pooling layers, a flattening layer, and two fully connected layers. It is designed to classify images of ASL letters into 25 classes (A-Y).
-
-## Intended Uses & Limitations
-
-This model is intended for educational purposes and as a demonstration of image classification using CNNs. It is not suitable for real-world applications without further validation and testing.
-
-## How to Use
-
-```python
-import torch
-from torchvision import transforms
-from PIL import Image
-
-# Load the model
-model = CNN()
-model.load_state_dict(torch.load("path_to_model/pytorch_model.bin"))
-model.eval()
-
-# Preprocess the image
-transform = transforms.Compose([
-    transforms.Grayscale(num_output_channels=1),
-    transforms.Resize((28, 28)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.5], std=[0.5])
-])
-image = Image.open("path_to_image").convert("RGB")
-image = transform(image).unsqueeze(0)
-
-# Make a prediction
-with torch.no_grad():
-    output = model(image)
-    _, predicted = torch.max(output.data, 1)
-print(f"Predicted ASL letter: {predicted.item()}")
-```
-
-## Training Data
-
-The model was trained on the `louiecerv/american_sign_language` dataset, which contains images of ASL letters.
-
-## Training Procedure
-
-The model was trained using the Adam optimizer with a learning rate of 0.001 and a batch size of 64. The training process included 5 epochs.
-
-## Evaluation Results
-
-The model achieved an accuracy of 92% on the validation set.
-"""
-    with open("model_repo/README.md", "w") as f:
-        f.write(model_card)
+        return self.model(x)
+
+flattened_img_size = 75 * 3 * 3
+
+# Input 1 x 28 x 28
+base_model = nn.Sequential(
+    MyConvBlock(IMG_CHS, 25, 0),  # 25 x 14 x 14
+    MyConvBlock(25, 50, 0.2),     # 50 x 7 x 7
+    MyConvBlock(50, 75, 0),       # 75 x 3 x 3
+    nn.Flatten(),
+    nn.Linear(flattened_img_size, 512),
+    nn.Dropout(.3),
+    nn.ReLU(),
+    nn.Linear(512, N_CLASSES)
+)
 
 # Streamlit app
 def main():
     st.title("American Sign Language Recognition")
+
+    # Move slider and button to sidebar
+    num_epochs = st.sidebar.slider("Number of Epochs", 1, 20, 5)
+    train_button = st.sidebar.button("Train Model")
 
     # Load the dataset from Hugging Face Hub
     dataset = load_dataset(DATASET_REPO_ID)
 
-    # Data loaders with preprocessing:
-    transform = transforms.Compose([
-        transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if needed
+    # Data loaders with preprocessing and data augmentation:
+    random_transforms = transforms.Compose([
+        transforms.RandomRotation(5),
+        transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.9, 1), ratio=(1, 1)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(brightness=.2, contrast=.5),
+        transforms.Normalize(mean=[0.5], std=[0.5])
     ])
 
     def collate_fn(batch):
@@ -121,18 +80,18 @@
         labels = []
         for item in batch:
            if 'pixel_values' in item and 'label' in item:
-                image = torch.tensor(item['pixel_values'])  # Convert to tensor
+                image = torch.tensor(item['pixel_values'])
                label = item['label']
                try:
-                    image = transform(image)
+                    image = random_transforms(image)
                    images.append(image)
                    labels.append(label)
                except Exception as e:
                    print(f"Error processing image: {e}")
-                    continue  # Skip to the next image
+                    continue
 
-        if not images:  # Check if the list is empty!
-            return torch.tensor([]), torch.tensor([])  # Return empty tensors if no images loaded
+        if not images:
+            return torch.tensor([]), torch.tensor([])
 
         images = torch.stack(images).to(device)
         labels = torch.tensor(labels).long().to(device)
@@ -142,59 +101,52 @@
     val_loader = DataLoader(dataset["validation"], batch_size=64, collate_fn=collate_fn)
 
     # Model, loss, and optimizer
-    model = CNN().to(device)
+    model = base_model.to(device)
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
-    # Training loop
-    num_epochs = st.slider("Number of Epochs", 1, 20, 5)  # Streamlit slider
-    if st.button("Train Model"):
+    loss_history = []
+    accuracy_history = []
+
+    if train_button:
        for epoch in range(num_epochs):
+            total = 0
+            correct = 0
+            epoch_loss = 0
            for i, (images, labels) in enumerate(train_loader):
-                if images.nelement() == 0:  # Check if images tensor is empty
+                if images.nelement() == 0:
                    continue
 
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
+                epoch_loss += loss.item()
 
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
 
-                if (i + 1) % 100 == 0:
-                    st.write(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')
-
-        # Validation
-        correct = 0
-        total = 0
-        with torch.no_grad():
-            for images, labels in val_loader:
-                if images.nelement() == 0:  # Check if images tensor is empty
-                    continue
-                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
 
-        if total > 0:
-            accuracy = 100 * correct / total
-            st.write(f'Accuracy of the model on the validation images: {accuracy:.2f}%')
-        else:
-            st.write("No validation images were processed.")
+            epoch_accuracy = 100 * correct / total
+            loss_history.append(epoch_loss / len(train_loader))
+            accuracy_history.append(epoch_accuracy)
 
-    # Save model to Hugging Face Hub
-    if HF_TOKEN:
-        repo = Repository(local_dir="model_repo", clone_from=MODEL_REPO_ID, use_auth_token=HF_TOKEN)
-        model_path = os.path.join(repo.local_dir, "pytorch_model.bin")
-        torch.save(model.state_dict(), model_path)
+            st.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.4f}, Accuracy: {epoch_accuracy:.2f}%')
 
-        create_model_card()
-        repo.push_to_hub(commit_message="Trained model and model card", blocking=True)
-        st.write(f"Model and model card saved to {MODEL_REPO_ID}")
-    else:
-        st.warning("HF_TOKEN environment variable not set. Model not saved.")
+        # Plot loss and accuracy
+        fig, ax1 = plt.subplots()
+        ax2 = ax1.twinx()
+        ax1.plot(loss_history, 'g-', label='Loss')
+        ax2.plot(accuracy_history, 'b-', label='Accuracy')
+        ax1.set_xlabel('Epoch')
+        ax1.set_ylabel('Loss', color='g')
+        ax2.set_ylabel('Accuracy (%)', color='b')
+        plt.title('Training Loss and Accuracy')
+        st.pyplot(fig)
 
 if __name__ == "__main__":
-    main()
+    main()
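A quick check on the shape comments in the new `base_model`: the 3x3 convolutions with padding 1 preserve height and width, and each `MaxPool2d(2, stride=2)` floor-halves them, so 28 → 14 → 7 → 3 and the flattened feature count is 75 × 3 × 3 = 675. A minimal sketch that verifies this with a dummy batch (it re-states `MyConvBlock` from the commit purely for the check):

```python
import torch
import torch.nn as nn

# Re-statement of the committed block, for shape checking only.
class MyConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_p):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),  # preserves H x W
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2),  # floor-halves H and W
        )

    def forward(self, x):
        return self.model(x)

blocks = nn.Sequential(
    MyConvBlock(1, 25, 0),     # 1 x 28 x 28 -> 25 x 14 x 14
    MyConvBlock(25, 50, 0.2),  # -> 50 x 7 x 7
    MyConvBlock(50, 75, 0),    # -> 75 x 3 x 3 (7 // 2 == 3)
)

x = torch.zeros(8, 1, 28, 28)  # dummy batch
print(blocks(x).shape)         # torch.Size([8, 75, 3, 3]); 75 * 3 * 3 == 675
```

One thing worth double-checking: the replaced model used 25 output units ("25 classes (A-Y)") while the new head uses `N_CLASSES = 24`. If the dataset's labels follow the common Sign Language MNIST encoding (0-24 with 9/J skipped, since J and Z require motion), a 24-unit head would not cover label 24, so the head size should be verified against the dataset's actual label range.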
backup.py ADDED
@@ -0,0 +1,204 @@
+import streamlit as st
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import transforms
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+from huggingface_hub import HfApi, Repository
+import os
+import matplotlib.pyplot as plt
+
+import utils
+
+# Hugging Face Hub credentials
+HF_TOKEN = os.getenv("HF_TOKEN")
+MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
+DATASET_REPO_ID = "louiecerv/american_sign_language"
+
+# Device configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+st.write(f"Device: {device}")
+
+# Define the CNN model
+class CNN(nn.Module):
+    def __init__(self):
+        super(CNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+        self.fc = nn.Linear(64 * 7 * 7, 128)  # Adjusted for 28x28 images
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(128, 25)  # 25 classes (A-Y)
+
+    def forward(self, x):
+        x = self.pool1(self.relu1(self.conv1(x)))
+        x = self.pool2(self.relu2(self.conv2(x)))
+        x = self.flatten(x)
+        x = self.relu3(self.fc(x))
+        x = self.fc2(x)
+        return x
+
+# Create a model card
+def create_model_card():
+    model_card = """
+---
+language: en
+tags:
+- image-classification
+- deep-learning
+- cnn
+license: apache-2.0
+datasets:
+Network (CNN) designed to recognize American Sign Language (ASL) letters from images. It was trained on the `louiecerv/american_sign_language` dataset.
+
+## Model Description
+
+The model consists of two convolutional layers followed by max-pooling layers, a flattening layer, and two fully connected layers. It is designed to classify images of ASL letters into 25 classes (A-Y).
+
+## Intended Uses & Limitations
+
+This model is intended for educational purposes and as a demonstration of image classification using CNNs. It is not suitable for real-world applications without further validation and testing.
+
+## How to Use
+
+```python
+import torch
+from torchvision import transforms
+from PIL import Image
+
+# Load the model
+model = CNN()
+model.load_state_dict(torch.load("path_to_model/pytorch_model.bin"))
+model.eval()
+
+# Preprocess the image
+transform = transforms.Compose([
+    transforms.Grayscale(num_output_channels=1),
+    transforms.Resize((28, 28)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5], std=[0.5])
+])
+image = Image.open("path_to_image").convert("RGB")
+image = transform(image).unsqueeze(0)
+
+# Make a prediction
+with torch.no_grad():
+    output = model(image)
+    _, predicted = torch.max(output.data, 1)
+print(f"Predicted ASL letter: {predicted.item()}")
+```
+
+## Training Data
+
+The model was trained on the `louiecerv/american_sign_language` dataset, which contains images of ASL letters.
+
+## Training Procedure
+
+The model was trained using the Adam optimizer with a learning rate of 0.001 and a batch size of 64. The training process included 5 epochs.
+
+## Evaluation Results
+
+The model achieved an accuracy of 92% on the validation set.
+"""
+    with open("model_repo/README.md", "w") as f:
+        f.write(model_card)
+
+# Streamlit app
+def main():
+    st.title("American Sign Language Recognition")
+
+    # Load the dataset from Hugging Face Hub
+    dataset = load_dataset(DATASET_REPO_ID)
+
+    # Data loaders with preprocessing:
+    transform = transforms.Compose([
+        transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if needed
+    ])
+
+    def collate_fn(batch):
+        images = []
+        labels = []
+        for item in batch:
+            if 'pixel_values' in item and 'label' in item:
+                image = torch.tensor(item['pixel_values'])  # Convert to tensor
+                label = item['label']
+                try:
+                    image = transform(image)
+                    images.append(image)
+                    labels.append(label)
+                except Exception as e:
+                    print(f"Error processing image: {e}")
+                    continue  # Skip to the next image
+
+        if not images:  # Check if the list is empty!
+            return torch.tensor([]), torch.tensor([])  # Return empty tensors if no images loaded
+
+        images = torch.stack(images).to(device)
+        labels = torch.tensor(labels).long().to(device)
+        return images, labels
+
+    train_loader = DataLoader(dataset["train"], batch_size=64, shuffle=True, collate_fn=collate_fn)
+    val_loader = DataLoader(dataset["validation"], batch_size=64, collate_fn=collate_fn)
+
+    # Model, loss, and optimizer
+    model = CNN().to(device)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+    # Training loop
+    num_epochs = st.slider("Number of Epochs", 1, 20, 5)  # Streamlit slider
+    if st.button("Train Model"):
+        for epoch in range(num_epochs):
+            for i, (images, labels) in enumerate(train_loader):
+                if images.nelement() == 0:  # Check if images tensor is empty
+                    continue
+
+                # Forward pass
+                outputs = model(images)
+                loss = criterion(outputs, labels)
+
+                # Backward and optimize
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                if (i + 1) % 100 == 0:
+                    st.write(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')
+
+        # Validation
+        correct = 0
+        total = 0
+        with torch.no_grad():
+            for images, labels in val_loader:
+                if images.nelement() == 0:  # Check if images tensor is empty
+                    continue
+                outputs = model(images)
+                _, predicted = torch.max(outputs.data, 1)
+                total += labels.size(0)
+                correct += (predicted == labels).sum().item()
+
+        if total > 0:
+            accuracy = 100 * correct / total
+            st.write(f'Accuracy of the model on the validation images: {accuracy:.2f}%')
+        else:
+            st.write("No validation images were processed.")
+
+    # Save model to Hugging Face Hub
+    if HF_TOKEN:
+        repo = Repository(local_dir="model_repo", clone_from=MODEL_REPO_ID, use_auth_token=HF_TOKEN)
+        model_path = os.path.join(repo.local_dir, "pytorch_model.bin")
+        torch.save(model.state_dict(), model_path)
+
+        create_model_card()
+        repo.push_to_hub(commit_message="Trained model and model card", blocking=True)
+        st.write(f"Model and model card saved to {MODEL_REPO_ID}")
+    else:
+        st.warning("HF_TOKEN environment variable not set. Model not saved.")
+
+if __name__ == "__main__":
+    main()
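backup.py preserves the Repository-based upload flow that the new app.py drops. `huggingface_hub.Repository` clones the model repo locally and pushes via git, and has since been deprecated in favor of the `HfApi` methods. A minimal alternative sketch (the `push_model` helper name is hypothetical; `MODEL_REPO_ID` and `HF_TOKEN` as defined in the file):

```python
import torch
from huggingface_hub import HfApi

def push_model(model, repo_id: str, token: str) -> None:
    """Hypothetical helper: upload trained weights without a local git clone."""
    torch.save(model.state_dict(), "pytorch_model.bin")
    api = HfApi(token=token)
    api.upload_file(
        path_or_fileobj="pytorch_model.bin",
        path_in_repo="pytorch_model.bin",
        repo_id=repo_id,
        commit_message="Trained model",
    )
```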
model_repo ADDED
@@ -0,0 +1 @@
+Subproject commit 3b1e03d8d415269d86b88c9df83295a9ef454bb5
requirements.txt CHANGED
@@ -5,5 +5,6 @@ huggingface_hub
 torch
 torchvision
 pandas
-Pillow # or PIL (Pillow is the actively maintained fork)
-scikit-learn
+Pillow
+scikit-learn
+matplotlib
utils.py ADDED
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+class MyConvBlock(nn.Module):
+    def __init__(self, in_ch, out_ch, dropout_p):
+        kernel_size = 3
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(),
+            nn.Dropout(dropout_p),
+            nn.MaxPool2d(2, stride=2)
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+def get_batch_accuracy(output, y, N):
+    pred = output.argmax(dim=1, keepdim=True)
+    correct = pred.eq(y.view_as(pred)).sum().item()
+    return correct / N
+
+
+def train(model, train_loader, train_N, random_trans, optimizer, loss_function):
+    loss = 0
+    accuracy = 0
+
+    model.train()
+    for x, y in train_loader:
+        output = model(random_trans(x))
+        optimizer.zero_grad()
+        batch_loss = loss_function(output, y)
+        batch_loss.backward()
+        optimizer.step()
+
+        loss += batch_loss.item()
+        accuracy += get_batch_accuracy(output, y, train_N)
+    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
+
+def validate(model, valid_loader, valid_N, loss_function):
+    loss = 0
+    accuracy = 0
+
+    model.eval()
+    with torch.no_grad():
+        for x, y in valid_loader:
+            output = model(x)
+
+            loss += loss_function(output, y).item()
+            accuracy += get_batch_accuracy(output, y, valid_N)
+    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
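utils.py is imported by the new app.py, but its `train`/`validate` helpers are not called there yet; app.py keeps its own inline loop. A minimal sketch of how they could be wired in, assuming `base_model`, `device`, `dataset`, `num_epochs`, and the loaders as defined in app.py, and a collate function that returns raw tensors so `random_transforms` is applied exactly once (inside `utils.train`):

```python
# Hypothetical wiring of the utils helpers (not part of this commit).
import torch.nn as nn
import torch.optim as optim
import utils

model = base_model.to(device)        # base_model, device from app.py
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_N = len(dataset["train"])      # dataset as loaded in app.py
valid_N = len(dataset["validation"])

for epoch in range(num_epochs):
    utils.train(model, train_loader, train_N, random_transforms, optimizer, loss_function)
    utils.validate(model, val_loader, valid_N, loss_function)
```

Note that `utils.train` reports loss summed over batches rather than averaged, so its printed loss scales with the number of batches; the per-epoch accuracy, by contrast, sums per-batch `correct / train_N` and so ends the epoch as an ordinary fraction of the training set.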