In [None]:
from typing import List, Union

import torch
from transformers import AutoModel

# Load model

In [None]:
# Load the pretrained checkpoint from the Hugging Face Hub (or the local cache).
# `trust_remote_code=True` lets transformers run the modelling code shipped in
# the model repository, so only use it with repositories you trust.
model = AutoModel.from_pretrained(
    "InstaDeepAI/segment_enformer",
    trust_remote_code=True,
)

# Define useful functions

In [None]:
def encode_sequences(sequences: Union[str, List[str]]) -> torch.Tensor:
    """
    One-hot encode a DNA sequence or a batch of DNA sequences.

    Symbols are matched case-insensitively:
        - A, C, G, T map to the one-hot vectors [1,0,0,0], [0,1,0,0],
          [0,0,1,0] and [0,0,0,1] respectively,
        - N maps to the all-zero vector [0,0,0,0],
        - any other symbol maps to the uniform vector [0.25,0.25,0.25,0.25].

    Args:
        sequences (Union[str, List[str]]): Either a DNA sequence or a list of
            DNA sequences. Any other non-string iterable of sequences (for
            example a tuple) is also treated as a batch. All sequences of a
            batch must have the same length.

    Returns:
        torch.Tensor: One-hot encoded
            - If `sequences` is just one sequence (str), output shape is
              (seq_len, 4), seq_len being the length of a sequence. An empty
              string yields a tensor of shape (0, 4).
            - If `sequences` is a batch of sequences, output shape is
              (num_sequences, seq_len, 4)

    Raises:
        RuntimeError: Raised by `torch.stack` when the batch is empty or when
            its sequences do not all have the same length.

    Example:
        >>> sequences = ["AC", "GT"]
        >>> encode_sequences(sequences)
        tensor([[[1., 0., 0., 0.],
                 [0., 1., 0., 0.]],

                [[0., 0., 1., 0.],
                 [0., 0., 0., 1.]]])
    """
    # Rows 0-3: one-hot vectors for A, C, G, T. Row 4: all-zero vector used
    # for N. Row 5: uniform vector used for every unrecognised symbol.
    encoding_table = torch.tensor([
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0.25, 0.25, 0.25, 0.25],
    ])
    row_of_symbol = {
        'a': 0, 'A': 0,
        'c': 1, 'C': 1,
        'g': 2, 'G': 2,
        't': 3, 'T': 3,
        'n': 4, 'N': 4,
    }
    unknown_row = 5

    def encode_sequence(seq_str):
        # Translate symbols to table rows first, then gather all rows with a
        # single indexing operation. This avoids building and stacking one
        # tiny tensor per nucleotide, which is very slow for long sequences.
        rows = torch.tensor(
            [row_of_symbol.get(char, unknown_row) for char in seq_str],
            dtype=torch.long,
        )
        return encoding_table[rows]

    # Dispatch on `str` rather than `list`: otherwise a tuple of sequences
    # would be iterated as if each whole sequence were a single symbol.
    if isinstance(sequences, str):
        return encode_sequence(sequences)
    return torch.stack([encode_sequence(seq) for seq in sequences])

# Inference example

In [None]:
# Length, in nucleotides, of each demo sequence.
# NOTE(review): presumably the input length the model expects - confirm
# against the model card.
SEQUENCE_LENGTH = 196608

# Two homopolymer sequences (all "A", all "G") used as a toy batch.
sequences = [base * SEQUENCE_LENGTH for base in ("A", "G")]
one_hot_encoding = encode_sequences(sequences)

In [None]:
# Inference only: disable autograd so no activations are retained for a
# backward pass, which lowers memory use on these very long inputs.
with torch.no_grad():
    preds = model(one_hot_encoding)

In [None]:
# Print the entry stored under the "logits" key of the model output.
logits = preds["logits"]
print(logits)