In [None]:
import torch
import numpy as np
from models_cifm.cifm import CIFM
import scanpy as sc

### 1. load model

In [None]:
def load_model(device='cpu'):
    """Build the CIFM model from the local argument file and put it in eval mode.

    Args:
        device: Device the model is moved to, e.g. ``'cpu'`` (default) or
            ``'cuda'`` if a GPU is available.

    Returns:
        The ``CIFM`` instance returned by ``CIFM.from_pretrained('ynyou/CIFM', ...)``,
        moved to ``device``, switched to evaluation mode, and carrying the
        source channel mapping in ``channel2ensembl_ids_source``.
    """
    # NOTE(review): torch.load unpickles the file contents; only load files
    # you trust. map_location='cpu' lets files that were saved from a GPU
    # session deserialize on CPU-only machines as well.
    args_model = torch.load('./models_cifm/args.pt', map_location='cpu')
    model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)
    # Source-side channel mapping, used later by model.channel_matching.
    model.channel2ensembl_ids_source = torch.load(
        './models_cifm/channel2ensembl.pt', map_location='cpu'
    )
    model.eval()
    return model


model = load_model()

CIFM(
 (gene_encoder): MLPBiasFree(
 (layers): ModuleList(
 (0): Linear(in_features=18289, out_features=1024, bias=False)
 (1-3): 3 x Linear(in_features=1024, out_features=1024, bias=False)
 )
 (layernorms): ModuleList(
 (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
 )
 (activation): ReLU()
 )
 (model): VIEGNNModel(
 (emb_in): Linear(in_features=1024, out_features=1024, bias=False)
 (convs): ModuleList(
 (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)
 )
 (pred): MLPBiasFree(
 (layers): ModuleList(
 (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)
 )
 (layernorms): ModuleList(
 (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
 )
 (activation): ReLU()
 )
 )
 (mask_cell_decoder): VIEGNNModel(
 (emb_in): Linear(in_features=1024, out_features=1024, bias=False)
 (convs): ModuleList(
 (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)
 )
 (pred): MLPBiasFree(
 (layers): ModuleList(
 (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=Fa

### 2. load and preprocess sample adata
- some requirements for adata:
- ```adata.X```: must contain the raw counts
- ```adata.obsm['spatial']```: the coordinates of the cells, in micrometers
- if the coordinates are in a different unit, the resulting geometric graph may be distorted: the model uses a radius of 20 (micrometers) to construct the geometric graph, so a different unit might result in an overly sparse or overly dense graph

In [None]:
# Read the sample AnnData object from disk.
adata = sc.read_h5ad('./adata.h5ad')
# Keep a copy of the original adata.X (expected to be raw counts) in a layer
# before adata.X is transformed by the calls below.
adata.layers['counts'] = adata.X.copy()
# scanpy preprocessing with default settings: total-count normalization
# followed by a log1p transform.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# Show the AnnData summary as the cell output.
adata

AnnData object with n_obs × n_vars = 24844 × 18289
 obs: 'in_tissue'
 var: 'feature_types', 'genome', 'gene_names'
 uns: 'log1p'
 obsm: 'spatial'
 layers: 'counts'

### 3. match feature channels
- we need a list that maps feature channels to Ensembl IDs: ```channel2ensembl_ids_target```
- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid2_for_channel1, ...], [ensemblid1_for_channel2, ensemblid2_for_channel2, ...], ...]```
- one channel can correspond to multiple Ensembl IDs, e.g., when the channels in your original data are annotated with gene names
- you can use BioMart to map each gene name to one or more Ensembl IDs

In [None]:
# Each feature channel of adata is identified by a single entry of
# adata.var.index, so every channel maps to a one-element list.
channel2ensembl_ids_target = [
    [ensembl_id] for ensembl_id in adata.var.index.tolist()
]
# Align the target channels with the channels stored on the model.
model.channel_matching(
    channel2ensembl_ids_target,
    model.channel2ensembl_ids_source,
)

matching 18289 gene channels out of 18289 unmatched channels: []


### 4. embed the microenvironments centered at each cell

In [5]:
# Inference only: run the embedding call with gradient tracking disabled.
with torch.no_grad():
 embeddings = model.embed(adata)
# Show the embedding tensor together with its shape as the cell output.
# NOTE(review): the saved output reports 13898 rows, while the AnnData summary
# shown earlier in this notebook lists 24844 observations — the stored output
# may be stale, or embed may drop cells; re-run and confirm.
embeddings, embeddings.shape

(tensor([[-0.4132, -0.9847, 0.1647, ..., -0.8351, -0.8177, -1.3235],
 [ 0.8701, 0.0967, -0.3676, ..., 0.2687, -1.4821, 0.1605],
 [-0.5178, -0.4442, -0.0862, ..., -0.7446, -0.5761, -0.5571],
 ...,
 [ 1.2264, 1.2326, 0.2791, ..., 0.8018, -1.4069, 1.4567],
 [ 0.6699, -0.6107, 0.2450, ..., -0.1975, -0.6034, -0.6608],
 [-1.9240, -1.8125, -0.0766, ..., -0.2799, -0.0217, -2.2051]]),
 torch.Size([13898, 1024]))

### 5. infer the potential gene expressions at certain locations

In [None]:
# For demonstration only: draw random query locations uniformly inside the
# bounding box of the observed cell coordinates. A seeded generator keeps the
# demo reproducible across kernel restarts.
n_target_locs = 10
rng = np.random.default_rng(0)

x_min, x_max = adata.obsm['spatial'][:, 0].min(), adata.obsm['spatial'][:, 0].max()
y_min, y_max = adata.obsm['spatial'][:, 1].min(), adata.obsm['spatial'][:, 1].max()
# low/high broadcast per column: column 0 spans [x_min, x_max), column 1 spans [y_min, y_max).
target_locs = rng.uniform(
    low=[x_min, y_min],
    high=[x_max, y_max],
    size=(n_target_locs, 2),
)

# Inference only: predict expressions at the sampled locations without gradients.
with torch.no_grad():
    expressions = model.predict_cells_at_locations(adata, target_locs)
# Show the predicted expression tensor together with its shape as the cell output.
expressions, expressions.shape

(tensor([[0.0000, 0.0000, 0.8603, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.6644, ..., 0.0000, 0.0000, 0.0000],
 [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
 ...,
 [0.0000, 0.0000, 0.9809, ..., 0.0000, 0.0000, 0.0000],
 [0.6641, 0.0000, 0.6858, ..., 0.0000, 0.0000, 0.0000],
 [0.4999, 0.0000, 0.5311, ..., 0.0000, 0.0000, 0.0000]]),
 torch.Size([10, 18289]))