# Uploaded by hannahcyberey ("Upload model.py", commit 0a82244, verified).
import os, warnings
from operator import attrgetter
from typing import List, Dict, Callable, Tuple
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import TextIteratorStreamer
from transformers import AutoTokenizer, BatchEncoding
import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy
# Turn off the Hugging Face `tokenizers` library's parallelism and silence
# all Python warnings for the whole process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
# Global tracing is switched off for using nnsight with multi-threading;
# see https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False
# Default checkpoint and steering setup consumed by load_model().
#   steering_vec / offset: paths to tensors loaded with torch.load.
#   layer: index of the decoder block whose output is steered.
#   k: scale pair; ModelBase.generate uses k[0] for a negative coefficient
#      and k[1] otherwise.
config = dict(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    steering_vec="activations/llama3-8b-steering-vec.pt",
    offset="activations/llama3-8b-offset.pt",
    layer=20,
    k=(8.5, 6),
)
def detect_module_attrs(model: "LanguageModel") -> str:
    """Return the dotted attribute path of the model's list of decoder blocks.

    Two layouts are recognised:

    * ``model.layers`` -- e.g. Llama-style ``*ForCausalLM`` models.
    * ``transformer.h`` -- e.g. GPT-2-style ``*LMHeadModel`` models.

    Args:
        model: the wrapped model. It must expose ``torch.nn.Module``-style
            ``_modules`` mappings on itself and on the inner submodule.

    Returns:
        ``"model.layers"`` or ``"transformer.h"``, suitable for
        ``operator.attrgetter``.

    Raises:
        Exception: if neither layout is present.
    """
    if "model" in model._modules and "layers" in model.model._modules:
        return "model.layers"
    # Note the singular ``transformer``: that is the attribute name used by
    # GPT-2-style Hugging Face models.
    if "transformer" in model._modules and "h" in model.transformer._modules:
        return "transformer.h"
    raise Exception("Failed to detect module attributes.")
class ModelBase:
    """Chat-model wrapper that can steer generation along one direction.

    When steering is enabled (see :meth:`generate`), the output hidden states
    of decoder block ``steering_layer`` are rewritten: the component of
    ``acts - offset`` along ``unit_vec`` is removed and replaced by
    ``coeff * k * unit_vec``.
    """

    def __init__(
        self, model_name: str,
        steering_vec: "TensorType", offset: "TensorType",
        k: Tuple[float, float], steering_layer: int,
        tokenizer: "AutoTokenizer" = None, block_module_attr=None
    ):
        """Load the tokenizer and model and prepare the steering direction.

        Args:
            model_name: Hugging Face model id or local path.
            steering_vec: steering direction; normalized along its last dim.
            offset: vector subtracted from activations before projecting
                them onto the steering direction.
            k: scale pair; ``k[0]`` is used for a negative ``coeff`` and
                ``k[1]`` otherwise (see :meth:`generate`).
            steering_layer: index into the block list of the layer to steer.
            tokenizer: optional pre-built tokenizer; loaded from
                ``model_name`` when ``None``.
            block_module_attr: dotted path to the block list (for example
                ``"model.layers"``); auto-detected when ``None``.
        """
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)
        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            block_module_attr = detect_module_attrs(self.model)
        self.block_modules = self.get_module(block_module_attr)
        self.steering_layer = steering_layer
        self.k = k
        unit_vec = F.normalize(steering_vec, dim=-1)
        # Cast to the model's dtype so these combine with its activations
        # inside generate().
        self.unit_vec, self.offset = self.set_dtype(unit_vec, offset)

    def _load_model(self, model_name: str, tokenizer: "AutoTokenizer") -> "LanguageModel":
        """Load ``model_name`` as an nnsight ``LanguageModel``.

        The model is loaded in bfloat16 with ``dispatch=True`` and
        ``device_map="auto"``.
        """
        return LanguageModel(
            model_name,
            tokenizer=tokenizer,
            dispatch=True,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def _load_tokenizer(self, model_name) -> "AutoTokenizer":
        """Load a left-padding tokenizer, using EOS as pad token if none is set."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Left padding keeps every prompt ending at the last position, which
        # is where decoder-only generation continues from.
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def tokenize(self, prompt: str) -> "BatchEncoding":
        """Tokenize ``prompt`` into PyTorch tensors (padded, never truncated)."""
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")

    def get_module(self, attr: str) -> "Envoy":
        """Resolve a dotted attribute path (e.g. ``"model.layers"``) on the model."""
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast tensors to the model's dtype.

        Returns:
            The single converted tensor when called with one argument,
            otherwise a tuple of converted tensors in argument order.
        """
        if len(tensors) == 1:
            return tensors[0].to(self.model.dtype)
        # A tuple (not a one-shot generator) so callers can index or reuse it.
        return tuple(t.to(self.model.dtype) for t in tensors)

    def apply_chat_template(self, instruction: str) -> str:
        """Render ``instruction`` as one user turn plus the generation prompt."""
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def generate(self, prompt: str, streamer: "TextIteratorStreamer", steering: bool, coeff: float, generation_config: Dict):
        """Generate a completion for ``prompt``, delivering tokens via ``streamer``.

        Args:
            prompt: raw user message; wrapped with the chat template first.
            streamer: passed to ``generate`` as the ``streamer`` argument.
            steering: when true, run through nnsight and rewrite the output
                of block ``steering_layer``. Otherwise call the underlying
                Hugging Face model's ``generate`` directly.
            coeff: steering coefficient. A negative value selects ``k[0]``;
                any other value selects ``k[1]``. Ignored when ``steering``
                is false.
            generation_config: extra keyword arguments forwarded to
                ``generate``. Sampling (``do_sample=True``) is the default,
                but the config may override it. ``None`` is treated as empty.
        """
        formatted_prompt = self.apply_chat_template(prompt)
        inputs = self.tokenize(formatted_prompt)
        # Merge instead of passing do_sample alongside **generation_config,
        # which raised a duplicate-keyword TypeError when the config set it.
        gen_kwargs = {"do_sample": True, **(generation_config or {})}
        gen_kwargs["streamer"] = streamer
        if steering:
            k = self.k[0] if coeff < 0 else self.k[1]
            with self.model.generate(inputs, **gen_kwargs):
                # nnsight: apply the interventions below on every generation
                # step, not only the first forward pass.
                self.block_modules.all()
                acts = self.block_modules[self.steering_layer].output[0].clone()
                # Orthogonal projection of (acts - offset) onto the unit direction.
                proj = (acts - self.offset) @ self.unit_vec.unsqueeze(-1) * self.unit_vec
                # Drop that component and substitute a fixed-magnitude one.
                self.block_modules[self.steering_layer].output[0][:] = acts - proj + coeff * k * self.unit_vec
        else:
            inputs = inputs.to(self.device)
            self.model._model.generate(**inputs, **gen_kwargs)
def load_model() -> ModelBase:
    """Build the steerable model described by the module-level ``config``.

    Loads the steering vector and the offset tensor (in that order) from the
    paths stored in ``config`` and passes them to :class:`ModelBase`, along
    with the checkpoint name, ``k`` scales, and steering layer.
    """
    vec, shift = (
        torch.load(config[key], weights_only=True)
        for key in ("steering_vec", "offset")
    )
    return ModelBase(
        config["model_name"],
        steering_vec=vec,
        offset=shift,
        k=config["k"],
        steering_layer=config["layer"],
    )