---
tags:
- biology
---
|
# Model Card for PlantGFM
|
PlantGFM is a genetic foundation model pre-trained on the complete genome sequences of 12 model plants, comprising 108 billion nucleotides. Built on the Hyena framework, the model has 220 million parameters and a context length of 64K bp, and it models sequences at single-nucleotide resolution. Pre-training used a length warm-up strategy, starting with 1K bp fragments and gradually increasing to 64K bp, which improved training stability and accelerated convergence.
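
For intuition, a length warm-up of this kind can be expressed as a schedule that maps the training step to a fragment length. The sketch below is illustrative only: the function name, the step counts, and the linear interpolation are assumptions, not the schedule actually used for pre-training.

```python
# Illustrative length warm-up schedule (assumed shape, not the actual
# pre-training configuration): the fragment length grows from 1K bp
# toward the full 64K bp context as training progresses.
def fragment_length(step: int, warmup_steps: int,
                    start_len: int = 1_000, max_len: int = 64_000) -> int:
    """Linearly interpolate the training fragment length from start_len to max_len."""
    frac = min(step / max(warmup_steps, 1), 1.0)
    return int(start_len + frac * (max_len - start_len))

# Example: fragment lengths at a few points of a hypothetical 100K-step warm-up.
for step in (0, 25_000, 50_000, 100_000):
    print(step, fragment_length(step, warmup_steps=100_000))
```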
|
|
|
### Model Sources

- **Repository:** [PlantGFM](https://github.com/hu-lab-PlantGLM/PlantGLM)
- **Manuscript:** [A Genetic Foundation Model for Discovery and Creation of Plant Genes]()

**Developed by:** hu-lab

# How to use the model
|
|
|
Install the runtime library first:

```bash
pip install transformers
```
|
To generate a new gene sequence with the model:

```python
import torch
from torch.cuda.amp import autocast
from transformers import PreTrainedTokenizerFast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the configuration, tokenizer, and model weights, then cast the model to bfloat16.
config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
model = model.to(dtype=torch.bfloat16)

num_texts = 1
batch_size = 1
generated_texts = []

# Generation starts from an empty prompt.
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

for i in range(0, num_texts, batch_size):
    with autocast(dtype=torch.bfloat16):
        outputs = model.generate(
            input_ids=input_ids,
            max_length=4000,
            do_sample=True,
        )
    # Decode each generated sequence in the batch.
    for output_sequence in outputs:
        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        generated_texts.append(generated_text)
        print(generated_text)
```
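
The decoded output is a plain nucleotide string. If you want to keep generated sequences for downstream analysis, one option is to write them to a FASTA file; the snippet below is a minimal sketch that assumes `generated_texts` holds the decoded sequences from the loop above, and the record names are arbitrary.

```python
# Minimal sketch: save the decoded sequences to a FASTA file for downstream analysis.
# Assumes generated_texts was filled by the generation loop above; record IDs are arbitrary.
with open("generated_genes.fasta", "w") as fasta:
    for idx, seq in enumerate(generated_texts):
        fasta.write(f">generated_gene_{idx}\n")
        fasta.write(seq.replace(" ", "") + "\n")  # drop any spaces the tokenizer may insert between tokens
```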
|
#### Hardware

The model was trained for 15 hours on 2 NVIDIA A100-40G GPUs.