Yuning You commited on
Commit
c23f8b5
·
1 Parent(s): c18ea1c
Files changed (2) hide show
  1. README.md +2 -2
  2. models_cifm/cifm.py +7 -6
README.md CHANGED
@@ -13,14 +13,14 @@ tags:
13
  # CI-FM: Cellular Interaction Foundation Model
14
 
15
  ## Overview
16
- This is the PyTorch implementation of the CI-FM model -- an AI model that can simulate the cellular activities within a living tissue (AI virtual tissue). The signature functions of CI-FM are:
17
  - **Embedding** of celllular microenvironments via ```embeddings = model.embed(adata)``` (Figure below panel D top);
18
  - **Inference** of cellular gene expressions at a certain microenvironment via ```expressions = model.predict_cells_at_locations(adata, rand_loc)``` (Figure below panel D bottom).
19
 
20
  The detailed usage of the model can be found in the [tutorial](https://huggingface.co/ynyou/CIFM/blob/main/test.ipynb).
21
  Before running the tutorial, please set up an environment following the [environment instruction](https://huggingface.co/ynyou/CIFM#environment).
22
 
23
- More information about the model can be found in the [preprint]().
24
 
25
  ![](./figures/cifm.png)
26
 
 
13
  # CI-FM: Cellular Interaction Foundation Model
14
 
15
  ## Overview
16
+ This is the PyTorch implementation of the CI-FM model -- an AI model that can simulate the biological activities within a living tissue (AI virtual tissue). The signature functions of CI-FM are:
17
  - **Embedding** of celllular microenvironments via ```embeddings = model.embed(adata)``` (Figure below panel D top);
18
  - **Inference** of cellular gene expressions at a certain microenvironment via ```expressions = model.predict_cells_at_locations(adata, rand_loc)``` (Figure below panel D bottom).
19
 
20
  The detailed usage of the model can be found in the [tutorial](https://huggingface.co/ynyou/CIFM/blob/main/test.ipynb).
21
  Before running the tutorial, please set up an environment following the [environment instruction](https://huggingface.co/ynyou/CIFM#environment).
22
 
23
+ <!-- More information about the model can be found in the [preprint](). -->
24
 
25
  ![](./figures/cifm.png)
26
 
models_cifm/cifm.py CHANGED
@@ -33,7 +33,7 @@ class CIFM(
33
  self.radius_spatial_graph = args.radius_spatial_graph
34
 
35
  def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
36
- # channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
37
 
38
  linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
39
  linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
@@ -72,9 +72,9 @@ class CIFM(
72
 
73
  num_matching += 1
74
 
75
- self.gene_encoder.layers[0] = linear_in
76
- self.mask_cell_expression.layers[-1] = linear_out1
77
- self.mask_cell_dropout.layers[-1] = linear_out2
78
 
79
  unmatched_channels = list(set(unmatched_channels))
80
  print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)
@@ -113,10 +113,11 @@ class CIFM(
113
 
114
  def predict_cells_at_locations(self, adata, locations):
115
  device = next(self.parameters()).device
116
- locations = torch.tensor(locations, dtype=torch.float32).to(device)
 
117
 
118
  expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
119
- expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
120
  coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
121
  coordinates = torch.cat([coordinates, locations], dim=0)
122
  coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
 
33
  self.radius_spatial_graph = args.radius_spatial_graph
34
 
35
  def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
36
+ device = next(self.parameters()).device
37
 
38
  linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
39
  linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
 
72
 
73
  num_matching += 1
74
 
75
+ self.gene_encoder.layers[0] = linear_in.to(device)
76
+ self.mask_cell_expression.layers[-1] = linear_out1.to(device)
77
+ self.mask_cell_dropout.layers[-1] = linear_out2.to(device)
78
 
79
  unmatched_channels = list(set(unmatched_channels))
80
  print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)
 
113
 
114
  def predict_cells_at_locations(self, adata, locations):
115
  device = next(self.parameters()).device
116
+
117
+ locations = torch.tensor(locations, dtype=torch.float32)
118
 
119
  expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
120
+ expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1]).to(device)], dim=0)
121
  coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
122
  coordinates = torch.cat([coordinates, locations], dim=0)
123
  coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)