#!/usr/bin/env python3
"""
Test script to verify MNIST training with DP-SGD works correctly.
Run this script to test the real trainer implementation.

The process exit status is 0 when every check passes and 1 otherwise.
"""

import sys
import os
import traceback

sys.path.append('.')

# Small hyperparameters shared by both trainer checks so a run stays quick.
# Each check passes its own copy to ``train`` so the constant is never mutated.
TEST_PARAMS = {
    'clipping_norm': 1.0,
    'noise_multiplier': 1.1,
    'batch_size': 128,
    'learning_rate': 0.01,
    'epochs': 2,  # Small number for testing
}


def _print_final_metrics(results):
    """Print accuracy, loss and training time from ``results['final_metrics']``."""
    final_metrics = results['final_metrics']
    print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
    print(f"Final loss: {final_metrics['loss']:.4f}")
    print(f"Training time: {final_metrics['training_time']:.2f} seconds")


def test_real_trainer():
    """Test the real trainer with MNIST dataset.

    Imports ``SimplifiedRealTrainer`` (falling back to ``RealTrainer``),
    instantiates it, runs ``train`` with ``TEST_PARAMS`` and prints the
    reported metrics.

    Returns:
        bool: True when training ran and every expected result key could be
        printed; False when an import failed or any other exception was raised.
    """
    print("Testing Real Trainer with MNIST Dataset")
    print("=" * 50)

    try:
        try:
            from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
            print("✅ Successfully imported SimplifiedRealTrainer")
        except ImportError:
            from app.training.real_trainer import RealTrainer
            print("✅ Successfully imported RealTrainer")

        # Initialize trainer
        trainer = RealTrainer()
        print("✅ Successfully initialized RealTrainer")
        print(f"✅ Training data shape: {trainer.x_train.shape}")
        print(f"✅ Test data shape: {trainer.x_test.shape}")

        # Test with small parameters for quick execution
        test_params = dict(TEST_PARAMS)

        print(f"\nTraining with parameters: {test_params}")
        results = trainer.train(test_params)

        print("\n✅ Training completed successfully!")
        _print_final_metrics(results)

        # The privacy budget is optional in the results; only print it if present.
        if 'privacy_budget' in results:
            print(f"Privacy budget (ε): {results['privacy_budget']:.2f}")

        print(f"Number of epochs recorded: {len(results['epochs_data'])}")
        print(f"Number of recommendations: {len(results['recommendations'])}")

        return True

    except ImportError as e:
        print(f"❌ Import Error: {e}")
        print("Make sure TensorFlow and TensorFlow Privacy are installed:")
        print("pip install tensorflow==2.15.0 tensorflow-privacy==0.9.0")
        return False
    except Exception as e:
        print(f"❌ Training Error: {e}")
        # The message alone (e.g. a bare KeyError key) is rarely enough to debug.
        traceback.print_exc()
        return False


def test_mock_trainer():
    """Test the mock trainer as fallback.

    Imports ``MockTrainer``, runs ``train`` with ``TEST_PARAMS`` and prints the
    reported metrics.

    Returns:
        bool: True when mock training completed and its metrics were printed;
        False when any exception (including a failed import) was raised.
    """
    print("\nTesting Mock Trainer (Fallback)")
    print("=" * 50)

    try:
        from app.training.mock_trainer import MockTrainer

        trainer = MockTrainer()
        results = trainer.train(dict(TEST_PARAMS))

        print("✅ Mock training completed!")
        _print_final_metrics(results)

        return True

    except Exception as e:
        print(f"❌ Mock trainer error: {e}")
        traceback.print_exc()
        return False


def test_web_app():
    """Test that the web app routes work.

    Only checks that ``app.routes`` can be imported and reports whether it
    exposes a usable real trainer; no request is sent.

    Returns:
        bool: True when the imports succeeded; False when any exception was
        raised.
    """
    print("\nTesting Web App Routes")
    print("=" * 50)

    try:
        # Imported only to prove the name is importable; it is not called.
        from app.routes import main
        print("✅ Successfully imported routes")

        # Test trainer status
        from app.routes import REAL_TRAINER_AVAILABLE, real_trainer
        print(f"Real trainer available: {REAL_TRAINER_AVAILABLE}")

        if REAL_TRAINER_AVAILABLE and real_trainer:
            print("✅ Real trainer is ready for use")
        else:
            print("⚠️ Will use mock trainer")

        return True

    except Exception as e:
        print(f"❌ Web app test error: {e}")
        traceback.print_exc()
        return False


def summarize_results(mock_success, real_success, web_success):
    """Return the closing message for the given check outcomes.

    "All tests passed" is reported only when all three checks passed, so the
    message can never contradict the PASS/FAIL table printed before it.

    Args:
        mock_success: Result of ``test_mock_trainer``.
        real_success: Result of ``test_real_trainer``.
        web_success: Result of ``test_web_app``.

    Returns:
        str: The message to print (it starts with a newline and may span
        several lines).
    """
    if mock_success and real_success and web_success:
        return "\n🎉 All tests passed! The system will use real MNIST data."
    if real_success:
        return "\n⚠️ Real trainer works, but other checks failed. See the summary above."
    if mock_success:
        return "\n⚠️ Real trainer failed, but mock trainer works. System will use synthetic data."
    return (
        "\n❌ Critical errors found. Please check your setup.\n"
        "\nTo install missing dependencies, run:\n"
        "pip install -r requirements.txt"
    )


def run_all_tests():
    """Run every check, print the summary table and return an exit status.

    Returns:
        int: 0 when all three checks passed, 1 otherwise.
    """
    print("DPSGD Training System Test")
    print("=" * 60)

    # Test components
    mock_success = test_mock_trainer()
    real_success = test_real_trainer()
    web_success = test_web_app()

    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    print(f"Mock Trainer: {'✅ PASS' if mock_success else '❌ FAIL'}")
    print(f"Real Trainer: {'✅ PASS' if real_success else '❌ FAIL'}")
    print(f"Web App: {'✅ PASS' if web_success else '❌ FAIL'}")

    print(summarize_results(mock_success, real_success, web_success))

    return 0 if (mock_success and real_success and web_success) else 1


if __name__ == "__main__":
    sys.exit(run_all_tests())