DPSGDTool / test_training.py
Shuya Feng
udpate
b0b2c21
#!/usr/bin/env python3
"""
Test script to verify MNIST training with DP-SGD works correctly.
Run this script to test the real trainer implementation.
"""
import sys
import os
# Make the local ``app`` package importable regardless of the directory the
# script is launched from: use this file's own directory (where ``app`` is
# expected to live) instead of the current working directory ('.').
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
def test_real_trainer():
    """Train the real DP-SGD trainer on MNIST with small parameters.

    Prefers ``SimplifiedRealTrainer`` and falls back to ``RealTrainer`` when
    the simplified module cannot be imported. Prints the data shapes and the
    final metrics found in the result of ``trainer.train``.

    Returns:
        bool: True if the trainer imported, initialised and trained without
        raising; False otherwise. Errors are printed (with a traceback for
        non-import failures) rather than re-raised, so the summary in
        ``__main__`` can still run.
    """
    import traceback

    print("Testing Real Trainer with MNIST Dataset")
    print("=" * 50)
    try:
        try:
            from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
        except ImportError:
            from app.training.real_trainer import RealTrainer
        # The alias hides which class was loaded; report the real one.
        trainer_name = RealTrainer.__name__
        print(f"✅ Successfully imported {trainer_name}")

        # Initialize trainer. It is assumed to expose x_train / x_test
        # array-like attributes with a ``.shape`` (TODO confirm in trainer).
        trainer = RealTrainer()
        print(f"✅ Successfully initialized {trainer_name}")
        print(f"✅ Training data shape: {trainer.x_train.shape}")
        print(f"✅ Test data shape: {trainer.x_test.shape}")

        # Test with small parameters for quick execution
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,  # Small number for testing
        }
        print(f"\nTraining with parameters: {test_params}")
        results = trainer.train(test_params)
        final_metrics = results['final_metrics']

        print("\n✅ Training completed successfully!")
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")
        if 'privacy_budget' in results:
            print(f"Privacy budget (ε): {results['privacy_budget']:.2f}")
        print(f"Number of epochs recorded: {len(results['epochs_data'])}")
        print(f"Number of recommendations: {len(results['recommendations'])}")
        return True
    except ImportError as e:
        print(f"❌ Import Error: {e}")
        print("Make sure TensorFlow and TensorFlow Privacy are installed:")
        print("pip install tensorflow==2.15.0 tensorflow-privacy==0.9.0")
        return False
    except Exception as e:
        print(f"❌ Training Error: {e}")
        # str(e) alone is often opaque (e.g. a KeyError prints only the key),
        # so show where the failure happened.
        traceback.print_exc()
        return False
def test_mock_trainer():
    """Run the mock (fallback) trainer with small parameters.

    Returns:
        bool: True if ``MockTrainer`` imported and ``train`` returned a result
        whose ``final_metrics`` contain accuracy, loss and training time;
        False if anything raised. The error and its traceback are printed,
        not re-raised.
    """
    import traceback

    print("\nTesting Mock Trainer (Fallback)")
    print("=" * 50)
    try:
        from app.training.mock_trainer import MockTrainer
        trainer = MockTrainer()
        # Same parameters as test_real_trainer so the two runs are comparable.
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,
        }
        results = trainer.train(test_params)
        final_metrics = results['final_metrics']

        print("✅ Mock training completed!")
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")
        return True
    except Exception as e:
        print(f"❌ Mock trainer error: {e}")
        # Keep the failure diagnosable instead of printing only str(e).
        traceback.print_exc()
        return False
def test_web_app():
    """Check that ``app.routes`` imports and report which trainer it exposes.

    Returns:
        bool: True if ``app.routes`` (including ``main``,
        ``REAL_TRAINER_AVAILABLE`` and ``real_trainer``) imported without
        raising; False otherwise. The error and its traceback are printed,
        not re-raised.
    """
    import traceback

    print("\nTesting Web App Routes")
    print("=" * 50)
    try:
        # Imported only to verify that the routes module loads.
        from app.routes import main  # noqa: F401
        print("✅ Successfully imported routes")

        # Test trainer status
        from app.routes import REAL_TRAINER_AVAILABLE, real_trainer
        print(f"Real trainer available: {REAL_TRAINER_AVAILABLE}")
        if REAL_TRAINER_AVAILABLE and real_trainer:
            print("✅ Real trainer is ready for use")
        else:
            print("⚠️ Will use mock trainer")
        return True
    except Exception as e:
        print(f"❌ Web app test error: {e}")
        traceback.print_exc()
        return False
if __name__ == "__main__":
    print("DPSGD Training System Test")
    print("=" * 60)

    # Test components; each check returns True/False and reports its own errors.
    mock_success = test_mock_trainer()
    real_success = test_real_trainer()
    web_success = test_web_app()

    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    print(f"Mock Trainer: {'✅ PASS' if mock_success else '❌ FAIL'}")
    print(f"Real Trainer: {'✅ PASS' if real_success else '❌ FAIL'}")
    print(f"Web App: {'✅ PASS' if web_success else '❌ FAIL'}")

    # Only claim "all tests passed" when every check actually passed.
    if mock_success and real_success and web_success:
        print("\n🎉 All tests passed! The system will use real MNIST data.")
    elif real_success:
        print("\n⚠️ Real trainer works, but some checks failed (see above).")
    elif mock_success:
        print("\n⚠️ Real trainer failed, but mock trainer works. System will use synthetic data.")
    else:
        print("\n❌ Critical errors found. Please check your setup.")
        print("\nTo install missing dependencies, run:")
        print("pip install -r requirements.txt")

    # Exit non-zero so shell/CI callers can detect a broken setup: the system
    # needs at least one working trainer and importable web routes.
    system_usable = (real_success or mock_success) and web_success
    sys.exit(0 if system_usable else 1)