"""Activation-steering wrapper around an nnsight ``LanguageModel``.

Loads a chat model together with a pre-computed steering vector and offset,
and exposes ``ModelBase.generate`` which streams text either unmodified or
with the residual stream of one decoder block steered along that vector.
"""

import os
import warnings
from operator import attrgetter
from typing import List, Dict, Callable, Optional, Tuple

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import TextIteratorStreamer
from transformers import AutoTokenizer, BatchEncoding
import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# nnsight with multi-threading: https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False

# Default settings consumed by ``load_model``.
config = {
    "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "steering_vec": "activations/llama3-8b-steering-vec.pt",
    "offset": "activations/llama3-8b-offset.pt",
    "layer": 20,
    # (scale used when coeff < 0, scale used when coeff >= 0)
    "k": (8.5, 6),
}

# (parent attribute, child attribute) pairs under which the list of decoder
# blocks is looked up, in priority order.  "model.layers" is the Llama-style
# layout; "transformer.h" is the GPT-2-style layout.
_BLOCK_LIST_CANDIDATES = (
    ("model", "layers"),
    ("transformer", "h"),
)


def detect_module_attrs(model: LanguageModel) -> str:
    """Return the dotted attribute path of the model's decoder-block list.

    Args:
        model: The wrapped language model to inspect.

    Returns:
        ``"model.layers"`` or ``"transformer.h"``, whichever is present
        (checked in that order).

    Raises:
        ValueError: If neither known layout is found on ``model``.
    """
    for parent, child in _BLOCK_LIST_CANDIDATES:
        if parent in model._modules and child in getattr(model, parent)._modules:
            return f"{parent}.{child}"
    raise ValueError("Failed to detect module attributes.")


class ModelBase:
    """A language model plus the state needed to steer one of its blocks."""

    def __init__(
        self,
        model_name: str,
        steering_vec: TensorType,
        offset: TensorType,
        k: Tuple[float, float],
        steering_layer: int,
        tokenizer: Optional[AutoTokenizer] = None,
        block_module_attr: Optional[str] = None,
    ):
        """Load the model and prepare the steering tensors.

        Args:
            model_name: Identifier passed to the tokenizer and model loaders.
            steering_vec: Steering direction; it is L2-normalised along its
                last dimension before use.
            offset: Tensor subtracted from the activations before they are
                projected onto the steering direction.
            k: Pair of scale factors; ``k[0]`` is used for negative
                coefficients and ``k[1]`` otherwise (see ``generate``).
            steering_layer: Index into the decoder-block list of the block
                whose output is modified.
            tokenizer: Optional pre-built tokenizer. When omitted, one is
                loaded from ``model_name``.
            block_module_attr: Optional dotted attribute path to the
                decoder-block list. When omitted, it is auto-detected.
        """
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer

        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size

        if block_module_attr is None:
            block_module_attr = detect_module_attrs(self.model)
        self.block_modules = self.get_module(block_module_attr)

        self.steering_layer = steering_layer
        self.k = k
        self.unit_vec = F.normalize(steering_vec, dim=-1)
        # Match the model's dtype so the arithmetic in ``generate`` does not
        # mix precisions.
        # NOTE(review): only the dtype is aligned here, not the device —
        # confirm the saved tensors live on the steering layer's device.
        self.unit_vec, self.offset = self.set_dtype(self.unit_vec, offset)

    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        """Instantiate the nnsight ``LanguageModel`` in bfloat16 with ``device_map="auto"``."""
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name) -> AutoTokenizer:
        """Load a left-padding tokenizer, using EOS as pad token if none is set."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        """Tokenize ``prompt`` into PyTorch tensors, padded and without truncation."""
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        """Resolve a dotted attribute path (e.g. ``"model.layers"``) on the model."""
        return attrgetter(attr)(self.model)

    def set_dtype(self, *vars):
        """Cast the given tensors to the model's dtype.

        Returns:
            The converted tensor itself when exactly one argument is given,
            otherwise a tuple of converted tensors in argument order.
        """
        target_dtype = self.model.dtype
        if len(vars) == 1:
            return vars[0].to(target_dtype)
        # A tuple (rather than a one-shot generator) can be unpacked, indexed
        # and iterated more than once.
        return tuple(var.to(target_dtype) for var in vars)

    def apply_chat_template(self, instruction: str) -> str:
        """Wrap ``instruction`` as a single user turn and render the chat prompt.

        The template is rendered as text (``tokenize=False``) with the
        generation prompt appended.
        """
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def generate(
        self,
        prompt: str,
        streamer: TextIteratorStreamer,
        steering: bool,
        coeff: float,
        generation_config: Dict,
    ):
        """Generate a sampled response to ``prompt``, pushing tokens to ``streamer``.

        Args:
            prompt: Raw user instruction; the chat template is applied here.
            streamer: Streamer handed to the underlying ``generate`` call.
            steering: When true, the output of block ``self.steering_layer``
                is rewritten during generation; otherwise the wrapped model
                generates unmodified.
            coeff: Steering coefficient. Its sign picks the scale factor:
                ``self.k[0]`` when negative, ``self.k[1]`` otherwise.
            generation_config: Extra keyword arguments forwarded to the
                ``generate`` call (``do_sample=True`` is always passed as
                well).

        Returns:
            None. Generated text is delivered through ``streamer``.
        """
        formatted_prompt = self.apply_chat_template(prompt)
        inputs = self.tokenize(formatted_prompt)

        if steering:
            k = self.k[0] if coeff < 0 else self.k[1]
            steered_block = self.block_modules[self.steering_layer]

            with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
                # Presumably makes nnsight apply the edit below at every
                # generation step, not only the first — verify against docs.
                self.block_modules.all()
                acts = steered_block.output[0].clone()
                # Orthogonal projection: the component of (acts - offset)
                # that lies along the unit steering direction.
                proj = (acts - self.offset) @ self.unit_vec.unsqueeze(-1) * self.unit_vec
                # Remove that component and add the requested amount along
                # the steering direction, writing in place into the output.
                steered_block.output[0][:] = acts - proj + coeff * k * self.unit_vec
        else:
            inputs = inputs.to(self.device)
            _ = self.model._model.generate(
                **inputs, do_sample=True, streamer=streamer, **generation_config
            )


def load_model() -> ModelBase:
    """Build a ``ModelBase`` from the module-level ``config``.

    Loads the steering vector and offset from the paths in ``config``
    (with ``weights_only=True``) and constructs the model wrapper.
    """
    steering_vec = torch.load(config['steering_vec'], weights_only=True)
    offset = torch.load(config['offset'], weights_only=True)
    model = ModelBase(
        config['model_name'],
        steering_vec=steering_vec,
        offset=offset,
        k=config['k'],
        steering_layer=config['layer'],
    )
    return model