Emily committed
Commit: f826c3b
Parent(s): 0e859f8
Add multi-dataset and ResNet-18 architecture support
- Fix dataset selection: now properly switches between MNIST, CIFAR-10, and Fashion-MNIST
- Add ResNet-18 architecture option alongside existing MLP and CNN models
- Implement dynamic data preprocessing based on model architecture (flatten for MLPs, keep 2D/3D for CNNs)
- Add model architecture parameter to frontend and backend
- Cache trainers by dataset+architecture combination for efficiency
- Update privacy budget calculations to use correct dataset sizes
- Support all architecture combinations across datasets
- app/routes.py +49 -17
- app/static/js/main.js +20 -2
- app/templates/index.html +1 -0
- app/training/simplified_real_trainer.py +206 -10
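
The dataset/architecture matrix introduced here can be exercised with a short loop — a minimal sketch, assuming the route module is importable as app.routes and that TensorFlow plus the Keras datasets are available locally (the import path and environment are assumptions, not part of the commit):

# Sketch: warm the per-(dataset, architecture) trainer cache added in this commit.
# Assumes app.routes is importable and REAL_TRAINER_AVAILABLE is True there.
from app.routes import get_or_create_trainer

datasets = ['mnist', 'fashion-mnist', 'cifar10']
architectures = ['simple-mlp', 'simple-cnn', 'advanced-cnn', 'resnet18']

for ds in datasets:
    for arch in architectures:
        trainer = get_or_create_trainer(ds, arch)
        if trainer is not None:
            # Repeated calls with the same pair should hit the cache,
            # not rebuild the trainer.
            assert trainer is get_or_create_trainer(ds, arch)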
app/routes.py
CHANGED

@@ -23,17 +23,27 @@ main = Blueprint('main', __name__)
 mock_trainer = MockTrainer()
 privacy_calculator = PrivacyCalculator()

+# We'll create trainers dynamically based on dataset selection
+real_trainers = {}  # Cache trainers by dataset to avoid reloading
+
+def get_or_create_trainer(dataset, model_architecture='simple-mlp'):
+    """Get or create a trainer for the specified dataset and architecture."""
+    if not REAL_TRAINER_AVAILABLE:
+        return None
+
+    # Create a unique key for dataset + architecture combination
+    trainer_key = f"{dataset}_{model_architecture}"
+
+    if trainer_key not in real_trainers:
+        try:
+            print(f"Creating new trainer for dataset: {dataset}, architecture: {model_architecture}")
+            real_trainers[trainer_key] = RealTrainer(dataset=dataset, model_architecture=model_architecture)
+            print(f"Trainer for {dataset} with {model_architecture} initialized successfully")
+        except Exception as e:
+            print(f"Failed to initialize trainer for {dataset} with {model_architecture}: {e}")
+            return None
+
+    return real_trainers[trainer_key]

 @main.route('/')
 def index():

@@ -62,24 +72,46 @@ def train():
         'epochs': int(data.get('epochs', 5))
     }

+    # Get dataset and model architecture selection
+    dataset = data.get('dataset', 'mnist')
+    model_architecture = data.get('model_architecture', 'simple-mlp')
+
     # Check if user wants to force mock training
     use_mock = data.get('use_mock', False)

     # Use real trainer if available and not forced to use mock
+    if REAL_TRAINER_AVAILABLE and not use_mock:
+        real_trainer = get_or_create_trainer(dataset, model_architecture)
+        if real_trainer:
+            print(f"Using real trainer with {dataset.upper()} dataset and {model_architecture} architecture")
+            results = real_trainer.train(params)
+            results['trainer_type'] = 'real'
+            results['dataset'] = dataset.upper()
+            results['model_architecture'] = model_architecture
+        else:
+            print("Failed to create real trainer, falling back to mock trainer")
+            results = mock_trainer.train(params)
+            results['trainer_type'] = 'mock'
+            results['dataset'] = 'synthetic'
+            results['model_architecture'] = 'mock'
     else:
         print("Using mock trainer with synthetic data")
         results = mock_trainer.train(params)
         results['trainer_type'] = 'mock'
         results['dataset'] = 'synthetic'
+        results['model_architecture'] = 'mock'

     # Add gradient information for visualization (if not already included)
     if 'gradient_info' not in results:
+        if REAL_TRAINER_AVAILABLE and not use_mock:
+            current_trainer = get_or_create_trainer(dataset, model_architecture)
+            if current_trainer:
+                trainer = current_trainer
+            else:
+                trainer = mock_trainer
+        else:
+            trainer = mock_trainer
+
         results['gradient_info'] = {
             'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
             'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm'])
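
With these changes the train() view reads two new fields, dataset and model_architecture, from the JSON body. A request could look like the sketch below; the route path, host, and numeric values are assumptions for illustration (they are not shown in this hunk), and only the field names come from the diff:

# Hypothetical client call against the updated train() view.
# '/train', the host/port, and the numeric values are assumed; only the
# JSON field names are taken from the diff above.
import requests

payload = {
    'epochs': 5,
    'batch_size': 256,
    'learning_rate': 0.15,
    'noise_multiplier': 1.1,
    'clipping_norm': 1.0,
    'dataset': 'cifar10',              # 'mnist', 'cifar10', or 'fashion-mnist'
    'model_architecture': 'resnet18',  # 'simple-mlp', 'simple-cnn', 'advanced-cnn', 'resnet18'
    'use_mock': False,
}

resp = requests.post('http://localhost:5000/train', json=payload)
results = resp.json()
print(results['trainer_type'], results['dataset'], results['model_architecture'])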
app/static/js/main.js
CHANGED

@@ -697,7 +697,9 @@ class DPSGDExplorer {
             noise_multiplier: parseFloat(document.getElementById('noise-multiplier').value),
             batch_size: parseInt(document.getElementById('batch-size').value),
             learning_rate: parseFloat(document.getElementById('learning-rate').value),
-            epochs: parseInt(document.getElementById('epochs').value)
+            epochs: parseInt(document.getElementById('epochs').value),
+            dataset: document.getElementById('dataset-select').value,
+            model_architecture: document.getElementById('model-select').value
         };
     }

@@ -720,7 +722,23 @@ class DPSGDExplorer {

     calculateEpochPrivacy(epoch) {
         const params = this.getParameters();
+
+        // Get dataset size based on selection
+        let datasetSize;
+        switch (params.dataset) {
+            case 'cifar10':
+                datasetSize = 50000; // CIFAR-10 training set size
+                break;
+            case 'fashion-mnist':
+                datasetSize = 60000; // Fashion-MNIST training set size
+                break;
+            case 'mnist':
+            default:
+                datasetSize = 60000; // MNIST training set size
+                break;
+        }
+
+        const samplingRate = params.batch_size / datasetSize;
         const steps = epoch * (1 / samplingRate);
         const delta = 1e-5;
         const c = Math.sqrt(2 * Math.log(1.25 / delta));
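
As a quick sanity check on the new per-dataset sizes, here is the same sampling-rate arithmetic mirrored in Python for an assumed batch size of 256 (the batch size is an example value; the dataset sizes are the constants hard-coded above):

# Worked example of the quantities calculateEpochPrivacy now derives.
# batch_size = 256 is an assumed example; dataset sizes match the JS constants.
import math

dataset_sizes = {'mnist': 60000, 'fashion-mnist': 60000, 'cifar10': 50000}
batch_size = 256
delta = 1e-5
c = math.sqrt(2 * math.log(1.25 / delta))  # ~4.84, same constant as in the JS

for name, n in dataset_sizes.items():
    sampling_rate = batch_size / n        # q = batch_size / dataset size
    steps_per_epoch = 1 / sampling_rate   # steps = epoch * (1 / q), shown for one epoch
    print(f"{name}: q = {sampling_rate:.5f}, steps/epoch ~ {steps_per_epoch:.0f}")

# mnist / fashion-mnist: q ~ 0.00427, ~234 steps per epoch
# cifar10:               q = 0.00512, ~195 steps per epoch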
app/templates/index.html
CHANGED

@@ -39,6 +39,7 @@
                 <option value="simple-mlp">Simple MLP</option>
                 <option value="simple-cnn">Simple CNN</option>
                 <option value="advanced-cnn">Advanced CNN</option>
+                <option value="resnet18">ResNet-18</option>
             </select>
         </div>

app/training/simplified_real_trainer.py
CHANGED

@@ -8,15 +8,32 @@ import logging
 logging.getLogger('tensorflow').setLevel(logging.ERROR)

 class SimplifiedRealTrainer:
-    def __init__(self):
+    def __init__(self, dataset='mnist', model_architecture='simple-mlp'):
         # Set random seeds for reproducibility
         tf.random.set_seed(42)
         np.random.seed(42)

+        self.dataset = dataset
+        self.model_architecture = model_architecture
+        self.input_shape = None
+        self.original_shape = None  # For CNNs that need 2D/3D inputs
+        self.num_classes = 10
+
+        # Load and preprocess the specified dataset
+        self.x_train, self.y_train, self.x_test, self.y_test = self._load_dataset(dataset)
         self.model = None

+    def _load_dataset(self, dataset):
+        """Load and preprocess the specified dataset."""
+        if dataset == 'mnist':
+            return self._load_mnist()
+        elif dataset == 'cifar10':
+            return self._load_cifar10()
+        elif dataset == 'fashion-mnist':
+            return self._load_fashion_mnist()
+        else:
+            raise ValueError(f"Unsupported dataset: {dataset}")
+
     def _load_mnist(self):
         """Load and preprocess MNIST dataset."""
         print("Loading MNIST dataset...")

@@ -28,9 +45,90 @@ class SimplifiedRealTrainer:
         x_train = x_train.astype('float32') / 255.0
         x_test = x_test.astype('float32') / 255.0

+        # Store original shape for CNNs (add channel dimension)
+        self.original_shape = (28, 28, 1)
+
+        # For MLPs, flatten the images; for CNNs, keep 2D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            x_train = x_train.reshape(-1, 28, 28, 1)
+            x_test = x_test.reshape(-1, 28, 28, 1)
+            self.input_shape = (28, 28, 1)
+        else:
+            x_train = x_train.reshape(-1, 28 * 28)
+            x_test = x_test.reshape(-1, 28 * 28)
+            self.input_shape = (784,)
+
+        self.num_classes = 10
+
+        # Convert labels to categorical
+        y_train = keras.utils.to_categorical(y_train, 10)
+        y_test = keras.utils.to_categorical(y_test, 10)
+
+        print(f"Training data shape: {x_train.shape}")
+        print(f"Test data shape: {x_test.shape}")
+
+        return x_train, y_train, x_test, y_test
+
+    def _load_cifar10(self):
+        """Load and preprocess CIFAR-10 dataset."""
+        print("Loading CIFAR-10 dataset...")
+
+        # Load CIFAR-10 data
+        (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+
+        # Normalize pixel values to [0, 1]
+        x_train = x_train.astype('float32') / 255.0
+        x_test = x_test.astype('float32') / 255.0
+
+        # Store original shape for CNNs
+        self.original_shape = (32, 32, 3)
+
+        # For MLPs, flatten the images; for CNNs, keep 3D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            # Keep original shape for CNNs
+            self.input_shape = (32, 32, 3)
+        else:
+            # Flatten for MLPs
+            x_train = x_train.reshape(-1, 32 * 32 * 3)
+            x_test = x_test.reshape(-1, 32 * 32 * 3)
+            self.input_shape = (3072,)
+
+        self.num_classes = 10
+
+        # Convert labels to categorical
+        y_train = keras.utils.to_categorical(y_train, 10)
+        y_test = keras.utils.to_categorical(y_test, 10)
+
+        print(f"Training data shape: {x_train.shape}")
+        print(f"Test data shape: {x_test.shape}")
+
+        return x_train, y_train, x_test, y_test
+
+    def _load_fashion_mnist(self):
+        """Load and preprocess Fashion-MNIST dataset."""
+        print("Loading Fashion-MNIST dataset...")
+
+        # Load Fashion-MNIST data
+        (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
+
+        # Normalize pixel values to [0, 1]
+        x_train = x_train.astype('float32') / 255.0
+        x_test = x_test.astype('float32') / 255.0
+
+        # Store original shape for CNNs (add channel dimension)
+        self.original_shape = (28, 28, 1)
+
+        # For MLPs, flatten the images; for CNNs, keep 2D shape
+        if self.model_architecture in ['simple-cnn', 'advanced-cnn', 'resnet18']:
+            x_train = x_train.reshape(-1, 28, 28, 1)
+            x_test = x_test.reshape(-1, 28, 28, 1)
+            self.input_shape = (28, 28, 1)
+        else:
+            x_train = x_train.reshape(-1, 28 * 28)
+            x_test = x_test.reshape(-1, 28 * 28)
+            self.input_shape = (784,)
+
+        self.num_classes = 10

         # Convert labels to categorical
         y_train = keras.utils.to_categorical(y_train, 10)

@@ -42,15 +140,113 @@
         return x_train, y_train, x_test, y_test

     def _create_model(self):
+        """Create a model based on the specified architecture."""
+        if self.model_architecture == 'simple-mlp':
+            return self._create_simple_mlp()
+        elif self.model_architecture == 'simple-cnn':
+            return self._create_simple_cnn()
+        elif self.model_architecture == 'advanced-cnn':
+            return self._create_advanced_cnn()
+        elif self.model_architecture == 'resnet18':
+            return self._create_resnet18()
+        else:
+            raise ValueError(f"Unsupported model architecture: {self.model_architecture}")
+
+    def _create_simple_mlp(self):
+        """Create a simple MLP model optimized for DP-SGD."""
         model = keras.Sequential([
+            keras.layers.Dense(256, activation='tanh', input_shape=self.input_shape),  # tanh works better with DP-SGD
             keras.layers.Dense(128, activation='tanh'),
+            keras.layers.Dense(self.num_classes, activation='softmax')
         ])
         return model

+    def _create_simple_cnn(self):
+        """Create a simple CNN model optimized for DP-SGD."""
+        model = keras.Sequential([
+            keras.layers.Conv2D(32, (3, 3), activation='tanh', input_shape=self.input_shape),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Flatten(),
+            keras.layers.Dense(128, activation='tanh'),
+            keras.layers.Dense(self.num_classes, activation='softmax')
+        ])
+        return model
+
+    def _create_advanced_cnn(self):
+        """Create an advanced CNN model optimized for DP-SGD."""
+        model = keras.Sequential([
+            keras.layers.Conv2D(32, (3, 3), activation='tanh', input_shape=self.input_shape),
+            keras.layers.BatchNormalization(),
+            keras.layers.Conv2D(32, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Dropout(0.25),
+
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.BatchNormalization(),
+            keras.layers.Conv2D(64, (3, 3), activation='tanh'),
+            keras.layers.MaxPooling2D((2, 2)),
+            keras.layers.Dropout(0.25),
+
+            keras.layers.Flatten(),
+            keras.layers.Dense(256, activation='tanh'),
+            keras.layers.Dropout(0.5),
+            keras.layers.Dense(128, activation='tanh'),
+            keras.layers.Dense(self.num_classes, activation='softmax')
+        ])
+        return model
+
+    def _create_resnet18(self):
+        """Create a ResNet-18 model optimized for DP-SGD."""
+        def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
+            """A residual block for ResNet."""
+            if conv_shortcut:
+                shortcut = keras.layers.Conv2D(filters, 1, strides=stride, padding='same')(x)
+                shortcut = keras.layers.BatchNormalization()(shortcut)
+            else:
+                shortcut = x
+
+            x = keras.layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
+            x = keras.layers.BatchNormalization()(x)
+            x = keras.layers.Activation('tanh')(x)  # Use tanh for DP-SGD
+
+            x = keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
+            x = keras.layers.BatchNormalization()(x)
+
+            x = keras.layers.Add()([shortcut, x])
+            x = keras.layers.Activation('tanh')(x)
+            return x
+
+        def resnet_block(x, filters, num_blocks, stride=1):
+            """A stack of residual blocks."""
+            x = residual_block(x, filters, stride=stride, conv_shortcut=True)
+            for _ in range(num_blocks - 1):
+                x = residual_block(x, filters)
+            return x
+
+        # Input layer
+        inputs = keras.layers.Input(shape=self.input_shape)
+
+        # Initial convolution
+        x = keras.layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
+        x = keras.layers.BatchNormalization()(x)
+        x = keras.layers.Activation('tanh')(x)
+        x = keras.layers.MaxPooling2D(3, strides=2, padding='same')(x)
+
+        # ResNet blocks
+        x = resnet_block(x, 64, 2)
+        x = resnet_block(x, 128, 2, stride=2)
+        x = resnet_block(x, 256, 2, stride=2)
+        x = resnet_block(x, 512, 2, stride=2)
+
+        # Global average pooling and output
+        x = keras.layers.GlobalAveragePooling2D()(x)
+        x = keras.layers.Dense(self.num_classes, activation='softmax')(x)
+
+        model = keras.Model(inputs, x)
+        return model
+
     def _clip_gradients(self, gradients, clipping_norm):
         """Clip gradients to a maximum L2 norm globally across all parameters."""
         # Calculate global L2 norm across all gradients