Yuning You committed · Commit 2f63b5b · Parent(s): 9565da7

update
models/cifm.py
CHANGED
@@ -32,8 +32,8 @@ class CIFM(
         self.hidden_dim = args.hidden_dim
         self.radius_spatial_graph = args.radius_spatial_graph

-    def channel_matching(self,
-        channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
+    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
+        # channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]

         linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
         linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)

@@ -97,8 +97,6 @@ class CIFM(
         expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
         dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))

-        # import pdb ; pdb.set_trace()
-
         expressions_dec[dropouts_dec<=0.5] = 0
         return expressions_dec
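The signature change above means callers now pass the channel-to-Ensembl-ID mapping of their own data explicitly instead of an AnnData object. A minimal call-site sketch, assuming `model` is a loaded CIFM instance and `adata` is an AnnData whose var index holds Ensembl IDs (as in test.ipynb below):

import torch

channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]         # one Ensembl-ID list per channel of the new data
channel2ensembl_ids_source = torch.load('./model_files/channel2ensembl.pt')  # channels the pretrained model was trained with
model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source,
                       zero_init_for_unmatched_genes=True)                   # default value from the new signature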
models/egnn_void_invariant.py
CHANGED
@@ -48,15 +48,6 @@ class VIEGNNModel(torch.nn.Module):
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
-
-        # MLP predictor for invariant tasks using only scalar features
-        # self.pred = torch.nn.Sequential(
-        #     torch.nn.Linear(emb_dim, emb_dim, bias=False),
-        #     torch.nn.ReLU(),
-        #     torch.nn.Linear(emb_dim, out_dim, bias=False)
-        # )
-        # layers = [torch.nn.Linear(emb_dim, emb_dim, bias=False), torch.nn.ReLU()] * (num_mlp_layers_in_module-1) + [torch.nn.Linear(emb_dim, out_dim, bias=False)]
-        # self.pred = torch.nn.Sequential(*layers)
         self.pred = MLPBiasFree(in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module)

         # unroll the batch argments and comment out the pooling operation
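The commented-out `Sequential` predictors removed here were already superseded by `MLPBiasFree`, whose definition is not part of this diff. A hypothetical sketch of what such a bias-free MLP could look like, inferred only from the constructor calls above (the repo's actual implementation may differ):

import torch

class MLPBiasFree(torch.nn.Module):
    # Hypothetical: stacks num_layer Linear layers without bias terms, with ReLU in between.
    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layer - 1) + [out_dim]
        layers = []
        for i in range(num_layer):
            layers.append(torch.nn.Linear(dims[i], dims[i + 1], bias=False))
            if i < num_layer - 1:
                layers.append(torch.nn.ReLU())
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)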
models/layers/__init__.py
DELETED
File without changes

models/layers/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (170 Bytes)

models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc
DELETED
Binary file (4.8 kB)
models/layers/egnn_layer_void_invariant.py
CHANGED
@@ -23,49 +23,11 @@ class EGNNLayer(MessagePassing):
         super().__init__(aggr=aggr)

         self.emb_dim = emb_dim
-        # self.activation = ReLU()

         self.dist_embedding = Linear(1, emb_dim, bias=False)
         self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
-
-        # MLP `\psi_h` for computing messages `m_ij`
-        # self.mlp_msg = Sequential(
-        #     Linear(2 * emb_dim + 1, emb_dim, bias=False),
-        #     torch.nn.LayerNorm(emb_dim, bias=False),
-        #     self.activation,
-        #     Linear(emb_dim, emb_dim, bias=False),
-        #     torch.nn.LayerNorm(emb_dim, bias=False),
-        #     self.activation,
-        # )
-        # layers = [Linear(2 * emb_dim + 1, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] \
-        #     + [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1)
-        # layers = [Linear(3 * emb_dim, emb_dim, bias=False)] \
-        #     + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1) \
-        #     + [torch.nn.LayerNorm(emb_dim, bias=False)]
-        # self.mlp_msg = Sequential(*layers)
         self.mlp_msg = MLPBiasFree(in_dim=3*emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
-
-        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
-        # self.mlp_pos = Sequential(
-        #     Linear(emb_dim, emb_dim), torch.nn.LayerNorm(emb_dim), self.activation, Linear(emb_dim, 1)
-        # )
-        # layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
-        # layers = [Linear(emb_dim, emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
-        # self.mlp_pos = Sequential(*layers)
         self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
-
-        # MLP `\phi` for computing updated node features `h_i^{l+1}`
-        # self.mlp_upd = Sequential(
-        #     Linear(2 * emb_dim, emb_dim, bias=False),
-        #     torch.nn.LayerNorm(emb_dim, bias=False),
-        #     self.activation,
-        #     Linear(emb_dim, emb_dim, bias=False),
-        #     torch.nn.LayerNorm(emb_dim, bias=False),
-        #     self.activation,
-        # )
-        # layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * num_mlp_layers
-        # layers = [Linear(emb_dim, emb_dim, bias=False)] + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1)
-        # self.mlp_upd = Sequential(*layers)
         self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)

     def forward(self, h, pos, edge_index):

@@ -83,7 +45,6 @@ class EGNNLayer(MessagePassing):
     def message(self, h_i, h_j, pos_i, pos_j):
         # Compute messages
         pos_diff = pos_i - pos_j
-        # dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
         dists = torch.exp(- torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30 ) # reference distances: 30um
         inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
         msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)

@@ -91,10 +52,6 @@ class EGNNLayer(MessagePassing):
         # Scale magnitude of displacement vector
         pos_diff = pos_diff * self.mlp_pos(msg)
         # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
-        # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
-        # print(torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1))
-        # print(msg)
-        # import pdb; pdb.set_trace()
         return msg, pos_diff, inner_prod

     def aggregate(self, inputs, index):

@@ -109,17 +66,12 @@ class EGNNLayer(MessagePassing):
         counts = scatter(counts, index, dim=0, reduce="add")
         counts[counts==0] = 1
         pos_aggr = pos_aggr / counts
-        # print(msgs)
-        # print(msg_aggr)
-        # import pdb; pdb.set_trace()
         return msg_aggr, pos_aggr

     def update(self, aggr_out, h, pos):
         msg_aggr, pos_aggr = aggr_out
-        # upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
         upd_out = self.mlp_upd(msg_aggr)
         upd_pos = pos + pos_aggr
-        # import pdb; pdb.set_trace()
         return upd_out, upd_pos

     def __repr__(self) -> str:
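The main functional line kept in `message` is the distance feature: instead of the raw Euclidean distance, the layer embeds exp(-d / 30), a value in (0, 1] that decays with distance on a 30 um reference scale. A small self-contained sketch of just that transform, with illustrative coordinates that are not from the repo:

import torch

pos_i = torch.tensor([[0.0, 0.0], [0.0, 0.0]])    # two receiver cells at the origin
pos_j = torch.tensor([[10.0, 0.0], [60.0, 0.0]])  # neighbours 10 um and 60 um away
pos_diff = pos_i - pos_j
dists = torch.exp(-torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30)  # reference distance: 30 um
print(dists)  # ~[[0.72], [0.14]]: nearby pairs stay close to 1, distant pairs decay toward 0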
test.ipynb
CHANGED
@@ -12,9 +12,16 @@
     "import scanpy as sc"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. load model"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {

@@ -94,13 +101,25 @@
    ],
    "source": [
     "args_model = torch.load('./model_files/args.pt')\n",
-    "
+    "device = 'cpu' # or 'cuda'\n",
+    "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
     "model.eval()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. load and preprocess sample adata\n",
+    "- requirements for adata:\n",
+    "- ```adata.X```: needs to hold the raw counts\n",
+    "- ```adata.obsm['spatial']```: the coordinates of cells, in micrometers\n",
+    "- if the coordinates are in a different unit, the geometric graph may be distorted: the model builds it with a radius of 20 micrometers, so a different unit can make the graph overly sparse or dense"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {

@@ -120,7 +139,6 @@
     }
    ],
    "source": [
-    "channel2ensembl = torch.load('./model_files/channel2ensembl.pt')\n",
     "adata = sc.read_h5ad('./adata.h5ad')\n",
     "adata.layers['counts'] = adata.X.copy()\n",
     "sc.pp.normalize_total(adata)\n",

@@ -128,9 +146,20 @@
     "adata"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. match feature channels\n",
+    "- we need a list that maps feature channels to Ensembl IDs: ```channel2ensembl_ids_target```\n",
+    "- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid2_for_channel1, ...], [ensemblid1_for_channel2, ensemblid2_for_channel2, ...], ...]```\n",
+    "- one channel can correspond to multiple Ensembl IDs, e.g., when the channels in your original data are annotated with gene names\n",
+    "- you can use BioMart to map each gene name to one or multiple Ensembl IDs"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {

@@ -142,7 +171,16 @@
     }
    ],
    "source": [
-    "
+    "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
+    "channel2ensembl_ids_source = torch.load('./model_files/channel2ensembl.pt')\n",
+    "model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 4. embed the microenvironments centered at each cell"
    ]
   },
   {

@@ -174,6 +212,13 @@
     "embeddings, embeddings.shape"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 5. infer the potential gene expressions at certain locations"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 5,
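Put together, the new markdown cells describe a five-step workflow. A script-style sketch of steps 1-3, assembled from the code cells visible in this diff (the `from models.cifm import CIFM` import path is an assumption; steps 4-5 live in notebook cells not shown in these hunks):

import torch
import scanpy as sc
from models.cifm import CIFM  # assumed import path for the CIFM class in models/cifm.py

# 1. load model
args_model = torch.load('./model_files/args.pt')
device = 'cpu'  # or 'cuda'
model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)
model.eval()

# 2. load and preprocess sample adata (adata.X holds raw counts,
#    adata.obsm['spatial'] holds coordinates in micrometers; any further
#    preprocessing in the notebook falls outside the shown hunks)
adata = sc.read_h5ad('./adata.h5ad')
adata.layers['counts'] = adata.X.copy()
sc.pp.normalize_total(adata)

# 3. match feature channels between the data and the pretrained model,
#    as in the channel_matching sketch under models/cifm.py above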