import torch
import transformers
import tensorflow as tf

from transformers import LlamaForCausalLM
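
# Note: both frameworks are imported deliberately. TensorFlow/Keras only runs the
# standalone toxicity classifier ("toxic.keras"); the language model itself and
# all generation tensors stay in PyTorch.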


class SafeGenerationModel(LlamaForCausalLM):
    """
    A wrapper around LlamaForCausalLM (or the appropriate base model class)
    that filters toxic inputs and outputs using a pre-trained toxicity classifier.
    """

    def __init__(self, config):
        super().__init__(config)

        # Toxicity classifier: a Keras model saved as "toxic.keras" that maps raw
        # text to a toxicity probability; scores at or above the threshold are
        # treated as toxic.
        self.toxicity_model = tf.keras.models.load_model("toxic.keras")
        self.toxicity_threshold = 0.6

        # The tokenizer is needed to decode prompts/outputs for the classifier and
        # to encode the canned safe responses; fall back to None if it cannot be
        # loaded so the model itself remains usable.
        try:
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.name_or_path)
        except Exception as e:
            self.tokenizer = None
            print(f"Warning: Tokenizer could not be loaded in SafeGenerationModel: {e}")

    def is_toxic(self, text: str) -> bool:
        """Return True if the given text is predicted toxic by the classifier."""
        if text is None or text.strip() == "":
            return False

        # The classifier takes a batch of raw strings and returns one toxicity
        # probability per example; keep the single score for this text.
        inputs = tf.constant([text])
        prob = float(self.toxicity_model.predict(inputs)[0, 0])
        return prob >= self.toxicity_threshold

    def generate(self, *args, **kwargs):
        # --- Input filtering ---------------------------------------------------
        # Decode the prompt (first sequence in the batch) so it can be screened by
        # the toxicity classifier before any generation happens.
        prompt_text = None
        if "input_ids" in kwargs and self.tokenizer:
            try:
                input_ids = kwargs["input_ids"]
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids[0].tolist()
                prompt_text = self.tokenizer.decode(input_ids, skip_special_tokens=True)
            except Exception:
                prompt_text = None
        elif args and self.tokenizer:
            # input_ids may also be passed positionally as the first argument.
            try:
                input_ids = args[0]
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids[0].tolist()
                prompt_text = self.tokenizer.decode(input_ids, skip_special_tokens=True)
            except Exception:
                prompt_text = None

        # If the prompt itself is toxic, skip generation entirely and return a
        # canned safe message encoded as token ids, matching generate()'s usual
        # tensor return type.
        if prompt_text and self.is_toxic(prompt_text):
            safe_msg = "Input is toxic, please be kind to yourself and others."
            if self.tokenizer:
                safe_ids = self.tokenizer(safe_msg, return_tensors="pt")["input_ids"]
                # Keep the returned ids on the same device as the model parameters.
                device = next(self.parameters()).device
                safe_ids = safe_ids.to(device)
                return safe_ids
            else:
                raise ValueError("Toxic input detected and no tokenizer available to generate safe response.")

        # Prompt passed the filter (or could not be decoded): generate as usual.
        generated_outputs = super().generate(*args, **kwargs)

        # --- Output filtering --------------------------------------------------
        # generate() may return a plain tensor of token ids or a ModelOutput with a
        # `sequences` attribute (e.g. when return_dict_in_generate=True).
        if isinstance(generated_outputs, torch.Tensor):
            output_sequences = generated_outputs
        else:
            output_sequences = generated_outputs.sequences if hasattr(generated_outputs, "sequences") else generated_outputs

        # Work on a detached CPU copy and make sure there is a batch dimension.
        output_sequences = output_sequences.detach().cpu()
        if output_sequences.ndim == 1:
            output_sequences = output_sequences.unsqueeze(0)

        # Screen each generated sequence; a toxic output is replaced by a canned
        # safe message that is padded or truncated to the original sequence length
        # so the results can still be stacked into a single tensor.
        safe_sequences = []
        for seq in output_sequences:
            output_text = self.tokenizer.decode(seq.tolist(), skip_special_tokens=True) if self.tokenizer else None
            if output_text and self.is_toxic(output_text):
                # output_text is only non-None when the tokenizer is available, so
                # it can be used directly to encode the replacement message.
                safe_msg = "Response is toxic, please be kind to yourself and others."
                safe_ids = self.tokenizer(safe_msg, return_tensors="pt")["input_ids"][0]

                seq_len = seq.shape[0]
                if safe_ids.shape[0] < seq_len:
                    # Pad with the EOS token id (falling back to the pad token or 0).
                    pad_id = self.config.eos_token_id if self.config.eos_token_id is not None else (self.config.pad_token_id or 0)
                    pad_length = seq_len - safe_ids.shape[0]
                    safe_ids = torch.cat([safe_ids, torch.full((pad_length,), pad_id, dtype=torch.long)], dim=0)
                else:
                    # Truncate if the safe message is longer than the original sequence.
                    safe_ids = safe_ids[:seq_len]
                safe_sequences.append(safe_ids)
            else:
                # Non-toxic (or undecodable) outputs are passed through unchanged.
                safe_sequences.append(seq)

        safe_sequences = torch.stack(safe_sequences, dim=0)
        return safe_sequences
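

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a local "toxic.keras"
# classifier that maps a batch of raw strings to one toxicity probability per
# example, and a Llama checkpoint you have access to; the checkpoint name below
# is a placeholder, not something this class requires.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # from_pretrained is inherited from LlamaForCausalLM, so weights load exactly
    # as for the base model; __init__ then attaches the toxicity classifier and
    # the wrapper's own copy of the tokenizer.
    model = SafeGenerationModel.from_pretrained(model_name)

    prompt = "Write a short poem about the sea."
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

    # generate() screens the prompt, generates only if it is clean, then screens
    # and, if needed, replaces each generated sequence before returning token ids.
    output_ids = model.generate(input_ids=input_ids, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))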