---
tags:
- biology
---
# Model Card for PlantGFM
PlantGFM is a genetic foundation model pre-trained on the complete genome sequences of 12 model plants, encompassing 108 billion nucleotides. Using the Hyena framework with 220 million parameters and a context length of 64K bp, PlantGFM models sequences at single-nucleotide resolution. The model employed a length warm-up strategy, starting with 1K bp fragments and gradually increasing to 64K bp, enhancing training stability and accelerating convergence.
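
The length warm-up is easiest to picture as a staged schedule. Below is a minimal, hypothetical sketch of such a schedule; the stage boundaries and step counts are illustrative assumptions, not the published training configuration:
```python
# Hypothetical length warm-up schedule: training starts on short fragments
# and the context grows in stages toward the full 64K bp window.
# Stage lengths and step counts are illustrative assumptions.
WARMUP_STAGES = [1_000, 2_000, 4_000, 8_000, 16_000, 32_000, 64_000]  # bp
STEPS_PER_STAGE = 1_000  # assumed optimizer steps spent at each length

def context_length_at(step: int) -> int:
    """Return the training context length (in bp) used at a given step."""
    stage = min(step // STEPS_PER_STAGE, len(WARMUP_STAGES) - 1)
    return WARMUP_STAGES[stage]

for step in (0, 1_000, 3_500, 10_000):
    print(step, context_length_at(step))  # 0 -> 1000, ..., 10000 -> 64000
```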

### Model Sources

- **Repository:** [PlantGFM](https://github.com/hu-lab-PlantGLM/PlantGLM)
- **Manuscript:** [A Genetic Foundation Model for Discovery and Creation of Plant Genes]() 

**Developed by:** hu-lab

# How to use the model

Install PyTorch and the Transformers library first:
```bash
pip install torch transformers
```
The script below also imports the `plantgfm` package; its modules are distributed with the model's repository files, so make sure they are available on your Python path.
To generate a new gene sequence using the model:
```python
import torch
from transformers import PreTrainedTokenizerFast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
model = model.to(dtype=torch.bfloat16)

num_texts = 1   # total number of sequences to generate
batch_size = 1  # sequences generated per batch

generated_texts = []

# Start from an empty prompt so the model samples sequences from scratch.
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

for i in range(0, num_texts, batch_size):
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=4000,
            do_sample=True,
        )
    for output_sequence in output_ids:
        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        generated_texts.append(generated_text)
        print(generated_text)
```
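
Depending on the tokenizer, the decoded text may contain whitespace between single-nucleotide tokens. A small post-processing sketch (the whitespace stripping and the `generated.fasta` filename are assumptions for illustration) normalizes the output and saves it in FASTA format:
```python
# Hypothetical post-processing: strip any whitespace between decoded
# nucleotide tokens and write each sequence as a FASTA record.
with open("generated.fasta", "w") as fasta:
    for idx, text in enumerate(generated_texts):
        sequence = "".join(text.split())  # drop spaces/newlines between tokens
        fasta.write(f">generated_{idx}\n{sequence}\n")
```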



#### Hardware
The model was trained for 15 hours on 2 NVIDIA A100-40GB GPUs.