---
license: apache-2.0
---

# InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on [ESM2-650](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation mostly follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function. For more information, check out our [preprint](https://www.biorxiv.org/content/10.1101/2025.02.06.636901v1). Our SAEs can be viewed and interacted with on [interprot.com](https://interprot.com).

## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```

## Usage

Install InterProt, then load ESM and the SAE:

```python
import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()

# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()
```

ESM -> SAE inference on an amino acid sequence of length `L`:

```python
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)

# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]

# Using ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts)  # (L+2, SAE_DIM)
sae_acts
```

## Note on the default checkpoint on [interprot.com](https://interprot.com)

In November 2024, we shared an earlier version of our layer 24 SAE on [X](https://x.com/liambai21/status/1852765669080879108?s=46) and received a lot of amazing community support in identifying SAE features; therefore, we have kept it as the default on [interprot.com](https://interprot.com). Since then, we have retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named `esm2_plm1280_l24_sae4096.safetensors`, whereas the original is named `esm2_plm1280_l24_sae4096_100k.safetensors`. We recommend using `esm2_plm1280_l24_sae4096.safetensors`, but if you'd like to reproduce the default SAE on [interprot.com](https://interprot.com), you can use `esm2_plm1280_l24_sae4096_100k.safetensors`. All other layer SAEs are trained with the same configurations as `esm2_plm1280_l24_sae4096.safetensors`.
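To switch between the two layer 24 checkpoints described above, the loading code from the Usage section could be wrapped in a small helper. This is a minimal sketch: the names `layer24_checkpoint` and `load_layer24_sae` are hypothetical and not part of the `interprot` package. It also assumes the 100K checkpoint loads with the same `SparseAutoencoder(1280, 4096)` constructor as the recommended one. If its hyperparameters require extra constructor arguments, check the `interprot` source.

```python
REPO_ID = "liambai/InterProt-ESM2-SAEs"

# Filenames of the two layer 24 checkpoints described in the note above.
RECOMMENDED_L24 = "esm2_plm1280_l24_sae4096.safetensors"  # retrained on 1M sequences
ORIGINAL_L24_100K = "esm2_plm1280_l24_sae4096_100k.safetensors"  # interprot.com default


def layer24_checkpoint(match_interprot_default: bool = False) -> str:
    """Hypothetical helper: return the layer 24 SAE filename.

    Pass True to reproduce the default SAE shown on interprot.com.
    """
    return ORIGINAL_L24_100K if match_interprot_default else RECOMMENDED_L24


def load_layer24_sae(match_interprot_default: bool = False, device: str = "cpu"):
    """Hypothetical helper: download and load a layer 24 SAE in eval mode.

    Assumes both checkpoints load with SparseAutoencoder(1280, 4096).
    """
    # Imported lazily so layer24_checkpoint works without these packages installed.
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from interprot.sae_model import SparseAutoencoder

    path = hf_hub_download(
        repo_id=REPO_ID, filename=layer24_checkpoint(match_interprot_default)
    )
    sae = SparseAutoencoder(1280, 4096)  # ESM_DIM, SAE_DIM as in the Usage section
    sae.load_state_dict(load_file(path))
    return sae.to(device).eval()
```

For example, `load_layer24_sae(match_interprot_default=True)` would fetch `esm2_plm1280_l24_sae4096_100k.safetensors`. The returned module can then be used in place of `sae_model` in the inference snippet.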