Yuning You
commited on
Commit
Β·
552cf9a
1
Parent(s):
3218749
update
Browse files- README.md +1 -0
- {model_files β models_cifm}/args.pt +0 -0
- {model_files β models_cifm}/channel2ensembl.pt +0 -0
- {models β models_cifm}/cifm.py +2 -2
- {models β models_cifm}/egnn_void_invariant.py +2 -2
- {models β models_cifm}/layers/egnn_layer_void_invariant.py +1 -1
- {models β models_cifm}/mlp_and_gnn.py +0 -0
- test.ipynb +3 -3
README.md
CHANGED
@@ -32,6 +32,7 @@ pip install torch-scatter torch-sparse torch-cluster torch-geometric -f https://
|
|
32 |
pip install scanpy e3nn
|
33 |
pip install lightning==2.1.0
|
34 |
pip install numpy==1.26.4
|
|
|
35 |
```
|
36 |
where the compatibility between ```torch``` and ```torch-geometric``` is not always guaranteed (the first two lines), since these two guys are picky on the platform and the version of other packages. You may trial and error to find the right version of ```torch``` and ```torch-geometric``` that works for you, e.g., in some machine I once installed via:
|
37 |
```
|
|
|
32 |
pip install scanpy e3nn
|
33 |
pip install lightning==2.1.0
|
34 |
pip install numpy==1.26.4
|
35 |
+
pip install transformers
|
36 |
```
|
37 |
where the compatibility between ```torch``` and ```torch-geometric``` is not always guaranteed (the first two lines), since these two guys are picky on the platform and the version of other packages. You may trial and error to find the right version of ```torch``` and ```torch-geometric``` that works for you, e.g., in some machine I once installed via:
|
38 |
```
|
{model_files β models_cifm}/args.pt
RENAMED
File without changes
|
{model_files β models_cifm}/channel2ensembl.pt
RENAMED
File without changes
|
{models β models_cifm}/cifm.py
RENAMED
@@ -3,8 +3,8 @@ import torch.nn as nn
|
|
3 |
from torch_geometric.nn import radius_graph
|
4 |
import scanpy as sc
|
5 |
from huggingface_hub import PyTorchModelHubMixin
|
6 |
-
from
|
7 |
-
from
|
8 |
|
9 |
|
10 |
class CIFM(
|
|
|
3 |
from torch_geometric.nn import radius_graph
|
4 |
import scanpy as sc
|
5 |
from huggingface_hub import PyTorchModelHubMixin
|
6 |
+
from models_cifm.mlp_and_gnn import MLPBiasFree
|
7 |
+
from models_cifm.egnn_void_invariant import VIEGNNModel
|
8 |
|
9 |
|
10 |
class CIFM(
|
{models β models_cifm}/egnn_void_invariant.py
RENAMED
@@ -2,8 +2,8 @@ import torch
|
|
2 |
from torch.nn import functional as F
|
3 |
from torch_geometric.nn import global_add_pool, global_mean_pool
|
4 |
|
5 |
-
from
|
6 |
-
from
|
7 |
|
8 |
|
9 |
class VIEGNNModel(torch.nn.Module):
|
|
|
2 |
from torch.nn import functional as F
|
3 |
from torch_geometric.nn import global_add_pool, global_mean_pool
|
4 |
|
5 |
+
from models_cifm.layers.egnn_layer_void_invariant import EGNNLayer
|
6 |
+
from models_cifm.mlp_and_gnn import MLPBiasFree
|
7 |
|
8 |
|
9 |
class VIEGNNModel(torch.nn.Module):
|
{models β models_cifm}/layers/egnn_layer_void_invariant.py
RENAMED
@@ -3,7 +3,7 @@ from torch.nn import Linear, ReLU, SiLU, Sequential
|
|
3 |
from torch_geometric.nn import MessagePassing
|
4 |
from torch_scatter import scatter
|
5 |
|
6 |
-
from
|
7 |
|
8 |
|
9 |
class EGNNLayer(MessagePassing):
|
|
|
3 |
from torch_geometric.nn import MessagePassing
|
4 |
from torch_scatter import scatter
|
5 |
|
6 |
+
from models_cifm.mlp_and_gnn import MLPBiasFree
|
7 |
|
8 |
|
9 |
class EGNNLayer(MessagePassing):
|
{models β models_cifm}/mlp_and_gnn.py
RENAMED
File without changes
|
test.ipynb
CHANGED
@@ -8,7 +8,7 @@
|
|
8 |
"source": [
|
9 |
"import torch\n",
|
10 |
"import numpy as np\n",
|
11 |
-
"from
|
12 |
"import scanpy as sc"
|
13 |
]
|
14 |
},
|
@@ -100,7 +100,7 @@
|
|
100 |
}
|
101 |
],
|
102 |
"source": [
|
103 |
-
"args_model = torch.load('./
|
104 |
"device = 'cpu' # or 'cuda\n",
|
105 |
"model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
|
106 |
"model.eval()"
|
@@ -172,7 +172,7 @@
|
|
172 |
],
|
173 |
"source": [
|
174 |
"channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
|
175 |
-
"channel2ensembl_ids_source = torch.load('./
|
176 |
"model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
|
177 |
]
|
178 |
},
|
|
|
8 |
"source": [
|
9 |
"import torch\n",
|
10 |
"import numpy as np\n",
|
11 |
+
"from models_cifm.cifm import CIFM\n",
|
12 |
"import scanpy as sc"
|
13 |
]
|
14 |
},
|
|
|
100 |
}
|
101 |
],
|
102 |
"source": [
|
103 |
+
"args_model = torch.load('./models_cifm/args.pt')\n",
|
104 |
"device = 'cpu' # or 'cuda\n",
|
105 |
"model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
|
106 |
"model.eval()"
|
|
|
172 |
],
|
173 |
"source": [
|
174 |
"channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
|
175 |
+
"channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n",
|
176 |
"model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
|
177 |
]
|
178 |
},
|