Yuning You commited on
Commit
552cf9a
Β·
1 Parent(s): 3218749
README.md CHANGED
@@ -32,6 +32,7 @@ pip install torch-scatter torch-sparse torch-cluster torch-geometric -f https://
32
  pip install scanpy e3nn
33
  pip install lightning==2.1.0
34
  pip install numpy==1.26.4
 
35
  ```
36
  where the compatibility between ```torch``` and ```torch-geometric``` is not always guaranteed (the first two lines), since these two guys are picky on the platform and the version of other packages. You may trial and error to find the right version of ```torch``` and ```torch-geometric``` that works for you, e.g., in some machine I once installed via:
37
  ```
 
32
  pip install scanpy e3nn
33
  pip install lightning==2.1.0
34
  pip install numpy==1.26.4
35
+ pip install transformers
36
  ```
37
  where the compatibility between ```torch``` and ```torch-geometric``` is not always guaranteed (the first two lines), since these two guys are picky on the platform and the version of other packages. You may trial and error to find the right version of ```torch``` and ```torch-geometric``` that works for you, e.g., in some machine I once installed via:
38
  ```
{model_files β†’ models_cifm}/args.pt RENAMED
File without changes
{model_files β†’ models_cifm}/channel2ensembl.pt RENAMED
File without changes
{models β†’ models_cifm}/cifm.py RENAMED
@@ -3,8 +3,8 @@ import torch.nn as nn
3
  from torch_geometric.nn import radius_graph
4
  import scanpy as sc
5
  from huggingface_hub import PyTorchModelHubMixin
6
- from models.mlp_and_gnn import MLPBiasFree
7
- from models.egnn_void_invariant import VIEGNNModel
8
 
9
 
10
  class CIFM(
 
3
  from torch_geometric.nn import radius_graph
4
  import scanpy as sc
5
  from huggingface_hub import PyTorchModelHubMixin
6
+ from models_cifm.mlp_and_gnn import MLPBiasFree
7
+ from models_cifm.egnn_void_invariant import VIEGNNModel
8
 
9
 
10
  class CIFM(
{models β†’ models_cifm}/egnn_void_invariant.py RENAMED
@@ -2,8 +2,8 @@ import torch
2
  from torch.nn import functional as F
3
  from torch_geometric.nn import global_add_pool, global_mean_pool
4
 
5
- from models.layers.egnn_layer_void_invariant import EGNNLayer
6
- from models.mlp_and_gnn import MLPBiasFree
7
 
8
 
9
  class VIEGNNModel(torch.nn.Module):
 
2
  from torch.nn import functional as F
3
  from torch_geometric.nn import global_add_pool, global_mean_pool
4
 
5
+ from models_cifm.layers.egnn_layer_void_invariant import EGNNLayer
6
+ from models_cifm.mlp_and_gnn import MLPBiasFree
7
 
8
 
9
  class VIEGNNModel(torch.nn.Module):
{models β†’ models_cifm}/layers/egnn_layer_void_invariant.py RENAMED
@@ -3,7 +3,7 @@ from torch.nn import Linear, ReLU, SiLU, Sequential
3
  from torch_geometric.nn import MessagePassing
4
  from torch_scatter import scatter
5
 
6
- from models.mlp_and_gnn import MLPBiasFree
7
 
8
 
9
  class EGNNLayer(MessagePassing):
 
3
  from torch_geometric.nn import MessagePassing
4
  from torch_scatter import scatter
5
 
6
+ from models_cifm.mlp_and_gnn import MLPBiasFree
7
 
8
 
9
  class EGNNLayer(MessagePassing):
{models β†’ models_cifm}/mlp_and_gnn.py RENAMED
File without changes
test.ipynb CHANGED
@@ -8,7 +8,7 @@
8
  "source": [
9
  "import torch\n",
10
  "import numpy as np\n",
11
- "from models.cifm import CIFM\n",
12
  "import scanpy as sc"
13
  ]
14
  },
@@ -100,7 +100,7 @@
100
  }
101
  ],
102
  "source": [
103
- "args_model = torch.load('./model_files/args.pt')\n",
104
  "device = 'cpu' # or 'cuda\n",
105
  "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
106
  "model.eval()"
@@ -172,7 +172,7 @@
172
  ],
173
  "source": [
174
  "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
175
- "channel2ensembl_ids_source = torch.load('./model_files/channel2ensembl.pt')\n",
176
  "model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
177
  ]
178
  },
 
8
  "source": [
9
  "import torch\n",
10
  "import numpy as np\n",
11
+ "from models_cifm.cifm import CIFM\n",
12
  "import scanpy as sc"
13
  ]
14
  },
 
100
  }
101
  ],
102
  "source": [
103
+ "args_model = torch.load('./models_cifm/args.pt')\n",
104
  "device = 'cpu' # or 'cuda\n",
105
  "model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
106
  "model.eval()"
 
172
  ],
173
  "source": [
174
  "channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
175
+ "channel2ensembl_ids_source = torch.load('./models_cifm/channel2ensembl.pt')\n",
176
  "model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
177
  ]
178
  },