# Wraps ExLlamaV2 MLP modules so GGUF control vectors are added to each layer's
# hidden state during the forward pass.
import glob
import sys

import torch
from gguf.gguf_reader import GGUFReader

class ExLlamaV2ModuleWrapper:
    @classmethod
    def wrap(cls, model, vector_configs):
        # Load every control vector GGUF found in the sibling "<model_dir>-vectors"
        # directory. Filenames are expected to end in "...-<vector>__<direction>.gguf".
        vectors = {}
        for file in glob.glob(str(model.config.model_dir) + '-vectors/*.gguf'):
            base = file.rsplit('-', 1)[-1].replace('.gguf', '')
            vector, direction = base.split('__')
            print(f"Loaded control vector: {vector}, Direction: {direction}")
            reader = GGUFReader(file)

            if reader.tensors[0].n_elements != model.config.hidden_size:
                print(f' ## Control vector n_elements ({reader.tensors[0].n_elements}) != model.config.hidden_size ({model.config.hidden_size})')
                sys.exit()
            # Stack the per-layer tensors into a (num_hidden_layers, hidden_size) matrix,
            # indexed by the layer number encoded at the end of each tensor name.
            layers = torch.zeros((model.config.num_hidden_layers, model.config.hidden_size), dtype=torch.float32)
            for tensor in reader.tensors:
                idx = int(tensor.name.split('.')[-1])
                layers[idx] = torch.from_numpy(tensor.data.copy())
            vectors.setdefault(vector, {})[direction] = layers

        # Build the combined control vector. Each entry in the comma-separated
        # config string has the form "vector:direction:weight".
        vector_configs = vector_configs.split(',')
        control_vector = torch.zeros((model.config.num_hidden_layers, model.config.hidden_size), dtype=torch.float32)
        for vector_config in vector_configs:
            vector, direction, weight = vector_config.split(':')
            # Match the requested name against loaded vectors, allowing partial matches.
            vector_dirs = None
            for k, v in vectors.items():
                if vector in k:
                    vector = k
                    vector_dirs = v
                    break
            if vector_dirs is None:
                print(' !! Error: No vector for "%s" (%s)' % (vector, vector_config))
                continue

            debias_layers = vector_dirs.get('debias', None)
            if debias_layers is None:
                print(' !! Error: No debias for "%s" (%s)' % (vector, vector_config))
                continue

            direction_layers = vector_dirs.get(direction, None)
            if direction_layers is None:
                print(' !! Error: No "%s" for "%s" (%s)' % (direction, vector, vector_config))
                continue

            try:
                weight = float(weight)
            except ValueError:
                print(' !! Non-float weight %s (%s)' % (weight, vector_config))
                weight = 1.0

            print(' -- Applying %s debias and %s * %s' % (vector, direction, weight))
            control_vector += debias_layers
            control_vector += direction_layers * weight


        # Wrap each transformer MLP module (skipping the embedding module and the
        # final norm/head modules) so the control vector is applied to its output.
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2) or module.name != 'MLP':
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(module, control_vector)

    def __init__(self, module, control_vector):
        self.module = module
        self.control_vector = control_vector

    def __getattribute__(self, name):
        # Route 'forward' to the wrapper; delegate every other attribute lookup to
        # the wrapped module first, then fall back to the wrapper's own attributes.
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')

        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def wrapped_forward(self, *args, **kwargs):
        x = self.module.forward(*args, **kwargs)
        try:
            # Add this layer's control vector to the hidden state, then rescale so
            # the result keeps its original L2 norm.
            prev_norm = torch.norm(x, p=2)
            x += self.control_vector[self.module.layer_idx].clone().to(x.device)
            x *= prev_norm / torch.norm(x, p=2)
        except IndexError:
            pass
        return x