Yuning You
commited on
Commit
·
c23f8b5
1
Parent(s):
c18ea1c
update
Browse files- README.md +2 -2
- models_cifm/cifm.py +7 -6
README.md
CHANGED
@@ -13,14 +13,14 @@ tags:
|
|
13 |
# CI-FM: Cellular Interaction Foundation Model
|
14 |
|
15 |
## Overview
|
16 |
-
This is the PyTorch implementation of the CI-FM model -- an AI model that can simulate the
|
17 |
- **Embedding** of celllular microenvironments via ```embeddings = model.embed(adata)``` (Figure below panel D top);
|
18 |
- **Inference** of cellular gene expressions at a certain microenvironment via ```expressions = model.predict_cells_at_locations(adata, rand_loc)``` (Figure below panel D bottom).
|
19 |
|
20 |
The detailed usage of the model can be found in the [tutorial](https://huggingface.co/ynyou/CIFM/blob/main/test.ipynb).
|
21 |
Before running the tutorial, please set up an environment following the [environment instruction](https://huggingface.co/ynyou/CIFM#environment).
|
22 |
|
23 |
-
More information about the model can be found in the [preprint]().
|
24 |
|
25 |

|
26 |
|
|
|
13 |
# CI-FM: Cellular Interaction Foundation Model
|
14 |
|
15 |
## Overview
|
16 |
+
This is the PyTorch implementation of the CI-FM model -- an AI model that can simulate the biological activities within a living tissue (AI virtual tissue). The signature functions of CI-FM are:
|
17 |
- **Embedding** of celllular microenvironments via ```embeddings = model.embed(adata)``` (Figure below panel D top);
|
18 |
- **Inference** of cellular gene expressions at a certain microenvironment via ```expressions = model.predict_cells_at_locations(adata, rand_loc)``` (Figure below panel D bottom).
|
19 |
|
20 |
The detailed usage of the model can be found in the [tutorial](https://huggingface.co/ynyou/CIFM/blob/main/test.ipynb).
|
21 |
Before running the tutorial, please set up an environment following the [environment instruction](https://huggingface.co/ynyou/CIFM#environment).
|
22 |
|
23 |
+
<!-- More information about the model can be found in the [preprint](). -->
|
24 |
|
25 |

|
26 |
|
models_cifm/cifm.py
CHANGED
@@ -33,7 +33,7 @@ class CIFM(
|
|
33 |
self.radius_spatial_graph = args.radius_spatial_graph
|
34 |
|
35 |
def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
|
36 |
-
|
37 |
|
38 |
linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
|
39 |
linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
|
@@ -72,9 +72,9 @@ class CIFM(
|
|
72 |
|
73 |
num_matching += 1
|
74 |
|
75 |
-
self.gene_encoder.layers[0] = linear_in
|
76 |
-
self.mask_cell_expression.layers[-1] = linear_out1
|
77 |
-
self.mask_cell_dropout.layers[-1] = linear_out2
|
78 |
|
79 |
unmatched_channels = list(set(unmatched_channels))
|
80 |
print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)
|
@@ -113,10 +113,11 @@ class CIFM(
|
|
113 |
|
114 |
def predict_cells_at_locations(self, adata, locations):
|
115 |
device = next(self.parameters()).device
|
116 |
-
|
|
|
117 |
|
118 |
expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
|
119 |
-
expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1])], dim=0)
|
120 |
coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
|
121 |
coordinates = torch.cat([coordinates, locations], dim=0)
|
122 |
coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
|
|
|
33 |
self.radius_spatial_graph = args.radius_spatial_graph
|
34 |
|
35 |
def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
|
36 |
+
device = next(self.parameters()).device
|
37 |
|
38 |
linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
|
39 |
linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
|
|
|
72 |
|
73 |
num_matching += 1
|
74 |
|
75 |
+
self.gene_encoder.layers[0] = linear_in.to(device)
|
76 |
+
self.mask_cell_expression.layers[-1] = linear_out1.to(device)
|
77 |
+
self.mask_cell_dropout.layers[-1] = linear_out2.to(device)
|
78 |
|
79 |
unmatched_channels = list(set(unmatched_channels))
|
80 |
print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)
|
|
|
113 |
|
114 |
def predict_cells_at_locations(self, adata, locations):
|
115 |
device = next(self.parameters()).device
|
116 |
+
|
117 |
+
locations = torch.tensor(locations, dtype=torch.float32)
|
118 |
|
119 |
expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
|
120 |
+
expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1]).to(device)], dim=0)
|
121 |
coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
|
122 |
coordinates = torch.cat([coordinates, locations], dim=0)
|
123 |
coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
|