import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
import scanpy as sc
from huggingface_hub import PyTorchModelHubMixin

from models_cifm.mlp_and_gnn import MLPBiasFree
from models_cifm.egnn_void_invariant import VIEGNNModel


class CIFM(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url='ynyou/CIFM',
    pipeline_tag='mask-generation',
    license='mit',
):
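    """Spatial single-cell model built on E(n)-equivariant GNNs.

    Per-cell gene expression is encoded with a bias-free MLP, refined by an
    equivariant GNN over a spatial radius graph, and the expression profiles of
    masked cells are reconstructed from their spatial neighborhood.
    """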

    def __init__(self, args):
        super().__init__()
        # expression encoder and equivariant GNN backbone
        self.gene_encoder = MLPBiasFree(in_dim=args.in_dim, out_dim=args.hidden_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.model = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                                 emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=True)
        # masked-cell decoder with expression and dropout prediction heads
        self.mask_cell_decoder = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                                             emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
        self.mask_cell_expression = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.mask_cell_dropout = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        # learnable placeholder embedding assigned to masked cells
        self.mask_embedding = nn.Embedding(1, args.hidden_dim)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.hidden_dim = args.hidden_dim
        self.radius_spatial_graph = args.radius_spatial_graph

    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
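        """Remap the model's gene channels to a new gene panel.

        Both arguments map channel indices to lists of Ensembl gene IDs. For each
        target channel, the input/output weights of all matching source channels are
        averaged into fresh first/last linear layers. Unmatched channels are
        zero-initialized when `zero_init_for_unmatched_genes` is True (otherwise they
        keep the default nn.Linear initialization) and are reported at the end.
        """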
        linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
        linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
        linear_out2 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
        if zero_init_for_unmatched_genes:
            linear_in.weight.data.zero_()
            linear_out1.weight.data.zero_()
            linear_out2.weight.data.zero_()

        num_matching = 0
        unmatched_channels = []
        for idx_target, ensembls in enumerate(channel2ensembl_ids_target):
            if len(ensembls) == 0:
                continue

            embs_in = []
            embs_out1 = []
            embs_out2 = []
            for ensembl in ensembls:
                for idx_source, ensembls_source in enumerate(channel2ensembl_ids_source):
                    if ensembl in ensembls_source:
                        embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
                        embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
                        embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])

            if len(embs_in) == 0:
                unmatched_channels += ensembls
                continue

            embs_in = torch.stack(embs_in).mean(dim=0)
            embs_out1 = torch.stack(embs_out1).mean(dim=0)
            embs_out2 = torch.stack(embs_out2).mean(dim=0)
            linear_in.weight.data[:, idx_target] = embs_in
            linear_out1.weight.data[idx_target] = embs_out1
            linear_out2.weight.data[idx_target] = embs_out2

            num_matching += 1

        self.gene_encoder.layers[0] = linear_in
        self.mask_cell_expression.layers[-1] = linear_out1
        self.mask_cell_dropout.layers[-1] = linear_out2

        unmatched_channels = list(set(unmatched_channels))
        print('matched', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)

    def forward(self):
        # no-op: inference goes through encode, encode_decode, embed,
        # and predict_cells_at_locations below
        pass

    def encode(self, expressions, coordinates, edge_index):
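        """Embed cells: encode expressions with the MLP, then run equivariant message passing over the spatial graph."""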
        embeddings = self.gene_encoder(expressions)
        embeddings, _ = self.model(embeddings, coordinates, edge_index)
        return embeddings

    def encode_decode(self, expressions, coordinates, edge_index, mapping):
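        """Reconstruct expression for the cells indexed by `mapping`.

        The selected cells' embeddings are replaced with the learned mask embedding,
        decoded by the masked-cell GNN, and passed through the expression and dropout
        heads; genes whose dropout-head score is at most 0.5 are zeroed out.
        """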
        device = expressions.device

        embeddings = self.encode(expressions, coordinates, edge_index)
        embeddings[mapping] = self.mask_embedding(torch.zeros(1, dtype=torch.int64).to(device))
        embeddings_dec = self.mask_cell_decoder(embeddings, coordinates, edge_index)[0][mapping]

        expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
        dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))

        # zero out genes whose dropout-head score is at most 0.5
        expressions_dec[dropouts_dec <= 0.5] = 0
        return expressions_dec

    def embed(self, adata):
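        """Return cell embeddings for an AnnData with sparse expression in `.X` and 2D coordinates in `.obsm['spatial']`."""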
        device = next(self.parameters()).device

        expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        # lift 2D spatial coordinates to 3D by appending a zero z-column
        coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
        edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)

        embeddings = self.encode(expressions, coordinates, edge_index)
        return embeddings

    def predict_cells_at_locations(self, adata, locations):
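        """Predict expression profiles for hypothetical cells placed at `locations` (an array of 2D spatial coordinates)."""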
        device = next(self.parameters()).device
        locations = torch.tensor(locations, dtype=torch.float32)

        # append all-zero expression rows for the queried cells and build the joint
        # spatial graph over observed and queried locations; tensors stay on CPU
        # until after concatenation to avoid mixing devices, then move to device
        expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32)
        expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0).to(device)
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        coordinates = torch.cat([coordinates, locations], dim=0)
        coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
        edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
        idx_cells_to_predict = torch.arange(expressions.shape[0] - locations.shape[0], expressions.shape[0]).to(device)

        expressions_pred = self.encode_decode(expressions, coordinates, edge_index, idx_cells_to_predict)
        return expressions_pred
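

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module): load
    # pretrained weights through PyTorchModelHubMixin and embed a toy AnnData.
    # The checkpoint id 'ynyou/CIFM' and the 1000-gene panel size are assumptions;
    # in practice the AnnData's gene panel should be aligned to the pretrained
    # panel, e.g. via channel_matching, before calling embed.
    import numpy as np
    import anndata as ad
    from scipy.sparse import csr_matrix

    model = CIFM.from_pretrained('ynyou/CIFM')  # from_pretrained is provided by the mixin
    model.eval()

    num_cells, num_genes = 50, 1000  # assumed to match the checkpoint's in_dim
    adata = ad.AnnData(
        X=csr_matrix(np.random.poisson(1.0, size=(num_cells, num_genes)).astype(np.float32)),
        obsm={'spatial': np.random.uniform(0, 100, size=(num_cells, 2)).astype(np.float32)},
    )

    with torch.no_grad():
        embeddings = model.embed(adata)  # (num_cells, hidden_dim)
        predictions = model.predict_cells_at_locations(  # (num_locations, num_genes)
            adata, np.array([[10.0, 20.0], [30.0, 40.0]], dtype=np.float32))
    print(embeddings.shape, predictions.shape)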