Mahesh2841 committed
Commit c70f7b4 · verified · 1 Parent(s): bb29408

Update custom_modeling.py

Files changed (1)
  custom_modeling.py  +30 -66
custom_modeling.py CHANGED
@@ -1,39 +1,3 @@
-import torch
-import tensorflow as tf
-from transformers import LlamaForCausalLM
-from transformers.utils import cached_file
-import os
-import logging
-
-# Set up logging
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
-
-class ToxicityChecker:
-    def __init__(self, model_path):
-        try:
-            # Check if file exists
-            if not os.path.exists(model_path):
-                raise FileNotFoundError(f"Toxicity model not found at {model_path}")
-
-            logger.info(f"Loading toxicity model from: {model_path}")
-            self.model = tf.keras.models.load_model(model_path)
-            logger.info("Toxicity model loaded successfully")
-        except Exception as e:
-            logger.error(f"Failed to load toxicity model: {str(e)}")
-            raise
-
-    def is_toxic(self, text, threshold=0.6):
-        try:
-            # Convert to TensorFlow constant
-            text_tensor = tf.constant([text])
-            prob = self.model.predict(text_tensor, verbose=0)[0][0]
-            logger.debug(f"Toxicity check: '{text[:30]}...' → prob: {prob:.3f}")
-            return prob > threshold
-        except Exception as e:
-            logger.error(f"Toxicity check failed: {str(e)}")
-            return False
-
 import os
 import torch
 import tensorflow as tf
@@ -43,50 +7,50 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-class ToxicityChecker:
-    def __init__(self, model_path):
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"Toxicity model not found at {model_path}")
-        logger.info(f"Loading toxicity model from: {model_path}")
-        self.model = tf.keras.models.load_model(model_path)
+class SafeGenerationModel(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Load toxicity model
+        toxic_path = cached_file(config._name_or_path, "toxic.keras")
+        if not os.path.exists(toxic_path):
+            raise FileNotFoundError(f"Toxicity model not found at {toxic_path}")
+
+        self.toxicity_model = tf.keras.models.load_model(toxic_path)
+        self.tokenizer = None
         logger.info("Toxicity model loaded successfully")
-
+
     def is_toxic(self, text, threshold=0.6):
         try:
-            text_tensor = tf.constant([text])
-            prob = self.model.predict(text_tensor, verbose=0)[0][0]
-            logger.debug(f"Toxicity: '{text[:30]}...' → {prob:.3f}")
+            prob = self.toxicity_model.predict([text], verbose=0)[0][0]
             return prob > threshold
-        except Exception:
+        except Exception as e:
+            logger.error(f"Toxicity check failed: {str(e)}")
             return False
 
-class SafeGenerationModel(LlamaForCausalLM):
-    def __init__(self, config):
-        super().__init__(config)
-        toxic_path = cached_file(config.name_or_path, "toxic.keras")
-        self.toxicity_checker = ToxicityChecker(toxic_path)
-        self.tokenizer = None
-
     def generate(self, *args, **kwargs):
         inputs = kwargs.get("input_ids")
-        if inputs is not None and self.tokenizer:
+
+        # Check input toxicity
+        if self.tokenizer and inputs is not None:
             input_text = self.tokenizer.decode(inputs[0], skip_special_tokens=True)
-            if self.toxicity_checker.is_toxic(input_text):
-                return self._safe_response_tensor()
+            if self.is_toxic(input_text):
+                return self._safe_response()
 
+        # Generate response
         outputs = super().generate(*args, **kwargs)
 
+        # Check output toxicity
         if self.tokenizer:
             output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            if self.toxicity_checker.is_toxic(output_text):
-                return self._safe_response_tensor()
+            if self.is_toxic(output_text):
+                return self._safe_response()
 
         return outputs
-
-    def _safe_response_tensor(self):
-        safe_response = "I'm unable to respond to that request. HAHAHA"
-        return self.tokenizer.encode(safe_response, return_tensors="pt").to(self.device)
-
+
+    def _safe_response(self):
+        safe_text = "I'm unable to respond to that request. HAHAHA"
+        return self.tokenizer.encode(safe_text, return_tensors="pt").to(self.device)
+
     def set_tokenizer(self, tokenizer):
-        self.tokenizer = tokenizer
-        logger.info("Tokenizer injected into model")
+        self.tokenizer = tokenizer
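
For context, a minimal usage sketch that is not part of this commit. It assumes the repository's config.json registers SafeGenerationModel under auto_map (so trust_remote_code=True loads custom_modeling.py) and uses a placeholder repo id. The relevant detail from the diff is that both toxicity checks in generate() are guarded by `if self.tokenizer`, so the filter is only armed once set_tokenizer() is called; if either check fires, generate() returns the encoded fixed refusal from _safe_response() instead of the model output.

# Hypothetical usage sketch -- the repo id and auto_map registration are
# assumptions, not established by this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Mahesh2841/model-with-toxic-filter"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # needed so custom_modeling.SafeGenerationModel is used
)

# generate() skips both toxicity checks while self.tokenizer is None,
# so inject the tokenizer before generating.
model.set_tokenizer(tokenizer)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
output_ids = model.generate(input_ids=inputs.input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))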