# Hugging Face Space status: Running on Zero
import os, warnings | |
from operator import attrgetter | |
from typing import List, Dict, Callable, Tuple | |
import torch | |
import torch.nn.functional as F | |
from torchtyping import TensorType | |
from transformers import TextIteratorStreamer | |
from transformers import AutoTokenizer, BatchEncoding | |
import nnsight | |
from nnsight import LanguageModel | |
from nnsight.intervention import Envoy | |
# Process-wide setup: silence every Python warning and set
# TOKENIZERS_PARALLELISM=false before any tokenizer is created.
warnings.filterwarnings("ignore")
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Disable nnsight's global tracing so the model can be driven from multiple
# threads; see https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False
# Settings read by load_model():
#   model_name    - Hugging Face model id to load
#   steering_vec  - path to the saved steering-direction tensor
#   offset        - path to the saved offset tensor
#   layer         - index of the decoder block that is steered
#   k             - steering scales: (used when coeff < 0, used otherwise)
config = dict(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    steering_vec="activations/llama3-8b-steering-vec.pt",
    offset="activations/llama3-8b-offset.pt",
    layer=20,
    k=(8.5, 6),
)
def detect_module_attrs(model: "LanguageModel") -> str:
    """Return the dotted attribute path of ``model``'s decoder-block list.

    Known layouts are tried in order: ``model.layers`` (Llama-style models),
    ``transformer.h`` (GPT-2 / GPT-J-style models) and, for backward
    compatibility with the previous lookup, ``transformers.h``. A layout
    matches when its parent is a registered child module of ``model`` and the
    block list is a registered child module of that parent.

    Raises:
        Exception: if none of the known layouts is present.
    """
    candidates = (
        ("model", "layers"),
        # HF GPT-2-family models name the wrapper `transformer` (singular).
        # The old check looked for `transformers`, so it never matched them.
        ("transformer", "h"),
        # Retained so anything that matched the old lookup still resolves.
        ("transformers", "h"),
    )
    for parent, blocks in candidates:
        if parent in model._modules and blocks in getattr(model, parent)._modules:
            return f"{parent}.{blocks}"
    raise Exception("Failed to detect module attributes.")
class ModelBase:
    """Chat wrapper around an nnsight ``LanguageModel`` with optional activation steering.

    When steering is enabled, the output hidden states of one decoder block are
    edited on every generation step. The projection of ``(acts - offset)`` onto
    a unit steering direction is removed, and ``coeff * k * unit_vec`` is added.
    """

    def __init__(
        self,
        model_name: str,
        steering_vec: TensorType,
        offset: TensorType,
        k: Tuple[float, float],
        steering_layer: int,
        tokenizer: AutoTokenizer = None,
        block_module_attr: str = None,
    ):
        """Load the tokenizer and model, and prepare the steering tensors.

        Args:
            model_name: Hugging Face model id passed to both loaders.
            steering_vec: Steering direction; normalized along its last dim.
            offset: Reference point subtracted from activations before projecting.
            k: Steering scales; ``k[0]`` is used when ``coeff < 0``, ``k[1]`` otherwise.
            steering_layer: Index into the decoder-block list to intervene on.
            tokenizer: Pre-built tokenizer. When None, one is loaded for ``model_name``.
            block_module_attr: Dotted path to the decoder-block list. When None,
                it is auto-detected with ``detect_module_attrs``.
        """
        self.tokenizer = self._load_tokenizer(model_name) if tokenizer is None else tokenizer
        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size

        if block_module_attr is None:
            block_module_attr = detect_module_attrs(self.model)
        self.block_modules = self.get_module(block_module_attr)

        self.steering_layer = steering_layer
        self.k = k
        # Cast to the model dtype (bfloat16, see _load_model) so the in-trace
        # arithmetic matches the activations.
        # NOTE(review): the tensors are NOT moved to the model's device; this
        # relies on them already living where the steered block runs — confirm.
        self.unit_vec, self.offset = self.set_dtype(F.normalize(steering_vec, dim=-1), offset)

    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        """Load ``model_name`` as an nnsight LanguageModel in bfloat16 with automatic device placement."""
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name) -> AutoTokenizer:
        """Load the tokenizer with left padding, falling back to EOS as the pad token."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        """Tokenize ``prompt`` into padded, untruncated PyTorch tensors."""
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> Envoy:
        """Resolve a dotted attribute path (e.g. ``"model.layers"``) on the wrapped model."""
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast tensors to the model's dtype.

        Returns:
            The cast tensor when called with exactly one argument. Otherwise, a
            tuple of cast tensors in argument order. Previously this was a
            one-shot generator; a tuple can be indexed and reused.
        """
        cast = tuple(t.to(self.model.dtype) for t in tensors)
        return cast[0] if len(cast) == 1 else cast

    def apply_chat_template(self, instruction: str) -> str:
        """Wrap ``instruction`` as a single user turn and render it as prompt text.

        Uses ``tokenize=False`` and appends the generation prompt.
        """
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def generate(self, prompt: str, streamer: TextIteratorStreamer, steering: bool, coeff: float, generation_config: Dict):
        """Sample a reply to ``prompt``, streaming decoded text into ``streamer``.

        Args:
            prompt: Raw user message; the chat template is applied before tokenizing.
            streamer: Receives generated text; nothing is returned.
            steering: When True, edit block ``steering_layer`` during generation.
                When False, call the underlying HF model directly.
            coeff: Steering strength. Its sign selects the scale:
                ``k[0]`` if negative, ``k[1]`` otherwise.
            generation_config: Extra keyword arguments forwarded to ``generate``.
        """
        inputs = self.tokenize(self.apply_chat_template(prompt))

        if not steering:
            inputs = inputs.to(self.device)
            self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)
            return

        # Plain Python float; identical to the original `coeff * k`.
        scale = coeff * (self.k[0] if coeff < 0 else self.k[1])

        with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
            # nnsight: presumably makes the edit below apply on every generation
            # step, not only the first forward pass — confirm for the nnsight version.
            self.block_modules.all()
            block = self.block_modules[self.steering_layer]
            acts = block.output[0].clone()
            # Component of the offset-shifted activations along unit_vec.
            # Assumes unit_vec is 1-D (hidden,), so the matmul gives one scalar
            # per position — TODO confirm shapes.
            proj = (acts - self.offset) @ self.unit_vec.unsqueeze(-1) * self.unit_vec
            block.output[0][:] = acts - proj + scale * self.unit_vec
def load_model() -> ModelBase:
    """Construct the steered chat model described by the module-level ``config``.

    Loads the steering-direction and offset tensors (``weights_only=True``) from
    the paths in ``config``. Passes them, with the model id, steering scales and
    layer index, to ``ModelBase``.
    """
    direction, shift = (
        torch.load(config[key], weights_only=True) for key in ("steering_vec", "offset")
    )
    return ModelBase(
        config["model_name"],
        steering_vec=direction,
        offset=shift,
        k=config["k"],
        steering_layer=config["layer"],
    )