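"""Activation steering for a chat model via nnsight.

Removes the component of a chosen decoder layer's output along a precomputed
steering direction (centered by an offset), then re-injects that direction
scaled by a signed coefficient during generation.
"""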
import os
import warnings
from operator import attrgetter
from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torchtyping import TensorType
from transformers import AutoTokenizer, BatchEncoding, TextIteratorStreamer

import nnsight
from nnsight import LanguageModel
from nnsight.intervention import Envoy

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# nnsight with multi-threading: https://github.com/ndif-team/nnsight/issues/280
nnsight.CONFIG.APP.GLOBAL_TRACING = False

config = {
    "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "steering_vec": "activations/llama3-8b-steering-vec.pt",
    "offset": "activations/llama3-8b-offset.pt",
    "layer": 20,     # decoder layer whose output is steered
    "k": (8.5, 6),   # steering magnitudes for negative / positive coefficients
}


def detect_module_attrs(model: LanguageModel) -> str:
    """Find the attribute path to the decoder block list (Llama-style vs. GPT-2-style layouts)."""
    if "model" in model._modules and "layers" in model.model._modules:
        return "model.layers"
    elif "transformer" in model._modules and "h" in model.transformer._modules:
        return "transformer.h"
    else:
        raise ValueError("Failed to detect the model's block module attribute.")



class ModelBase:
    def __init__(
        self, model_name: str,
        steering_vec: TensorType, offset: TensorType,
        k: Tuple[float, float], steering_layer: int,
        tokenizer: AutoTokenizer = None, block_module_attr=None
    ):
        if tokenizer is None:
            self.tokenizer = self._load_tokenizer(model_name)
        else:
            self.tokenizer = tokenizer
        self.model = self._load_model(model_name, self.tokenizer)

        self.device = self.model.device
        self.hidden_size = self.model.config.hidden_size
        if block_module_attr is None:
            self.block_modules = self.get_module(detect_module_attrs(self.model))
        else:
            self.block_modules = self.get_module(block_module_attr)
        self.steering_layer = steering_layer
        self.k = k
        # Normalize the steering direction to unit length, then match the model dtype.
        self.unit_vec = F.normalize(steering_vec, dim=-1)
        self.unit_vec, self.offset = self.set_dtype(self.unit_vec, offset)
    
    def _load_model(self, model_name: str, tokenizer: AutoTokenizer) -> LanguageModel:
        return LanguageModel(model_name, tokenizer=tokenizer, dispatch=True, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16)
    
    def _load_tokenizer(self, model_name) -> AutoTokenizer:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.padding_side = "left"
        if not tokenizer.pad_token:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def tokenize(self, prompt: str) -> BatchEncoding:
        return self.tokenizer(prompt, padding=True, truncation=False, return_tensors="pt")
    
    def get_module(self, attr: str) -> Envoy:
        return attrgetter(attr)(self.model)

    def set_dtype(self, *tensors):
        """Cast one or more tensors to the model's dtype."""
        if len(tensors) == 1:
            return tensors[0].to(self.model.dtype)
        return tuple(t.to(self.model.dtype) for t in tensors)
    
    def apply_chat_template(self, instruction: str) -> str:
        messages = [{"role": "user", "content": instruction}]
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def generate(self, prompt: str, streamer: TextIteratorStreamer, steering: bool, coeff: float, generation_config: Dict):
        formatted_prompt = self.apply_chat_template(prompt)
        inputs = self.tokenize(formatted_prompt)

        if steering:
            # Pick the steering magnitude for the sign of the coefficient.
            k = self.k[0] if coeff < 0 else self.k[1]

            with self.model.generate(inputs, do_sample=True, streamer=streamer, **generation_config):
                # Apply the intervention at every generation step, not just the prompt pass.
                self.block_modules.all()
                acts = self.block_modules[self.steering_layer].output[0].clone()
                # Projection of the offset-centered activations onto the steering direction.
                proj = (acts - self.offset) @ self.unit_vec.unsqueeze(-1) * self.unit_vec
                # Ablate that component, then re-inject the direction scaled by coeff * k.
                self.block_modules[self.steering_layer].output[0][:] = acts - proj + coeff * k * self.unit_vec
        else:
            # Without steering, call the underlying HuggingFace model directly.
            inputs = inputs.to(self.device)
            _ = self.model._model.generate(**inputs, do_sample=True, streamer=streamer, **generation_config)


def load_model() -> ModelBase:
    steering_vec = torch.load(config["steering_vec"], weights_only=True)
    offset = torch.load(config["offset"], weights_only=True)
    return ModelBase(
        config["model_name"], steering_vec=steering_vec, offset=offset,
        k=config["k"], steering_layer=config["layer"],
    )
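

# A minimal usage sketch (not part of the original module): it assumes the
# activation files referenced in `config` exist locally, and the prompt string
# is illustrative only. `generate` runs in a background thread so the main
# thread can consume the TextIteratorStreamer as chunks arrive.
if __name__ == "__main__":
    from threading import Thread

    model = load_model()
    streamer = TextIteratorStreamer(model.tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {"max_new_tokens": 128, "temperature": 0.7}

    thread = Thread(
        target=model.generate,
        kwargs=dict(prompt="Tell me about your day.", streamer=streamer,
                    steering=True, coeff=1.0, generation_config=gen_kwargs),
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()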