Spaces:
Running
Running
Upload 2 files
Browse files
server/security/notebook_training_gr.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
server/security/prompt_guard.py
CHANGED
@@ -25,7 +25,6 @@ def get_embedding(documents: list[str]) -> NDArray[np.float32]:
|
|
25 |
return model.encode(documents)
|
26 |
|
27 |
|
28 |
-
|
29 |
class Guardrail:
|
30 |
"""
|
31 |
A class to handle guardrail analysis based on query embeddings.
|
@@ -38,11 +37,11 @@ class Guardrail:
|
|
38 |
"""
|
39 |
Initializes the Guardrail class with a guardrail model instance.
|
40 |
"""
|
41 |
-
file_path = os.path.join("server","security","storage","guardrail_multi.pkl")
|
42 |
with open(file_path, "rb") as f:
|
43 |
self.guardrail = load(f)
|
44 |
|
45 |
-
def analyze_language(self, query:str) -> bool:
|
46 |
"""
|
47 |
Analyzes the given query to determine what language it is written in and whether it is english, french, german or spanish.
|
48 |
|
@@ -53,8 +52,8 @@ class Guardrail:
|
|
53 |
bool: Returns `False` if the query is not a supported language, `True` otherwise.
|
54 |
"""
|
55 |
det = detect(query)
|
56 |
-
return det in ["en","fr","de","es"]
|
57 |
-
|
58 |
def analyze_query(self, query: str) -> bool:
|
59 |
"""
|
60 |
Analyzes the given query to determine if it passes the guardrail check.
|
@@ -68,7 +67,6 @@ class Guardrail:
|
|
68 |
embed_query = get_embedding(documents=[query])
|
69 |
pred = self.guardrail.predict(embed_query.reshape(1, -1))
|
70 |
return pred != 1 # Return True if pred is not 1, otherwise False
|
71 |
-
|
72 |
|
73 |
def incremental_learning(self, X_new, y_new):
|
74 |
"""
|
@@ -80,9 +78,11 @@ class Guardrail:
|
|
80 |
"""
|
81 |
# Extraction des caractéristiques
|
82 |
embedding = model.encode(X_new)
|
83 |
-
|
84 |
# Mise à jour incrémentale du modèle
|
85 |
-
self.guardrail.partial_fit(embedding, y_new, classes=[0, 1])
|
86 |
|
87 |
-
with open(
|
|
|
|
|
88 |
dump(self.guardrail, f)
|
|
|
25 |
return model.encode(documents)
|
26 |
|
27 |
|
|
|
28 |
class Guardrail:
|
29 |
"""
|
30 |
A class to handle guardrail analysis based on query embeddings.
|
|
|
37 |
"""
|
38 |
Initializes the Guardrail class with a guardrail model instance.
|
39 |
"""
|
40 |
+
file_path = os.path.join("server", "security", "storage", "guardrail_multi.pkl")
|
41 |
with open(file_path, "rb") as f:
|
42 |
self.guardrail = load(f)
|
43 |
|
44 |
+
def analyze_language(self, query: str) -> bool:
|
45 |
"""
|
46 |
Analyzes the given query to determine what language it is written in and whether it is english, french, german or spanish.
|
47 |
|
|
|
52 |
bool: Returns `False` if the query is not a supported language, `True` otherwise.
|
53 |
"""
|
54 |
det = detect(query)
|
55 |
+
return det in ["en", "fr", "de", "es"]
|
56 |
+
|
57 |
def analyze_query(self, query: str) -> bool:
|
58 |
"""
|
59 |
Analyzes the given query to determine if it passes the guardrail check.
|
|
|
67 |
embed_query = get_embedding(documents=[query])
|
68 |
pred = self.guardrail.predict(embed_query.reshape(1, -1))
|
69 |
return pred != 1 # Return True if pred is not 1, otherwise False
|
|
|
70 |
|
71 |
def incremental_learning(self, X_new, y_new):
|
72 |
"""
|
|
|
78 |
"""
|
79 |
# Extraction des caractéristiques
|
80 |
embedding = model.encode(X_new)
|
81 |
+
|
82 |
# Mise à jour incrémentale du modèle
|
83 |
+
self.guardrail.partial_fit(embedding.reshape(1, -1), y_new, classes=[0, 1])
|
84 |
|
85 |
+
with open(
|
86 |
+
os.path.join("server", "security", "storage", "guardrail_multi.pkl"), "wb"
|
87 |
+
) as f:
|
88 |
dump(self.guardrail, f)
|