Normalize hidden state after adding control vectors to preserve L2 norm
Browse files — exl2_wrapper.py (+2 −0)
exl2_wrapper.py
CHANGED
|
@@ -78,7 +78,9 @@ class ExLlamaV2ModuleWrapper:
|
|
| 78 |
def wrapped_forward(self, *args, **kwargs):
    """Run the wrapped module's forward pass, then steer the hidden state.

    Adds this layer's control vector to the module output. Layers whose
    index falls past the end of ``self.control_vector`` are deliberately
    left untouched (the IndexError is swallowed).
    """
    hidden = self.module.forward(*args, **kwargs)
    try:
        # Copy the vector so the stored control vector is never mutated,
        # and move it to wherever the hidden state lives.
        steer = self.control_vector[self.module.layer_idx]
        hidden += steer.clone().to(hidden.device)
    except IndexError:
        # No control vector for this layer — pass the output through as-is.
        pass
    return hidden
|
|
|
|
| 78 |
def wrapped_forward(self, *args, **kwargs):
    """Run the wrapped module's forward pass, then steer the hidden state.

    Adds this layer's control vector to the module output, then rescales
    the result so its global L2 norm matches the pre-addition norm (the
    steering changes direction, not magnitude). Layers whose index falls
    past the end of ``self.control_vector`` are left untouched.
    """
    x = self.module.forward(*args, **kwargs)
    try:
        # Only this lookup can legitimately raise IndexError; keep the
        # try body minimal so real bugs in the norm math aren't masked.
        vec = self.control_vector[self.module.layer_idx]
    except IndexError:
        return x
    # NOTE(review): norms are taken over the whole tensor, so for batched
    # input this preserves the batch-global norm, not per-token norms —
    # confirm that is the intent.
    prev_norm = torch.norm(x, p=2)
    # Clone so the stored control vector is never mutated in place.
    x += vec.clone().to(x.device)
    new_norm = torch.norm(x, p=2)
    # Guard against division by zero: if the steered state is all zeros,
    # rescaling would produce NaNs — leave it as-is instead.
    if new_norm > 0:
        x *= prev_norm / new_norm
    return x