|
from flask import Blueprint, render_template, jsonify, request, current_app |
|
from app.training.mock_trainer import MockTrainer |
|
from app.training.privacy_calculator import PrivacyCalculator |
|
from flask_cors import cross_origin |
|
import os |
|
|
|
|
|
# Prefer the simplified real trainer, then the full one; if neither can be
# imported, REAL_TRAINER_AVAILABLE is False and routes use the mock trainer.
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as simplified_error:
    # The simplified implementation is the one that just failed, so the
    # next attempt is the full implementation.
    print(f"Simplified real trainer not available ({simplified_error}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as full_error:
        print(f"No real trainer available ({full_error}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False
|
|
|
# Blueprint that every route in this module is registered on.
main = Blueprint('main', __name__)

# Module-wide singletons shared by all requests: the mock trainer serves
# synthetic-data runs, and the privacy calculator is the epsilon fallback
# used by the privacy-budget endpoint.
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()

# Real trainers built so far, keyed "<dataset>_<model_architecture>".
# Populated lazily by get_or_create_trainer.
real_trainers = dict()
|
|
|
def get_or_create_trainer(dataset, model_architecture='simple-mlp'):
    """Return a real trainer for ``dataset`` / ``model_architecture``.

    Instances are memoised in ``real_trainers`` under the key
    ``"<dataset>_<model_architecture>"`` so each combination is built once.

    Returns:
        The trainer instance, or ``None`` when no real trainer class could be
        imported or when constructing one raised. A failed construction is
        not cached, so a later call will try again.
    """
    if not REAL_TRAINER_AVAILABLE:
        return None

    cache_key = f"{dataset}_{model_architecture}"
    if cache_key in real_trainers:
        return real_trainers[cache_key]

    try:
        print(f"Creating new trainer for dataset: {dataset}, architecture: {model_architecture}")
        real_trainers[cache_key] = RealTrainer(dataset=dataset, model_architecture=model_architecture)
        print(f"Trainer for {dataset} with {model_architecture} initialized successfully")
    except Exception as init_error:
        # Best-effort: callers treat None as "fall back to the mock trainer".
        print(f"Failed to initialize trainer for {dataset} with {model_architecture}: {init_error}")
        return None
    else:
        return real_trainers[cache_key]
|
|
|
@main.route('/')
def index():
    """Serve the root URL by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
|
|
|
@main.route('/learning')
def learning():
    """Serve ``/learning`` by rendering the ``learning.html`` template."""
    template_name = 'learning.html'
    return render_template(template_name)
|
|
|
def _parse_training_params(data):
    """Convert the JSON payload into typed training hyper-parameters.

    Missing keys take the defaults below. Raises ``TypeError``,
    ``ValueError`` or ``OverflowError`` (e.g. ``int(float('inf'))``) when a
    supplied value cannot be converted.
    """
    return {
        'clipping_norm': float(data.get('clipping_norm', 1.0)),
        'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
        'batch_size': int(data.get('batch_size', 64)),
        'learning_rate': float(data.get('learning_rate', 0.01)),
        'epochs': int(data.get('epochs', 5)),
    }


def _run_mock_training(params):
    """Train with the shared mock trainer and tag the results as mock/synthetic."""
    results = mock_trainer.train(params)
    results['trainer_type'] = 'mock'
    results['dataset'] = 'synthetic'
    results['model_architecture'] = 'mock'
    return results


def _attach_gradient_info(results, trainer, clipping_norm):
    """Add before/after-clipping gradient data unless ``results`` already has it."""
    if 'gradient_info' not in results:
        results['gradient_info'] = {
            'before_clipping': trainer.generate_gradient_norms(clipping_norm),
            'after_clipping': trainer.generate_clipped_gradients(clipping_norm),
        }
    return results


@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Run one training job and return its results as JSON.

    The JSON body may contain ``clipping_norm``, ``noise_multiplier``,
    ``batch_size``, ``learning_rate``, ``epochs`` (defaults in
    ``_parse_training_params``), plus ``dataset`` (default ``'mnist'``),
    ``model_architecture`` (default ``'simple-mlp'``) and ``use_mock``.

    A real trainer is used when one is importable, ``use_mock`` is falsy and
    the trainer initialises; otherwise the mock trainer is used. Results are
    tagged with ``trainer_type``, ``dataset`` and ``model_architecture`` and
    always carry ``gradient_info``.

    Responses:
        400 for a missing/non-object body or unconvertible parameters, and
        for ``TypeError``/``ValueError`` raised while training.
        On any other training error, falls back to the mock trainer and
        adds ``fallback_reason``; 500 only if that fallback fails too.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True returns None for a missing or malformed JSON body instead
    # of raising, so such requests get a 400 rather than reaching the
    # training-error fallback.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Parsed outside the training try-block so ``params`` is always bound
    # when the mock fallback below needs it.
    try:
        params = _parse_training_params(data)
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400

    dataset = data.get('dataset', 'mnist')
    model_architecture = data.get('model_architecture', 'simple-mlp')
    use_mock = data.get('use_mock', False)
    want_real = REAL_TRAINER_AVAILABLE and not use_mock

    try:
        trainer = get_or_create_trainer(dataset, model_architecture) if want_real else None

        if trainer is not None:
            print(f"Using real trainer with {dataset.upper()} dataset and {model_architecture} architecture")
            results = trainer.train(params)
            results['trainer_type'] = 'real'
            results['dataset'] = dataset.upper()
            results['model_architecture'] = model_architecture
        else:
            if want_real:
                print("Failed to create real trainer, falling back to mock trainer")
            else:
                print("Using mock trainer with synthetic data")
            trainer = mock_trainer
            results = _run_mock_training(params)

        # Reuse the trainer chosen above: asking get_or_create_trainer again
        # would retry a failed (uncached) initialisation.
        _attach_gradient_info(results, trainer, params['clipping_norm'])
        return jsonify(results)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        print(f"Training error: {str(e)}")
        try:
            print("Falling back to mock trainer due to error")
            results = _run_mock_training(params)
            results['fallback_reason'] = str(e)
            _attach_gradient_info(results, mock_trainer, params['clipping_norm'])
            return jsonify(results)
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500
|
|
|
@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Compute the privacy budget (epsilon) for the posted hyper-parameters.

    Reads ``clipping_norm``, ``noise_multiplier``, ``batch_size`` and
    ``epochs`` from the JSON body, plus the optional ``dataset`` (default
    ``'mnist'``), ``model_architecture`` (default ``'simple-mlp'``) and
    ``use_mock`` selectors. If a real trainer for that dataset/architecture
    is available and exposes ``_calculate_privacy_budget``, its calculation
    is used. Otherwise ``privacy_calculator.calculate_epsilon`` is used.

    Returns:
        ``{'epsilon': ...}``. 400 for a missing/non-object body or
        unconvertible parameters; 500 for any other failure.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    try:
        # silent=True yields None for a missing or malformed JSON body
        # rather than raising, so such requests get the 400 below.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5)),
        }

        # Real trainers are created lazily per dataset/architecture, so look
        # up the one for this request (this may initialise it on first use).
        trainer = None
        if REAL_TRAINER_AVAILABLE and not data.get('use_mock', False):
            trainer = get_or_create_trainer(
                data.get('dataset', 'mnist'),
                data.get('model_architecture', 'simple-mlp'),
            )

        real_calculate = getattr(trainer, '_calculate_privacy_budget', None)
        if callable(real_calculate):
            epsilon = real_calculate(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)

        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Server error: {str(e)}'}), 500
|
|
|
@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Report, based only on REAL_TRAINER_AVAILABLE, which trainer kind is in use."""
    if REAL_TRAINER_AVAILABLE:
        trainer_kind, dataset_label = 'real', 'MNIST'
    else:
        trainer_kind, dataset_label = 'mock', 'synthetic'

    status = {
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': trainer_kind,
        'dataset': dataset_label,
    }
    return jsonify(status)