|  | """A local gradio app that detects seizures with EEG using FHE.""" | 
					
						
						|  | from PIL import Image | 
					
						
						|  | import os | 
					
						
						|  | import shutil | 
					
						
						|  | import subprocess | 
					
						
						|  | import time | 
					
						
						|  | import gradio as gr | 
					
						
						|  | import numpy | 
					
						
						|  | import requests | 
					
						
						|  | from itertools import chain | 
					
						
						|  |  | 
					
						
						|  | from common import ( | 
					
						
						|  | CLIENT_TMP_PATH, | 
					
						
						|  | SERVER_TMP_PATH, | 
					
						
						|  | EXAMPLES, | 
					
						
						|  | INPUT_SHAPE, | 
					
						
						|  | KEYS_PATH, | 
					
						
						|  | REPO_DIR, | 
					
						
						|  | SERVER_URL, | 
					
						
						|  | ) | 
					
						
						|  | from client_server_interface import FHEClient | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR) | 
					
						
						|  | time.sleep(3) | 
					
						
						|  |  | 
					
						
						|  | def shorten_bytes_object(bytes_object, limit=500): | 
					
						
						|  | """Shorten the input bytes object to a given length. | 
					
						
						|  |  | 
					
						
						|  | Encrypted data is too large for displaying it in the browser using Gradio. This function | 
					
						
						|  | provides a shorten representation of it. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | bytes_object (bytes): The input to shorten | 
					
						
						|  | limit (int): The length to consider. Default to 500. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | str: Hexadecimal string shorten representation of the input byte object. | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | shift = 100 | 
					
						
						|  | return bytes_object[shift : limit + shift].hex() | 
					
						
						|  |  | 
					
						
						|  | def get_client(user_id): | 
					
						
						|  | """Get the client API. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | FHEClient: The client API. | 
					
						
						|  | """ | 
					
						
						|  | return FHEClient( | 
					
						
						|  | key_dir=KEYS_PATH / f"seizure_detection_{user_id}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def get_client_file_path(name, user_id): | 
					
						
						|  | """Get the correct temporary file path for the client. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | name (str): The desired file name. | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | pathlib.Path: The file path. | 
					
						
						|  | """ | 
					
						
						|  | return CLIENT_TMP_PATH / f"{name}_seizure_detection_{user_id}" | 
					
						
						|  |  | 
					
						
						|  | def clean_temporary_files(n_keys=20): | 
					
						
						|  | """Clean keys and encrypted images. | 
					
						
						|  |  | 
					
						
						|  | A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this | 
					
						
						|  | limit is reached, the oldest files are deleted. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | n_keys (int): The maximum number of keys and associated files to be stored. Default to 20. | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | key_dirs = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | user_ids = [] | 
					
						
						|  | if len(key_dirs) > n_keys: | 
					
						
						|  | n_keys_to_delete = len(key_dirs) - n_keys | 
					
						
						|  | for key_dir in key_dirs[:n_keys_to_delete]: | 
					
						
						|  | user_ids.append(key_dir.name) | 
					
						
						|  | shutil.rmtree(key_dir) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | client_files = CLIENT_TMP_PATH.iterdir() | 
					
						
						|  | server_files = SERVER_TMP_PATH.iterdir() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for file in chain(client_files, server_files): | 
					
						
						|  | for user_id in user_ids: | 
					
						
						|  | if user_id in file.name: | 
					
						
						|  | file.unlink() | 
					
						
						|  |  | 
					
						
						|  | def keygen(): | 
					
						
						|  | """Generate the private key for seizure detection. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | (user_id, True) (Tuple[int, bool]): The current user's ID and a boolean used for visual display. | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | clean_temporary_files() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | user_id = numpy.random.randint(0, 2**32) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | client = get_client(user_id) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | client.generate_private_and_evaluation_keys(force=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | evaluation_key = client.get_serialized_evaluation_keys() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | evaluation_key_path = get_client_file_path("evaluation_key", user_id) | 
					
						
						|  |  | 
					
						
						|  | with evaluation_key_path.open("wb") as evaluation_key_file: | 
					
						
						|  | evaluation_key_file.write(evaluation_key) | 
					
						
						|  |  | 
					
						
						|  | return (user_id, True) | 
					
						
						|  |  | 
					
						
						|  | def encrypt(user_id, input_image): | 
					
						
						|  | """Encrypt the given image for seizure detection. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  | input_image (numpy.ndarray): The image to encrypt. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | (input_image, encrypted_image_short) (Tuple[bytes]): The encrypted image and one of its | 
					
						
						|  | representation. | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | if user_id == "": | 
					
						
						|  | raise gr.Error("Please generate the private key first.") | 
					
						
						|  |  | 
					
						
						|  | if input_image is None: | 
					
						
						|  | raise gr.Error("Please choose an image first.") | 
					
						
						|  |  | 
					
						
						|  | if input_image.shape[-1] != 3: | 
					
						
						|  | raise ValueError(f"Input image must have 3 channels (RGB). Current shape: {input_image.shape}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if input_image.shape != (224, 224, 3): | 
					
						
						|  | input_image_pil = Image.fromarray(input_image) | 
					
						
						|  | input_image_pil = input_image_pil.resize((224, 224)) | 
					
						
						|  | input_image = numpy.array(input_image_pil) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | client = get_client(user_id) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_image = client.encrypt_serialize(input_image) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_image_path = get_client_file_path("encrypted_image", user_id) | 
					
						
						|  |  | 
					
						
						|  | with encrypted_image_path.open("wb") as encrypted_image_file: | 
					
						
						|  | encrypted_image_file.write(encrypted_image) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_image_short = shorten_bytes_object(encrypted_image) | 
					
						
						|  |  | 
					
						
						|  | return (resize_img(input_image), encrypted_image_short) | 
					
						
						|  |  | 
					
						
						|  | def send_input(user_id): | 
					
						
						|  | """Send the encrypted input image as well as the evaluation key to the server. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | evaluation_key_path = get_client_file_path("evaluation_key", user_id) | 
					
						
						|  |  | 
					
						
						|  | if user_id == "" or not evaluation_key_path.is_file(): | 
					
						
						|  | raise gr.Error("Please generate the private key first.") | 
					
						
						|  |  | 
					
						
						|  | encrypted_input_path = get_client_file_path("encrypted_image", user_id) | 
					
						
						|  |  | 
					
						
						|  | if not encrypted_input_path.is_file(): | 
					
						
						|  | raise gr.Error("Please generate the private key and then encrypt an image first.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | data = { | 
					
						
						|  | "user_id": user_id, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | files = [ | 
					
						
						|  | ("files", open(encrypted_input_path, "rb")), | 
					
						
						|  | ("files", open(evaluation_key_path, "rb")), | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | url = SERVER_URL + "send_input" | 
					
						
						|  | with requests.post( | 
					
						
						|  | url=url, | 
					
						
						|  | data=data, | 
					
						
						|  | files=files, | 
					
						
						|  | ) as response: | 
					
						
						|  | return response.ok | 
					
						
						|  |  | 
					
						
						|  | def run_fhe(user_id): | 
					
						
						|  | """Apply the seizure detection model on the encrypted image previously sent using FHE. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  | """ | 
					
						
						|  | data = { | 
					
						
						|  | "user_id": user_id, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | url = SERVER_URL + "run_fhe" | 
					
						
						|  | with requests.post( | 
					
						
						|  | url=url, | 
					
						
						|  | data=data, | 
					
						
						|  | ) as response: | 
					
						
						|  | if response.ok: | 
					
						
						|  | return response.json() | 
					
						
						|  | else: | 
					
						
						|  | raise gr.Error("Please wait for the input image to be sent to the server.") | 
					
						
						|  |  | 
					
						
						|  | def get_output(user_id): | 
					
						
						|  | """Retrieve the encrypted output (boolean). | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | encrypted_output_short (bytes): A representation of the encrypted result. | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | data = { | 
					
						
						|  | "user_id": user_id, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | url = SERVER_URL + "get_output" | 
					
						
						|  | with requests.post( | 
					
						
						|  | url=url, | 
					
						
						|  | data=data, | 
					
						
						|  | ) as response: | 
					
						
						|  | if response.ok: | 
					
						
						|  | encrypted_output = response.content | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_output_path = get_client_file_path("encrypted_output", user_id) | 
					
						
						|  |  | 
					
						
						|  | with encrypted_output_path.open("wb") as encrypted_output_file: | 
					
						
						|  | encrypted_output_file.write(encrypted_output) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_output_short = shorten_bytes_object(encrypted_output) | 
					
						
						|  |  | 
					
						
						|  | return encrypted_output_short | 
					
						
						|  | else: | 
					
						
						|  | raise gr.Error("Please wait for the FHE execution to be completed.") | 
					
						
						|  |  | 
					
						
						|  | def decrypt_output(user_id): | 
					
						
						|  | """Decrypt the result. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | user_id (int): The current user's ID. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | bool: The decrypted output (True if seizure detected, False otherwise) | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | if user_id == "": | 
					
						
						|  | raise gr.Error("Please generate the private key first.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypted_output_path = get_client_file_path("encrypted_output", user_id) | 
					
						
						|  |  | 
					
						
						|  | if not encrypted_output_path.is_file(): | 
					
						
						|  | raise gr.Error("Please run the FHE execution first.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with encrypted_output_path.open("rb") as encrypted_output_file: | 
					
						
						|  | encrypted_output = encrypted_output_file.read() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | client = get_client(user_id) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | decrypted_output = client.deserialize_decrypt_post_process(encrypted_output) | 
					
						
						|  |  | 
					
						
						|  | return "Seizure detected" if decrypted_output else "No seizure detected" | 
					
						
						|  |  | 
					
						
						|  | def resize_img(img, width=256, height=256): | 
					
						
						|  | """Resize the image.""" | 
					
						
						|  | if img.dtype != numpy.uint8: | 
					
						
						|  | img = img.astype(numpy.uint8) | 
					
						
						|  | img_pil = Image.fromarray(img) | 
					
						
						|  |  | 
					
						
						|  | resized_img_pil = img_pil.resize((width, height)) | 
					
						
						|  |  | 
					
						
						|  | return numpy.array(resized_img_pil) | 
					
						
						|  |  | 
					
						
						|  | demo = gr.Blocks() | 
					
						
						|  |  | 
					
						
						|  | print("Starting the demo...") | 
					
						
						|  | with demo: | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | """ | 
					
						
						|  | <h1 align="center">Seizure Detection on Encrypted EEG Data Using Fully Homomorphic Encryption</h1> | 
					
						
						|  | """ | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("## Client side") | 
					
						
						|  | gr.Markdown("### Step 1: Upload an EEG image. ") | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | f"The image will automatically be resized to shape (224x224). " | 
					
						
						|  | "The image here, however, is displayed in its original resolution." | 
					
						
						|  | ) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | input_image = gr.Image( | 
					
						
						|  | value=None, label="Upload an EEG image here.", height=256, | 
					
						
						|  | width=256, sources="upload", interactive=True, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | examples = gr.Examples( | 
					
						
						|  | examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 2: Generate the private key.") | 
					
						
						|  | keygen_button = gr.Button("Generate the private key.") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False) | 
					
						
						|  |  | 
					
						
						|  | user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 3: Encrypt the image using FHE.") | 
					
						
						|  | encrypt_button = gr.Button("Encrypt the image using FHE.") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | encrypted_input = gr.Textbox( | 
					
						
						|  | label="Encrypted input representation:", max_lines=2, interactive=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("## Server side") | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | "The encrypted value is received by the server. The server can then compute the seizure " | 
					
						
						|  | "detection directly over encrypted values. Once the computation is finished, the server returns " | 
					
						
						|  | "the encrypted results to the client." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 4: Send the encrypted image to the server.") | 
					
						
						|  | send_input_button = gr.Button("Send the encrypted image to the server.") | 
					
						
						|  | send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 5: Run FHE execution.") | 
					
						
						|  | execute_fhe_button = gr.Button("Run FHE execution.") | 
					
						
						|  | fhe_execution_time = gr.Textbox( | 
					
						
						|  | label="Total FHE execution time (in seconds):", max_lines=1, interactive=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 6: Receive the encrypted output from the server.") | 
					
						
						|  | get_output_button = gr.Button("Receive the encrypted output from the server.") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | encrypted_output = gr.Textbox( | 
					
						
						|  | label="Encrypted output representation:", | 
					
						
						|  | max_lines=2, | 
					
						
						|  | interactive=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("## Client side") | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | "The encrypted output is sent back to the client, who can finally decrypt it with the " | 
					
						
						|  | "private key. Only the client is aware of the original image and the detection result." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown("### Step 7: Decrypt the output.") | 
					
						
						|  | decrypt_button = gr.Button("Decrypt the output") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | decrypted_output = gr.Textbox( | 
					
						
						|  | label="Seizure detection result:", | 
					
						
						|  | interactive=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | keygen_button.click( | 
					
						
						|  | keygen, | 
					
						
						|  | outputs=[user_id, keygen_checkbox], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | encrypt_button.click( | 
					
						
						|  | encrypt, | 
					
						
						|  | inputs=[user_id, input_image], | 
					
						
						|  | outputs=[input_image, encrypted_input], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | send_input_button.click( | 
					
						
						|  | send_input, inputs=[user_id], outputs=[send_input_checkbox] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | execute_fhe_button.click(run_fhe, inputs=[user_id], outputs=[fhe_execution_time]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | get_output_button.click( | 
					
						
						|  | get_output, | 
					
						
						|  | inputs=[user_id], | 
					
						
						|  | outputs=[encrypted_output] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | decrypt_button.click( | 
					
						
						|  | decrypt_output, | 
					
						
						|  | inputs=[user_id], | 
					
						
						|  | outputs=[decrypted_output], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a " | 
					
						
						|  | "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). " | 
					
						
						|  | "Try it yourself and don't forget to star on Github ⭐." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | demo.launch(share=False) | 
					
						
						|  |  |