|
|
|
""" |
|
Smoke-test script verifying that MNIST training with DP-SGD works correctly.

Run this script directly to exercise the mock trainer, the real trainer and
the web app routes, then print a pass/fail summary.
|
""" |
|
|
|
import os
import sys
import traceback
|
# Add the current working directory to the import path so the `app.*`
# imports used by the tests below resolve when the script is launched from
# the directory that contains the `app` package.
sys.path.extend(['.'])
|
|
|
def test_real_trainer():
    """Train the real DP-SGD trainer briefly on MNIST and report the results.

    ``SimplifiedRealTrainer`` is preferred; ``RealTrainer`` is used when the
    simplified module cannot be imported.

    Returns:
        True if a trainer was imported, initialized and trained without
        raising; False otherwise (a diagnostic is printed in that case).
    """
    print("Testing Real Trainer with MNIST Dataset")
    print("=" * 50)

    try:
        try:
            from app.training.simplified_real_trainer import (
                SimplifiedRealTrainer as RealTrainer,
            )
            trainer_name = "SimplifiedRealTrainer"
        except ImportError:
            # Fall back to the full implementation. If this import also
            # fails, the ImportError reaches the outer handler below.
            from app.training.real_trainer import RealTrainer
            trainer_name = "RealTrainer"
        print(f"✅ Successfully imported {trainer_name}")

        trainer = RealTrainer()
        # Report the class that was actually loaded, not always "RealTrainer".
        print(f"✅ Successfully initialized {trainer_name}")
        print(f"✅ Training data shape: {trainer.x_train.shape}")
        print(f"✅ Test data shape: {trainer.x_test.shape}")

        # Small, fast configuration: two epochs are enough for a smoke test.
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,
        }

        print(f"\nTraining with parameters: {test_params}")
        results = trainer.train(test_params)

        final_metrics = results['final_metrics']
        print("\n✅ Training completed successfully!")
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")

        # The budget is optional; guard against both a missing key and a
        # None value, either of which would break the float format below.
        privacy_budget = results.get('privacy_budget')
        if privacy_budget is not None:
            print(f"Privacy budget (ε): {privacy_budget:.2f}")

        print(f"Number of epochs recorded: {len(results['epochs_data'])}")
        print(f"Number of recommendations: {len(results['recommendations'])}")

        return True

    except ImportError as e:
        # Covers a missing trainer module as well as missing ML dependencies
        # imported while constructing or running the trainer.
        print(f"❌ Import Error: {e}")
        print("Make sure TensorFlow and TensorFlow Privacy are installed:")
        print("pip install tensorflow==2.15.0 tensorflow-privacy==0.9.0")
        return False

    except Exception as e:
        print(f"❌ Training Error: {e}")
        # Keep the traceback: the one-line message alone is rarely enough to
        # diagnose a failure inside the training loop.
        traceback.print_exc()
        return False
|
|
|
def test_mock_trainer():
    """Run a short training job on the mock trainer (the synthetic-data fallback).

    Returns:
        True if ``MockTrainer`` imported and trained without raising; False
        otherwise (the error and its traceback are printed).
    """
    print("\nTesting Mock Trainer (Fallback)")
    print("=" * 50)

    try:
        from app.training.mock_trainer import MockTrainer

        trainer = MockTrainer()
        # Same parameters as the real-trainer test so results are comparable.
        test_params = {
            'clipping_norm': 1.0,
            'noise_multiplier': 1.1,
            'batch_size': 128,
            'learning_rate': 0.01,
            'epochs': 2,
        }

        results = trainer.train(test_params)

        final_metrics = results['final_metrics']
        print("✅ Mock training completed!")
        print(f"Final accuracy: {final_metrics['accuracy']:.2f}%")
        print(f"Final loss: {final_metrics['loss']:.4f}")
        print(f"Training time: {final_metrics['training_time']:.2f} seconds")

        return True

    except Exception as e:
        print(f"❌ Mock trainer error: {e}")
        traceback.print_exc()
        return False
|
|
|
def test_web_app():
    """Check that the web app routes import and report which trainer they use.

    Returns:
        True if ``app.routes`` and its ``main``, ``REAL_TRAINER_AVAILABLE`` and
        ``real_trainer`` names import without raising; False otherwise.
    """
    print("\nTesting Web App Routes")
    print("=" * 50)

    try:
        # Importing `main` is itself the smoke test; the name is not used.
        from app.routes import main  # noqa: F401
        print("✅ Successfully imported routes")

        from app.routes import REAL_TRAINER_AVAILABLE, real_trainer
        print(f"Real trainer available: {REAL_TRAINER_AVAILABLE}")
        if REAL_TRAINER_AVAILABLE and real_trainer:
            print("✅ Real trainer is ready for use")
        else:
            print("⚠️ Will use mock trainer")

        return True

    except Exception as e:
        print(f"❌ Web app test error: {e}")
        traceback.print_exc()
        return False
|
|
|
if __name__ == "__main__":
    print("DPSGD Training System Test")
    print("=" * 60)

    # The mock trainer runs first; the real trainer and web app checks follow.
    mock_success = test_mock_trainer()
    real_success = test_real_trainer()
    web_success = test_web_app()
    all_passed = mock_success and real_success and web_success

    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)
    for label, passed in (
        ("Mock Trainer", mock_success),
        ("Real Trainer", real_success),
        ("Web App", web_success),
    ):
        print(f"{label}: {'✅ PASS' if passed else '❌ FAIL'}")

    # Only claim "all tests passed" when every check actually passed;
    # previously a passing real trainer alone triggered that message.
    if all_passed:
        print("\n🎉 All tests passed! The system will use real MNIST data.")
    elif real_success:
        print("\n⚠️ Real trainer works, but other tests failed. See the summary above.")
    elif mock_success:
        print("\n⚠️ Real trainer failed, but mock trainer works. System will use synthetic data.")
    else:
        print("\n❌ Critical errors found. Please check your setup.")

    # The dependency hint is only useful when something failed.
    if not all_passed:
        print("\nTo install missing dependencies, run:")
        print("pip install -r requirements.txt")