import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
import scanpy as sc
from huggingface_hub import PyTorchModelHubMixin

from models_cifm.mlp_and_gnn import MLPBiasFree
from models_cifm.egnn_void_invariant import VIEGNNModel


class CIFM(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    repo_url='ynyou/CIFM',
    pipeline_tag='mask-generation',
    license='mit',
):
    """Gene-expression encoder / masked-cell decoder over a spatial radius graph.

    The model is assembled from:

    * ``gene_encoder``: ``MLPBiasFree`` mapping ``args.in_dim`` expression
      channels to ``args.hidden_dim`` features.
    * ``model``: ``VIEGNNModel`` (residual) applied to the encoded features
      together with the coordinates and the edge index.
    * ``mask_cell_decoder``: a second ``VIEGNNModel`` (non-residual) applied
      after the embeddings of the masked cells were replaced by a learned
      mask token.
    * ``mask_cell_expression`` / ``mask_cell_dropout``: ``MLPBiasFree`` heads
      mapping decoded features back to ``args.in_dim`` channels; the first is
      passed through ReLU, the second through a sigmoid and used as a gate.

    Args:
        args: Object exposing ``in_dim``, ``hidden_dim``, ``num_layer``,
            ``num_mlp_layers_in_module`` and ``radius_spatial_graph``.
    """

    # Neighbour cap handed to ``radius_graph`` when building the spatial graph.
    _MAX_NUM_NEIGHBORS = 10000

    def __init__(self, args):
        super().__init__()
        # NOTE: the construction order below is kept stable on purpose; it
        # determines parameter registration order and RNG consumption.
        self.gene_encoder = MLPBiasFree(
            in_dim=args.in_dim,
            out_dim=args.hidden_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )
        self.model = VIEGNNModel(
            num_layers=args.num_layer,
            num_mlp_layers_in_module=args.num_mlp_layers_in_module,
            emb_dim=args.hidden_dim,
            in_dim=args.hidden_dim,
            out_dim=args.hidden_dim,
            residual=True,
        )
        self.mask_cell_decoder = VIEGNNModel(
            num_layers=args.num_layer,
            num_mlp_layers_in_module=args.num_mlp_layers_in_module,
            emb_dim=args.hidden_dim,
            in_dim=args.hidden_dim,
            out_dim=args.hidden_dim,
            residual=False,
        )
        self.mask_cell_expression = MLPBiasFree(
            in_dim=args.hidden_dim,
            out_dim=args.in_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )
        self.mask_cell_dropout = MLPBiasFree(
            in_dim=args.hidden_dim,
            out_dim=args.in_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )
        # A single learned vector that stands in for every masked cell.
        self.mask_embedding = nn.Embedding(1, args.hidden_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.hidden_dim = args.hidden_dim
        self.radius_spatial_graph = args.radius_spatial_graph

    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source,
                         zero_init_for_unmatched_genes=True):
        """Re-wire the input/output gene channels to a new (target) gene panel.

        The first layer of ``gene_encoder`` and the last layers of
        ``mask_cell_expression`` and ``mask_cell_dropout`` are replaced by
        fresh bias-free ``nn.Linear`` layers sized for the target panel. For
        each target channel, the weights of all source channels sharing at
        least one identifier with it are averaged and copied over.

        Args:
            channel2ensembl_ids_target: Sequence with one collection of
                identifiers per target channel.
            channel2ensembl_ids_source: Sequence with one collection of
                identifiers per source channel (the channels the current
                weights correspond to).
            zero_init_for_unmatched_genes: When true, the weights of target
                channels without any match are zero; otherwise they keep the
                default ``nn.Linear`` initialisation.

        Returns:
            None. The three layers are replaced in place and a summary is
            printed.
        """
        device = next(self.parameters()).device
        num_target = len(channel2ensembl_ids_target)

        linear_in = nn.Linear(num_target, self.hidden_dim, bias=False)
        linear_out1 = nn.Linear(self.hidden_dim, num_target, bias=False)
        linear_out2 = nn.Linear(self.hidden_dim, num_target, bias=False)
        if zero_init_for_unmatched_genes:
            linear_in.weight.data.zero_()
            linear_out1.weight.data.zero_()
            linear_out2.weight.data.zero_()

        # Current weights, indexed per source channel below.
        source_in = self.gene_encoder.layers[0].weight.data
        source_out1 = self.mask_cell_expression.layers[-1].weight.data
        source_out2 = self.mask_cell_dropout.layers[-1].weight.data

        num_matching = 0
        unmatched_channels = []
        for idx_target, ensembls in enumerate(channel2ensembl_ids_target):
            if len(ensembls) == 0:
                continue

            embs_in = []
            embs_out1 = []
            embs_out2 = []
            for ensembl in ensembls:
                for idx_source, ensembles2 in enumerate(channel2ensembl_ids_source):
                    if ensembl in ensembles2:
                        # Input weights are stored column-wise per channel,
                        # output weights row-wise per channel.
                        embs_in.append(source_in[:, idx_source])
                        embs_out1.append(source_out1[idx_source])
                        embs_out2.append(source_out2[idx_source])

            if len(embs_in) == 0:
                unmatched_channels.extend(ensembls)
                continue

            linear_in.weight.data[:, idx_target] = torch.stack(embs_in).mean(dim=0)
            linear_out1.weight.data[idx_target] = torch.stack(embs_out1).mean(dim=0)
            linear_out2.weight.data[idx_target] = torch.stack(embs_out2).mean(dim=0)
            num_matching += 1

        self.gene_encoder.layers[0] = linear_in.to(device)
        self.mask_cell_expression.layers[-1] = linear_out1.to(device)
        self.mask_cell_dropout.layers[-1] = linear_out2.to(device)

        unmatched_channels = list(set(unmatched_channels))
        print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target),
              '; unmatched channels:', unmatched_channels)

    def forward(self):
        """Intentionally a no-op; use ``encode``, ``encode_decode``, ``embed``
        or ``predict_cells_at_locations`` instead."""
        pass

    def encode(self, expressions, coordinates, edge_index):
        """Embed cells from their expressions and spatial graph.

        Args:
            expressions: Tensor fed to ``gene_encoder``.
            coordinates: Tensor of cell coordinates passed to the graph model.
            edge_index: Edge index of the spatial graph.

        Returns:
            The first output of ``self.model`` (the cell embeddings).
        """
        embeddings = self.gene_encoder(expressions)
        embeddings, _ = self.model(embeddings, coordinates, edge_index)
        return embeddings

    def encode_decode(self, expressions, coordinates, edge_index, mapping):
        """Encode all cells, mask the cells in ``mapping`` and decode them.

        Args:
            expressions: Tensor fed to ``gene_encoder``.
            coordinates: Tensor of cell coordinates.
            edge_index: Edge index of the spatial graph.
            mapping: Index selecting the cells to mask and predict.

        Returns:
            Predicted expressions for the cells selected by ``mapping``;
            entries whose dropout gate is ``<= 0.5`` are set to zero.
        """
        device = expressions.device
        embeddings = self.encode(expressions, coordinates, edge_index)

        # Overwrite the selected cells with the (single) learned mask token;
        # the (1, hidden_dim) token broadcasts over all selected rows.
        mask_token_index = torch.zeros(1, dtype=torch.int64, device=device)
        embeddings[mapping] = self.mask_embedding(mask_token_index)

        embeddings_dec = self.mask_cell_decoder(embeddings, coordinates, edge_index)[0][mapping]
        expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
        dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
        # The dropout head gates the expression head.
        expressions_dec[dropouts_dec <= 0.5] = 0
        return expressions_dec

    @staticmethod
    def _dense_expressions(adata, device):
        """Return ``adata.X`` as a float32 tensor on ``device``.

        ``adata.X`` is densified through ``toarray()`` only when it provides
        that method, so both sparse and already-dense matrices are accepted.
        """
        matrix = adata.X
        if hasattr(matrix, 'toarray'):
            matrix = matrix.toarray()
        return torch.tensor(matrix, dtype=torch.float32).to(device)

    @staticmethod
    def _padded_coordinates(adata, device, extra_locations=None):
        """Return ``adata.obsm['spatial']`` (plus optional extra rows) with a
        zero column appended, as a float32 tensor on ``device``."""
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        if extra_locations is not None:
            coordinates = torch.cat([coordinates, extra_locations], dim=0)
        zero_column = torch.zeros(coordinates.shape[0], 1)
        return torch.cat([coordinates, zero_column], dim=1).to(device)

    def _spatial_graph(self, coordinates):
        """Build the radius graph (with self loops) used by the graph models."""
        return radius_graph(
            coordinates,
            r=self.radius_spatial_graph,
            max_num_neighbors=self._MAX_NUM_NEIGHBORS,
            loop=True,
        )

    def embed(self, adata):
        """Compute embeddings for every cell of ``adata``.

        Args:
            adata: Object exposing ``X`` (sparse or dense expression matrix)
                and ``obsm['spatial']`` (presumably an ``AnnData``).

        Returns:
            The embeddings returned by ``encode``.
        """
        device = next(self.parameters()).device
        expressions = self._dense_expressions(adata, device)
        coordinates = self._padded_coordinates(adata, device)
        edge_index = self._spatial_graph(coordinates)
        return self.encode(expressions, coordinates, edge_index)

    def predict_cells_at_locations(self, adata, locations):
        """Predict expressions of hypothetical cells placed at ``locations``.

        The new locations are appended to the observed cells as rows with
        all-zero expression, and are then masked and decoded.

        Args:
            adata: Object exposing ``X`` (sparse or dense expression matrix)
                and ``obsm['spatial']`` (presumably an ``AnnData``).
            locations: Array-like of locations with the same number of
                columns as ``adata.obsm['spatial']``.

        Returns:
            Predicted expressions, one row per requested location.
        """
        device = next(self.parameters()).device
        locations = torch.tensor(locations, dtype=torch.float32)
        num_new = locations.shape[0]

        observed = self._dense_expressions(adata, device)
        # Placeholder (all-zero) expressions for the cells to be predicted.
        placeholders = torch.zeros(num_new, observed.shape[1]).to(device)
        expressions = torch.cat([observed, placeholders], dim=0)

        coordinates = self._padded_coordinates(adata, device, extra_locations=locations)
        edge_index = self._spatial_graph(coordinates)

        # The appended rows are the last ``num_new`` ones.
        num_total = expressions.shape[0]
        idx_cells_to_predict = torch.arange(num_total - num_new, num_total).to(device)
        return self.encode_decode(expressions, coordinates, edge_index, idx_cells_to_predict)