---
tags:
- biology
---
# Model Card for PlantGFM
PlantGFM is a genetic foundation model pre-trained on the complete genome sequences of 12 model plants, encompassing 108 billion nucleotides. Built on the Hyena framework with 220 million parameters and a context length of 64K bp, PlantGFM models sequences at single-nucleotide resolution. Training used a length warm-up strategy, starting with 1K bp fragments and gradually increasing to 64K bp, which enhanced training stability and accelerated convergence.
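The warm-up can be pictured as a stepwise schedule over training steps. Below is a minimal sketch of such a schedule; the intermediate fragment lengths and switch-over steps are illustrative assumptions, not the published training configuration:

```python
# Illustrative length warm-up schedule (assumed values, not the actual
# PlantGFM training code). Each entry is (fragment_length_bp, start_step).
WARMUP_STAGES = [
    (1_000, 0),        # begin pre-training on short 1K bp fragments
    (8_000, 10_000),   # hypothetical intermediate stage
    (64_000, 50_000),  # finish at the full 64K bp context length
]

def fragment_length_at(step: int) -> int:
    """Return the training fragment length (in bp) active at a given step."""
    length = WARMUP_STAGES[0][0]
    for frag_len, start_step in WARMUP_STAGES:
        if step >= start_step:
            length = frag_len
    return length
```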
### Model Sources
- **Repository:** [PlantGFM](https://github.com/hu-lab-PlantGLM/PlantGLM)
- **Manuscript:** [A Genetic Foundation Model for Discovery and Creation of Plant Genes]()
- **Developed by:** hu-lab
# How to use the model
Install the runtime libraries first (the `plantgfm` modules used below are provided by the project repository; PyTorch is also required):
```bash
pip install transformers torch
```
To generate a new gene sequence using the model:
```python
import torch
from transformers import PreTrainedTokenizerFast
from torch.cuda.amp import autocast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
model = model.to(dtype=torch.bfloat16)

num_texts = 1
batch_size = 1
generated_texts = []

# Start from an empty prompt; the model samples a new sequence from scratch.
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

for i in range(0, num_texts, batch_size):
    with autocast(dtype=torch.bfloat16):
        outputs = model.generate(
            input_ids=input_ids,
            max_length=4000,
            do_sample=True,
        )
    for output_sequence in outputs:
        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        generated_texts.append(generated_text)
        print(generated_text)
```
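Since the model works at single-nucleotide resolution, the decoded output is a plain nucleotide string. A small helper like the sketch below (an assumption of typical usage, not part of the PlantGFM API) can clean the output and write it to FASTA for downstream tools:

```python
def save_fasta(sequences, path="generated_genes.fasta"):
    """Write generated nucleotide strings to a FASTA file."""
    with open(path, "w") as handle:
        for idx, seq in enumerate(sequences):
            # keep only nucleotide characters, dropping spaces/special tokens
            seq = "".join(c for c in seq.upper() if c in "ACGT")
            handle.write(f">generated_gene_{idx}\n")
            # wrap at 60 characters per line, as is conventional for FASTA
            for start in range(0, len(seq), 60):
                handle.write(seq[start:start + 60] + "\n")

save_fasta(generated_texts)
```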
#### Hardware
The model was trained for 15 hours on 2 NVIDIA A100-40GB GPUs.