# CIFM / models_cifm / cifm.py
# Author: Yuning You
# Commit: 552cf9a ("update")
# File size: 6.41 kB
import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
import scanpy as sc
from huggingface_hub import PyTorchModelHubMixin
from models_cifm.mlp_and_gnn import MLPBiasFree
from models_cifm.egnn_void_invariant import VIEGNNModel
class CIFM(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    repo_url='ynyou/CIFM',
    pipeline_tag='mask-generation',
    license='mit',
):
    """Cell model over a spatial radius graph.

    Per-cell expression vectors are projected by ``gene_encoder`` and passed
    through ``model`` over a radius graph built from cell coordinates.  For
    masked cells, ``mask_cell_decoder`` followed by two heads
    (``mask_cell_expression`` and ``mask_cell_dropout``) reconstructs an
    expression vector.
    """

    def __init__(self, args):
        """Build all sub-modules from ``args``.

        Attributes read from ``args``: ``in_dim``, ``hidden_dim``,
        ``num_layer``, ``num_mlp_layers_in_module`` and
        ``radius_spatial_graph``.
        """
        super().__init__()

        # Input projection: expression channels -> hidden embedding.
        self.gene_encoder = MLPBiasFree(
            in_dim=args.in_dim,
            out_dim=args.hidden_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )
        # Encoder over the spatial graph.
        self.model = VIEGNNModel(
            num_layers=args.num_layer,
            num_mlp_layers_in_module=args.num_mlp_layers_in_module,
            emb_dim=args.hidden_dim,
            in_dim=args.hidden_dim,
            out_dim=args.hidden_dim,
            residual=True,
        )
        # Decoder applied after masked cells are replaced by the mask embedding.
        self.mask_cell_decoder = VIEGNNModel(
            num_layers=args.num_layer,
            num_mlp_layers_in_module=args.num_mlp_layers_in_module,
            emb_dim=args.hidden_dim,
            in_dim=args.hidden_dim,
            out_dim=args.hidden_dim,
            residual=False,
        )
        # Output heads: hidden embedding -> expression channels.
        self.mask_cell_expression = MLPBiasFree(
            in_dim=args.hidden_dim,
            out_dim=args.in_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )
        self.mask_cell_dropout = MLPBiasFree(
            in_dim=args.hidden_dim,
            out_dim=args.in_dim,
            hidden_dim=args.hidden_dim,
            num_layer=args.num_mlp_layers_in_module,
        )

        # Single learned vector that stands in for a masked cell.
        self.mask_embedding = nn.Embedding(1, args.hidden_dim)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.hidden_dim = args.hidden_dim
        self.radius_spatial_graph = args.radius_spatial_graph

    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
        """Re-map the input/output gene channels onto a new channel layout.

        The first layer of ``gene_encoder`` and the last layers of
        ``mask_cell_expression`` / ``mask_cell_dropout`` are replaced by new
        bias-free linear layers sized for ``channel2ensembl_ids_target``.  For
        each target channel, the weights of every source channel sharing one of
        its IDs are averaged and copied over.

        Args:
            channel2ensembl_ids_target: one collection of IDs per target
                channel, e.g. ``[[i] for i in adata.var.index.tolist()]``.
            channel2ensembl_ids_source: one collection of IDs per channel of
                the current layers.  A bare string entry is treated as a
                single ID.
            zero_init_for_unmatched_genes: when true, target channels without
                a match get all-zero weights; otherwise they keep the default
                ``nn.Linear`` initialisation.

        Prints the number of matched channels and the unmatched IDs.
        """
        enc_weight = self.gene_encoder.layers[0].weight
        expr_weight = self.mask_cell_expression.layers[-1].weight
        drop_weight = self.mask_cell_dropout.layers[-1].weight

        num_target = len(channel2ensembl_ids_target)
        # Create the replacements on the device/dtype of the layers they
        # replace, so a model already moved to GPU (or cast) stays consistent.
        linear_in = nn.Linear(num_target, self.hidden_dim, bias=False).to(
            device=enc_weight.device, dtype=enc_weight.dtype)
        linear_out1 = nn.Linear(self.hidden_dim, num_target, bias=False).to(
            device=expr_weight.device, dtype=expr_weight.dtype)
        linear_out2 = nn.Linear(self.hidden_dim, num_target, bias=False).to(
            device=drop_weight.device, dtype=drop_weight.dtype)

        # ID -> ascending list of source channel indices containing it.  Built
        # once so matching is a dict lookup instead of a scan of every source
        # channel per target ID.
        source_index = {}
        for idx_source, source_ids in enumerate(channel2ensembl_ids_source):
            if isinstance(source_ids, str):
                source_ids = (source_ids,)
            for source_id in source_ids:
                hits = source_index.setdefault(source_id, [])
                # An ID repeated inside one source channel counts once.
                if not hits or hits[-1] != idx_source:
                    hits.append(idx_source)

        num_matching = 0
        unmatched_channels = []
        with torch.no_grad():
            if zero_init_for_unmatched_genes:
                linear_in.weight.zero_()
                linear_out1.weight.zero_()
                linear_out2.weight.zero_()

            for idx_target, ensembls in enumerate(channel2ensembl_ids_target):
                if len(ensembls) == 0:
                    continue

                matched = [
                    idx_source
                    for ensembl in ensembls
                    for idx_source in source_index.get(ensembl, ())
                ]
                if not matched:
                    unmatched_channels.extend(ensembls)
                    continue

                # Encoder weight is (hidden, channels): channels are columns.
                # Output-head weights are (channels, hidden): channels are rows.
                linear_in.weight[:, idx_target] = enc_weight.t()[matched].mean(dim=0)
                linear_out1.weight[idx_target] = expr_weight[matched].mean(dim=0)
                linear_out2.weight[idx_target] = drop_weight[matched].mean(dim=0)
                num_matching += 1

        self.gene_encoder.layers[0] = linear_in
        self.mask_cell_expression.layers[-1] = linear_out1
        self.mask_cell_dropout.layers[-1] = linear_out2

        unmatched_channels = list(set(unmatched_channels))
        print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)

    def forward(self):
        """Do nothing and return ``None``; use the explicit methods instead."""
        pass

    def encode(self, expressions, coordinates, edge_index):
        """Return per-cell embeddings for the given graph.

        ``expressions`` is projected by ``gene_encoder`` and passed with
        ``coordinates`` and ``edge_index`` to ``model``; only the first element
        of the model output is returned.
        """
        embeddings = self.gene_encoder(expressions)
        embeddings, _ = self.model(embeddings, coordinates, edge_index)
        return embeddings

    def encode_decode(self, expressions, coordinates, edge_index, mapping):
        """Predict expressions for the cells selected by ``mapping``.

        The selected cells' embeddings are overwritten with the mask
        embedding, the decoder is run on the whole graph, and the two heads
        are applied to the selected cells.  Values whose dropout-head sigmoid
        is ``<= 0.5`` are set to zero.

        Returns:
            Tensor with one row per entry of ``mapping``.
        """
        device = expressions.device
        embeddings = self.encode(expressions, coordinates, edge_index)

        mask_token = torch.zeros(1, dtype=torch.int64, device=device)
        # The (1, hidden) mask vector broadcasts over all selected rows.
        embeddings[mapping] = self.mask_embedding(mask_token)

        embeddings_dec = self.mask_cell_decoder(embeddings, coordinates, edge_index)[0][mapping]
        expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
        dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))

        # Out-of-place so the ReLU output saved for autograd is left intact;
        # the resulting values match an in-place masked write.
        return expressions_dec.masked_fill(dropouts_dec <= 0.5, 0)

    def embed(self, adata):
        """Return embeddings for every cell in ``adata``.

        Reads ``adata.X`` as expressions and ``adata.obsm['spatial']`` as
        coordinates, builds the radius graph and runs :meth:`encode` on the
        device holding the model parameters.
        """
        device = next(self.parameters()).device
        expressions = self._expressions_from_adata(adata, device)
        coordinates = self._with_zero_column(
            torch.tensor(adata.obsm['spatial'], dtype=torch.float32, device=device))
        edge_index = self._spatial_graph(coordinates)
        return self.encode(expressions, coordinates, edge_index)

    def predict_cells_at_locations(self, adata, locations):
        """Predict expressions for hypothetical cells at ``locations``.

        The cells of ``adata`` are extended with one all-zero expression row
        per requested location; those appended rows are masked and decoded by
        :meth:`encode_decode`.

        Args:
            adata: object providing ``X`` and ``obsm['spatial']``.
            locations: coordinates of the cells to predict, with the same
                number of columns as ``adata.obsm['spatial']``.

        Returns:
            Tensor with one predicted expression row per location.
        """
        device = next(self.parameters()).device
        locations = torch.as_tensor(locations, dtype=torch.float32, device=device)

        observed = self._expressions_from_adata(adata, device)
        # Every tensor is created on ``device`` so the concatenations below do
        # not mix CPU and GPU operands.
        placeholders = torch.zeros(
            locations.shape[0], observed.shape[1], dtype=observed.dtype, device=device)
        expressions = torch.cat([observed, placeholders], dim=0)

        observed_xy = torch.tensor(adata.obsm['spatial'], dtype=torch.float32, device=device)
        coordinates = self._with_zero_column(torch.cat([observed_xy, locations], dim=0))
        edge_index = self._spatial_graph(coordinates)

        # The appended rows sit at the tail of the batch.
        idx_cells_to_predict = torch.arange(
            observed.shape[0], expressions.shape[0], device=device)
        return self.encode_decode(expressions, coordinates, edge_index, idx_cells_to_predict)

    @staticmethod
    def _expressions_from_adata(adata, device):
        """Return ``adata.X`` as a dense float32 tensor on ``device``."""
        matrix = adata.X
        # Sparse matrices expose ``toarray``; dense arrays are used as-is.
        if hasattr(matrix, 'toarray'):
            matrix = matrix.toarray()
        return torch.tensor(matrix, dtype=torch.float32, device=device)

    @staticmethod
    def _with_zero_column(coordinates):
        """Append a column of zeros to ``coordinates``."""
        # NOTE(review): presumably pads 2-D positions to the 3-D input the
        # graph model expects -- confirm against VIEGNNModel.
        zeros = torch.zeros(
            coordinates.shape[0], 1, dtype=coordinates.dtype, device=coordinates.device)
        return torch.cat([coordinates, zeros], dim=1)

    def _spatial_graph(self, coordinates):
        """Build the self-looped radius graph used by the encoder/decoder."""
        return radius_graph(
            coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)