Yanisadel committed on
Commit d33edb8 · verified
1 Parent(s): de4c019

Update README.md with How to use

Files changed (1)
  1. README.md +44 -1
README.md CHANGED
@@ -22,9 +22,52 @@ tissue-specific promoters and enhancers, and CTCF-bound sites) elements.
 
 ### How to use
 
- To Be Done
+ Until its next release, the transformers library needs to be installed from source with the first command below in order to use the models. PyTorch is also required to one-hot encode the input sequences.
 
+ ```
+ pip install --upgrade git+https://github.com/huggingface/transformers.git
+ pip install torch
+ ```
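+ 
+ To check that the source install succeeded, you can print the installed version; a build from the main branch usually carries a ".dev0" suffix.
+ 
+ ```
+ import transformers
+ 
+ print(transformers.__version__)  # a ".dev0" suffix usually indicates a source install
+ ```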
+ 
+ The small snippet of code below retrieves the logits for two dummy DNA sequences.
 
+ ```
+ import torch
+ from transformers import AutoModel
+ 
+ model = AutoModel.from_pretrained("InstaDeepAI/segment_enformer", trust_remote_code=True)
+ 
+ def encode_sequences(sequences):
+     # Map each nucleotide (upper or lower case) to a one-hot vector; "N" maps to all zeros.
+     one_hot_map = {
+         'a': torch.tensor([1., 0., 0., 0.]),
+         'c': torch.tensor([0., 1., 0., 0.]),
+         'g': torch.tensor([0., 0., 1., 0.]),
+         't': torch.tensor([0., 0., 0., 1.]),
+         'n': torch.tensor([0., 0., 0., 0.]),
+         'A': torch.tensor([1., 0., 0., 0.]),
+         'C': torch.tensor([0., 1., 0., 0.]),
+         'G': torch.tensor([0., 0., 1., 0.]),
+         'T': torch.tensor([0., 0., 0., 1.]),
+         'N': torch.tensor([0., 0., 0., 0.])
+     }
+ 
+     def encode_sequence(seq_str):
+         # Any other character falls back to a uniform 0.25 vector.
+         one_hot_list = []
+         for char in seq_str:
+             one_hot_vector = one_hot_map.get(char, torch.tensor([0.25, 0.25, 0.25, 0.25]))
+             one_hot_list.append(one_hot_vector)
+         return torch.stack(one_hot_list)
+ 
+     if isinstance(sequences, list):
+         return torch.stack([encode_sequence(seq) for seq in sequences])
+     else:
+         return encode_sequence(sequences)
+ 
+ sequences = ["A"*196608, "G"*196608]
+ one_hot_encoding = encode_sequences(sequences)
+ preds = model(one_hot_encoding)
+ print(preds['logits'])
+ ```
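+ 
+ The logits above are raw scores. Assuming the model returns one score per position and per genomic-element class in a multi-label setting (an assumption worth checking against the model card), a sigmoid converts them into per-element probabilities.
+ 
+ ```
+ # Hypothetical post-processing: assumes multi-label logits of shape
+ # (batch, positions, num_element_classes).
+ probabilities = torch.sigmoid(preds['logits'])
+ print(probabilities.shape)
+ ```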
 
 ## Training data
 