Update custom_modeling.py
Browse files- custom_modeling.py +14 -6
custom_modeling.py
CHANGED
@@ -45,7 +45,9 @@ class _SafeGenerationMixin:
|
|
45 |
repo_id=self.config.name_or_path,
|
46 |
filename="toxic.keras",
|
47 |
)
|
48 |
-
|
|
|
|
|
49 |
return self._toxicity_model
|
50 |
|
51 |
@property
|
@@ -55,7 +57,9 @@ class _SafeGenerationMixin:
|
|
55 |
repo_id=self.config.name_or_path,
|
56 |
filename="PI.keras",
|
57 |
)
|
58 |
-
|
|
|
|
|
59 |
return self._pi_model
|
60 |
|
61 |
def _ensure_tokenizer(self):
|
@@ -70,15 +74,19 @@ class _SafeGenerationMixin:
|
|
70 |
def _is_toxic(self, text: str) -> bool:
|
71 |
if not text.strip():
|
72 |
return False
|
73 |
-
|
74 |
-
|
|
|
|
|
75 |
return prob >= self._tox_threshold
|
76 |
|
77 |
def _has_prompt_injection(self, text: str) -> bool:
|
78 |
if not text.strip():
|
79 |
return False
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
return prob >= self._pi_threshold
|
83 |
|
84 |
def _safe_ids(self, message: str, length: int | None = None):
|
|
|
45 |
repo_id=self.config.name_or_path,
|
46 |
filename="toxic.keras",
|
47 |
)
|
48 |
+
# ✅ Force CPU loading for better performance
|
49 |
+
with tf.device('/CPU:0'):
|
50 |
+
self._toxicity_model = tf.keras.models.load_model(path, compile=False)
|
51 |
return self._toxicity_model
|
52 |
|
53 |
@property
|
|
|
57 |
repo_id=self.config.name_or_path,
|
58 |
filename="PI.keras",
|
59 |
)
|
60 |
+
# ✅ Force CPU loading for better performance
|
61 |
+
with tf.device('/CPU:0'):
|
62 |
+
self._pi_model = tf.keras.models.load_model(path, compile=False)
|
63 |
return self._pi_model
|
64 |
|
65 |
def _ensure_tokenizer(self):
|
|
|
74 |
def _is_toxic(self, text: str) -> bool:
|
75 |
if not text.strip():
|
76 |
return False
|
77 |
+
# ✅ Ensure CPU execution
|
78 |
+
with tf.device('/CPU:0'):
|
79 |
+
inputs = tf.constant([text], dtype=tf.string)
|
80 |
+
prob = float(self._tox_model.predict(inputs, verbose=0)[0, 0])
|
81 |
return prob >= self._tox_threshold
|
82 |
|
83 |
def _has_prompt_injection(self, text: str) -> bool:
|
84 |
if not text.strip():
|
85 |
return False
|
86 |
+
# ✅ Ensure CPU execution
|
87 |
+
with tf.device('/CPU:0'):
|
88 |
+
inputs = tf.constant([text], dtype=tf.string)
|
89 |
+
prob = float(self._prompt_injection_model.predict(inputs, verbose=0)[0, 0])
|
90 |
return prob >= self._pi_threshold
|
91 |
|
92 |
def _safe_ids(self, message: str, length: int | None = None):
|