# File size: 4,876 bytes (revision b0b2c21)
#!/usr/bin/env python3
"""
Test script to verify MNIST training with DP-SGD works correctly.
Run this script to test the real trainer implementation.
"""
import os
import sys
import traceback
# Make the ``app`` package importable. '.' keeps the original behaviour of
# running from the project root; the script's own directory makes the import
# work no matter which working directory the script is launched from.
sys.path.append('.')
sys.path.append(os.path.dirname(os.path.abspath(__file__)))


def test_real_trainer():
    """Test the real DP-SGD trainer on the MNIST dataset.

    Prefers ``SimplifiedRealTrainer`` from
    ``app.training.simplified_real_trainer``. It falls back to ``RealTrainer``
    from ``app.training.real_trainer`` when the former cannot be imported.
    The trainer is then constructed and trained for two epochs with small
    parameters. Progress and the final metrics are printed to stdout.

    Returns:
        bool: ``True`` if the import, initialisation and training completed
        without raising. ``False`` if an import failed, or if any other
        exception was raised. Non-import errors also get their traceback
        written to stderr.
    """
    print("Testing Real Trainer with MNIST Dataset")
    print("=" * 50)

    try:
        try:
            from app.training.simplified_real_trainer import (
                SimplifiedRealTrainer as RealTrainer,
            )
            print("✅ Successfully imported SimplifiedRealTrainer")
        except ImportError:
            # Fall back to the full trainer. If this import fails too, the
            # ImportError propagates to the outer handler below.
            from app.training.real_trainer import RealTrainer
            print("✅ Successfully imported RealTrainer")

        trainer = RealTrainer()
        # Report the class actually loaded, not the local alias.
        print(f"✅ Successfully initialized {RealTrainer.__name__}")
        print(f"✅ Training data shape: {trainer.x_train.shape}")
        print(f"✅ Test data shape: {trainer.x_test.shape}")

        # Small parameters for quick execution.
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,  # Small number for testing
        }
        print(f"\nTraining with parameters: {test_params}")
        results = trainer.train(test_params)

        print("\n✅ Training completed successfully!")
        final_metrics = results['final_metrics']
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")
        # The privacy budget is optional in the results; other keys are not.
        if 'privacy_budget' in results:
            print(f"Privacy budget (ε): {results['privacy_budget']:.2f}")
        print(f"Number of epochs recorded: {len(results['epochs_data'])}")
        print(f"Number of recommendations: {len(results['recommendations'])}")
        return True
    except ImportError as e:
        print(f"❌ Import Error: {e}")
        print("Make sure TensorFlow and TensorFlow Privacy are installed:")
        print("pip install tensorflow==2.15.0 tensorflow-privacy==0.9.0")
        return False
    except Exception as e:
        print(f"❌ Training Error: {e}")
        # Keep the traceback so failures inside the trainer are debuggable.
        traceback.print_exc()
        return False


def test_mock_trainer():
    """Test the mock trainer that serves as the synthetic-data fallback.

    Imports ``MockTrainer`` from ``app.training.mock_trainer`` and trains it
    with the same small parameters used for the real trainer. It then prints
    the final metrics.

    Returns:
        bool: ``True`` if import and training completed without raising.
        ``False`` otherwise; the error is printed and its traceback is
        written to stderr.
    """
    print("\nTesting Mock Trainer (Fallback)")
    print("=" * 50)

    try:
        from app.training.mock_trainer import MockTrainer

        trainer = MockTrainer()
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,
        }
        results = trainer.train(test_params)

        print("✅ Mock training completed!")
        final_metrics = results['final_metrics']
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")
        return True
    except Exception as e:
        print(f"❌ Mock trainer error: {e}")
        traceback.print_exc()
        return False


def test_web_app():
    """Test that the web app routes module imports and report trainer status.

    Imports ``main`` from ``app.routes`` as a smoke test. It then reads
    ``REAL_TRAINER_AVAILABLE`` and ``real_trainer`` from the same module and
    prints whether the real or the mock trainer will be used.

    Returns:
        bool: ``True`` if the imports succeeded. ``False`` otherwise; the
        error is printed and its traceback is written to stderr.
    """
    print("\nTesting Web App Routes")
    print("=" * 50)

    try:
        # The import itself is the check; ``main`` is not used further.
        from app.routes import main  # noqa: F401
        print("✅ Successfully imported routes")

        # Test trainer status
        from app.routes import REAL_TRAINER_AVAILABLE, real_trainer
        print(f"Real trainer available: {REAL_TRAINER_AVAILABLE}")
        if REAL_TRAINER_AVAILABLE and real_trainer:
            print("✅ Real trainer is ready for use")
        else:
            print("⚠️ Will use mock trainer")
        return True
    except Exception as e:
        print(f"❌ Web app test error: {e}")
        traceback.print_exc()
        return False


if __name__ == "__main__":
print("DPSGD Training System Test")
print("=" * 60)
# Test components
mock_success = test_mock_trainer()
real_success = test_real_trainer()
web_success = test_web_app()
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
print(f"Mock Trainer: {'β
PASS' if mock_success else 'β FAIL'}")
print(f"Real Trainer: {'β
PASS' if real_success else 'β FAIL'}")
print(f"Web App: {'β
PASS' if web_success else 'β FAIL'}")
if real_success:
print("\nπ All tests passed! The system will use real MNIST data.")
elif mock_success:
print("\nβ οΈ Real trainer failed, but mock trainer works. System will use synthetic data.")
else:
print("\nβ Critical errors found. Please check your setup.")
print("\nTo install missing dependencies, run:")
print("pip install -r requirements.txt") |