Emily committed on
Commit f826c3b · 1 Parent(s): 0e859f8

Add multi-dataset and ResNet-18 architecture support

- Fix dataset selection: now properly switches between MNIST, CIFAR-10, and Fashion-MNIST
- Add ResNet-18 architecture option alongside existing MLP and CNN models
- Implement dynamic data preprocessing based on model architecture (flatten for MLPs, keep 2D/3D for CNNs)
- Add model architecture parameter to frontend and backend
- Cache trainers by dataset+architecture combination for efficiency
- Update privacy budget calculations to use correct dataset sizes
- Support all architecture combinations across datasets (a request sketch follows below)
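
As a reviewer aid, a minimal Python sketch of a client request exercising the new parameters. This is an assumption, not part of the commit: the actual route path is not shown in this diff, so the URL below is a placeholder; defaults mirror the `data.get(...)` fallbacks in app/routes.py.

    import requests  # hypothetical client script, not part of this commit

    payload = {
        "dataset": "cifar10",              # 'mnist', 'cifar10', or 'fashion-mnist'
        "model_architecture": "resnet18",  # 'simple-mlp', 'simple-cnn', 'advanced-cnn', or 'resnet18'
        "epochs": 5,
        "batch_size": 128,
        "learning_rate": 0.01,
        "noise_multiplier": 1.1,
        "clipping_norm": 1.0,
        "use_mock": False,                 # force the mock trainer when True
    }

    # "/train" is a placeholder path; see the train() route in app/routes.py
    resp = requests.post("http://localhost:5000/train", json=payload)
    results = resp.json()
    print(results.get("trainer_type"), results.get("dataset"), results.get("model_architecture"))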

app/routes.py CHANGED
@@ -23,17 +23,27 @@ main = Blueprint('main', __name__)
 mock_trainer = MockTrainer()
 privacy_calculator = PrivacyCalculator()
 
-# Initialize real trainer if available
-if REAL_TRAINER_AVAILABLE:
-    try:
-        real_trainer = RealTrainer()
-        print("Real trainer initialized successfully")
-    except Exception as e:
-        print(f"Failed to initialize real trainer: {e}")
-        REAL_TRAINER_AVAILABLE = False
-        real_trainer = None
-else:
-    real_trainer = None
+# We'll create trainers dynamically based on dataset selection
+real_trainers = {}  # Cache trainers by dataset to avoid reloading
+
+def get_or_create_trainer(dataset, model_architecture='simple-mlp'):
+    """Get or create a trainer for the specified dataset and architecture."""
+    if not REAL_TRAINER_AVAILABLE:
+        return None
+
+    # Create a unique key for dataset + architecture combination
+    trainer_key = f"{dataset}_{model_architecture}"
+
+    if trainer_key not in real_trainers:
+        try:
+            print(f"Creating new trainer for dataset: {dataset}, architecture: {model_architecture}")
+            real_trainers[trainer_key] = RealTrainer(dataset=dataset, model_architecture=model_architecture)
+            print(f"Trainer for {dataset} with {model_architecture} initialized successfully")
+        except Exception as e:
+            print(f"Failed to initialize trainer for {dataset} with {model_architecture}: {e}")
+            return None
+
+    return real_trainers[trainer_key]
 
 @main.route('/')
 def index():
@@ -62,24 +72,46 @@ def train():
         'epochs': int(data.get('epochs', 5))
     }
 
+    # Get dataset and model architecture selection
+    dataset = data.get('dataset', 'mnist')
+    model_architecture = data.get('model_architecture', 'simple-mlp')
+
     # Check if user wants to force mock training
     use_mock = data.get('use_mock', False)
 
     # Use real trainer if available and not forced to use mock
-    if REAL_TRAINER_AVAILABLE and real_trainer and not use_mock:
-        print("Using real trainer with MNIST dataset")
-        results = real_trainer.train(params)
-        results['trainer_type'] = 'real'
-        results['dataset'] = 'MNIST'
+    if REAL_TRAINER_AVAILABLE and not use_mock:
+        real_trainer = get_or_create_trainer(dataset, model_architecture)
+        if real_trainer:
+            print(f"Using real trainer with {dataset.upper()} dataset and {model_architecture} architecture")
+            results = real_trainer.train(params)
+            results['trainer_type'] = 'real'
+            results['dataset'] = dataset.upper()
+            results['model_architecture'] = model_architecture
+        else:
+            print("Failed to create real trainer, falling back to mock trainer")
+            results = mock_trainer.train(params)
+            results['trainer_type'] = 'mock'
+            results['dataset'] = 'synthetic'
+            results['model_architecture'] = 'mock'
     else:
         print("Using mock trainer with synthetic data")
         results = mock_trainer.train(params)
         results['trainer_type'] = 'mock'
        results['dataset'] = 'synthetic'
+        results['model_architecture'] = 'mock'
 
     # Add gradient information for visualization (if not already included)
     if 'gradient_info' not in results:
-        trainer = real_trainer if (REAL_TRAINER_AVAILABLE and real_trainer and not use_mock) else mock_trainer
+        if REAL_TRAINER_AVAILABLE and not use_mock:
+            current_trainer = get_or_create_trainer(dataset, model_architecture)
+            if current_trainer:
+                trainer = current_trainer
+            else:
+                trainer = mock_trainer
+        else:
+            trainer = mock_trainer
+
         results['gradient_info'] = {
             'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
             'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm'])
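
A standalone sketch of the keyed-cache pattern introduced above, with a stub standing in for RealTrainer (an assumption here, so no datasets are loaded). It illustrates the intended behavior: repeated requests for the same dataset/architecture pair reuse one instance, while a new pair gets a fresh trainer.

    class StubTrainer:
        """Stand-in for RealTrainer; only records its configuration."""
        def __init__(self, dataset, model_architecture):
            self.dataset = dataset
            self.model_architecture = model_architecture

    trainers = {}  # cache keyed by dataset + architecture, as in app/routes.py

    def get_or_create(dataset, model_architecture='simple-mlp'):
        key = f"{dataset}_{model_architecture}"
        if key not in trainers:
            trainers[key] = StubTrainer(dataset, model_architecture)
        return trainers[key]

    a = get_or_create('mnist', 'simple-mlp')
    b = get_or_create('mnist', 'simple-mlp')
    c = get_or_create('cifar10', 'resnet18')
    assert a is b      # same key -> cached instance reused
    assert a is not c  # different key -> new trainer created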
app/static/js/main.js CHANGED
@@ -697,7 +697,9 @@ class DPSGDExplorer {
             noise_multiplier: parseFloat(document.getElementById('noise-multiplier').value),
             batch_size: parseInt(document.getElementById('batch-size').value),
             learning_rate: parseFloat(document.getElementById('learning-rate').value),
-            epochs: parseInt(document.getElementById('epochs').value)
+            epochs: parseInt(document.getElementById('epochs').value),
+            dataset: document.getElementById('dataset-select').value,
+            model_architecture: document.getElementById('model-select').value
         };
     }
 
@@ -720,7 +722,23 @@ class DPSGDExplorer {
 
     calculateEpochPrivacy(epoch) {
        const params = this.getParameters();
-        const samplingRate = params.batch_size / 60000; // Assuming MNIST size
+
+        // Get dataset size based on selection
+        let datasetSize;
+        switch(params.dataset) {
+            case 'cifar10':
+                datasetSize = 50000; // CIFAR-10 training set size
+                break;
+            case 'fashion-mnist':
+                datasetSize = 60000; // Fashion-MNIST training set size
+                break;
+            case 'mnist':
+            default:
+                datasetSize = 60000; // MNIST training set size
+                break;
+        }
+
+        const samplingRate = params.batch_size / datasetSize;
         const steps = epoch * (1 / samplingRate);
         const delta = 1e-5;
         const c = Math.sqrt(2 * Math.log(1.25 / delta));
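
For reference, the dataset-size arithmetic in calculateEpochPrivacy can be cross-checked outside the browser. This Python sketch reproduces only the quantities visible in this hunk; the remainder of the epsilon calculation lies outside the diff, so it is not reconstructed here.

    import math

    def epoch_privacy_inputs(epoch, batch_size, dataset='mnist'):
        # Training-set sizes mirroring the switch statement above
        dataset_size = {'cifar10': 50000, 'fashion-mnist': 60000, 'mnist': 60000}[dataset]
        sampling_rate = batch_size / dataset_size
        steps = epoch * (1 / sampling_rate)        # steps taken after `epoch` epochs
        delta = 1e-5
        c = math.sqrt(2 * math.log(1.25 / delta))  # Gaussian-mechanism constant
        return sampling_rate, steps, delta, c

    print(epoch_privacy_inputs(1, 128, 'cifar10'))
    # sampling_rate = 0.00256, steps = 390.625, delta = 1e-05, c ≈ 4.845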
app/templates/index.html CHANGED
@@ -39,6 +39,7 @@
                 <option value="simple-mlp">Simple MLP</option>
                 <option value="simple-cnn">Simple CNN</option>
                 <option value="advanced-cnn">Advanced CNN</option>
+                <option value="resnet18">ResNet-18</option>
             </select>
         </div>
 
app/training/simplified_real_trainer.py CHANGED
@@ -8,15 +8,32 @@ import logging
 logging.getLogger('tensorflow').setLevel(logging.ERROR)
 
 class SimplifiedRealTrainer:
-    def __init__(self):
+    def __init__(self, dataset='mnist', model_architecture='simple-mlp'):
         # Set random seeds for reproducibility
         tf.random.set_seed(42)
         np.random.seed(42)
 
-        # Load and preprocess MNIST dataset
-        self.x_train, self.y_train, self.x_test, self.y_test = self._load_mnist()
+        self.dataset = dataset
+        self.model_architecture = model_architecture
+        self.input_shape = None
+        self.original_shape = None  # For CNNs that need 2D/3D inputs
+        self.num_classes = 10
+
+        # Load and preprocess the specified dataset
+        self.x_train, self.y_train, self.x_test, self.y_test = self._load_dataset(dataset)
         self.model = None
 
+    def _load_dataset(self, dataset):
+        """Load and preprocess the specified dataset."""
+        if dataset == 'mnist':
+            return self._load_mnist()
+        elif dataset == 'cifar10':
+            return self._load_cifar10()
+        elif dataset == 'fashion-mnist':
+            return self._load_fashion_mnist()
+        else:
+            raise ValueError(f"Unsupported dataset: {dataset}")
+
     def _load_mnist(self):
         """Load and preprocess MNIST dataset."""
         print("Loading MNIST dataset...")
@@ -28,9 +45,90 @@ class SimplifiedRealTrainer:
         x_train = x_train.astype('float32') / 255.0
         x_test = x_test.astype('float32') / 255.0
 
-        # Reshape to flatten images
-        x_train = x_train.reshape(-1, 28 * 28)
-        x_test = x_test.reshape(-1, 28 * 28)
+        # Store original shape for CNNs (add channel dimension)
+        self.original_shape = (28, 28, 1)
+
+        # For MLPs, flatten the images; for CNNs, keep 2D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            x_train = x_train.reshape(-1, 28, 28, 1)
+            x_test = x_test.reshape(-1, 28, 28, 1)
+            self.input_shape = (28, 28, 1)
+        else:
+            x_train = x_train.reshape(-1, 28 * 28)
+            x_test = x_test.reshape(-1, 28 * 28)
+            self.input_shape = (784,)
+
+        self.num_classes = 10
+
+        # Convert labels to categorical
+        y_train = keras.utils.to_categorical(y_train, 10)
+        y_test = keras.utils.to_categorical(y_test, 10)
+
+        print(f"Training data shape: {x_train.shape}")
+        print(f"Test data shape: {x_test.shape}")
+
+        return x_train, y_train, x_test, y_test
+
+    def _load_cifar10(self):
+        """Load and preprocess CIFAR-10 dataset."""
+        print("Loading CIFAR-10 dataset...")
+
+        # Load CIFAR-10 data
+        (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+
+        # Normalize pixel values to [0, 1]
+        x_train = x_train.astype('float32') / 255.0
+        x_test = x_test.astype('float32') / 255.0
+
+        # Store original shape for CNNs
+        self.original_shape = (32, 32, 3)
+
+        # For MLPs, flatten the images; for CNNs, keep 3D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            # Keep original shape for CNNs
+            self.input_shape = (32, 32, 3)
+        else:
+            # Flatten for MLPs
+            x_train = x_train.reshape(-1, 32 * 32 * 3)
+            x_test = x_test.reshape(-1, 32 * 32 * 3)
+            self.input_shape = (3072,)
+
+        self.num_classes = 10
+
+        # Convert labels to categorical
+        y_train = keras.utils.to_categorical(y_train, 10)
+        y_test = keras.utils.to_categorical(y_test, 10)
+
+        print(f"Training data shape: {x_train.shape}")
+        print(f"Test data shape: {x_test.shape}")
+
+        return x_train, y_train, x_test, y_test
+
+    def _load_fashion_mnist(self):
+        """Load and preprocess Fashion-MNIST dataset."""
+        print("Loading Fashion-MNIST dataset...")
+
+        # Load Fashion-MNIST data
+        (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
+
+        # Normalize pixel values to [0, 1]
+        x_train = x_train.astype('float32') / 255.0
+        x_test = x_test.astype('float32') / 255.0
+
+        # Store original shape for CNNs (add channel dimension)
+        self.original_shape = (28, 28, 1)
+
+        # For MLPs, flatten the images; for CNNs, keep 2D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            x_train = x_train.reshape(-1, 28, 28, 1)
+            x_test = x_test.reshape(-1, 28, 28, 1)
+            self.input_shape = (28, 28, 1)
+        else:
+            x_train = x_train.reshape(-1, 28 * 28)
+            x_test = x_test.reshape(-1, 28 * 28)
+            self.input_shape = (784,)
+
+        self.num_classes = 10
 
         # Convert labels to categorical
         y_train = keras.utils.to_categorical(y_train, 10)
@@ -42,15 +140,113 @@ class SimplifiedRealTrainer:
         return x_train, y_train, x_test, y_test
 
     def _create_model(self):
-        """Create a simple MLP model for MNIST classification optimized for DP-SGD."""
-        # Use a simpler, more robust architecture for DP-SGD
+        """Create a model based on the specified architecture."""
+        if self.model_architecture == 'simple-mlp':
+            return self._create_simple_mlp()
+        elif self.model_architecture == 'simple-cnn':
+            return self._create_simple_cnn()
+        elif self.model_architecture == 'advanced-cnn':
+            return self._create_advanced_cnn()
+        elif self.model_architecture == 'resnet18':
+            return self._create_resnet18()
+        else:
+            raise ValueError(f"Unsupported model architecture: {self.model_architecture}")
+
+    def _create_simple_mlp(self):
+        """Create a simple MLP model optimized for DP-SGD."""
         model = keras.Sequential([
-            keras.layers.Dense(256, activation='tanh', input_shape=(784,)),  # tanh works better with DP-SGD
+            keras.layers.Dense(256, activation='tanh', input_shape=self.input_shape),  # tanh works better with DP-SGD
             keras.layers.Dense(128, activation='tanh'),
-            keras.layers.Dense(10, activation='softmax')
+            keras.layers.Dense(self.num_classes, activation='softmax')
         ])
         return model
 
+    def _create_simple_cnn(self):
+        """Create a simple CNN model optimized for DP-SGD."""
+        model = keras.Sequential([
+            keras.layers.Conv2D(32, (3, 3), activation='tanh', input_shape=self.input_shape),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Flatten(),
+            keras.layers.Dense(128, activation='tanh'),
+            keras.layers.Dense(self.num_classes, activation='softmax')
+        ])
+        return model
+
+    def _create_advanced_cnn(self):
+        """Create an advanced CNN model optimized for DP-SGD."""
+        model = keras.Sequential([
+            keras.layers.Conv2D(32, (3, 3), activation='tanh', input_shape=self.input_shape),
+            keras.layers.BatchNormalization(),
+            keras.layers.Conv2D(32, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Dropout(0.25),
+
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.BatchNormalization(),
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Dropout(0.25),
+
+            keras.layers.Flatten(),
+            keras.layers.Dense(256, activation='tanh'),
+            keras.layers.Dropout(0.5),
+            keras.layers.Dense(128, activation='tanh'),
+            keras.layers.Dense(self.num_classes, activation='softmax')
+        ])
+        return model
+
+    def _create_resnet18(self):
+        """Create a ResNet-18 model optimized for DP-SGD."""
+        def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
+            """A residual block for ResNet."""
+            if conv_shortcut:
+                shortcut = keras.layers.Conv2D(filters, 1, strides=stride, padding='same')(x)
+                shortcut = keras.layers.BatchNormalization()(shortcut)
+            else:
+                shortcut = x
+
+            x = keras.layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
+            x = keras.layers.BatchNormalization()(x)
+            x = keras.layers.Activation('tanh')(x)  # Use tanh for DP-SGD
+
+            x = keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
+            x = keras.layers.BatchNormalization()(x)
+
+            x = keras.layers.Add()([shortcut, x])
+            x = keras.layers.Activation('tanh')(x)
+            return x
+
+        def resnet_block(x, filters, num_blocks, stride=1):
+            """A stack of residual blocks."""
+            x = residual_block(x, filters, stride=stride, conv_shortcut=True)
+            for _ in range(num_blocks - 1):
+                x = residual_block(x, filters)
+            return x
+
+        # Input layer
+        inputs = keras.layers.Input(shape=self.input_shape)
+
+        # Initial convolution
+        x = keras.layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
+        x = keras.layers.BatchNormalization()(x)
+        x = keras.layers.Activation('tanh')(x)
+        x = keras.layers.MaxPooling2D(3, strides=2, padding='same')(x)
+
+        # ResNet blocks
+        x = resnet_block(x, 64, 2)
+        x = resnet_block(x, 128, 2, stride=2)
+        x = resnet_block(x, 256, 2, stride=2)
+        x = resnet_block(x, 512, 2, stride=2)
+
+        # Global average pooling and output
+        x = keras.layers.GlobalAveragePooling2D()(x)
+        x = keras.layers.Dense(self.num_classes, activation='softmax')(x)
+
+        model = keras.Model(inputs, x)
+        return model
+
     def _clip_gradients(self, gradients, clipping_norm):
         """Clip gradients to a maximum L2 norm globally across all parameters."""
         # Calculate global L2 norm across all gradients
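
As a sketch, the extended constructor can be exercised directly; the module path below is assumed from the file location in this commit, and the first run downloads each dataset via Keras. Output shapes follow the preprocessing rules added above.

    # Hypothetical usage; module path assumed from app/training/simplified_real_trainer.py
    from app.training.simplified_real_trainer import SimplifiedRealTrainer

    cnn_trainer = SimplifiedRealTrainer(dataset='cifar10', model_architecture='resnet18')
    print(cnn_trainer.input_shape)    # (32, 32, 3) -- CNNs keep the 3D image shape
    print(cnn_trainer.x_train.shape)  # (50000, 32, 32, 3)

    mlp_trainer = SimplifiedRealTrainer(dataset='fashion-mnist', model_architecture='simple-mlp')
    print(mlp_trainer.input_shape)    # (784,) -- MLP inputs are flattened
    print(mlp_trainer.x_train.shape)  # (60000, 784)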