metadata
tags:
  - biology

Model Card for PlantGFM

PlantGFM is a genetic foundation model pre-trained on the complete genome sequences of 12 model plants, comprising 108 billion nucleotides. Built on the Hyena architecture, it has 220 million parameters and a context length of 64K bp, and it models sequences at single-nucleotide resolution. Training used a sequence-length warm-up strategy: it started with 1K bp fragments and gradually increased to 64K bp, which improved training stability and accelerated convergence.
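The length warm-up described above can be pictured as a schedule that maps the training step to a fragment length. This is a minimal sketch under assumptions the card does not state: the length doubles at fixed, hypothetical intervals (`steps_per_stage`), "1K" and "64K" are read as 1,024 and 65,536 bp, and fragments are random crops of a genome sequence. The names `warmup_seq_len` and `crop_fragment` are hypothetical and are not part of PlantGFM's code.

```python
import random


def warmup_seq_len(step: int, steps_per_stage: int = 1000,
                   start_len: int = 1024, max_len: int = 65536) -> int:
    """Hypothetical length warm-up: double the fragment length every
    `steps_per_stage` training steps, starting at `start_len`, capped at `max_len`."""
    if step < 0 or steps_per_stage <= 0 or not 0 < start_len <= max_len:
        raise ValueError("invalid schedule arguments")
    length = start_len
    for _ in range(step // steps_per_stage):
        if length >= max_len:
            break  # already at the full context length
        length *= 2
    return min(length, max_len)


def crop_fragment(sequence: str, length: int, rng: random.Random) -> str:
    """Take a random contiguous window of `length` bases (or the whole sequence if shorter)."""
    if length <= 0:
        raise ValueError("length must be positive")
    if len(sequence) <= length:
        return sequence
    start = rng.randrange(len(sequence) - length + 1)
    return sequence[start:start + length]


# Usage with a placeholder sequence (not real plant genome data)
rng = random.Random(0)
toy_genome = "ACGT" * 25_000  # 100,000 bp
for step in (0, 1000, 3000, 6000, 10_000):
    fragment = crop_fragment(toy_genome, warmup_seq_len(step), rng)
    print(step, len(fragment))
```

Under these assumptions the length reaches 64K bp after six doublings (1,024 × 2^6 = 65,536); the actual PlantGFM schedule may use different stage lengths or a non-doubling increase.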

Model Sources

Developed by: hu-lab

How to use the model

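Install the required libraries first:

pip install torch transformers

The example also imports the plantgfm package (plantgfm.configuration_plantgfm and plantgfm.modeling_plantgfm), which this command does not install, so make sure it is importable in your environment.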

To generate a new gene sequence using the model:

import torch
from transformers import PreTrainedTokenizerFast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

model_name = "hu-lab/PlantGFM-Gene-generation"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the configuration, tokenizer, and weights from the Hugging Face Hub
config = PlantGFMConfig.from_pretrained(model_name)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
model = PlantGFMForCausalLM.from_pretrained(model_name, config=config).to(device)
model = model.to(dtype=torch.bfloat16)

num_texts = 1    # total number of sequences to generate
batch_size = 1   # sequences generated per call to generate()
generated_texts = []

# Encode an empty prompt (only special tokens the tokenizer adds, if any),
# then repeat it once per sequence in the batch
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

for _ in range(0, num_texts, batch_size):
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast and also runs on CPU
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        outputs = model.generate(
            input_ids=input_ids,
            max_length=4000,   # maximum total length in tokens, prompt included
            do_sample=True,    # sample each next token instead of greedy decoding
        )
    for output_sequence in outputs:
        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
        generated_texts.append(generated_text)
        print(generated_text)
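Setting do_sample=True makes generate() draw each next token at random from the model's predicted distribution instead of always taking the most likely token, so repeated runs yield different sequences. The sketch below illustrates that difference for a single step with fixed, made-up logits over a hypothetical four-letter vocabulary. It is not PlantGFM's tokenizer or the Transformers implementation, which can also apply options such as temperature, top-k, and top-p from the generation configuration; the helper names are hypothetical.

```python
import math
import random


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    """Convert raw scores into probabilities, scaled by a temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - top) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_next_token(logits: list[float], rng: random.Random,
                      temperature: float = 1.0) -> int:
    """Sampling (do_sample=True style): draw an index in proportion to its probability."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def greedy_next_token(logits: list[float]) -> int:
    """Greedy (do_sample=False style): always pick the highest-scoring index."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical single-nucleotide vocabulary and made-up logits for one step
vocab = ["A", "C", "G", "T"]
logits = [2.0, 0.5, 0.5, 1.0]
rng = random.Random(42)

greedy = "".join(vocab[greedy_next_token(logits)] for _ in range(20))
sampled = "".join(vocab[sample_next_token(logits, rng)] for _ in range(20))
print("greedy: ", greedy)
print("sampled:", sampled)
```

In real generation the logits are recomputed at every step from the sequence produced so far; they are held fixed here only to isolate the sampling rule.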

Hardware

The model was trained for 15 hours on 2 NVIDIA A100 (40 GB) GPUs.