"""Flask blueprint serving the UI pages and the training / privacy-budget APIs.

Training requests are handled by a real trainer when one can be imported and
initialised, and by ``MockTrainer`` (synthetic data) otherwise.
"""
import os

from flask import Blueprint, render_template, jsonify, request, current_app
from flask_cors import cross_origin

from app.training.mock_trainer import MockTrainer
from app.training.privacy_calculator import PrivacyCalculator

# Prefer the simplified real trainer, then the full one. If neither can be
# imported (e.g. missing ML dependencies), every request uses MockTrainer.
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as e:
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False

main = Blueprint('main', __name__)

mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()

# Real trainers are created lazily and cached per (dataset, architecture)
# pair, so each combination is only constructed once per process.
real_trainers = {}


def _trainer_key(dataset, model_architecture):
    """Return the ``real_trainers`` cache key for a dataset/architecture pair."""
    return f"{dataset}_{model_architecture}"


def get_or_create_trainer(dataset, model_architecture='simple-mlp'):
    """Get or create a trainer for the specified dataset and architecture.

    Args:
        dataset: Dataset name passed to ``RealTrainer``.
        model_architecture: Architecture name passed to ``RealTrainer``.

    Returns:
        The cached (or newly constructed) real trainer. Returns None when no
        real trainer could be imported or when construction raised.
    """
    if not REAL_TRAINER_AVAILABLE:
        return None

    trainer_key = _trainer_key(dataset, model_architecture)
    if trainer_key not in real_trainers:
        try:
            print(f"Creating new trainer for dataset: {dataset}, architecture: {model_architecture}")
            real_trainers[trainer_key] = RealTrainer(dataset=dataset, model_architecture=model_architecture)
            print(f"Trainer for {dataset} with {model_architecture} initialized successfully")
        except Exception as e:
            # Failures are not cached, so a later request will retry construction.
            print(f"Failed to initialize trainer for {dataset} with {model_architecture}: {e}")
            return None
    return real_trainers[trainer_key]


def _get_cached_trainer(dataset, model_architecture):
    """Return an already-initialised real trainer, or None. Never constructs one."""
    if not REAL_TRAINER_AVAILABLE:
        return None
    return real_trainers.get(_trainer_key(dataset, model_architecture))


def _train_with_mock(params):
    """Train with MockTrainer and tag the results as mock/synthetic."""
    results = mock_trainer.train(params)
    results['trainer_type'] = 'mock'
    results['dataset'] = 'synthetic'
    results['model_architecture'] = 'mock'
    return results


@main.route('/')
def index():
    """Render the landing page."""
    return render_template('index.html')


@main.route('/learning')
def learning():
    """Render the learning page."""
    return render_template('learning.html')


@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Train with the differential-privacy parameters in the JSON body.

    Uses a real trainer for the requested dataset/architecture when one is
    available and ``use_mock`` is falsy; otherwise uses MockTrainer.

    Responses:
        400: missing or non-object body, or parameters that fail conversion
            (TypeError/ValueError).
        200: the training results. If any other error occurs, training is
            retried once with MockTrainer and the cause is reported in
            ``fallback_reason``.
        500: the retry also fails, or the error happened before the
            parameters were parsed.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # Bound before the try so the generic handler can tell whether parsing
    # got far enough to have parameters to retry with.
    params = None
    try:
        # silent=True: a malformed or non-JSON body yields None (-> 400)
        # instead of raising an HTTP error into the generic handler below.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'No data provided'}), 400

        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'learning_rate': float(data.get('learning_rate', 0.01)),
            'epochs': int(data.get('epochs', 5)),
        }

        dataset = str(data.get('dataset', 'mnist'))
        model_architecture = str(data.get('model_architecture', 'simple-mlp'))
        # Set use_mock to force mock training even when a real trainer exists.
        use_mock = data.get('use_mock', False)

        # This is the trainer that actually produces the results. It is reused
        # below for gradient visualisation, so a failed real-trainer build is
        # not retried and the gradients always come from the trainer that ran.
        trainer = None
        if REAL_TRAINER_AVAILABLE and not use_mock:
            trainer = get_or_create_trainer(dataset, model_architecture)
            if trainer is None:
                print("Failed to create real trainer, falling back to mock trainer")
        else:
            print("Using mock trainer with synthetic data")

        if trainer is not None:
            print(f"Using real trainer with {dataset.upper()} dataset and {model_architecture} architecture")
            results = trainer.train(params)
            results['trainer_type'] = 'real'
            results['dataset'] = dataset.upper()
            results['model_architecture'] = model_architecture
        else:
            trainer = mock_trainer
            results = _train_with_mock(params)

        # Add gradient information for visualization (if not already included).
        if 'gradient_info' not in results:
            results['gradient_info'] = {
                'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
                'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm']),
            }

        return jsonify(results)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        print(f"Training error: {str(e)}")
        if params is None:
            # Nothing parsed yet, so there is nothing to retry with.
            return jsonify({'error': f'Server error: {str(e)}'}), 500
        try:
            print("Falling back to mock trainer due to error")
            results = _train_with_mock(params)
            results['fallback_reason'] = str(e)
            return jsonify(results)
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500


@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Return ``{'epsilon': ...}`` for the DP parameters in the JSON body.

    If a real trainer for the requested dataset/architecture has already been
    initialised by ``/api/train``, its ``_calculate_privacy_budget`` is used.
    Otherwise ``PrivacyCalculator.calculate_epsilon`` is used. A trainer is
    never constructed here, to avoid loading a dataset just to compute epsilon.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    try:
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({'error': 'No data provided'}), 400

        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5)),
        }

        trainer = _get_cached_trainer(
            str(data.get('dataset', 'mnist')),
            str(data.get('model_architecture', 'simple-mlp')),
        )
        # The method is private on the trainer, so check it exists on
        # whichever trainer class was imported.
        if trainer is not None and hasattr(trainer, '_calculate_privacy_budget'):
            epsilon = trainer._calculate_privacy_budget(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)

        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Server error: {str(e)}'}), 500


@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Endpoint to check which trainer is being used.

    NOTE(review): 'current_trainer' and 'dataset' reflect only whether a real
    trainer could be imported. They do not reflect per-request dataset
    selection or ``use_mock``.
    """
    return jsonify({
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': 'real' if REAL_TRAINER_AVAILABLE else 'mock',
        'dataset': 'MNIST' if REAL_TRAINER_AVAILABLE else 'synthetic'
    })