{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from models_cifm.cifm import CIFM\n", "import scanpy as sc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. load model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CIFM(\n", " (gene_encoder): MLPBiasFree(\n", " (layers): ModuleList(\n", " (0): Linear(in_features=18289, out_features=1024, bias=False)\n", " (1-3): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n", " )\n", " (layernorms): ModuleList(\n", " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n", " )\n", " (activation): ReLU()\n", " )\n", " (model): VIEGNNModel(\n", " (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n", " (convs): ModuleList(\n", " (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n", " )\n", " (pred): MLPBiasFree(\n", " (layers): ModuleList(\n", " (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n", " )\n", " (layernorms): ModuleList(\n", " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n", " )\n", " (activation): ReLU()\n", " )\n", " )\n", " (mask_cell_decoder): VIEGNNModel(\n", " (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n", " (convs): ModuleList(\n", " (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n", " )\n", " (pred): MLPBiasFree(\n", " (layers): ModuleList(\n", " (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n", " )\n", " (layernorms): ModuleList(\n", " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n", " )\n", " (activation): ReLU()\n", " )\n", " )\n", " (mask_cell_expression): MLPBiasFree(\n", " (layers): ModuleList(\n", " (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n", " (3): Linear(in_features=1024, out_features=18289, bias=False)\n", " )\n", " (layernorms): ModuleList(\n", " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n", " )\n", " (activation): ReLU()\n", " )\n", " (mask_cell_dropout): MLPBiasFree(\n", " (layers): ModuleList(\n", " (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n", " (3): Linear(in_features=1024, out_features=18289, bias=False)\n", " )\n", " (layernorms): ModuleList(\n", " (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n", " )\n", " (activation): ReLU()\n", " )\n", " (mask_embedding): Embedding(1, 1024)\n", " (relu): ReLU()\n", " (sigmoid): Sigmoid()\n", ")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def load_model():\n", " args_model = torch.load('./models_cifm/args.pt')\n", " device = 'cpu' # or 'cuda' if you have a GPU\n", " model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n", " model.channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n", " model.eval()\n", " return model\n", "model = load_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. load and preprocess sample adata\n", "- some requirements for adata:\n", "- ```adata.X```: need to the raw count\n", "- ```adata.obsm['spatial']```: the coordinates of cells in the unit of micrometer\n", "- if in a different unit, it might result in a weird geometric graph: we use a radius 20 (micrometer) to construct the geometric graph in the model, so a different unit might result in a overly sparse or dense graph" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 24844 × 18289\n", " obs: 'in_tissue'\n", " var: 'feature_types', 'genome', 'gene_names'\n", " uns: 'log1p'\n", " obsm: 'spatial'\n", " layers: 'counts'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata = sc.read_h5ad('./adata.h5ad')\n", "adata.layers['counts'] = adata.X.copy()\n", "sc.pp.normalize_total(adata)\n", "sc.pp.log1p(adata)\n", "adata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. match feature channels\n", "- we need a list which maps feature channels to ensemble ids: ```channel2ensembl_ids_target```\n", "- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid2_for_channel1, ...], [ensemblid1_for_channel2, ensemblid2_for_channel2, ...], ...]```\n", "- one channel could correspond to multiple ensemble ids, e.g., when in your original data the channels are annotated with gene names\n", "- you can use BioMart to map your gene name into one or multiple ensemble ids" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "matching 18289 gene channels out of 18289 unmatched channels: []\n" ] } ], "source": [ "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n", "model.channel_matching(channel2ensembl_ids_target, model.channel2ensembl_ids_source)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. embed the microenvironments centered at each cell" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[-0.4132, -0.9847, 0.1647, ..., -0.8351, -0.8177, -1.3235],\n", " [ 0.8701, 0.0967, -0.3676, ..., 0.2687, -1.4821, 0.1605],\n", " [-0.5178, -0.4442, -0.0862, ..., -0.7446, -0.5761, -0.5571],\n", " ...,\n", " [ 1.2264, 1.2326, 0.2791, ..., 0.8018, -1.4069, 1.4567],\n", " [ 0.6699, -0.6107, 0.2450, ..., -0.1975, -0.6034, -0.6608],\n", " [-1.9240, -1.8125, -0.0766, ..., -0.2799, -0.0217, -2.2051]]),\n", " torch.Size([13898, 1024]))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.no_grad():\n", " embeddings = model.embed(adata)\n", "embeddings, embeddings.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. infer the potential gene expressions at certain locations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0.0000, 0.0000, 0.8603, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.6644, ..., 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " ...,\n", " [0.0000, 0.0000, 0.9809, ..., 0.0000, 0.0000, 0.0000],\n", " [0.6641, 0.0000, 0.6858, ..., 0.0000, 0.0000, 0.0000],\n", " [0.4999, 0.0000, 0.5311, ..., 0.0000, 0.0000, 0.0000]]),\n", " torch.Size([10, 18289]))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# we here randomly generate the locations for the cells just for demonstration\n", "target_locs = np.random.rand(10, 2)\n", "x_min, x_max = adata.obsm['spatial'][:, 0].min(), adata.obsm['spatial'][:, 0].max()\n", "y_min, y_max = adata.obsm['spatial'][:, 1].min(), adata.obsm['spatial'][:, 1].max()\n", "target_locs[:, 0] = target_locs[:, 0] * (x_max - x_min) + x_min\n", "target_locs[:, 1] = target_locs[:, 1] * (y_max - y_min) + y_min\n", "\n", "with torch.no_grad():\n", " expressions = model.predict_cells_at_locations(adata, target_locs)\n", "expressions, expressions.shape" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 2 }