# DPSGDTool - app/routes.py
# Commit f826c3b (Emily): Add multi-dataset and ResNet-18 architecture support
from flask import Blueprint, render_template, jsonify, request, current_app
from app.training.mock_trainer import MockTrainer
from app.training.privacy_calculator import PrivacyCalculator
from flask_cors import cross_origin
# Try to import a real trainer, falling back to MockTrainer if dependencies aren't available
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use real datasets")
except ImportError as e:
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use real datasets")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False

main = Blueprint('main', __name__)
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()
# Trainers are created on demand based on the dataset and architecture selection
real_trainers = {}  # Cache trainers by dataset + architecture to avoid reloading

def get_or_create_trainer(dataset, model_architecture='simple-mlp'):
    """Get or create a trainer for the specified dataset and architecture."""
    if not REAL_TRAINER_AVAILABLE:
        return None
    # Create a unique key for dataset + architecture combination
    trainer_key = f"{dataset}_{model_architecture}"
    if trainer_key not in real_trainers:
        try:
            print(f"Creating new trainer for dataset: {dataset}, architecture: {model_architecture}")
            real_trainers[trainer_key] = RealTrainer(dataset=dataset, model_architecture=model_architecture)
            print(f"Trainer for {dataset} with {model_architecture} initialized successfully")
        except Exception as e:
            print(f"Failed to initialize trainer for {dataset} with {model_architecture}: {e}")
            return None
    return real_trainers[trainer_key]
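
# Illustrative use of the trainer cache (a sketch only: the 'mnist' / 'simple-mlp' strings are
# the defaults used by the routes below; whether other dataset or architecture strings, e.g. a
# ResNet-18 key, are accepted depends on the installed trainer implementation):
#   trainer = get_or_create_trainer('mnist', 'simple-mlp')  # builds and caches under "mnist_simple-mlp"
#   trainer = get_or_create_trainer('mnist', 'simple-mlp')  # second call reuses the cached instance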

@main.route('/')
def index():
    return render_template('index.html')


@main.route('/learning')
def learning():
    return render_template('learning.html')

@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'learning_rate': float(data.get('learning_rate', 0.01)),
            'epochs': int(data.get('epochs', 5))
        }
        # Get the dataset and model architecture selection
        dataset = data.get('dataset', 'mnist')
        model_architecture = data.get('model_architecture', 'simple-mlp')
        # Check whether the user wants to force mock training
        use_mock = data.get('use_mock', False)
        # Use the real trainer if it is available and mock training isn't forced
        if REAL_TRAINER_AVAILABLE and not use_mock:
            real_trainer = get_or_create_trainer(dataset, model_architecture)
            if real_trainer:
                print(f"Using real trainer with {dataset.upper()} dataset and {model_architecture} architecture")
                results = real_trainer.train(params)
                results['trainer_type'] = 'real'
                results['dataset'] = dataset.upper()
                results['model_architecture'] = model_architecture
            else:
                print("Failed to create real trainer, falling back to mock trainer")
                results = mock_trainer.train(params)
                results['trainer_type'] = 'mock'
                results['dataset'] = 'synthetic'
                results['model_architecture'] = 'mock'
        else:
            print("Using mock trainer with synthetic data")
            results = mock_trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
            results['model_architecture'] = 'mock'
        # Add gradient information for visualization (if not already included)
        if 'gradient_info' not in results:
            if REAL_TRAINER_AVAILABLE and not use_mock:
                # Reuses the cached trainer; falls back to the mock trainer if creation failed
                trainer = get_or_create_trainer(dataset, model_architecture) or mock_trainer
            else:
                trainer = mock_trainer
            results['gradient_info'] = {
                'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
                'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm'])
            }
        return jsonify(results)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        print(f"Training error: {str(e)}")
        # Fall back to the mock trainer on any unexpected training error
        try:
            print("Falling back to mock trainer due to error")
            results = mock_trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
            results['fallback_reason'] = str(e)
            return jsonify(results)
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500
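
# Example /api/train request body (a reference sketch; the keys and defaults are taken from
# the parameter parsing above, and any response fields beyond the ones set in this route
# depend on the trainer implementation):
#   {
#     "clipping_norm": 1.0, "noise_multiplier": 1.0, "batch_size": 64,
#     "learning_rate": 0.01, "epochs": 5,
#     "dataset": "mnist", "model_architecture": "simple-mlp", "use_mock": false
#   }
# The JSON response includes 'trainer_type', 'dataset', 'model_architecture' and
# 'gradient_info' alongside whatever metrics the selected trainer returns.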

@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5))
        }
        # Use the real trainer's privacy accounting when a trainer can be created for the
        # requested dataset/architecture; otherwise fall back to the analytical calculator
        dataset = data.get('dataset', 'mnist')
        model_architecture = data.get('model_architecture', 'simple-mlp')
        real_trainer = get_or_create_trainer(dataset, model_architecture)
        if real_trainer:
            epsilon = real_trainer._calculate_privacy_budget(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)
        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Server error: {str(e)}'}), 500

@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Endpoint to check which trainer is being used."""
    return jsonify({
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': 'real' if REAL_TRAINER_AVAILABLE else 'mock',
        # MNIST is only the default; the dataset actually used is chosen per /api/train request
        'dataset': 'MNIST' if REAL_TRAINER_AVAILABLE else 'synthetic'
    })
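
# Minimal manual test sketch (assumption: the package exposes a Flask application factory named
# `create_app` in app/__init__.py, a common layout that is not shown in this file):
#   from app import create_app
#   client = create_app().test_client()
#   resp = client.post('/api/train', json={'epochs': 1, 'use_mock': True})
#   print(resp.get_json()['trainer_type'])          # 'mock'
#   print(client.get('/api/trainer-status').get_json())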