# 777_test / custom_modeling.py

import torch
import transformers
import tensorflow as tf
from transformers import LlamaForCausalLM # use the base model class matching your model's architecture
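
# NOTE (assumption): the safety filter below expects "toxic.keras" to be a saved
# Keras model whose predict() accepts a batch of raw strings (e.g. because a
# TextVectorization layer is saved inside the model) and returns an array of
# shape (batch_size, 1) containing sigmoid "toxic" probabilities. Nothing here
# enforces that contract; a quick way to sanity-check a candidate classifier
# (illustrative sketch only, run separately):
#
#     clf = tf.keras.models.load_model("toxic.keras")
#     probs = clf.predict(tf.constant(["hello there"]))
#     assert probs.shape == (1, 1) and 0.0 <= float(probs[0, 0]) <= 1.0
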
class SafeGenerationModel(LlamaForCausalLM):
"""
A wrapper around LlamaForCausalLM (or the appropriate base model class)
that filters toxic inputs and outputs using a pre-trained toxicity classifier.
"""

    def __init__(self, config):
        super().__init__(config)
        # Load the pre-trained toxicity classifier (a saved Keras model).
        self.toxicity_model = tf.keras.models.load_model("toxic.keras")
        self.toxicity_threshold = 0.6  # Probability threshold to consider content toxic.
        try:
            # Load the tokenizer for encoding/decoding text.
            # This uses the same repository (config.name_or_path) as the model.
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.name_or_path)
        except Exception as e:
            self.tokenizer = None
            print(f"Warning: Tokenizer could not be loaded in SafeGenerationModel: {e}")

    def is_toxic(self, text: str) -> bool:
        """Utility: return True if the given text is predicted toxic by the classifier."""
        if text is None or text.strip() == "":
            return False
        # Prepare input for the classifier (it expects a batch dimension).
        inputs = tf.constant([text])
        # Get the toxicity probability (assuming the model outputs a sigmoid probability
        # for "toxic"; see the contract note above the class definition).
        prob = float(self.toxicity_model.predict(inputs)[0, 0])
        return prob >= self.toxicity_threshold

    def generate(self, *args, **kwargs):
        # Intercept the generate call to filter toxic prompts and outputs.
        # 1. Check prompt toxicity (if the prompt text can be recovered via the tokenizer).
        #    The prompt may arrive as input_ids in kwargs or as the first positional argument
        #    (e.g. when called through pipelines).
        prompt_text = None
        if 'input_ids' in kwargs and self.tokenizer:
            try:
                # Decode input_ids to get the prompt text (assume batch size 1 for simplicity).
                input_ids = kwargs['input_ids']
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids[0].tolist()  # use the first sequence
                prompt_text = self.tokenizer.decode(input_ids, skip_special_tokens=True)
            except Exception:
                prompt_text = None
        elif args and self.tokenizer:
            # If input_ids were passed positionally in *args (e.g., as the first argument).
            try:
                input_ids = args[0]
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids[0].tolist()
                prompt_text = self.tokenizer.decode(input_ids, skip_special_tokens=True)
            except Exception:
                prompt_text = None
        if prompt_text and self.is_toxic(prompt_text):
            # If the user prompt is toxic, return a safe response without generating.
            safe_msg = "Input is toxic, please be kind to yourself and others."
            if self.tokenizer:
                # Encode the safe message into token IDs.
                safe_ids = self.tokenizer(safe_msg, return_tensors="pt")["input_ids"]
                # Move to the same device as the model's first parameter.
                device = next(self.parameters()).device
                safe_ids = safe_ids.to(device)
                return safe_ids  # Return the safe message's token IDs as the generated output.
            else:
                # Without a tokenizer we cannot encode a safe response, so fail loudly.
                raise ValueError("Toxic input detected and no tokenizer available to generate safe response.")
        # 2. If the input is fine, proceed with normal text generation using the base class method.
        generated_outputs = super().generate(*args, **kwargs)
        # Decode outputs to text for toxicity checking.
        # `generated_outputs` is usually a torch.LongTensor of shape (batch_size, seq_length).
        if isinstance(generated_outputs, torch.Tensor):
            output_sequences = generated_outputs
        else:
            # In case generate returns a ModelOutput dataclass (e.g. with return_dict_in_generate=True).
            output_sequences = generated_outputs.sequences if hasattr(generated_outputs, "sequences") else generated_outputs
        # Ensure we have a 2D tensor of token IDs, moved to the CPU for decoding.
        output_sequences = output_sequences.detach().cpu()
        if output_sequences.ndim == 1:
            output_sequences = output_sequences.unsqueeze(0)
        # Check each generated sequence for toxicity.
        safe_sequences = []
        for seq in output_sequences:
            output_text = self.tokenizer.decode(seq.tolist(), skip_special_tokens=True) if self.tokenizer else None
            if output_text and self.is_toxic(output_text):
                # Replace the toxic output with a safe message.
                safe_msg = "Response is toxic, please be kind to yourself and others."
                if self.tokenizer:
                    safe_ids = self.tokenizer(safe_msg, return_tensors="pt")["input_ids"][0]
                    # Pad or truncate safe_ids to match the original sequence length for consistency.
                    seq_len = seq.shape[0]
                    if safe_ids.shape[0] < seq_len:
                        # Pad with the EOS or PAD token if defined, else fall back to 0.
                        pad_id = self.config.eos_token_id if self.config.eos_token_id is not None else (self.config.pad_token_id or 0)
                        pad_length = seq_len - safe_ids.shape[0]
                        safe_ids = torch.cat([safe_ids, torch.full((pad_length,), pad_id, dtype=torch.long)], dim=0)
                    else:
                        # Truncate if the safe message is longer than the original sequence.
                        safe_ids = safe_ids[:seq_len]
                    safe_sequences.append(safe_ids)
                else:
                    # Without a tokenizer we cannot encode the safe message; append an all-zero sequence instead.
                    safe_sequences.append(torch.zeros_like(seq))
            else:
                # Non-toxic output (or no tokenizer to decode it); keep the original sequence.
                safe_sequences.append(seq)
        # Stack the sequences back into a (batch_size, seq_length) tensor.
        safe_sequences = torch.stack(safe_sequences, dim=0)
        return safe_sequences
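

# Minimal usage sketch (assumptions: this file is referenced from the repo's
# config.json via an "auto_map" entry so that trust_remote_code=True resolves
# SafeGenerationModel, and "Mahesh2841/777_test" below is a placeholder repo id --
# substitute the actual repository):
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Mahesh2841/777_test"  # placeholder -- replace with the real repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    prompt = "Write a short, friendly greeting."
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    # generate() runs the toxicity checks defined above on both the prompt and the output.
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=50)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))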