Yuning You committed
Commit 4981657 · 1 parent: 0036d26
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 adata.h5ad filter=lfs diff=lfs merge=lfs -text
+figures/cifm.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -10,11 +10,11 @@ tags:
 - Library: ynyou/CIFM
 - Docs: [More Information Needed] -->
 
-# CI-FM: Cellular Interaction Foundation Model
+# CIFM: Cellular Interaction Foundation Model
 
 ## Overview
-This is the PyTorch implementation of the CI-FM model -- an AI model that can simulate the activities within a living tissue (AI virtual tissue).
-The current version of CI-FM has 138M parameters and is trained on around 23M cells of spatial genomics. The signature functions of CI-FM are:
+This is the PyTorch implementation of the CIFM model -- an AI model that can simulate the activities within a living tissue (AI virtual tissue).
+The current version of CIFM has 138M parameters and is trained on around 23M cells of spatial genomics. The signature functions of CIFM are:
 - **Embedding** of cellular microenvironments via ```embeddings = model.embed(adata)``` (the 1st figure below, panel D top);
 - **Inference/simulation** of cellular gene expressions within a certain microenvironment via ```expressions = model.predict_cells_at_locations(adata, target_locs)``` (the 1st figure below, panel D bottom, and the 2nd figure below).
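Read together with the test.ipynb changes below, the two signature calls suggest a short end-to-end workflow. The following is a minimal sketch, not the repository's documented API: the `from_pretrained` loader is an assumption (the notebook shows a `model.safetensors` download widget, which suggests a Hugging Face Hub mixin), and the `adata.obsm['spatial'][:10]` target locations are hypothetical; the preprocessing lines and the two model calls are taken verbatim from this commit.

```python
import numpy as np
import scanpy as sc
import torch

from models_cifm.cifm import CIFM  # class defined in models_cifm/cifm.py

# Preprocessing as in test.ipynb: stash raw counts, then CP10K-normalize and log1p.
adata = sc.read_h5ad('./adata.h5ad')
adata.layers['counts'] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Assumption: the checkpoint loads via a PyTorchModelHubMixin-style loader;
# the notebook's model.safetensors download widget points that way.
model = CIFM.from_pretrained('ynyou/CIFM')
model.eval()

with torch.no_grad():
    # Signature function 1: embed each cell's microenvironment
    # (one 1024-d vector per cell, per the notebook output).
    embeddings = model.embed(adata)

    # Signature function 2: simulate expression at query locations.
    # 'spatial' is the conventional scanpy coordinate key, assumed here.
    target_locs = adata.obsm['spatial'][:10]
    expressions = model.predict_cells_at_locations(adata, target_locs)

# Invert the log1p, as in the notebook's last cell.
counts_normalized = np.exp(expressions) - 1
```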
figures/autoregressive.gif CHANGED
figures/cifm.png CHANGED
Git LFS Details
  • SHA256: 57927c560bd4a3625814dbead09178ea660637c07f9e1a3ab5e7787678f5998f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
models_cifm/cifm.py CHANGED
@@ -19,12 +19,13 @@ class CIFM(
         super().__init__()
         self.gene_encoder = MLPBiasFree(in_dim=args.in_dim, out_dim=args.hidden_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
         self.model = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
-                                 emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=True)
+                                 emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
         self.mask_cell_decoder = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                                              emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
         self.mask_cell_expression = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
         self.mask_cell_dropout = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
         self.mask_embedding = nn.Embedding(1, args.hidden_dim)
+        self.proj = MLPBiasFree(in_dim=args.hidden_dim, out_dim=1, hidden_dim=args.hidden_dim, num_layer=4)
 
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
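Two changes land in this constructor: the backbone `VIEGNNModel` drops its residual connections (`residual=True` becomes `residual=False`, matching `mask_cell_decoder`), and a new scalar head `self.proj` is added. For orientation, the `MLPBiasFree` block used throughout is pinned down by the module repr printed in test.ipynb (bias-free `Linear` layers, `LayerNorm` with `elementwise_affine=False`, `ReLU`); the sketch below is one implementation consistent with that repr, not the repository's actual source.

```python
import torch.nn as nn

class MLPBiasFree(nn.Module):
    """Sketch of a bias-free MLP matching the repr in test.ipynb:
    num_layer Linear(bias=False) layers, with an affine-free LayerNorm
    and a ReLU between consecutive layers."""

    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layer - 1) + [out_dim]
        self.layers = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1], bias=False) for i in range(num_layer)
        )
        # num_layer - 1 norms, one after each non-final layer, with no
        # learnable scale/shift so the whole module stays bias-free.
        self.layernorms = nn.ModuleList(
            nn.LayerNorm(hidden_dim, elementwise_affine=False)
            for _ in range(num_layer - 1)
        )
        self.activation = nn.ReLU()

    def forward(self, x):
        for layer, norm in zip(self.layers[:-1], self.layernorms):
            x = self.activation(norm(layer(x)))
        return self.layers[-1](x)
```

With `in_dim=1024, out_dim=1, hidden_dim=1024, num_layer=4`, this reproduces the layer shapes of the new `self.proj` head; with `in_dim=18289, out_dim=1024` it matches the `gene_encoder` repr in the notebook.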
test.ipynb CHANGED
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -21,82 +21,22 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
+     "application/vnd.jupyter.widget-view+json": {
+      "model_id": "18d58ba0049e4560b7bd0916fbd6ea33",
+      "version_major": 2,
+      "version_minor": 0
+     },
      "text/plain": [
-      "CIFM(\n",
-      "  (gene_encoder): MLPBiasFree(\n",
-      "    (layers): ModuleList(\n",
-      "      (0): Linear(in_features=18289, out_features=1024, bias=False)\n",
-      "      (1-3): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "    )\n",
-      "    (layernorms): ModuleList(\n",
-      "      (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
-      "    )\n",
-      "    (activation): ReLU()\n",
-      "  )\n",
-      "  (model): VIEGNNModel(\n",
-      "    (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "    (convs): ModuleList(\n",
-      "      (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
-      "    )\n",
-      "    (pred): MLPBiasFree(\n",
-      "      (layers): ModuleList(\n",
-      "        (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "      )\n",
-      "      (layernorms): ModuleList(\n",
-      "        (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
-      "      )\n",
-      "      (activation): ReLU()\n",
-      "    )\n",
-      "  )\n",
-      "  (mask_cell_decoder): VIEGNNModel(\n",
-      "    (emb_in): Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "    (convs): ModuleList(\n",
-      "      (0-1): 2 x EGNNLayer(emb_dim=1024, aggr=sum)\n",
-      "    )\n",
-      "    (pred): MLPBiasFree(\n",
-      "      (layers): ModuleList(\n",
-      "        (0-3): 4 x Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "      )\n",
-      "      (layernorms): ModuleList(\n",
-      "        (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
-      "      )\n",
-      "      (activation): ReLU()\n",
-      "    )\n",
-      "  )\n",
-      "  (mask_cell_expression): MLPBiasFree(\n",
-      "    (layers): ModuleList(\n",
-      "      (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "      (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
-      "    )\n",
-      "    (layernorms): ModuleList(\n",
-      "      (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
-      "    )\n",
-      "    (activation): ReLU()\n",
-      "  )\n",
-      "  (mask_cell_dropout): MLPBiasFree(\n",
-      "    (layers): ModuleList(\n",
-      "      (0-2): 3 x Linear(in_features=1024, out_features=1024, bias=False)\n",
-      "      (3): Linear(in_features=1024, out_features=18289, bias=False)\n",
-      "    )\n",
-      "    (layernorms): ModuleList(\n",
-      "      (0-2): 3 x LayerNorm((1024,), eps=1e-05, elementwise_affine=False)\n",
-      "    )\n",
-      "    (activation): ReLU()\n",
-      "  )\n",
-      "  (mask_embedding): Embedding(1, 1024)\n",
-      "  (relu): ReLU()\n",
-      "  (sigmoid): Sigmoid()\n",
-      ")"
+      "model.safetensors:   0%|          | 0.00/569M [00:00<?, ?B/s]"
     ]
    },
-   "execution_count": 2,
    "metadata": {},
-   "output_type": "execute_result"
+   "output_type": "display_data"
   }
  ],
  "source": [
@@ -123,7 +63,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
@@ -145,7 +85,7 @@
  "source": [
   "adata = sc.read_h5ad('./adata.h5ad')\n",
   "adata.layers['counts'] = adata.X.copy()\n",
-  "sc.pp.normalize_total(adata)\n",
+  "sc.pp.normalize_total(adata, target_sum=1e4)\n",
   "sc.pp.log1p(adata)\n",
   "adata"
  ]
@@ -163,14 +103,14 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "matching 18289 gene channels out of 18289 unmatched channels: []\n"
+     "matching 18289 gene channels out of 18289 ; unmatched channels: []\n"
    ]
   }
  ],
@@ -194,14 +134,14 @@
  {
   "data": {
    "text/plain": [
-    "(tensor([[-0.4132, -0.9847,  0.1647,  ..., -0.8351, -0.8177, -1.3235],\n",
-    "         [ 0.8701,  0.0967, -0.3676,  ...,  0.2687, -1.4821,  0.1605],\n",
-    "         [-0.5178, -0.4442, -0.0862,  ..., -0.7446, -0.5761, -0.5571],\n",
+    "(tensor([[-0.4326, -0.8625,  0.1121,  ...,  0.4980,  0.3855, -0.1965],\n",
+    "         [-0.6833, -0.9950,  0.1927,  ..., -0.2064,  0.6193,  0.0387],\n",
+    "         [-0.2099, -0.9877,  0.3462,  ...,  0.2102,  0.6807, -0.2155],\n",
     "         ...,\n",
-    "         [ 1.2264,  1.2326,  0.2791,  ...,  0.8018, -1.4069,  1.4567],\n",
-    "         [ 0.6699, -0.6107,  0.2450,  ..., -0.1975, -0.6034, -0.6608],\n",
-    "         [-1.9240, -1.8125, -0.0766,  ..., -0.2799, -0.0217, -2.2051]]),\n",
-    " torch.Size([13898, 1024]))"
+    "         [-0.0187, -0.8444,  0.3058,  ...,  0.1030,  0.8362, -0.1859],\n",
+    "         [-0.5535, -0.8201,  0.7805,  ..., -0.1402,  0.5221, -0.3520],\n",
+    "         [-0.9339, -0.8467,  0.0600,  ...,  0.0406,  0.3608,  0.3418]]),\n",
+    " torch.Size([24844, 1024]))"
    ]
   },
   "execution_count": 5,
@@ -224,23 +164,23 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 6,
   "metadata": {},
  "outputs": [
   {
    "data": {
     "text/plain": [
-     "(tensor([[0.0000, 0.0000, 0.8603,  ..., 0.0000, 0.0000, 0.0000],\n",
-     "         [0.0000, 0.0000, 0.6644,  ..., 0.0000, 0.0000, 0.0000],\n",
+     "(tensor([[0.0000, 0.0000, 2.8781,  ..., 0.0000, 0.0000, 0.0000],\n",
+     "         [0.0000, 0.0000, 2.9699,  ..., 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "         ...,\n",
-     "         [0.0000, 0.0000, 0.9809,  ..., 0.0000, 0.0000, 0.0000],\n",
-     "         [0.6641, 0.0000, 0.6858,  ..., 0.0000, 0.0000, 0.0000],\n",
-     "         [0.4999, 0.0000, 0.5311,  ..., 0.0000, 0.0000, 0.0000]]),\n",
+     "         [0.0000, 0.0000, 3.2570,  ..., 0.0000, 0.0000, 0.0000],\n",
+     "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
+     "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),\n",
     "  torch.Size([10, 18289]))"
    ]
   },
-  "execution_count": 5,
+  "execution_count": 6,
   "metadata": {},
   "output_type": "execute_result"
  }
@@ -260,9 +200,27 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 7,
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "(tensor([[0.0000, 0.0000, 0.0002,  ..., 0.0000, 0.0000, 0.0000],\n",
+      "         [0.0000, 0.0000, 0.0002,  ..., 0.0000, 0.0000, 0.0000],\n",
+      "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
+      "         ...,\n",
+      "         [0.0000, 0.0000, 0.0003,  ..., 0.0000, 0.0000, 0.0000],\n",
+      "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
+      "         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),\n",
+      " torch.Size([10, 18289]))"
+     ]
+    },
+    "execution_count": 7,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
  "source": [
   "# you can convert it into normalized counts\n",
   "counts_normalized = np.exp(expressions) - 1\n",