Yuning You commited on
Commit
2f63b5b
·
1 Parent(s): 9565da7
models/cifm.py CHANGED
@@ -32,8 +32,8 @@ class CIFM(
32
  self.hidden_dim = args.hidden_dim
33
  self.radius_spatial_graph = args.radius_spatial_graph
34
 
35
- def channel_matching(self, adata, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
36
- channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
37
 
38
  linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
39
  linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
@@ -97,8 +97,6 @@ class CIFM(
97
  expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
98
  dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
99
 
100
- # import pdb ; pdb.set_trace()
101
-
102
  expressions_dec[dropouts_dec<=0.5] = 0
103
  return expressions_dec
104
 
 
32
  self.hidden_dim = args.hidden_dim
33
  self.radius_spatial_graph = args.radius_spatial_graph
34
 
35
+ def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
36
+ # channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
37
 
38
  linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
39
  linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
 
97
  expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
98
  dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
99
 
 
 
100
  expressions_dec[dropouts_dec<=0.5] = 0
101
  return expressions_dec
102
 
models/egnn_void_invariant.py CHANGED
@@ -48,15 +48,6 @@ class VIEGNNModel(torch.nn.Module):
48
  self.convs = torch.nn.ModuleList()
49
  for _ in range(num_layers):
50
  self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
51
-
52
- # MLP predictor for invariant tasks using only scalar features
53
- # self.pred = torch.nn.Sequential(
54
- # torch.nn.Linear(emb_dim, emb_dim, bias=False),
55
- # torch.nn.ReLU(),
56
- # torch.nn.Linear(emb_dim, out_dim, bias=False)
57
- # )
58
- # layers = [torch.nn.Linear(emb_dim, emb_dim, bias=False), torch.nn.ReLU()] * (num_mlp_layers_in_module-1) + [torch.nn.Linear(emb_dim, out_dim, bias=False)]
59
- # self.pred = torch.nn.Sequential(*layers)
60
  self.pred = MLPBiasFree(in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module)
61
 
62
  # unroll the batch argments and comment out the pooling operation
 
48
  self.convs = torch.nn.ModuleList()
49
  for _ in range(num_layers):
50
  self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
 
 
 
 
 
 
 
 
 
51
  self.pred = MLPBiasFree(in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module)
52
 
53
  # unroll the batch argments and comment out the pooling operation
models/layers/__init__.py DELETED
File without changes
models/layers/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (170 Bytes)
 
models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc DELETED
Binary file (4.8 kB)
 
models/layers/egnn_layer_void_invariant.py CHANGED
@@ -23,49 +23,11 @@ class EGNNLayer(MessagePassing):
23
  super().__init__(aggr=aggr)
24
 
25
  self.emb_dim = emb_dim
26
- # self.activation = ReLU()
27
 
28
  self.dist_embedding = Linear(1, emb_dim, bias=False)
29
  self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
30
-
31
- # MLP `\psi_h` for computing messages `m_ij`
32
- # self.mlp_msg = Sequential(
33
- # Linear(2 * emb_dim + 1, emb_dim, bias=False),
34
- # torch.nn.LayerNorm(emb_dim, bias=False),
35
- # self.activation,
36
- # Linear(emb_dim, emb_dim, bias=False),
37
- # torch.nn.LayerNorm(emb_dim, bias=False),
38
- # self.activation,
39
- # )
40
- # layers = [Linear(2 * emb_dim + 1, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] \
41
- # + [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1)
42
- # layers = [Linear(3 * emb_dim, emb_dim, bias=False)] \
43
- # + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1) \
44
- # + [torch.nn.LayerNorm(emb_dim, bias=False)]
45
- # self.mlp_msg = Sequential(*layers)
46
  self.mlp_msg = MLPBiasFree(in_dim=3*emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
47
-
48
- # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
49
- # self.mlp_pos = Sequential(
50
- # Linear(emb_dim, emb_dim), torch.nn.LayerNorm(emb_dim), self.activation, Linear(emb_dim, 1)
51
- # )
52
- # layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
53
- # layers = [Linear(emb_dim, emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
54
- # self.mlp_pos = Sequential(*layers)
55
  self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
56
-
57
- # MLP `\phi` for computing updated node features `h_i^{l+1}`
58
- # self.mlp_upd = Sequential(
59
- # Linear(2 * emb_dim, emb_dim, bias=False),
60
- # torch.nn.LayerNorm(emb_dim, bias=False),
61
- # self.activation,
62
- # Linear(emb_dim, emb_dim, bias=False),
63
- # torch.nn.LayerNorm(emb_dim, bias=False),
64
- # self.activation,
65
- # )
66
- # layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * num_mlp_layers
67
- # layers = [Linear(emb_dim, emb_dim, bias=False)] + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1)
68
- # self.mlp_upd = Sequential(*layers)
69
  self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
70
 
71
  def forward(self, h, pos, edge_index):
@@ -83,7 +45,6 @@ class EGNNLayer(MessagePassing):
83
  def message(self, h_i, h_j, pos_i, pos_j):
84
  # Compute messages
85
  pos_diff = pos_i - pos_j
86
- # dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
87
  dists = torch.exp(- torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30 ) # reference distances: 30um
88
  inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
89
  msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
@@ -91,10 +52,6 @@ class EGNNLayer(MessagePassing):
91
  # Scale magnitude of displacement vector
92
  pos_diff = pos_diff * self.mlp_pos(msg)
93
  # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
94
- # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
95
- # print(torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1))
96
- # print(msg)
97
- # import pdb; pdb.set_trace()
98
  return msg, pos_diff, inner_prod
99
 
100
  def aggregate(self, inputs, index):
@@ -109,17 +66,12 @@ class EGNNLayer(MessagePassing):
109
  counts = scatter(counts, index, dim=0, reduce="add")
110
  counts[counts==0] = 1
111
  pos_aggr = pos_aggr / counts
112
- # print(msgs)
113
- # print(msg_aggr)
114
- # import pdb; pdb.set_trace()
115
  return msg_aggr, pos_aggr
116
 
117
  def update(self, aggr_out, h, pos):
118
  msg_aggr, pos_aggr = aggr_out
119
- # upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
120
  upd_out = self.mlp_upd(msg_aggr)
121
  upd_pos = pos + pos_aggr
122
- # import pdb; pdb.set_trace()
123
  return upd_out, upd_pos
124
 
125
  def __repr__(self) -> str:
 
23
  super().__init__(aggr=aggr)
24
 
25
  self.emb_dim = emb_dim
 
26
 
27
  self.dist_embedding = Linear(1, emb_dim, bias=False)
28
  self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  self.mlp_msg = MLPBiasFree(in_dim=3*emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
 
 
 
 
 
 
 
 
30
  self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
32
 
33
  def forward(self, h, pos, edge_index):
 
45
  def message(self, h_i, h_j, pos_i, pos_j):
46
  # Compute messages
47
  pos_diff = pos_i - pos_j
 
48
  dists = torch.exp(- torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30 ) # reference distances: 30um
49
  inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
50
  msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
 
52
  # Scale magnitude of displacement vector
53
  pos_diff = pos_diff * self.mlp_pos(msg)
54
  # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
 
 
 
 
55
  return msg, pos_diff, inner_prod
56
 
57
  def aggregate(self, inputs, index):
 
66
  counts = scatter(counts, index, dim=0, reduce="add")
67
  counts[counts==0] = 1
68
  pos_aggr = pos_aggr / counts
 
 
 
69
  return msg_aggr, pos_aggr
70
 
71
  def update(self, aggr_out, h, pos):
72
  msg_aggr, pos_aggr = aggr_out
 
73
  upd_out = self.mlp_upd(msg_aggr)
74
  upd_pos = pos + pos_aggr
 
75
  return upd_out, upd_pos
76
 
77
  def __repr__(self) -> str:
test.ipynb CHANGED
@@ -12,9 +12,16 @@
12
  "import scanpy as sc"
13
  ]
14
  },
 
 
 
 
 
 
 
15
  {
16
  "cell_type": "code",
17
- "execution_count": 2,
18
  "metadata": {},
19
  "outputs": [
20
  {
@@ -94,13 +101,25 @@
94
  ],
95
  "source": [
96
  "args_model = torch.load('./model_files/args.pt')\n",
97
- "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model)\n",
 
98
  "model.eval()"
99
  ]
100
  },
 
 
 
 
 
 
 
 
 
 
 
101
  {
102
  "cell_type": "code",
103
- "execution_count": 3,
104
  "metadata": {},
105
  "outputs": [
106
  {
@@ -120,7 +139,6 @@
120
  }
121
  ],
122
  "source": [
123
- "channel2ensembl = torch.load('./model_files/channel2ensembl.pt')\n",
124
  "adata = sc.read_h5ad('./adata.h5ad')\n",
125
  "adata.layers['counts'] = adata.X.copy()\n",
126
  "sc.pp.normalize_total(adata)\n",
@@ -128,9 +146,20 @@
128
  "adata"
129
  ]
130
  },
 
 
 
 
 
 
 
 
 
 
 
131
  {
132
  "cell_type": "code",
133
- "execution_count": 4,
134
  "metadata": {},
135
  "outputs": [
136
  {
@@ -142,7 +171,16 @@
142
  }
143
  ],
144
  "source": [
145
- "model.channel_matching(adata, channel2ensembl)"
 
 
 
 
 
 
 
 
 
146
  ]
147
  },
148
  {
@@ -174,6 +212,13 @@
174
  "embeddings, embeddings.shape"
175
  ]
176
  },
 
 
 
 
 
 
 
177
  {
178
  "cell_type": "code",
179
  "execution_count": 5,
 
12
  "import scanpy as sc"
13
  ]
14
  },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {},
18
+ "source": [
19
+ "### 1. load model"
20
+ ]
21
+ },
22
  {
23
  "cell_type": "code",
24
+ "execution_count": null,
25
  "metadata": {},
26
  "outputs": [
27
  {
 
101
  ],
102
  "source": [
103
  "args_model = torch.load('./model_files/args.pt')\n",
104
+ "device = 'cpu' # or 'cuda\n",
105
+ "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
106
  "model.eval()"
107
  ]
108
  },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "### 2. load and preprocess sample adata\n",
114
+ "- some requirements for adata:\n",
115
+ "- ```adata.X```: need to the raw count\n",
116
+ "- ```adata.obsm['spatial']```: the coordinates of cells in the unit of micrometer\n",
117
+ "- if in a different unit, it might result in a weird geometric graph: we use a radius 20 (micrometer) to construct the geometric graph in the model, so a different unit might result in a overly sparse or dense graph"
118
+ ]
119
+ },
120
  {
121
  "cell_type": "code",
122
+ "execution_count": null,
123
  "metadata": {},
124
  "outputs": [
125
  {
 
139
  }
140
  ],
141
  "source": [
 
142
  "adata = sc.read_h5ad('./adata.h5ad')\n",
143
  "adata.layers['counts'] = adata.X.copy()\n",
144
  "sc.pp.normalize_total(adata)\n",
 
146
  "adata"
147
  ]
148
  },
149
+ {
150
+ "cell_type": "markdown",
151
+ "metadata": {},
152
+ "source": [
153
+ "### 3. match feature channels\n",
154
+ "- we need a list which maps feature channels to ensemble ids: ```channel2ensembl_ids_target```\n",
155
+ "- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid1_for_channel1, ...], [ensemblid11_for_channel2, ensemblid12_for_channel2, ...], ...]```\n",
156
+ "- one channel could correspond to multiple ensemble ids, e.g., when your original data the channels are annotated with gene names\n",
157
+ "- you can use to BioMart map you each gene name to one or multiple ensemble ids"
158
+ ]
159
+ },
160
  {
161
  "cell_type": "code",
162
+ "execution_count": null,
163
  "metadata": {},
164
  "outputs": [
165
  {
 
171
  }
172
  ],
173
  "source": [
174
+ "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
175
+ "channel2ensembl_ids_source = torch.load('./model_files/channel2ensembl.pt')\n",
176
+ "model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "metadata": {},
182
+ "source": [
183
+ "### 4. embed the microenvironments centered at each cell"
184
  ]
185
  },
186
  {
 
212
  "embeddings, embeddings.shape"
213
  ]
214
  },
215
+ {
216
+ "cell_type": "markdown",
217
+ "metadata": {},
218
+ "source": [
219
+ "### 5. infer the potential gene expressions at certain locations"
220
+ ]
221
+ },
222
  {
223
  "cell_type": "code",
224
  "execution_count": 5,