chore: add encryption
Browse files- app.py +35 -11
- utils_demo.py +5 -0
app.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
from typing import Dict, List
|
| 6 |
-
|
| 7 |
import gradio as gr
|
| 8 |
import pandas as pd
|
| 9 |
from fhe_anonymizer import FHEAnonymizer
|
|
@@ -11,16 +11,23 @@ from openai import OpenAI
|
|
| 11 |
from utils_demo import *
|
| 12 |
from concrete.ml.deployment import FHEModelClient
|
| 13 |
|
|
|
|
| 14 |
ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
|
| 15 |
ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
|
| 16 |
MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
| 18 |
clean_directory()
|
| 19 |
|
| 20 |
anonymizer = FHEAnonymizer()
|
| 21 |
|
| 22 |
client = OpenAI(api_key=os.environ.get("openaikey"))
|
| 23 |
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def select_static_sentences_fn(selected_sentences: List):
|
| 26 |
|
|
@@ -39,11 +46,7 @@ def key_gen_fn() -> Dict:
|
|
| 39 |
Returns:
|
| 40 |
dict: A dictionary containing the generated keys and related information.
|
| 41 |
"""
|
| 42 |
-
print("Key
|
| 43 |
-
|
| 44 |
-
# Generate a random user ID
|
| 45 |
-
user_id = np.random.randint(0, 2**32)
|
| 46 |
-
print(f"Your user ID is: {user_id}....")
|
| 47 |
|
| 48 |
client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
|
| 49 |
client.load()
|
|
@@ -74,16 +77,16 @@ def key_gen_fn() -> Dict:
|
|
| 74 |
|
| 75 |
|
| 76 |
def encrypt_query_fn(query):
|
| 77 |
-
print(f"Query: {query}")
|
| 78 |
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
if not evaluation_key_path.is_file():
|
| 82 |
error_message = "Error β: Please generate the key first!"
|
| 83 |
return {output_encrypted_box: gr.update(value=error_message)}
|
| 84 |
|
| 85 |
if is_user_query_valid(query):
|
| 86 |
-
# TODO: check if the query is related to our context
|
| 87 |
error_msg = (
|
| 88 |
"Unable to process β: The request exceeds the length limit or falls "
|
| 89 |
"outside the scope of this document. Please refine your query."
|
|
@@ -91,9 +94,30 @@ def encrypt_query_fn(query):
|
|
| 91 |
print(error_msg)
|
| 92 |
return {query_box: gr.update(value=error_msg)}
|
| 93 |
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
|
| 99 |
|
|
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
from typing import Dict, List
|
| 6 |
+
import numpy
|
| 7 |
import gradio as gr
|
| 8 |
import pandas as pd
|
| 9 |
from fhe_anonymizer import FHEAnonymizer
|
|
|
|
| 11 |
from utils_demo import *
|
| 12 |
from concrete.ml.deployment import FHEModelClient
|
| 13 |
|
| 14 |
+
|
| 15 |
ORIGINAL_DOCUMENT = read_txt(ORIGINAL_FILE_PATH).split("\n\n")
|
| 16 |
ANONYMIZED_DOCUMENT = read_txt(ANONYMIZED_FILE_PATH)
|
| 17 |
MAPPING_SENTENCES = read_pickle(MAPPING_SENTENCES_PATH)
|
| 18 |
|
| 19 |
+
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
|
| 20 |
+
time.sleep(3)
|
| 21 |
+
|
| 22 |
clean_directory()
|
| 23 |
|
| 24 |
anonymizer = FHEAnonymizer()
|
| 25 |
|
| 26 |
client = OpenAI(api_key=os.environ.get("openaikey"))
|
| 27 |
|
| 28 |
+
# Generate a random user ID
|
| 29 |
+
user_id = numpy.random.randint(0, 2**32)
|
| 30 |
+
print(f"Your user ID is: {user_id}....")
|
| 31 |
|
| 32 |
def select_static_sentences_fn(selected_sentences: List):
|
| 33 |
|
|
|
|
| 46 |
Returns:
|
| 47 |
dict: A dictionary containing the generated keys and related information.
|
| 48 |
"""
|
| 49 |
+
print("Step 1: Key Generation:")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
|
| 52 |
client.load()
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
def encrypt_query_fn(query):
|
|
|
|
| 80 |
|
| 81 |
+
print(f"Step 2 Query encryption: {query=}")
|
| 82 |
+
|
| 83 |
+
evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
|
| 84 |
|
| 85 |
if not evaluation_key_path.is_file():
|
| 86 |
error_message = "Error β: Please generate the key first!"
|
| 87 |
return {output_encrypted_box: gr.update(value=error_message)}
|
| 88 |
|
| 89 |
if is_user_query_valid(query):
|
|
|
|
| 90 |
error_msg = (
|
| 91 |
"Unable to process β: The request exceeds the length limit or falls "
|
| 92 |
"outside the scope of this document. Please refine your query."
|
|
|
|
| 94 |
print(error_msg)
|
| 95 |
return {query_box: gr.update(value=error_msg)}
|
| 96 |
|
| 97 |
+
# Retrieve the client API
|
| 98 |
+
client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
|
| 99 |
+
client.load()
|
| 100 |
+
|
| 101 |
+
# Pattern to identify words and non-words (including punctuation, spaces, etc.)
|
| 102 |
+
tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", query)
|
| 103 |
+
encrypted_tokens = []
|
| 104 |
+
|
| 105 |
+
for token in tokens:
|
| 106 |
+
if bool(re.match(r"^\s+$", token)):
|
| 107 |
+
continue
|
| 108 |
+
# Directly append non-word tokens or whitespace to processed_tokens
|
| 109 |
+
|
| 110 |
+
# Prediction for each word
|
| 111 |
+
emb_x = get_batch_text_representation([token], EMBEDDINGS_MODEL, TOKENIZER)
|
| 112 |
+
encrypted_x = client.quantize_encrypt_serialize(emb_x)
|
| 113 |
+
assert isinstance(encrypted_x, bytes)
|
| 114 |
+
|
| 115 |
+
encrypted_tokens.append(encrypted_x)
|
| 116 |
+
|
| 117 |
+
write_pickle(KEYS_DIR / f"{user_id}/encrypted_input", encrypted_tokens)
|
| 118 |
+
|
| 119 |
|
| 120 |
+
#anonymizer.encrypt_query(query)
|
| 121 |
|
| 122 |
encrypted_quant_tokens_hex = [token.hex()[500:510] for token in encrypted_tokens]
|
| 123 |
|
utils_demo.py
CHANGED
|
@@ -6,6 +6,7 @@ import shutil
|
|
| 6 |
import string
|
| 7 |
from collections import Counter
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
|
@@ -35,6 +36,10 @@ PROMPT_PATH = DATA_PATH / "chatgpt_prompt.txt"
|
|
| 35 |
|
| 36 |
ALL_DIRS = [KEYS_DIR]
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
PUNCTUATION_LIST = list(string.punctuation)
|
| 39 |
PUNCTUATION_LIST.remove("%")
|
| 40 |
PUNCTUATION_LIST.remove("$")
|
|
|
|
| 6 |
import string
|
| 7 |
from collections import Counter
|
| 8 |
from pathlib import Path
|
| 9 |
+
from transformers import AutoModel, AutoTokenizer
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
|
|
|
| 36 |
|
| 37 |
ALL_DIRS = [KEYS_DIR]
|
| 38 |
|
| 39 |
+
# Load tokenizer and model
|
| 40 |
+
TOKENIZER = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
|
| 41 |
+
EMBEDDINGS_MODEL = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
|
| 42 |
+
|
| 43 |
PUNCTUATION_LIST = list(string.punctuation)
|
| 44 |
PUNCTUATION_LIST.remove("%")
|
| 45 |
PUNCTUATION_LIST.remove("$")
|