File size: 5,583 Bytes
8ad5d56
6640531
 
8ad5d56
b0b2c21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6640531
 
 
 
 
b0b2c21
 
 
 
 
 
 
 
 
 
 
 
6640531
 
 
 
 
 
 
 
8ad5d56
 
6640531
8ad5d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b2c21
 
8ad5d56
b0b2c21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ad5d56
 
 
 
 
b0b2c21
 
 
 
 
 
 
 
 
 
 
8ad5d56
 
 
6640531
8ad5d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b2c21
 
 
 
 
 
8ad5d56
 
 
 
b0b2c21
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from flask import Blueprint, render_template, jsonify, request, current_app
from app.training.mock_trainer import MockTrainer
from app.training.privacy_calculator import PrivacyCalculator
from flask_cors import cross_origin
import os

# Pick a trainer implementation at import time. Preference order:
#   1. SimplifiedRealTrainer (aliased to RealTrainer)
#   2. the full RealTrainer
#   3. neither importable -> REAL_TRAINER_AVAILABLE is False and callers use MockTrainer
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as e:
    # The simplified trainer is the one that just failed; the full trainer is tried next.
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False

# Blueprint plus the helpers that are always available.
main = Blueprint('main', __name__)
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()

# Start from "no real trainer"; it is only replaced when construction succeeds.
real_trainer = None
if REAL_TRAINER_AVAILABLE:
    try:
        real_trainer = RealTrainer()
        print("Real trainer initialized successfully")
    except Exception as init_error:
        # Construction failed: report it and behave as if no real trainer existed.
        print(f"Failed to initialize real trainer: {init_error}")
        real_trainer = None
        REAL_TRAINER_AVAILABLE = False

@main.route('/')
def index():
    """Serve the root URL by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page

@main.route('/learning')
def learning():
    """Serve ``/learning`` by rendering the ``learning.html`` template."""
    page = render_template('learning.html')
    return page

@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Run a training job with the posted hyperparameters and return the results.

    The request body must be a JSON object. Recognised keys, all optional:
    ``clipping_norm`` (default 1.0), ``noise_multiplier`` (1.0),
    ``batch_size`` (64), ``learning_rate`` (0.01), ``epochs`` (5) and
    ``use_mock`` (False; a truthy value forces the mock trainer).

    Returns:
        The trainer's results as JSON, extended with ``trainer_type`` and
        ``dataset`` and, when the trainer did not provide it,
        ``gradient_info``. If training fails unexpectedly, the mock trainer
        is run instead and ``fallback_reason`` carries the original error.

        A 400 response is returned when the body is missing, unparsable,
        not a JSON object, or holds a value that cannot be converted (also
        when the trainer itself raises ``TypeError``/``ValueError``). A 500
        response is returned only when the mock fallback fails as well.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True turns malformed JSON / wrong content type into None instead
    # of raising an HTTP exception that the broad handler below would catch.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid parameter values: expected a JSON object'}), 400

    # Parse parameters before entering the fallback-protected region so that
    # `params` is always bound when the fallback path needs it.
    try:
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'learning_rate': float(data.get('learning_rate', 0.01)),
            'epochs': int(data.get('epochs', 5))
        }
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400

    # Check if user wants to force mock training
    use_mock = data.get('use_mock', False)
    use_real = bool(REAL_TRAINER_AVAILABLE and real_trainer and not use_mock)

    try:
        # Use real trainer if available and not forced to use mock
        if use_real:
            print("Using real trainer with MNIST dataset")
            trainer = real_trainer
            results = trainer.train(params)
            results['trainer_type'] = 'real'
            results['dataset'] = 'MNIST'
        else:
            print("Using mock trainer with synthetic data")
            trainer = mock_trainer
            results = trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'

        # Add gradient information for visualization (if not already included)
        if 'gradient_info' not in results:
            results['gradient_info'] = {
                'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
                'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm'])
            }

        return jsonify(results)
    except (TypeError, ValueError) as e:
        # Kept from the original contract: these errors raised during training
        # are reported to the client as invalid parameters.
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        print(f"Training error: {str(e)}")
        # Fallback to mock trainer on any error
        try:
            print("Falling back to mock trainer due to error")
            results = mock_trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
            results['fallback_reason'] = str(e)
            return jsonify(results)
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500

@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Compute the privacy budget (epsilon) for the posted hyperparameters.

    The request body must be a JSON object. Recognised keys, all optional:
    ``clipping_norm`` (default 1.0), ``noise_multiplier`` (1.0),
    ``batch_size`` (64) and ``epochs`` (5).

    Returns:
        ``{"epsilon": <value>}`` as JSON. The value comes from the real
        trainer when one is available, otherwise from the module-level
        ``privacy_calculator``.

        A 400 response is returned when the body is missing, unparsable,
        not a JSON object, or a ``TypeError``/``ValueError`` is raised while
        converting or computing. Any other failure yields a 500 response.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True turns malformed JSON / wrong content type into None, so
    # client mistakes are answered with 400 instead of a generic 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid parameter values: expected a JSON object'}), 400

    try:
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5))
        }

        # Use real trainer's privacy calculation if available, otherwise use privacy calculator
        if REAL_TRAINER_AVAILABLE and real_trainer:
            epsilon = real_trainer._calculate_privacy_budget(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)

        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Server error: {str(e)}'}), 500

@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Report whether the real trainer is available and which trainer/dataset that implies."""
    if REAL_TRAINER_AVAILABLE:
        active_trainer, active_dataset = 'real', 'MNIST'
    else:
        active_trainer, active_dataset = 'mock', 'synthetic'

    payload = {
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': active_trainer,
        'dataset': active_dataset,
    }
    return jsonify(payload)