Mahesh2841 committed · verified
Commit 59afed2 · 1 Parent(s): aa809cf

Update custom_modeling.py

Files changed (1): custom_modeling.py  +14 -6
custom_modeling.py CHANGED
@@ -45,7 +45,9 @@ class _SafeGenerationMixin:
                 repo_id=self.config.name_or_path,
                 filename="toxic.keras",
             )
-            self._toxicity_model = tf.keras.models.load_model(path, compile=False)
+            # Force CPU loading for better performance
+            with tf.device('/CPU:0'):
+                self._toxicity_model = tf.keras.models.load_model(path, compile=False)
         return self._toxicity_model
 
     @property
@@ -55,7 +57,9 @@ class _SafeGenerationMixin:
                 repo_id=self.config.name_or_path,
                 filename="PI.keras",
             )
-            self._pi_model = tf.keras.models.load_model(path, compile=False)
+            # Force CPU loading for better performance
+            with tf.device('/CPU:0'):
+                self._pi_model = tf.keras.models.load_model(path, compile=False)
         return self._pi_model
 
     def _ensure_tokenizer(self):
@@ -70,15 +74,19 @@ class _SafeGenerationMixin:
     def _is_toxic(self, text: str) -> bool:
         if not text.strip():
             return False
-        inputs = tf.constant([text], dtype=tf.string)
-        prob = float(self._tox_model.predict(inputs)[0, 0])
+        # Ensure CPU execution
+        with tf.device('/CPU:0'):
+            inputs = tf.constant([text], dtype=tf.string)
+            prob = float(self._tox_model.predict(inputs, verbose=0)[0, 0])
         return prob >= self._tox_threshold
 
     def _has_prompt_injection(self, text: str) -> bool:
         if not text.strip():
             return False
-        inputs = tf.constant([text], dtype=tf.string)
-        prob = float(self._prompt_injection_model.predict(inputs)[0, 0])
+        # Ensure CPU execution
+        with tf.device('/CPU:0'):
+            inputs = tf.constant([text], dtype=tf.string)
+            prob = float(self._prompt_injection_model.predict(inputs, verbose=0)[0, 0])
         return prob >= self._pi_threshold
 
     def _safe_ids(self, message: str, length: int | None = None):
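
For context, the sketch below distills the pattern this commit introduces: load each Keras safety classifier inside a tf.device('/CPU:0') scope and run its predict() call on CPU with verbose=0. It is a minimal, self-contained illustration rather than the repo's actual _SafeGenerationMixin; the file name "toxic.keras" (taken from the diff), the helper names, and the 0.5 threshold are stand-ins for illustration only.

import tensorflow as tf

def load_classifier_on_cpu(path: str) -> tf.keras.Model:
    # Place the classifier's weights on the CPU so it does not take GPU
    # memory away from the main generation model (mirroring the diff's
    # "Force CPU loading" comment).
    with tf.device('/CPU:0'):
        return tf.keras.models.load_model(path, compile=False)

def score_text(model: tf.keras.Model, text: str) -> float:
    # Empty input gets a zero score, mirroring the early return in the diff.
    if not text.strip():
        return 0.0
    # Run inference on CPU; verbose=0 suppresses Keras' per-call progress bar,
    # which matters when the check runs on every generation request.
    with tf.device('/CPU:0'):
        inputs = tf.constant([text], dtype=tf.string)
        return float(model.predict(inputs, verbose=0)[0, 0])

# Hypothetical usage:
# clf = load_classifier_on_cpu("toxic.keras")           # file name from the diff
# blocked = score_text(clf, "some user prompt") >= 0.5  # threshold assumed

Note that tf.device only controls where these particular ops and variables are placed; the rest of the process can still use the GPU for generation as before.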