{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from typing import List, Union\n", "\n", "import torch\n", "from transformers import AutoModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = AutoModel.from_pretrained(\"InstaDeepAI/segment_enformer\", trust_remote_code=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define useful functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def encode_sequences(sequences: Union[str, List[str]]) -> torch.Tensor:\n", "    \"\"\"\n", "    One-hot encode a DNA sequence or a batch of DNA sequences.\n", "\n", "    Args:\n", "        sequences (Union[str, List[str]]): Either a DNA sequence or a list of DNA sequences.\n", "            Sequences in a list must all have the same length.\n", "\n", "    Returns:\n", "        torch.Tensor: One-hot encoded sequences. `N`/`n` is encoded as all zeros;\n", "        any other unrecognized character is encoded as 0.25 in every channel.\n", "            - If `sequences` is a single sequence (str), output shape is (seq_len, 4), seq_len being the length of the sequence\n", "            - If `sequences` is a list of sequences, output shape is (num_sequences, seq_len, 4)\n", "\n", "    Example:\n", "        >>> sequences = [\"AC\", \"GT\"]\n", "        >>> encode_sequences(sequences)\n", "        tensor([[[1., 0., 0., 0.],\n", "                 [0., 1., 0., 0.]],\n", "\n", "                [[0., 0., 1., 0.],\n", "                 [0., 0., 0., 1.]]])\n", "    \"\"\"\n", "    # Channel order is A, C, G, T; lowercase and uppercase bases are treated alike.\n", "    one_hot_map = {\n", "        'a': torch.tensor([1., 0., 0., 0.]),\n", "        'c': torch.tensor([0., 1., 0., 0.]),\n", "        'g': torch.tensor([0., 0., 1., 0.]),\n", "        't': torch.tensor([0., 0., 0., 1.]),\n", "        'n': torch.tensor([0., 0., 0., 0.]),\n", "        'A': torch.tensor([1., 0., 0., 0.]),\n", "        'C': torch.tensor([0., 1., 0., 0.]),\n", "        'G': torch.tensor([0., 0., 1., 0.]),\n", "        'T': torch.tensor([0., 0., 0., 1.]),\n", "        'N': torch.tensor([0., 0., 0., 0.])\n", "    }\n", "\n", "    def encode_sequence(seq_str):\n", "        one_hot_list = []\n", "        for char in seq_str:\n", "            # Characters missing from the map fall back to a uniform distribution over the four bases.\n", "            one_hot_vector = one_hot_map.get(char, torch.tensor([0.25, 0.25, 0.25, 0.25]))\n", "            one_hot_list.append(one_hot_vector)\n", "        return torch.stack(one_hot_list)\n", "\n", "    if isinstance(sequences, list):\n", "        return torch.stack([encode_sequence(seq) for seq in sequences])\n", "    else:\n", "        return encode_sequence(sequences)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Inference example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sequences = [\"A\"*196608, \"G\"*196608]\n", "one_hot_encoding = encode_sequences(sequences)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preds = model(one_hot_encoding)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(preds['logits'])" ] } ], "metadata": { "kernelspec": { "display_name": "genomics-research-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 2 }