File size: 7,511 Bytes
5451fa1
 
29c4c86
5451fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29c4c86
 
 
 
5451fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69037ca
 
 
 
 
 
 
 
 
5451fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""A gradio app. that runs locally (analytics=False and share=False) about sentiment analysis on tweets."""

import random
import numpy as np
import gradio as gr
from concrete.ml.deployment import FHEModelClient
import numpy
import os
from pathlib import Path

import shutil
import torch

from model import Autoencoder
from concrete.ml.torch.compile import compile_torch_model

# Autoencoder hyper-parameters. Presumably these are the values the checkpoints
# in deployment/ were trained with: load_state_dict below is strict and raises
# on any parameter-shape mismatch.
sequence_length = 50  # number of packet ids per window (run_fhe reshapes to this length)
input_size = 12  # bits per packet id (run_fhe encodes each id as 12 binary digits)
latent_size = 8  # width of the latent vector fed to the decoder (see dummy_input)
hidden_size = 64  # passed to Autoencoder as hidden_size (LSTM width, presumably)

# Seed every RNG before the model is built and before the calibration sample
# below is drawn, so that runs are reproducible.
# NOTE(review): this also makes numpy.random deterministic for the rest of the
# process, including anything that draws "random" ids from it.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

ae_model = Autoencoder(
    input_size=input_size,
    hidden_size=hidden_size,
    latent_size=latent_size,
    sequence_length=sequence_length,
    num_lstm_layers=1,
)

# The encoder runs in the clear (run_fhe calls it directly on plaintext data).
encoder = ae_model.encoder
encoder.load_state_dict(torch.load("deployment/encoder.pth", weights_only=True))

# Only the decoder is compiled for FHE below.
decoder = ae_model.decoder
decoder.load_state_dict(torch.load("deployment/decoder.pth", weights_only=True))

# Reconstruction error; run_fhe compares it against a user-chosen threshold.
criterion = torch.nn.MSELoss()

# Compile the decoder with Concrete ML. dummy_input is a single random latent
# vector used as the compilation input set.
# NOTE(review): one random sample is a weak calibration set for quantization;
# consider feeding real encoder outputs instead — TODO confirm accuracy impact.
dummy_input = torch.randn(1, latent_size)
compiled_decoder = compile_torch_model(
    decoder,
    dummy_input.numpy(),
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)

# Encrypted data limit for the browser to display
# (encrypted data is too large to display in the browser)
ENCRYPTED_DATA_BROWSER_LIMIT = 100
# Maximum number of per-user key directories kept under .fhe_keys/
# (enforced by clean_tmp_directory).
N_USER_KEY_STORED = 20
# Directory handed to FHEModelClient in keygen.
FHE_MODEL_PATH = "deployment"

def clean_tmp_directory(keys_dir=".fhe_keys/", tmp_dir="tmp/", max_keys=None):
    """Evict the oldest user key directories and their temporary files.

    Only the `max_keys` most recently modified sub-directories of `keys_dir`
    (one per user id) are kept. Older ones are removed, together with every
    `.npy` file in `tmp_dir` named `<prefix>_<user_id>.npy` (or
    `<user_id>.npy`) for an evicted user id.

    Missing directories are treated as empty, so a first launch (before any
    key has been generated) does not crash.

    Args:
        keys_dir: Directory holding one sub-directory of FHE keys per user id.
        tmp_dir: Directory holding per-user temporary `.npy` files.
        max_keys: Number of user key directories to keep; defaults to
            N_USER_KEY_STORED.
    """
    if max_keys is None:
        max_keys = N_USER_KEY_STORED

    keys_path = Path(keys_dir)
    if not keys_path.is_dir():
        return

    # Oldest first, so the head of the list is what gets evicted.
    user_dirs = sorted(
        (p for p in keys_path.iterdir() if p.is_dir()), key=os.path.getmtime
    )
    n_to_delete = len(user_dirs) - max_keys
    if n_to_delete <= 0:
        return

    evicted_ids = set()
    for user_dir in user_dirs[:n_to_delete]:
        evicted_ids.add(user_dir.name)
        shutil.rmtree(user_dir)

    tmp_path = Path(tmp_dir)
    if not tmp_path.is_dir():
        return

    # Compare the exact id after the last "_". A bare endswith(f"{user_id}.npy")
    # would also delete user 112's files when evicting user 12, and would try
    # to unlink the same file twice when two evicted ids share a suffix.
    for file in tmp_path.iterdir():
        if (
            file.is_file()
            and file.suffix == ".npy"
            and file.stem.rsplit("_", 1)[-1] in evicted_ids
        ):
            file.unlink(missing_ok=True)


def keygen():
    """Create a new user id and generate a fresh FHE key pair for it.

    The private and evaluation keys are generated by an FHEModelClient whose
    key directory is `.fhe_keys/<user_id>`. The serialized evaluation key is
    also saved to `tmp/tmp_evaluation_key_<user_id>.npy`, because it is too
    large to pass through regular Gradio components.

    Returns:
        A two-element list: the first ENCRYPTED_DATA_BROWSER_LIMIT elements of
        the serialized evaluation key (for display) and the user id.
    """
    import secrets

    # Evict old users' keys/files before adding a new one.
    clean_tmp_directory()

    print("Initializing FHEModelClient...")

    # Draw the id from the OS CSPRNG. The global NumPy RNG is seeded with 0 at
    # import time, so numpy.random would hand out the same id sequence after
    # every restart (predictable ids that collide with existing key dirs).
    user_id = secrets.randbelow(2**32)
    fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
    fhe_api.load()

    # Generate a fresh key
    fhe_api.generate_private_and_evaluation_keys(force=True)
    evaluation_key = fhe_api.get_serialized_evaluation_keys()

    # Save evaluation_key in a file, since too large to pass through regular Gradio
    # buttons, https://github.com/gradio-app/gradio/issues/1877
    Path("tmp").mkdir(parents=True, exist_ok=True)
    numpy.save(f"tmp/tmp_evaluation_key_{user_id}.npy", evaluation_key)

    # Slice before converting: list(evaluation_key) would materialise the whole
    # (very large) key just to keep its first few elements.
    return [list(evaluation_key[:ENCRYPTED_DATA_BROWSER_LIMIT]), user_id]


def _encode_packet_ids(packets_text, n_bits=12):
    """Parse whitespace-separated hex ids into a (1, n_ids, n_bits) float32 bit tensor.

    Each id is written MSB-first as `n_bits` binary digits, the same layout as
    `bin(x)[2:].zfill(n_bits)`.

    Raises:
        ValueError: if no ids are given, an id is not valid hexadecimal, or an
            id does not fit in `n_bits` bits.
    """
    tokens = packets_text.split()
    if not tokens:
        raise ValueError("No packet ids were provided.")
    try:
        values = [int(token, 16) for token in tokens]
    except ValueError as err:
        raise ValueError(f"Packet ids must be hexadecimal values ({err}).") from err
    out_of_range = [t for t, v in zip(tokens, values) if not 0 <= v < 2**n_bits]
    if out_of_range:
        raise ValueError(
            f"Packet ids must fit in {n_bits} bits, got: {' '.join(out_of_range)}"
        )
    ids = np.asarray(values, dtype=np.int64)
    # Shift each id right by n_bits-1 ... 0 and keep the low bit: MSB-first bits.
    bits = (ids[:, None] >> np.arange(n_bits - 1, -1, -1)) & 1
    return torch.from_numpy(bits.astype(np.float32)).unsqueeze(0)


def run_fhe(packets_ids, threshold=0.05):
    """Score a sequence of CAN packet ids with the autoencoder.

    The ids are encoded as bits and passed through the clear encoder. The
    latent vector then goes through the compiled decoder in FHE *simulation*
    mode, and the MSE between the reconstruction and the input is compared to
    `threshold`.

    Args:
        packets_ids: Whitespace-separated hexadecimal packet ids; exactly
            `ae_model.sequence_length` of them are expected.
        threshold: Reconstruction loss above which the sequence is flagged.

    Returns:
        `[loss, is_attack]` with `loss` a Python float and `is_attack` a bool.

    Raises:
        gr.Error: on malformed input, so Gradio shows the message to the user.
    """
    try:
        packets = _encode_packet_ids(packets_ids, n_bits=input_size)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    if packets.size(1) != ae_model.sequence_length:
        raise gr.Error(
            f"Expected {ae_model.sequence_length} packet ids, got {packets.size(1)}."
        )

    # The encoder must also run without autograd: latent.numpy() raises on a
    # tensor that requires grad.
    with torch.no_grad():
        latent = encoder(packets)
        # fhe="simulate": no encryption happens here, only the quantized circuit
        # is simulated in the clear.
        reconstruction = compiled_decoder.forward(latent.numpy(), fhe="simulate")

    reconstruction = torch.as_tensor(reconstruction, dtype=packets.dtype).view(
        -1, ae_model.sequence_length, packets.size(2)
    )

    loss = criterion(reconstruction, packets).item()
    return [loss, loss > threshold]


# Three-step UI: key generation, packet-id input, detection.
# Gradio's usage analytics are disabled so the app runs fully locally.
demo = gr.Blocks(analytics_enabled=False)

with demo:
    gr.Markdown(
        """
        <h1 align="center">CAN Bus Intrusion Detection With FHE</h1>

        <p align="center">
            <img src="https://ars.els-cdn.com/content/image/1-s2.0-S0167404824000786-gr001_lrg.jpg" width="60%" height="60%">
        </p>
        """
    )

    gr.Markdown("## Step 1: Generate the keys")

    b_gen_key_and_install = gr.Button("Get the keys")

    # Only a truncated prefix of the evaluation key is displayed.
    evaluation_key = gr.Textbox(
        label="Evaluation key (truncated):",
        max_lines=1,
        interactive=False,
    )

    # Receives the user id produced by keygen; hidden from the user.
    user_id = gr.Textbox(
        label="",
        max_lines=1,
        interactive=False,
        visible=False,
    )

    gr.Markdown(
        """
        ## Step 2: Provide the packet IDs
        Enter a sequence of packet IDs exchanged by electronic control units (ECUs) on the in-vehicle network (IVN).
        """
    )

    packets_ids = gr.Textbox(
        label="Packet IDs",
        info="Enter a sequence of 50 hexadecimal packet IDs separated by spaces",
        max_lines=1,
    )
    gr.Examples(
        label="Attack-free",
        examples=[
            "316 18F 260 2A0 329 545 002 153 2C0 130 131 140 350 43F 370 440 316 18F 260 2A0 329 4F0 545 430 4B1 1F1 153 002 2C0 350 130 131 140 370 43F 440 5F0 18F 260 2A0 316 329 545 002 153 2C0 130 131 140 350",
            "329 4F0 545 430 4B1 1F1 153 002 2C0 350 130 131 140 370 43F 440 5F0 18F 260 2A0 316 329 545 002 153 2C0 130 131 140 350 43F 370 0A0 0A1 440 316 18F 260 2A0 329 4F0 545 430 4B1 1F1 153 002 2C0 350 130",
            "316 329 545 002 153 2C0 130 131 140 350 43F 370 0A0 0A1 440 316 18F 260 2A0 329 4F0 545 430 4B1 1F1 153 002 2C0 350 130 131 140 370 43F 440 316 18F 260 2A0 329 545 002 153 2C0 130 131 140 350 43F 370",
        ],
        inputs=[packets_ids],
    )
    gr.Examples(
        label="DoS attacks",
        examples=[
            "130 131 140 370 43F 440 316 18F 260 2A0 329 545 002 153 2C0 130 131 140 350 43F 370 440 316 18F 260 2A0 329 545 4F0 430 2C0 4B1 1F1 153 002 350 000 130 000 131 000 140 000 370 000 43F 000 440 000 000",
            "370 440 316 18F 260 2A0 329 545 4F0 430 2C0 4B1 1F1 153 002 350 000 130 000 131 000 140 000 370 000 43F 000 440 000 000 000 18F 000 260 000 2A0 000 316 000 329 000 545 000 000 000 000 000 002 000 153",
            "000 140 000 370 000 43F 000 440 000 000 000 18F 000 260 000 2A0 000 316 000 329 000 545 000 000 000 000 000 002 000 153 000 2C0 000 130 000 131 000 140 000 350 000 370 000 43F 000 440 000 000 000 316",
        ],
        inputs=[packets_ids],
    )
    gr.Examples(
        label="Fuzzy attacks",
        examples=[
            "260 2A0 329 545 2B0 002 153 350 130 131 140 2C0 370 43F 440 4F0 18F 260 2A0 316 329 545 5F0 2B0 430 4B1 1F1 153 002 350 2C0 370 43F 130 131 140 440 2B0 130 2B0 130 440 002 2A0 430 2B0 350 0DF 6EA 2FD",
            "329 545 5F0 2B0 430 4B1 1F1 153 002 350 2C0 370 43F 130 131 140 440 2B0 130 2B0 130 440 002 2A0 430 2B0 350 0DF 6EA 2FD 2B0 370 2B0 430 4B1 1F1 12D 153 002 2C0 33A 350 130 131 39F 131 7E3 130 491 522",
            "130 440 002 2A0 430 2B0 350 0DF 6EA 2FD 2B0 370 2B0 430 4B1 1F1 12D 153 002 2C0 33A 350 130 131 39F 131 7E3 130 491 522 2FD 78D 260 330 329 494 4F0 545 2B0 430 4B1 1F1 153 002 350 2C0 370 43F 130 131",
        ],
        inputs=[packets_ids],
    )

    gr.Markdown(
        """
        ## Step 3: Detect the attack using FHE
        """
    )
    # NOTE(review): run_fhe evaluates the compiled decoder with fhe="simulate"
    # and does not use the keys from step 1 — real encrypted execution is TODO.
    threshold = gr.Slider(0, 1, value=0.05, label="Threshold", info="Choose between 0 and 1")
    b_detect = gr.Button("Detect", variant="primary")
    prediction = gr.Textbox(
        label="Prediction",
        max_lines=1,
        interactive=False,
    )
    loss = gr.Textbox(
        label="Loss",
        max_lines=1,
        interactive=False,
    )

    b_gen_key_and_install.click(keygen, inputs=[], outputs=[evaluation_key, user_id])
    b_detect.click(run_fhe, inputs=[packets_ids, threshold], outputs=[loss, prediction])

# Only start the server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=False)