Sahm269 committed on
Commit
7432833
·
verified ·
1 Parent(s): 1806489

Upload 2 files

Browse files
server/security/notebook_training_gr.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
server/security/prompt_guard.py CHANGED
@@ -25,7 +25,6 @@ def get_embedding(documents: list[str]) -> NDArray[np.float32]:
25
  return model.encode(documents)
26
 
27
 
28
-
29
  class Guardrail:
30
  """
31
  A class to handle guardrail analysis based on query embeddings.
@@ -38,11 +37,11 @@ class Guardrail:
38
  """
39
  Initializes the Guardrail class with a guardrail model instance.
40
  """
41
- file_path = os.path.join("server","security","storage","guardrail_multi.pkl")
42
  with open(file_path, "rb") as f:
43
  self.guardrail = load(f)
44
 
45
- def analyze_language(self, query:str) -> bool:
46
  """
47
  Analyzes the given query to determine what language it is written in and whether it is English, French, German or Spanish.
48
 
@@ -53,8 +52,8 @@ class Guardrail:
53
  bool: Returns `False` if the query is not a supported language, `True` otherwise.
54
  """
55
  det = detect(query)
56
- return det in ["en","fr","de","es"]
57
-
58
  def analyze_query(self, query: str) -> bool:
59
  """
60
  Analyzes the given query to determine if it passes the guardrail check.
@@ -68,7 +67,6 @@ class Guardrail:
68
  embed_query = get_embedding(documents=[query])
69
  pred = self.guardrail.predict(embed_query.reshape(1, -1))
70
  return pred != 1 # Return True if pred is not 1, otherwise False
71
-
72
 
73
  def incremental_learning(self, X_new, y_new):
74
  """
@@ -80,9 +78,11 @@ class Guardrail:
80
  """
81
  # Extraction des caractéristiques
82
  embedding = model.encode(X_new)
83
-
84
  # Mise à jour incrémentale du modèle
85
- self.guardrail.partial_fit(embedding, y_new, classes=[0, 1])
86
 
87
- with open(os.path.join("server","security","storage","guardrail_multi.pkl"), "wb") as f:
 
 
88
  dump(self.guardrail, f)
 
25
  return model.encode(documents)
26
 
27
 
 
28
  class Guardrail:
29
  """
30
  A class to handle guardrail analysis based on query embeddings.
 
37
  """
38
  Initializes the Guardrail class with a guardrail model instance.
39
  """
40
+ file_path = os.path.join("server", "security", "storage", "guardrail_multi.pkl")
41
  with open(file_path, "rb") as f:
42
  self.guardrail = load(f)
43
 
44
+ def analyze_language(self, query: str) -> bool:
45
  """
46
  Analyzes the given query to determine what language it is written in and whether it is English, French, German or Spanish.
47
 
 
52
  bool: Returns `False` if the query is not a supported language, `True` otherwise.
53
  """
54
  det = detect(query)
55
+ return det in ["en", "fr", "de", "es"]
56
+
57
  def analyze_query(self, query: str) -> bool:
58
  """
59
  Analyzes the given query to determine if it passes the guardrail check.
 
67
  embed_query = get_embedding(documents=[query])
68
  pred = self.guardrail.predict(embed_query.reshape(1, -1))
69
  return pred != 1 # Return True if pred is not 1, otherwise False
 
70
 
71
  def incremental_learning(self, X_new, y_new):
72
  """
 
78
  """
79
  # Extraction des caractéristiques
80
  embedding = model.encode(X_new)
81
+
82
  # Mise à jour incrémentale du modèle
83
+ self.guardrail.partial_fit(embedding.reshape(1, -1), y_new, classes=[0, 1])
84
 
85
+ with open(
86
+ os.path.join("server", "security", "storage", "guardrail_multi.pkl"), "wb"
87
+ ) as f:
88
  dump(self.guardrail, f)