---
license: apache-2.0
---
# InterProt ESM2 SAE Models
A set of sparse autoencoder (SAE) models trained on [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation mostly follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function.
For more information, check out our [preprint](https://www.biorxiv.org/content/10.1101/2025.02.06.636901v1). Our SAEs can be viewed and interacted with on [interprot.com](https://interprot.com).
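To make the Top-K activation concrete, here is a minimal NumPy sketch of a Top-K SAE forward pass in the spirit of Gao et al. It is an illustration only, not the InterProt implementation: the helper `topk_activation`, the weight names, and the toy sizes (including `k`) are all assumptions chosen for readability.

```python
import numpy as np


def topk_activation(pre_acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest values in each row and zero out everything else."""
    # Column indices of the k largest entries per row (order within the k is arbitrary)
    top_idx = np.argpartition(pre_acts, -k, axis=-1)[..., -k:]
    top_vals = np.take_along_axis(pre_acts, top_idx, axis=-1)
    out = np.zeros_like(pre_acts)
    np.put_along_axis(out, top_idx, top_vals, axis=-1)
    return out


# Toy sizes for illustration; these are NOT the InterProt hyperparameters.
rng = np.random.default_rng(0)
d_model, d_sae, k = 8, 32, 4
W_enc = rng.normal(size=(d_model, d_sae))  # hypothetical encoder weights
W_dec = rng.normal(size=(d_sae, d_model))  # hypothetical decoder weights
b_pre = np.zeros(d_model)                  # hypothetical pre-encoder bias

x = rng.normal(size=(5, d_model))            # 5 token embeddings
z = topk_activation((x - b_pre) @ W_enc, k)  # sparse code: at most k non-zeros per row
x_hat = z @ W_dec + b_pre                    # reconstruction of x from the sparse code
```

Because only `k` latents per token survive, sparsity is fixed by construction rather than encouraged through an L1 penalty.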
## Installation
```bash
pip install git+https://github.com/etowahadams/interprot.git
```
## Usage
After installing InterProt, load the ESM and SAE models:
```python
import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download
ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()
# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors",
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()
```
Run ESM -> SAE inference on an amino acid sequence of length `L`:
```python
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"
# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]
# Using ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts) # (L+2, SAE_DIM)
sae_acts
```
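Once you have `sae_acts`, a common next step is to ask which SAE features fire most strongly at each residue. The sketch below is one way to do that in NumPy. The helper `top_features_per_residue` is hypothetical (not part of InterProt), and a random array stands in for real activations so the snippet runs without downloading any models. With real output you could convert first with `sae_acts.detach().cpu().numpy()`.

```python
import numpy as np


def top_features_per_residue(sae_acts: np.ndarray, n: int = 3) -> np.ndarray:
    """Return the indices of the n most active SAE features for each residue.

    Expects an array of shape (L+2, SAE_DIM); the first and last rows
    (BoS and EoS tokens) are dropped before ranking.
    """
    residue_acts = sae_acts[1:-1]  # (L, SAE_DIM)
    # argsort is ascending: take the last n columns, then reverse so the
    # strongest feature comes first in each row
    return np.argsort(residue_acts, axis=-1)[:, -n:][:, ::-1]


# Random stand-in for real activations: 46 residues + BoS/EoS, 4096 features
rng = np.random.default_rng(0)
fake_sae_acts = rng.random((46 + 2, 4096))

top = top_features_per_residue(fake_sae_acts, n=3)
print(top.shape)  # (46, 3)
```

Each row of `top` lists feature indices for one residue, strongest first.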
## Note on the default checkpoint on [interprot.com](https://interprot.com)
In November 2024, we shared an earlier version of our layer 24 SAE on [X](https://x.com/liambai21/status/1852765669080879108?s=46) and received a great deal of community support in identifying SAE features, so we have kept it as the default on [interprot.com](https://interprot.com). Since then, we have retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named `esm2_plm1280_l24_sae4096.safetensors`, whereas the original is named `esm2_plm1280_l24_sae4096_100k.safetensors`.
We recommend using `esm2_plm1280_l24_sae4096.safetensors`, but if you'd like to reproduce the default SAE on [interprot.com](https://interprot.com), you can use `esm2_plm1280_l24_sae4096_100k.safetensors`. All other layer SAEs are trained with the same configuration as `esm2_plm1280_l24_sae4096.safetensors`.
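If you switch between the two layer 24 checkpoints, a small helper can make the choice explicit. This is only a convenience sketch: the function `layer24_checkpoint` and its flag name are hypothetical, while the repository ID and filenames are the ones given above.

```python
REPO_ID = "liambai/InterProt-ESM2-SAEs"
RETRAINED_L24 = "esm2_plm1280_l24_sae4096.safetensors"      # trained on 1M sequences
ORIGINAL_L24 = "esm2_plm1280_l24_sae4096_100k.safetensors"  # interprot.com default


def layer24_checkpoint(match_interprot_default: bool = False) -> str:
    """Pick the layer 24 checkpoint filename.

    Returns the retrained (recommended) SAE unless the caller wants to
    reproduce the default SAE shown on interprot.com.
    """
    return ORIGINAL_L24 if match_interprot_default else RETRAINED_L24


print(layer24_checkpoint())
print(layer24_checkpoint(match_interprot_default=True))
```

The returned string can be passed as `filename` to `hf_hub_download(repo_id=REPO_ID, filename=...)`, as in the Usage section.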