|
--- |
|
tags: |
|
- biology |
|
--- |
|
# Model Card for PlantGFM-Gene-generation
|
PlantGFM-Gene-generation is a gene-generation model re-trained from PlantGFM on the DNA sequences of 355,190 natural plant genes, each at most 4,000 base pairs long. Re-training used prompt-based training for two epochs, with the prompt "gene" guiding the model to generate novel plant gene sequences that follow the patterns and structures of natural genes.
|
|
|
### Model Sources |
|
|
|
- **Repository:** [PlantGFM](https://github.com/hu-lab-PlantGLM/PlantGLM) |
|
- **Manuscript:** [A Genetic Foundation Model for Discovery and Creation of Plant Genes]() |
|
|
|
**Developed by:** hu-lab |
|
|
|
# How to use the model |
|
|
|
Install the runtime library first:

```bash
pip install transformers
```
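The usage example below runs the model in bfloat16, which requires hardware support (e.g. NVIDIA Ampere GPUs or newer). As a minimal pre-flight sketch, you can check support before loading the model; `torch.cuda.is_bf16_supported` is a standard PyTorch call, and the fallback advice is only a suggestion:

```python
import torch

# Check that the current device can run bfloat16 before loading the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda" and not torch.cuda.is_bf16_supported():
    print("This GPU lacks native bfloat16 support; consider float16/float32 instead.")
```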
|
To generate a new gene sequence with the model (the `plantgfm` modules imported below are PlantGFM's custom model code; see the repository linked above):
|
```python
import torch
from transformers import PreTrainedTokenizerFast
from torch.cuda.amp import autocast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

# Load the tokenizer, config, and model, and cast the model to bfloat16.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
model = model.to(dtype=torch.bfloat16)

num_texts = 1    # total number of sequences to generate
batch_size = 1   # sequences per generation batch
generated_texts = []

# Generation starts from an empty prompt, replicated across the batch.
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

for i in range(0, num_texts, batch_size):
    with autocast(dtype=torch.bfloat16):
        output = model.generate(
            input_ids=input_ids,
            max_length=4000,  # matches the 4,000 bp cap of the training genes
            do_sample=True,
        )
    for output_sequence in output:
        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        generated_texts.append(generated_text)
        print(generated_text)
```
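
The decoded outputs are plain nucleotide strings. As a minimal post-processing sketch (the whitespace stripping, record naming, and file name here are illustrative assumptions, not part of the official pipeline), you could write the generated sequences to a FASTA file:

```python
# Illustrative helper: write generated sequences to a FASTA file.
# Assumes `generated_texts` from the example above; the record names and
# 60-column line wrapping are conventions, not requirements of the model.
def save_as_fasta(sequences, path="generated_genes.fasta"):
    with open(path, "w") as f:
        for idx, seq in enumerate(sequences):
            cleaned = "".join(seq.split()).upper()  # drop any whitespace
            f.write(f">generated_gene_{idx}\n")
            for j in range(0, len(cleaned), 60):
                f.write(cleaned[j:j + 60] + "\n")

save_as_fasta(generated_texts)
```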
|
|
|
|
|
|
|
#### Hardware |
|
The model was trained for 15 hours on two NVIDIA A100-40G GPUs.