Update README.md
README.md CHANGED

````diff
@@ -21,27 +21,41 @@ pip install transformers
 To calculate the embedding of a DNA sequence:
 ```python
 
+
 import torch
 from transformers import PreTrainedTokenizerFast
-from plantgfm.modeling_plantgfm import PlantGFMForCausalLM
+from torch.cuda.amp import autocast
 from plantgfm.configuration_plantgfm import PlantGFMConfig
+from plantgfm.modeling_plantgfm import PlantGFMForCausalLM
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
+tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
+model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
+model = model.to(dtype=torch.bfloat16)
+
+num_texts = 1
+batch_size = 1
+generated_texts = []
 
-config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM")
-tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM")
-model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM", config=config)
 
+input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
+input_ids = input_ids.expand(batch_size, -1)
 
-sequences = ["CCCTAAACCCTAAACCCTAAA", "ATGGCGTGGCTG"]
 
-
-single_nucleotide_sequences = list(map(lambda seq: " ".join(list(seq)), sequences))
+for i in range(0, num_texts, batch_size):
+    with autocast(dtype=torch.bfloat16):
+        output = model.generate(
+            input_ids=input_ids,
+            max_length=4000,
+            do_sample=True,
+        )
+    for output_sequence in output:
+        generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
+        print(generated_text)
 
 
-tokenized_sequences = tokenizer(single_nucleotide_sequences, padding="longest")["input_ids"]
-input_ids = torch.LongTensor(tokenized_sequences)
 
-embd = model(input_ids=input_ids, output_hidden_states=True)["hidden_states"][0]
-print(embd)
 ```
 
 
````
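An aside on the removed embedding example: it takes `["hidden_states"][0]`, which is the output of the embedding layer rather than the final transformer block. A common alternative for a fixed-length, sequence-level embedding is to mean-pool the last hidden state over non-padding positions. This is a minimal sketch of that generic technique, not the repo's documented method; it assumes the `hu-lab/PlantGFM` checkpoint and `plantgfm` import paths from the removed snippet, and that the model's forward returns `hidden_states` as shown there.

```python
import torch
from transformers import PreTrainedTokenizerFast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM")
model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM", config=config)

sequences = ["CCCTAAACCCTAAACCCTAAA", "ATGGCGTGGCTG"]
# Space-separate nucleotides so each base maps to one token, as in the removed snippet.
single_nucleotide_sequences = [" ".join(seq) for seq in sequences]

enc = tokenizer(single_nucleotide_sequences, padding="longest", return_tensors="pt")
with torch.no_grad():
    # Last element of hidden_states = output of the final transformer block,
    # shape (batch, seq_len, hidden_size).
    hidden = model(input_ids=enc["input_ids"], output_hidden_states=True)["hidden_states"][-1]

# Mean-pool over non-padding positions to get one vector per sequence.
mask = enc["attention_mask"].unsqueeze(-1)          # (batch, seq_len, 1)
embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
print(embeddings.shape)                             # (batch, hidden_size)
```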
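And the added generation example as one self-contained script, with the bugs fixed (the `generate` result is assigned to `output` before it is iterated, and the stray trailing space in the checkpoint id is dropped). The `plantgfm` import paths, the `hu-lab/PlantGFM-Gene-generation` checkpoint id, the empty-prompt trick, and `max_length=4000` come from the diff above; the switch to `torch.amp.autocast` (the current spelling of the deprecated `torch.cuda.amp.autocast`) and collecting results into `generated_texts` are assumptions of this sketch.

```python
import torch
from transformers import PreTrainedTokenizerFast
from plantgfm.configuration_plantgfm import PlantGFMConfig
from plantgfm.modeling_plantgfm import PlantGFMForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
model = PlantGFMForCausalLM.from_pretrained(
    "hu-lab/PlantGFM-Gene-generation", config=config
).to(device, dtype=torch.bfloat16)

num_texts = 1   # total sequences to generate
batch_size = 1  # sequences per generate() call

# Empty prompt: generation starts from whatever token(s) the tokenizer
# prepends to the empty string.
input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
input_ids = input_ids.expand(batch_size, -1)

generated_texts = []
for _ in range(0, num_texts, batch_size):
    with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
        output = model.generate(input_ids=input_ids, max_length=4000, do_sample=True)
    for output_sequence in output:
        generated_texts.append(
            tokenizer.decode(output_sequence, skip_special_tokens=True)
        )

for text in generated_texts:
    print(text)
```

Raising `num_texts` and `batch_size` generates more sequences per call; with `do_sample=True` each run yields different sequences.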