---
license: apache-2.0
---

# InterProt ESM2 SAE Models

A set of SAE models trained on [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation largely follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function.

For more information, check out our [preprint](https://www.biorxiv.org/content/10.1101/2025.02.06.636901v1). Our SAEs can be viewed and interacted with on [interprot.com](https://interprot.com).
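
To give a sense of the architecture, here is a minimal sketch of a Top-K sparse autoencoder in the style of Gao et al. It is an illustration only, not the InterProt implementation; the class name, the absence of bias tying, and the value of `k` are assumptions.

```python
import torch
import torch.nn as nn


class TopKSAESketch(nn.Module):
    """Illustrative Top-K sparse autoencoder (hypothetical, for explanation only)."""

    def __init__(self, d_model: int = 1280, d_hidden: int = 4096, k: int = 32):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Encode each residue embedding into a wider latent space,
        # then keep only the k largest pre-activations per token.
        pre_acts = self.encoder(x)
        topk = torch.topk(pre_acts, k=self.k, dim=-1)
        acts = torch.zeros_like(pre_acts).scatter_(-1, topk.indices, topk.values)
        # Reconstruct the original embedding from the sparse code.
        recon = self.decoder(acts)
        return recon, acts
```

The SAEs in this repository map 1280-dimensional ESM2 embeddings to 4096 sparse features per residue, which is the shape exposed by the checkpoints below.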
## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```
## Usage

Install InterProt, then load the ESM2 model and the SAE:

```python
import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()

# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()
```
Run ESM -> SAE inference on an amino acid sequence of length `L`:

```python
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)

# esm_layer_acts has shape (L+2, ESM_DIM); the +2 accounts for the BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]

# Using the ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts)  # (L+2, SAE_DIM)
sae_acts
```
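
As a quick follow-up (illustrative, not part of the snippet above), you can drop the BoS/EoS positions so that indices line up with residues in `seq`, then look at which SAE features fire most strongly at a given position:

```python
# Drop the BoS and EoS positions: shape (L, SAE_DIM)
per_residue_acts = sae_acts[1:-1]

# Top 3 SAE features activated at the first residue (indices into the 4096 features)
values, feature_ids = per_residue_acts[0].topk(3)
print(feature_ids.tolist(), values.tolist())
```

Feature indices can then be explored interactively on [interprot.com](https://interprot.com).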
## Note on the default checkpoint on [interprot.com](https://interprot.com)

In November 2024, we shared an earlier version of our layer 24 SAE on [X](https://x.com/liambai21/status/1852765669080879108?s=46) and received a lot of amazing community support in identifying SAE features; therefore, we have kept it as the default on [interprot.com](https://interprot.com). Since then, we retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named `esm2_plm1280_l24_sae4096.safetensors`, whereas the original is named `esm2_plm1280_l24_sae4096_100k.safetensors`.

We recommend using `esm2_plm1280_l24_sae4096.safetensors`, but if you'd like to reproduce the default SAE on [interprot.com](https://interprot.com), you can use `esm2_plm1280_l24_sae4096_100k.safetensors`. All other layer SAEs are trained with the same configuration as `esm2_plm1280_l24_sae4096.safetensors`.
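
For example, loading the original 100K-sequence checkpoint only requires swapping the filename in the snippet above; a sketch reusing the imports and constants from the Usage section (and assuming the same 1280 -> 4096 dimensions indicated by the filename):

```python
# Download and load the original 100K-sequence layer 24 checkpoint
checkpoint_path_100k = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096_100k.safetensors"
)
sae_model_100k = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model_100k.load_state_dict(load_file(checkpoint_path_100k))
sae_model_100k.to(device)
sae_model_100k.eval()
```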