File size: 5,583 Bytes
8ad5d56 6640531 8ad5d56 b0b2c21 6640531 b0b2c21 6640531 8ad5d56 6640531 8ad5d56 b0b2c21 8ad5d56 b0b2c21 8ad5d56 b0b2c21 8ad5d56 6640531 8ad5d56 b0b2c21 8ad5d56 b0b2c21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from flask import Blueprint, render_template, jsonify, request, current_app
from app.training.mock_trainer import MockTrainer
from app.training.privacy_calculator import PrivacyCalculator
from flask_cors import cross_origin
import os
# Pick a trainer implementation at import time: prefer the simplified real
# trainer, then the full real trainer, and finally fall back to the mock trainer
# (created below) when neither module can be imported.
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as e:
    # The simplified trainer is what failed here; the full one is tried next.
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False
main = Blueprint('main', __name__)

# Module-level singletons shared by every route handler in this blueprint.
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()

# Default to "no real trainer"; only replace it when construction succeeds.
# A failed construction also downgrades REAL_TRAINER_AVAILABLE so the routes
# and the status endpoint fall back to the mock trainer.
real_trainer = None
if REAL_TRAINER_AVAILABLE:
    try:
        real_trainer = RealTrainer()
        print("Real trainer initialized successfully")
    except Exception as e:
        print(f"Failed to initialize real trainer: {e}")
        REAL_TRAINER_AVAILABLE = False
        real_trainer = None
@main.route('/')
def index():
    """Render the ``index.html`` template for the blueprint's root URL."""
    template_name = 'index.html'
    return render_template(template_name)
@main.route('/learning')
def learning():
    """Render the ``learning.html`` template for ``/learning``."""
    page = 'learning.html'
    return render_template(page)
def _attach_gradient_info(results, trainer, clipping_norm):
    """Add a ``gradient_info`` entry built by *trainer* unless *results* has one."""
    if 'gradient_info' not in results:
        results['gradient_info'] = {
            'before_clipping': trainer.generate_gradient_norms(clipping_norm),
            'after_clipping': trainer.generate_clipped_gradients(clipping_norm),
        }
    return results


@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Run one training job with the requested hyperparameters.

    Expects a JSON object body. Recognised keys (all optional, defaults in
    parentheses): ``clipping_norm`` (1.0), ``noise_multiplier`` (1.0),
    ``batch_size`` (64), ``learning_rate`` (0.01), ``epochs`` (5), and
    ``use_mock`` (False) to force the mock trainer even when a real one exists.

    Returns the trainer's result dict as JSON, augmented with ``trainer_type``,
    ``dataset`` and (if the trainer did not supply it) ``gradient_info``. If
    training raises, the job is retried once with the mock trainer and the
    response additionally carries ``fallback_reason``.

    Error responses:
        400 -- body missing/empty, not a JSON object, or a parameter that
               cannot be converted to the expected number type.
        500 -- training failed and the mock fallback failed as well.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True turns malformed JSON / a non-JSON content type into None
    # (-> 400 below) instead of an exception raised before `params` exists.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Only the parameter coercion is guarded here, so a TypeError/ValueError
    # raised deep inside a trainer is not misreported as a client error.
    try:
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'learning_rate': float(data.get('learning_rate', 0.01)),
            'epochs': int(data.get('epochs', 5)),
        }
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400

    # Use the real trainer only if it loaded and the client did not force mock.
    use_mock = data.get('use_mock', False)
    use_real = bool(REAL_TRAINER_AVAILABLE and real_trainer is not None and not use_mock)

    try:
        if use_real:
            print("Using real trainer with MNIST dataset")
            trainer, trainer_type, dataset = real_trainer, 'real', 'MNIST'
        else:
            print("Using mock trainer with synthetic data")
            trainer, trainer_type, dataset = mock_trainer, 'mock', 'synthetic'
        results = trainer.train(params)
        results['trainer_type'] = trainer_type
        results['dataset'] = dataset
        return jsonify(_attach_gradient_info(results, trainer, params['clipping_norm']))
    except Exception as e:
        # Request boundary: never let a training failure escape as a bare 500;
        # retry once on the mock trainer and tell the client why.
        print(f"Training error: {str(e)}")
        try:
            print("Falling back to mock trainer due to error")
            results = mock_trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
            results['fallback_reason'] = str(e)
            return jsonify(_attach_gradient_info(results, mock_trainer, params['clipping_norm']))
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500
@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Compute the privacy budget (epsilon) for a set of training parameters.

    Expects a JSON object body. Recognised keys (all optional, defaults in
    parentheses): ``clipping_norm`` (1.0), ``noise_multiplier`` (1.0),
    ``batch_size`` (64), ``epochs`` (5), and ``use_mock`` (False). When
    ``use_mock`` is truthy, or no real trainer is loaded, the shared
    ``privacy_calculator`` is used. This matches the trainer choice made by
    ``/api/train``.

    Returns ``{"epsilon": <value>}``.

    Error responses:
        400 -- body missing/empty, not a JSON object, or invalid parameter
               values (including a TypeError/ValueError from the calculation).
        500 -- any other failure while computing epsilon.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True: malformed JSON / wrong content type -> None -> 400, not 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    try:
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5)),
        }
    except (TypeError, ValueError, OverflowError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400

    # Pick the same backend /api/train would use for this request.
    # NOTE(review): `_calculate_privacy_budget` is a private method and the
    # module may have loaded either SimplifiedRealTrainer or RealTrainer;
    # presumably not both implement it, so fall back to the calculator if absent.
    use_mock = data.get('use_mock', False)
    real_calc = None
    if REAL_TRAINER_AVAILABLE and real_trainer is not None and not use_mock:
        real_calc = getattr(real_trainer, '_calculate_privacy_budget', None)

    try:
        if real_calc is not None:
            epsilon = real_calc(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)
        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Server error: {str(e)}'}), 500
@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Report which trainer backend this module selected at import time.

    The JSON response has ``real_trainer_available`` (bool),
    ``current_trainer`` (``'real'`` or ``'mock'``) and ``dataset``
    (``'MNIST'`` or ``'synthetic'``). It does not reflect a per-request
    ``use_mock`` override sent to the training endpoint.
    """
    if REAL_TRAINER_AVAILABLE:
        trainer_kind, dataset_name = 'real', 'MNIST'
    else:
        trainer_kind, dataset_name = 'mock', 'synthetic'
    payload = {
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': trainer_kind,
        'dataset': dataset_name,
    }
    return jsonify(payload)