hu-lab committed · Commit d84c46c · verified · 1 Parent(s): 8b1cf03

Update README.md

Files changed (1): README.md (+25 -11)
README.md CHANGED
@@ -21,27 +21,41 @@ pip install transformers
To calculate the embedding of a DNA sequence:

```python
import torch
from transformers import PreTrainedTokenizerFast
- from plantgfm.modeling_plantgfm import PlantGFMForCausalLM
from plantgfm.configuration_plantgfm import PlantGFMConfig

- config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM")
- tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM")
- model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM", config=config)

- sequences = ["CCCTAAACCCTAAACCCTAAA", "ATGGCGTGGCTG"]

- # get single-nucleotide sequences with a space between each base
- single_nucleotide_sequences = list(map(lambda seq: " ".join(list(seq)), sequences))

- tokenized_sequences = tokenizer(single_nucleotide_sequences, padding="longest")["input_ids"]
- input_ids = torch.LongTensor(tokenized_sequences)

- embd = model(input_ids=input_ids, output_hidden_states=True)["hidden_states"][0]
- print(embd)
```
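The removed snippet prints per-token hidden states of shape `(batch, seq_len, hidden_size)`, so a fixed-size per-sequence embedding still needs a pooling step. A minimal sketch (editorial, not part of the commit), assuming the padded batch from the snippet above and plain mean pooling over non-padding positions:

```python
# Hypothetical follow-up to the removed snippet: mean-pool the token
# embeddings into one vector per sequence, masking out padding tokens.
attention_mask = torch.LongTensor(
    tokenizer(single_nucleotide_sequences, padding="longest")["attention_mask"]
)
mask = attention_mask.unsqueeze(-1)                  # (batch, seq_len, 1)
pooled = (embd * mask).sum(dim=1) / mask.sum(dim=1)  # average over real tokens only
print(pooled.shape)                                  # (batch, hidden_size)
```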
 
 
To generate DNA sequences:

```python
+
import torch
from transformers import PreTrainedTokenizerFast
+ from torch.cuda.amp import autocast
from plantgfm.configuration_plantgfm import PlantGFMConfig
+ from plantgfm.modeling_plantgfm import PlantGFMForCausalLM
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ config = PlantGFMConfig.from_pretrained("hu-lab/PlantGFM-Gene-generation")
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("hu-lab/PlantGFM-Gene-generation")
+ model = PlantGFMForCausalLM.from_pretrained("hu-lab/PlantGFM-Gene-generation", config=config).to(device)
+ model = model.to(dtype=torch.bfloat16)
+
+ num_texts = 1
+ batch_size = 1
+ generated_texts = []
+
+ # start generation from an empty prompt, one row per batch element
+ input_ids = tokenizer.encode("", return_tensors="pt").to(device, dtype=torch.long)
+ input_ids = input_ids.expand(batch_size, -1)
+
+ for i in range(0, num_texts, batch_size):
+     with autocast(dtype=torch.bfloat16):
+         output = model.generate(
+             input_ids=input_ids,
+             max_length=4000,
+             do_sample=True,
+         )
+     for output_sequence in output:
+         generated_text = tokenizer.decode(output_sequence, skip_special_tokens=True)
+         generated_texts.append(generated_text)
+         print(generated_text)
```
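Since the embedding snippet tokenizes space-separated single nucleotides, the decoded generations presumably come back in the same spaced format. A small post-processing sketch (an assumption about the output format, not part of the commit) that collapses each decoded string into a plain DNA sequence:

```python
# Assumption: decoded text is space-separated single bases, e.g. "A T G C".
# Strip the spaces to recover a plain nucleotide string for downstream use.
plain_sequences = [text.replace(" ", "") for text in generated_texts]
for seq in plain_sequences:
    print(len(seq), seq[:60])
```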