atticusg commited on
Commit
e1060e3
·
verified ·
1 Parent(s): 1874eca

Delete featurizer.py

Browse files
Files changed (1) hide show
  1. featurizer.py +0 -52
featurizer.py DELETED
@@ -1,52 +0,0 @@
1
- """
2
- Copy of the existing SubspaceFeaturizer implementation for submission.
3
- This file provides the same SubspaceFeaturizer functionality in a self-contained format.
4
- """
5
-
6
- import torch
7
- import torch.nn as nn
8
- import pyvene as pv
9
- from CausalAbstraction.neural.featurizers import Featurizer
10
-
11
-
12
- class SubspaceFeaturizerModuleCopy(torch.nn.Module):
13
- def __init__(self, rotate_layer):
14
- super().__init__()
15
- self.rotate = rotate_layer
16
-
17
- def forward(self, x):
18
- r = self.rotate.weight.T
19
- f = x.to(r.dtype) @ r.T
20
- error = x - (f @ r).to(x.dtype)
21
- return f, error
22
-
23
-
24
- class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module):
25
- def __init__(self, rotate_layer):
26
- super().__init__()
27
- self.rotate = rotate_layer
28
-
29
- def forward(self, f, error):
30
- r = self.rotate.weight.T
31
- return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
32
-
33
-
34
- class SubspaceFeaturizerCopy(Featurizer):
35
- def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"):
36
- assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided."
37
- if shape is not None:
38
- self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
39
- elif rotation_subspace is not None:
40
- shape = rotation_subspace.shape
41
- self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
42
- self.rotate.weight.data.copy_(rotation_subspace)
43
- self.rotate = torch.nn.utils.parametrizations.orthogonal(self.rotate)
44
-
45
- if not trainable:
46
- self.rotate.requires_grad_(False)
47
-
48
- # Create module-based featurizer and inverse_featurizer
49
- featurizer = SubspaceFeaturizerModuleCopy(self.rotate)
50
- inverse_featurizer = SubspaceInverseFeaturizerModuleCopy(self.rotate)
51
-
52
- super().__init__(featurizer, inverse_featurizer, n_features=self.rotate.weight.shape[1], id=id)