Shuya Feng commited on
Commit
8ad5d56
·
1 Parent(s): adae711

Update gradients clipping chart

Browse files
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: gunicorn run:app
app/__init__.py CHANGED
@@ -2,8 +2,26 @@ from flask import Flask
2
  from flask_cors import CORS
3
 
4
  def create_app():
5
- app = Flask(__name__)
6
- CORS(app)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Register blueprints
9
  from app.routes import main
 
2
  from flask_cors import CORS
3
 
4
  def create_app():
5
+ app = Flask(__name__,
6
+ static_folder='static',
7
+ template_folder='templates')
8
+
9
+ # Configure CORS
10
+ CORS(app, resources={
11
+ r"/*": {
12
+ "origins": ["http://localhost:5000", "http://127.0.0.1:5000"],
13
+ "methods": ["GET", "POST", "OPTIONS"],
14
+ "allow_headers": ["Content-Type"]
15
+ }
16
+ })
17
+
18
+ # Configure security headers
19
+ @app.after_request
20
+ def add_security_headers(response):
21
+ response.headers['Access-Control-Allow-Origin'] = '*'
22
+ response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
23
+ response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
24
+ return response
25
 
26
  # Register blueprints
27
  from app.routes import main
app/routes.py CHANGED
@@ -1,6 +1,7 @@
1
- from flask import Blueprint, render_template, jsonify, request
2
  from app.training.mock_trainer import MockTrainer
3
  from app.training.privacy_calculator import PrivacyCalculator
 
4
 
5
  main = Blueprint('main', __name__)
6
  mock_trainer = MockTrainer()
@@ -14,30 +15,61 @@ def index():
14
  def learning():
15
  return render_template('learning.html')
16
 
17
- @main.route('/api/train', methods=['POST'])
 
18
  def train():
19
- data = request.json
20
- params = {
21
- 'clipping_norm': float(data.get('clipping_norm', 1.0)),
22
- 'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
23
- 'batch_size': int(data.get('batch_size', 64)),
24
- 'learning_rate': float(data.get('learning_rate', 0.01)),
25
- 'epochs': int(data.get('epochs', 5))
26
- }
27
-
28
- # Get mock training results
29
- results = mock_trainer.train(params)
30
- return jsonify(results)
31
-
32
- @main.route('/api/privacy-budget', methods=['POST'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def calculate_privacy_budget():
34
- data = request.json
35
- params = {
36
- 'clipping_norm': float(data.get('clipping_norm', 1.0)),
37
- 'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
38
- 'batch_size': int(data.get('batch_size', 64)),
39
- 'epochs': int(data.get('epochs', 5))
40
- }
41
-
42
- epsilon = privacy_calculator.calculate_epsilon(params)
43
- return jsonify({'epsilon': epsilon})
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Blueprint, render_template, jsonify, request, current_app
2
  from app.training.mock_trainer import MockTrainer
3
  from app.training.privacy_calculator import PrivacyCalculator
4
+ from flask_cors import cross_origin
5
 
6
  main = Blueprint('main', __name__)
7
  mock_trainer = MockTrainer()
 
15
  def learning():
16
  return render_template('learning.html')
17
 
18
+ @main.route('/api/train', methods=['POST', 'OPTIONS'])
19
+ @cross_origin()
20
  def train():
21
+ if request.method == 'OPTIONS':
22
+ return jsonify({'status': 'ok'})
23
+
24
+ try:
25
+ data = request.json
26
+ if not data:
27
+ return jsonify({'error': 'No data provided'}), 400
28
+
29
+ params = {
30
+ 'clipping_norm': float(data.get('clipping_norm', 1.0)),
31
+ 'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
32
+ 'batch_size': int(data.get('batch_size', 64)),
33
+ 'learning_rate': float(data.get('learning_rate', 0.01)),
34
+ 'epochs': int(data.get('epochs', 5))
35
+ }
36
+
37
+ # Get mock training results
38
+ results = mock_trainer.train(params)
39
+
40
+ # Add gradient information for visualization
41
+ results['gradient_info'] = {
42
+ 'before_clipping': mock_trainer.generate_gradient_norms(params['clipping_norm']),
43
+ 'after_clipping': mock_trainer.generate_clipped_gradients(params['clipping_norm'])
44
+ }
45
+
46
+ return jsonify(results)
47
+ except (TypeError, ValueError) as e:
48
+ return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
49
+ except Exception as e:
50
+ return jsonify({'error': f'Server error: {str(e)}'}), 500
51
+
52
+ @main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
53
+ @cross_origin()
54
  def calculate_privacy_budget():
55
+ if request.method == 'OPTIONS':
56
+ return jsonify({'status': 'ok'})
57
+
58
+ try:
59
+ data = request.json
60
+ if not data:
61
+ return jsonify({'error': 'No data provided'}), 400
62
+
63
+ params = {
64
+ 'clipping_norm': float(data.get('clipping_norm', 1.0)),
65
+ 'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
66
+ 'batch_size': int(data.get('batch_size', 64)),
67
+ 'epochs': int(data.get('epochs', 5))
68
+ }
69
+
70
+ epsilon = privacy_calculator.calculate_epsilon(params)
71
+ return jsonify({'epsilon': epsilon})
72
+ except (TypeError, ValueError) as e:
73
+ return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
74
+ except Exception as e:
75
+ return jsonify({'error': f'Server error: {str(e)}'}), 500
app/static/css/styles.css CHANGED
@@ -204,16 +204,15 @@ body {
204
 
205
  /* Charts */
206
  .chart-container {
 
207
  height: 300px;
 
208
  margin-bottom: 1rem;
209
- position: relative;
210
  }
211
 
212
- .chart {
213
- width: 100%;
214
- height: 100%;
215
- border: 1px solid var(--border-color);
216
- border-radius: 4px;
217
  }
218
 
219
  /* Metrics */
@@ -295,30 +294,35 @@ body {
295
  .status-badge {
296
  display: flex;
297
  align-items: center;
298
- margin-top: 1rem;
299
- padding: 0.5rem;
300
- background-color: var(--background-off);
301
  border-radius: 4px;
 
302
  }
303
 
304
  .pulse {
305
  display: inline-block;
306
- width: 10px;
307
- height: 10px;
308
  border-radius: 50%;
309
- background: var(--secondary-color);
310
- margin-right: 0.5rem;
311
- animation: pulse 1.5s infinite;
312
  }
313
 
314
  @keyframes pulse {
315
  0% {
 
316
  box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.7);
317
  }
 
318
  70% {
 
319
  box-shadow: 0 0 0 10px rgba(76, 175, 80, 0);
320
  }
 
321
  100% {
 
322
  box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
323
  }
324
  }
@@ -454,4 +458,26 @@ body {
454
 
455
  .concept-box .box2 {
456
  background-color: #fff8e1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  }
 
204
 
205
  /* Charts */
206
  .chart-container {
207
+ position: relative;
208
  height: 300px;
209
+ width: 100%;
210
  margin-bottom: 1rem;
 
211
  }
212
 
213
+ .chart-container canvas {
214
+ width: 100% !important;
215
+ height: 100% !important;
 
 
216
  }
217
 
218
  /* Metrics */
 
294
  .status-badge {
295
  display: flex;
296
  align-items: center;
297
+ gap: 1rem;
298
+ padding: 0.5rem 1rem;
299
+ background-color: #f5f5f5;
300
  border-radius: 4px;
301
+ margin-top: 1rem;
302
  }
303
 
304
  .pulse {
305
  display: inline-block;
306
+ width: 8px;
307
+ height: 8px;
308
  border-radius: 50%;
309
+ background-color: #4caf50;
310
+ animation: pulse 1s infinite;
 
311
  }
312
 
313
  @keyframes pulse {
314
  0% {
315
+ transform: scale(0.95);
316
  box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.7);
317
  }
318
+
319
  70% {
320
+ transform: scale(1);
321
  box-shadow: 0 0 0 10px rgba(76, 175, 80, 0);
322
  }
323
+
324
  100% {
325
+ transform: scale(0.95);
326
  box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
327
  }
328
  }
 
458
 
459
  .concept-box .box2 {
460
  background-color: #fff8e1;
461
+ }
462
+
463
+ /* Error Message */
464
+ .error-message {
465
+ background-color: #ffebee;
466
+ color: #c62828;
467
+ padding: 1rem;
468
+ margin-bottom: 1rem;
469
+ border-radius: 4px;
470
+ border-left: 4px solid #c62828;
471
+ animation: slideIn 0.3s ease-out;
472
+ }
473
+
474
+ @keyframes slideIn {
475
+ from {
476
+ transform: translateY(-20px);
477
+ opacity: 0;
478
+ }
479
+ to {
480
+ transform: translateY(0);
481
+ opacity: 1;
482
+ }
483
  }
app/static/js/main.js CHANGED
@@ -2,6 +2,7 @@ class DPSGDExplorer {
2
  constructor() {
3
  this.trainingChart = null;
4
  this.privacyChart = null;
 
5
  this.isTraining = false;
6
  this.initializeUI();
7
  }
@@ -31,11 +32,29 @@ class DPSGDExplorer {
31
  for (const [id, slider] of Object.entries(sliders)) {
32
  if (slider) {
33
  slider.addEventListener('input', (e) => {
34
- document.getElementById(`${id}-value`).textContent = e.target.value;
 
 
 
35
  this.updatePrivacyBudget();
 
 
 
 
 
36
  });
37
  }
38
  }
 
 
 
 
 
 
 
 
 
 
39
  }
40
 
41
  initializePresets() {
@@ -92,6 +111,7 @@ class DPSGDExplorer {
92
  initializeCharts() {
93
  const trainingCtx = document.getElementById('training-chart')?.getContext('2d');
94
  const privacyCtx = document.getElementById('privacy-chart')?.getContext('2d');
 
95
 
96
  if (trainingCtx) {
97
  this.trainingChart = new Chart(trainingCtx, {
@@ -115,6 +135,7 @@ class DPSGDExplorer {
115
  },
116
  options: {
117
  responsive: true,
 
118
  interaction: {
119
  mode: 'index',
120
  intersect: false,
@@ -127,7 +148,9 @@ class DPSGDExplorer {
127
  title: {
128
  display: true,
129
  text: 'Accuracy (%)'
130
- }
 
 
131
  },
132
  y1: {
133
  type: 'linear',
@@ -137,6 +160,8 @@ class DPSGDExplorer {
137
  display: true,
138
  text: 'Loss'
139
  },
 
 
140
  grid: {
141
  drawOnChartArea: false,
142
  },
@@ -159,17 +184,80 @@ class DPSGDExplorer {
159
  },
160
  options: {
161
  responsive: true,
 
162
  scales: {
163
  y: {
 
164
  title: {
165
  display: true,
166
  text: 'Privacy Budget (ε)'
167
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  x: {
 
 
 
 
 
 
 
 
 
 
 
170
  title: {
171
  display: true,
172
- text: 'Epoch'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  }
174
  }
175
  }
@@ -236,6 +324,8 @@ class DPSGDExplorer {
236
  this.resetCharts();
237
 
238
  try {
 
 
239
  const response = await fetch('/api/train', {
240
  method: 'POST',
241
  headers: {
@@ -245,10 +335,28 @@ class DPSGDExplorer {
245
  });
246
 
247
  const data = await response.json();
 
 
 
 
 
 
 
 
248
  this.updateCharts(data.epochs_data);
249
  this.updateResults(data);
250
  } catch (error) {
251
  console.error('Training error:', error);
 
 
 
 
 
 
 
 
 
 
252
  } finally {
253
  this.stopTraining();
254
  }
@@ -277,12 +385,21 @@ class DPSGDExplorer {
277
  this.privacyChart.data.datasets[0].data = [];
278
  this.privacyChart.update();
279
  }
 
 
 
 
 
 
280
  }
281
 
282
  updateCharts(epochsData) {
283
  if (!this.trainingChart || !epochsData) return;
284
 
285
- const labels = epochsData.map(d => d.epoch);
 
 
 
286
  const accuracies = epochsData.map(d => d.accuracy);
287
  const losses = epochsData.map(d => d.loss);
288
 
@@ -291,14 +408,85 @@ class DPSGDExplorer {
291
  this.trainingChart.data.datasets[1].data = losses;
292
  this.trainingChart.update();
293
 
294
- // Update privacy chart
 
 
 
 
 
 
 
 
295
  if (this.privacyChart) {
296
- this.privacyChart.data.labels = labels;
297
- this.privacyChart.data.datasets[0].data = epochsData.map((_, i) =>
298
  this.calculateEpochPrivacy(i + 1)
299
  );
 
 
300
  this.privacyChart.update();
301
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  }
303
 
304
  updateResults(data) {
@@ -363,6 +551,95 @@ class DPSGDExplorer {
363
  const c = Math.sqrt(2 * Math.log(1.25 / delta));
364
  return Math.min((c * samplingRate * Math.sqrt(steps)) / params.noise_multiplier, 10);
365
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  }
367
 
368
  // Initialize the application when the DOM is loaded
 
2
  constructor() {
3
  this.trainingChart = null;
4
  this.privacyChart = null;
5
+ this.gradientChart = null;
6
  this.isTraining = false;
7
  this.initializeUI();
8
  }
 
32
  for (const [id, slider] of Object.entries(sliders)) {
33
  if (slider) {
34
  slider.addEventListener('input', (e) => {
35
+ const value = parseFloat(e.target.value);
36
+ document.getElementById(`${id}-value`).textContent = value.toFixed(1);
37
+
38
+ // Update privacy budget
39
  this.updatePrivacyBudget();
40
+
41
+ // Update gradient visualization when clipping norm changes
42
+ if (id === 'clipping-norm') {
43
+ this.updateGradientVisualization(value);
44
+ }
45
  });
46
  }
47
  }
48
+
49
+ // Add event listener for the visual clipping norm slider
50
+ const visualSlider = document.getElementById('clipping-norm-visual');
51
+ if (visualSlider) {
52
+ visualSlider.addEventListener('input', (e) => {
53
+ const value = parseFloat(e.target.value);
54
+ document.getElementById('clipping-norm-visual-value').textContent = value.toFixed(1);
55
+ this.updateGradientVisualization(value);
56
+ });
57
+ }
58
  }
59
 
60
  initializePresets() {
 
111
  initializeCharts() {
112
  const trainingCtx = document.getElementById('training-chart')?.getContext('2d');
113
  const privacyCtx = document.getElementById('privacy-chart')?.getContext('2d');
114
+ const gradientCtx = document.getElementById('gradient-chart')?.getContext('2d');
115
 
116
  if (trainingCtx) {
117
  this.trainingChart = new Chart(trainingCtx, {
 
135
  },
136
  options: {
137
  responsive: true,
138
+ maintainAspectRatio: false,
139
  interaction: {
140
  mode: 'index',
141
  intersect: false,
 
148
  title: {
149
  display: true,
150
  text: 'Accuracy (%)'
151
+ },
152
+ min: 0,
153
+ max: 100
154
  },
155
  y1: {
156
  type: 'linear',
 
160
  display: true,
161
  text: 'Loss'
162
  },
163
+ min: 0,
164
+ max: 2,
165
  grid: {
166
  drawOnChartArea: false,
167
  },
 
184
  },
185
  options: {
186
  responsive: true,
187
+ maintainAspectRatio: false,
188
  scales: {
189
  y: {
190
+ beginAtZero: true,
191
  title: {
192
  display: true,
193
  text: 'Privacy Budget (ε)'
194
  }
195
+ }
196
+ }
197
+ }
198
+ });
199
+ }
200
+
201
+ if (gradientCtx) {
202
+ this.gradientChart = new Chart(gradientCtx, {
203
+ type: 'scatter',
204
+ data: {
205
+ datasets: [
206
+ {
207
+ label: 'Before Clipping',
208
+ borderColor: '#2196f3',
209
+ backgroundColor: 'rgba(33, 150, 243, 0.1)',
210
+ data: [],
211
+ showLine: true
212
  },
213
+ {
214
+ label: 'After Clipping',
215
+ borderColor: '#f44336',
216
+ backgroundColor: 'rgba(244, 67, 54, 0.1)',
217
+ data: [],
218
+ showLine: true
219
+ }
220
+ ]
221
+ },
222
+ options: {
223
+ responsive: true,
224
+ maintainAspectRatio: false,
225
+ scales: {
226
  x: {
227
+ type: 'linear',
228
+ position: 'bottom',
229
+ title: {
230
+ display: true,
231
+ text: 'Gradient Norm'
232
+ },
233
+ min: 0
234
+ },
235
+ y: {
236
+ type: 'linear',
237
+ position: 'left',
238
  title: {
239
  display: true,
240
+ text: 'Density'
241
+ },
242
+ min: 0
243
+ }
244
+ },
245
+ plugins: {
246
+ annotation: {
247
+ annotations: {
248
+ line1: {
249
+ type: 'line',
250
+ xMin: 1,
251
+ xMax: 1,
252
+ borderColor: '#f44336',
253
+ borderWidth: 2,
254
+ borderDash: [5, 5],
255
+ label: {
256
+ content: 'Clipping Threshold',
257
+ display: true,
258
+ position: 'top'
259
+ }
260
+ }
261
  }
262
  }
263
  }
 
324
  this.resetCharts();
325
 
326
  try {
327
+ console.log('Starting training with parameters:', this.getParameters()); // Debug log
328
+
329
  const response = await fetch('/api/train', {
330
  method: 'POST',
331
  headers: {
 
335
  });
336
 
337
  const data = await response.json();
338
+
339
+ if (!response.ok) {
340
+ throw new Error(data.error || 'Unknown error occurred');
341
+ }
342
+
343
+ console.log('Received training data:', data); // Debug log
344
+
345
+ // Update charts and results
346
  this.updateCharts(data.epochs_data);
347
  this.updateResults(data);
348
  } catch (error) {
349
  console.error('Training error:', error);
350
+ // Show error message to user
351
+ const errorMessage = document.createElement('div');
352
+ errorMessage.className = 'error-message';
353
+ errorMessage.textContent = error.message || 'An error occurred during training';
354
+ document.querySelector('.lab-main').insertBefore(errorMessage, document.querySelector('.lab-main').firstChild);
355
+
356
+ // Remove error message after 5 seconds
357
+ setTimeout(() => {
358
+ errorMessage.remove();
359
+ }, 5000);
360
  } finally {
361
  this.stopTraining();
362
  }
 
385
  this.privacyChart.data.datasets[0].data = [];
386
  this.privacyChart.update();
387
  }
388
+
389
+ if (this.gradientChart) {
390
+ this.gradientChart.data.datasets[0].data = [];
391
+ this.gradientChart.data.datasets[1].data = [];
392
+ this.gradientChart.update();
393
+ }
394
  }
395
 
396
  updateCharts(epochsData) {
397
  if (!this.trainingChart || !epochsData) return;
398
 
399
+ console.log('Updating charts with data:', epochsData); // Debug log
400
+
401
+ // Update training metrics chart
402
+ const labels = epochsData.map(d => `Epoch ${d.epoch}`);
403
  const accuracies = epochsData.map(d => d.accuracy);
404
  const losses = epochsData.map(d => d.loss);
405
 
 
408
  this.trainingChart.data.datasets[1].data = losses;
409
  this.trainingChart.update();
410
 
411
+ // Update current epoch display
412
+ const currentEpoch = document.getElementById('current-epoch');
413
+ const totalEpochs = document.getElementById('total-epochs');
414
+ if (currentEpoch && totalEpochs) {
415
+ currentEpoch.textContent = epochsData.length;
416
+ totalEpochs.textContent = this.getParameters().epochs;
417
+ }
418
+
419
+ // Update privacy budget chart
420
  if (this.privacyChart) {
421
+ const privacyBudgets = epochsData.map((_, i) =>
 
422
  this.calculateEpochPrivacy(i + 1)
423
  );
424
+ this.privacyChart.data.labels = labels;
425
+ this.privacyChart.data.datasets[0].data = privacyBudgets;
426
  this.privacyChart.update();
427
  }
428
+
429
+ // Update gradient visualization
430
+ if (this.gradientChart) {
431
+ const clippingNorm = this.getParameters().clipping_norm;
432
+
433
+ // Generate gradient data if not provided in epochsData
434
+ let gradientData;
435
+ if (epochsData[epochsData.length - 1]?.gradient_info) {
436
+ gradientData = epochsData[epochsData.length - 1].gradient_info;
437
+ } else {
438
+ // Generate synthetic gradient data
439
+ const beforeClipping = [];
440
+ const afterClipping = [];
441
+
442
+ // Generate log-normal distributed gradients
443
+ const mu = Math.log(clippingNorm) - 0.5;
444
+ const sigma = 0.8;
445
+
446
+ for (let i = 0; i < 100; i++) {
447
+ const u1 = Math.random();
448
+ const u2 = Math.random();
449
+ const z = Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2);
450
+ const norm = Math.exp(mu + sigma * z);
451
+
452
+ const density = Math.exp(-(Math.pow(Math.log(norm) - mu, 2) / (2 * sigma * sigma))) /
453
+ (norm * sigma * Math.sqrt(2 * Math.PI));
454
+ const y = 0.2 + 0.8 * (density / 0.8) + 0.1 * (Math.random() - 0.5);
455
+
456
+ beforeClipping.push({ x: norm, y: y });
457
+ afterClipping.push({ x: Math.min(norm, clippingNorm), y: y });
458
+ }
459
+
460
+ gradientData = {
461
+ before_clipping: beforeClipping.sort((a, b) => a.x - b.x),
462
+ after_clipping: afterClipping.sort((a, b) => a.x - b.x)
463
+ };
464
+ }
465
+
466
+ // Update gradient chart
467
+ this.gradientChart.data.datasets[0].data = gradientData.before_clipping;
468
+ this.gradientChart.data.datasets[1].data = gradientData.after_clipping;
469
+
470
+ // Update clipping threshold line
471
+ this.gradientChart.options.plugins.annotation.annotations.line1 = {
472
+ type: 'line',
473
+ xMin: clippingNorm,
474
+ xMax: clippingNorm,
475
+ borderColor: '#f44336',
476
+ borderWidth: 2,
477
+ borderDash: [5, 5],
478
+ label: {
479
+ content: `Clipping Threshold (C=${clippingNorm.toFixed(1)})`,
480
+ display: true,
481
+ position: 'top'
482
+ }
483
+ };
484
+
485
+ // Update x-axis scale based on clipping norm
486
+ this.gradientChart.options.scales.x.max = Math.max(clippingNorm * 2.5, 5);
487
+
488
+ this.gradientChart.update('active');
489
+ }
490
  }
491
 
492
  updateResults(data) {
 
551
  const c = Math.sqrt(2 * Math.log(1.25 / delta));
552
  return Math.min((c * samplingRate * Math.sqrt(steps)) / params.noise_multiplier, 10);
553
  }
554
+
555
+ updateGradientVisualization(clippingNorm) {
556
+ if (!this.gradientChart) return;
557
+
558
+ // Generate random gradient norms following a log-normal distribution
559
+ const numPoints = 100;
560
+ const beforeClipping = [];
561
+ const afterClipping = [];
562
+
563
+ // Parameters for log-normal distribution
564
+ const mu = Math.log(clippingNorm) - 0.5;
565
+ const sigma = 0.8;
566
+
567
+ // Generate gradient norms
568
+ for (let i = 0; i < numPoints; i++) {
569
+ // Generate log-normal distributed gradient norms
570
+ const u1 = Math.random();
571
+ const u2 = Math.random();
572
+ const z = Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2);
573
+ const norm = Math.exp(mu + sigma * z);
574
+
575
+ // Calculate density using kernel density estimation
576
+ const density = Math.exp(-(Math.pow(Math.log(norm) - mu, 2) / (2 * sigma * sigma))) / (norm * sigma * Math.sqrt(2 * Math.PI));
577
+
578
+ // Normalize density and add some randomness
579
+ const y = 0.2 + 0.8 * (density / 0.8) + 0.1 * (Math.random() - 0.5);
580
+
581
+ beforeClipping.push({ x: norm, y: y });
582
+ afterClipping.push({ x: Math.min(norm, clippingNorm), y: y });
583
+ }
584
+
585
+ // Sort points by x-value for smoother lines
586
+ beforeClipping.sort((a, b) => a.x - b.x);
587
+ afterClipping.sort((a, b) => a.x - b.x);
588
+
589
+ // Update chart data
590
+ this.gradientChart.data.datasets[0].data = beforeClipping;
591
+ this.gradientChart.data.datasets[1].data = afterClipping;
592
+
593
+ // Update clipping threshold line
594
+ this.gradientChart.options.plugins.annotation.annotations.line1 = {
595
+ type: 'line',
596
+ xMin: clippingNorm,
597
+ xMax: clippingNorm,
598
+ borderColor: '#f44336',
599
+ borderWidth: 2,
600
+ borderDash: [5, 5],
601
+ label: {
602
+ content: `Clipping Threshold (C=${clippingNorm.toFixed(1)})`,
603
+ display: true,
604
+ position: 'top'
605
+ }
606
+ };
607
+
608
+ // Update x-axis scale based on clipping norm
609
+ this.gradientChart.options.scales.x.max = Math.max(clippingNorm * 2.5, 5);
610
+
611
+ // Update the chart with animation
612
+ this.gradientChart.update('active');
613
+ }
614
+
615
+ updateGradientVisualizationWithData(beforeClipping, afterClipping, clippingNorm) {
616
+ if (!this.gradientChart) return;
617
+
618
+ // Update chart data with real training data
619
+ this.gradientChart.data.datasets[0].data = beforeClipping;
620
+ this.gradientChart.data.datasets[1].data = afterClipping;
621
+
622
+ // Update clipping threshold line
623
+ this.gradientChart.options.plugins.annotation.annotations.line1 = {
624
+ type: 'line',
625
+ xMin: clippingNorm,
626
+ xMax: clippingNorm,
627
+ borderColor: '#f44336',
628
+ borderWidth: 2,
629
+ borderDash: [5, 5],
630
+ label: {
631
+ content: `Clipping Threshold (C=${clippingNorm.toFixed(1)})`,
632
+ display: true,
633
+ position: 'top'
634
+ }
635
+ };
636
+
637
+ // Update x-axis scale based on clipping norm
638
+ this.gradientChart.options.scales.x.max = Math.max(clippingNorm * 2.5, 5);
639
+
640
+ // Update the chart with animation
641
+ this.gradientChart.update('active');
642
+ }
643
  }
644
 
645
  // Initialize the application when the DOM is loaded
app/templates/base.html CHANGED
@@ -7,6 +7,11 @@
7
  <link rel="stylesheet" href="{{ url_for('static', filename='css/styles.css') }}">
8
  <link rel="stylesheet" href="{{ url_for('static', filename='css/learning.css') }}">
9
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
 
 
 
 
 
10
  {% block extra_head %}{% endblock %}
11
  </head>
12
  <body>
 
7
  <link rel="stylesheet" href="{{ url_for('static', filename='css/styles.css') }}">
8
  <link rel="stylesheet" href="{{ url_for('static', filename='css/learning.css') }}">
9
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
10
+ <script src="https://cdn.jsdelivr.net/npm/chartjs-plugin-annotation"></script>
11
+ <script>
12
+ // Register the annotation plugin
13
+ Chart.register(ChartAnnotation);
14
+ </script>
15
  {% block extra_head %}{% endblock %}
16
  </head>
17
  <body>
app/templates/index.html CHANGED
@@ -190,8 +190,8 @@
190
  </div>
191
 
192
  <div id="training-tab" class="tab-content active">
193
- <div class="chart-container">
194
- <canvas id="training-chart" class="chart"></canvas>
195
  </div>
196
 
197
  <div id="training-status" class="status-badge" style="display: none;">
@@ -214,8 +214,8 @@
214
  </p>
215
  </div>
216
 
217
- <div class="canvas-container">
218
- <canvas id="gradient-canvas" width="600" height="300"></canvas>
219
  </div>
220
  </div>
221
 
@@ -231,8 +231,8 @@
231
  </p>
232
  </div>
233
 
234
- <div class="chart-container">
235
- <canvas id="privacy-chart" class="chart"></canvas>
236
  </div>
237
  </div>
238
  </div>
 
190
  </div>
191
 
192
  <div id="training-tab" class="tab-content active">
193
+ <div class="chart-container" style="position: relative; height: 300px; width: 100%;">
194
+ <canvas id="training-chart"></canvas>
195
  </div>
196
 
197
  <div id="training-status" class="status-badge" style="display: none;">
 
214
  </p>
215
  </div>
216
 
217
+ <div class="chart-container">
218
+ <canvas id="gradient-chart" class="chart"></canvas>
219
  </div>
220
  </div>
221
 
 
231
  </p>
232
  </div>
233
 
234
+ <div class="chart-container" style="position: relative; height: 300px; width: 100%;">
235
+ <canvas id="privacy-chart"></canvas>
236
  </div>
237
  </div>
238
  </div>
app/training/mock_trainer.py CHANGED
@@ -41,10 +41,17 @@ class MockTrainer:
41
  # Generate recommendations
42
  recommendations = self._generate_recommendations(params, final_metrics)
43
 
 
 
 
 
 
 
44
  return {
45
  'epochs_data': epochs_data,
46
  'final_metrics': final_metrics,
47
- 'recommendations': recommendations
 
48
  }
49
 
50
  def _calculate_privacy_factor(self, clipping_norm: float, noise_multiplier: float) -> float:
@@ -149,4 +156,32 @@ class MockTrainer:
149
  'text': 'Model accuracy is low. Consider adjusting privacy parameters.'
150
  })
151
 
152
- return recommendations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Generate recommendations
42
  recommendations = self._generate_recommendations(params, final_metrics)
43
 
44
+ # Generate gradient information
45
+ gradient_info = {
46
+ 'before_clipping': self.generate_gradient_norms(clipping_norm),
47
+ 'after_clipping': self.generate_clipped_gradients(clipping_norm)
48
+ }
49
+
50
  return {
51
  'epochs_data': epochs_data,
52
  'final_metrics': final_metrics,
53
+ 'recommendations': recommendations,
54
+ 'gradient_info': gradient_info
55
  }
56
 
57
  def _calculate_privacy_factor(self, clipping_norm: float, noise_multiplier: float) -> float:
 
156
  'text': 'Model accuracy is low. Consider adjusting privacy parameters.'
157
  })
158
 
159
+ return recommendations
160
+
161
+ def generate_gradient_norms(self, clipping_norm: float) -> List[Dict[str, float]]:
162
+ """Generate realistic gradient norms following a log-normal distribution."""
163
+ num_points = 100
164
+ gradients = []
165
+
166
+ # Parameters for log-normal distribution
167
+ mu = np.log(clipping_norm) - 0.5
168
+ sigma = 0.8
169
+
170
+ for _ in range(num_points):
171
+ # Generate log-normal distributed gradient norms
172
+ u1, u2 = np.random.random(2)
173
+ z = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
174
+ norm = np.exp(mu + sigma * z)
175
+
176
+ # Calculate density using kernel density estimation
177
+ density = np.exp(-(np.power(np.log(norm) - mu, 2) / (2 * sigma * sigma))) / (norm * sigma * np.sqrt(2 * np.pi))
178
+ density = 0.2 + 0.8 * (density / 0.8) + 0.1 * (np.random.random() - 0.5)
179
+
180
+ gradients.append({'x': float(norm), 'y': float(density)})
181
+
182
+ return sorted(gradients, key=lambda x: x['x'])
183
+
184
+ def generate_clipped_gradients(self, clipping_norm: float) -> List[Dict[str, float]]:
185
+ """Generate clipped versions of the gradient norms."""
186
+ original_gradients = self.generate_gradient_norms(clipping_norm)
187
+ return [{'x': min(g['x'], clipping_norm), 'y': g['y']} for g in original_gradients]
run.py CHANGED
@@ -1,6 +1,12 @@
1
  from app import create_app
 
2
 
3
  app = create_app()
4
 
5
  if __name__ == '__main__':
6
- app.run(debug=True)
 
 
 
 
 
 
1
  from app import create_app
2
+ import os
3
 
4
  app = create_app()
5
 
6
  if __name__ == '__main__':
7
+ # Enable debug mode for development
8
+ app.config['DEBUG'] = True
9
+ # Disable CORS in development
10
+ app.config['CORS_HEADERS'] = 'Content-Type'
11
+ # Run the application
12
+ app.run(host='127.0.0.1', port=5000, debug=True)
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.8.12