Yaning1001 commited on
Commit
c91d7b1
·
verified ·
1 Parent(s): 92b9080

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. deeprobust/graph/defense_pyg/__init__.py +15 -0
  2. deeprobust/graph/defense_pyg/appnp.py +79 -0
  3. deeprobust/graph/defense_pyg/base_model.py +206 -0
  4. deeprobust/graph/defense_pyg/gpr.py +135 -0
  5. deeprobust/graph/defense_pyg/mygat_conv.py +198 -0
  6. deeprobust/graph/rl/__init__.py +0 -0
  7. deeprobust/graph/rl/env.py +258 -0
  8. deeprobust/graph/rl/q_net_node.py +228 -0
  9. deeprobust/image/adversary_examples/advexample.png +0 -0
  10. deeprobust/image/adversary_examples/cifar_advexample_orig.png +0 -0
  11. deeprobust/image/adversary_examples/cifar_advexample_pgd.png +0 -0
  12. deeprobust/image/adversary_examples/deepfool_diff.png +0 -0
  13. deeprobust/image/adversary_examples/imageexample.png +0 -0
  14. deeprobust/image/adversary_examples/test.jpg +0 -0
  15. deeprobust/image/adversary_examples/test1.jpg +0 -0
  16. deeprobust/image/evaluation_attack.py +226 -0
  17. deeprobust/image/netmodels/CNN.py +125 -0
  18. deeprobust/image/netmodels/CNN_multilayer.py +122 -0
  19. deeprobust/image/netmodels/YOPOCNN.py +70 -0
  20. deeprobust/image/netmodels/resnet.py +168 -0
  21. deeprobust/image/netmodels/train_model.py +146 -0
  22. deeprobust/image/netmodels/train_resnet.py +39 -0
  23. deeprobust/image/netmodels/vgg.py +116 -0
  24. deeprobust/image/synset_words.txt +1000 -0
  25. docs/Makefile +20 -0
  26. docs/conf.py +71 -0
  27. docs/index.rst +65 -0
  28. examples/graph/cgscore_datasets.py +255 -0
  29. examples/graph/cgscore_datasets_multigpus.py +299 -0
  30. examples/graph/cgscore_datasets_multigpus2.py +208 -0
  31. examples/graph/cgscore_env.yaml +193 -0
  32. examples/graph/cgscore_experiments/attack_method/attack_minmax.py +106 -0
  33. examples/graph/cgscore_experiments/attack_method/attack_nettack.py +212 -0
  34. examples/graph/cgscore_experiments/defense_method/GAT.py +61 -0
  35. examples/graph/cgscore_experiments/defense_method/GCN.py +73 -0
  36. examples/graph/cgscore_experiments/defense_method/GCNJaccard.py +68 -0
  37. examples/graph/cgscore_experiments/defense_method/GCNSVD.py +63 -0
  38. examples/graph/cgscore_experiments/defense_method/GNNGuard.py +64 -0
  39. examples/graph/cgscore_experiments/defense_method/ProGNN.py +80 -0
  40. examples/graph/cgscore_experiments/defense_method/RGCN.py +66 -0
  41. examples/graph/cgscore_experiments/defense_method/cgscore.py +0 -0
  42. examples/graph/cgscore_experiments/grb/grb_data.py +32 -0
  43. examples/graph/cgscore_save.py +402 -0
  44. examples/graph/test_adv_train_evasion.py +112 -0
  45. examples/graph/test_adv_train_poisoning.py +78 -0
  46. examples/graph/test_all.py +13 -0
  47. examples/graph/test_chebnet.py +48 -0
  48. examples/graph/test_deepwalk.py +39 -0
  49. examples/graph/test_gat.py +55 -0
  50. examples/graph/test_gcn.py +69 -0
deeprobust/graph/defense_pyg/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from .gcn import GCN
3
+ from .gat import GAT
4
+ from .appnp import APPNP
5
+ from .sage import SAGE
6
+ from .gpr import GPRGNN
7
+ from .airgnn import AirGNN
8
+ except ImportError as e:
9
+ print(e)
10
+ warnings.warn("Please install pytorch geometric if you " +
11
+ "would like to use the datasets from pytorch " +
12
+ "geometric. See details in https://pytorch-geom" +
13
+ "etric.readthedocs.io/en/latest/notes/installation.html")
14
+
15
+ __all__ = ["GCN", "GAT", "APPNP", "SAGE", "GPRGNN", "AirGNN"]
deeprobust/graph/defense_pyg/appnp.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import math
4
+ import torch
5
+ from torch.nn.parameter import Parameter
6
+ from torch.nn.modules.module import Module
7
+ from torch_geometric.nn import APPNP as APPNPConv
8
+ from torch.nn import Linear
9
+ from .base_model import BaseModel
10
+
11
+
12
+ class APPNP(BaseModel):
13
+
14
+ def __init__(self, nfeat, nhid, nclass, K=10, alpha=0.1, dropout=0.5, lr=0.01,
15
+ with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
16
+
17
+ super(APPNP, self).__init__()
18
+
19
+ assert device is not None, "Please specify 'device'!"
20
+ self.device = device
21
+
22
+
23
+ self.lin1 = Linear(nfeat, nhid)
24
+ if with_bn:
25
+ self.bn1 = nn.BatchNorm1d(nhid)
26
+ self.bn2 = nn.BatchNorm1d(nclass)
27
+
28
+ self.lin2 = Linear(nhid, nclass)
29
+ self.prop1 = APPNPConv(K, alpha)
30
+
31
+ self.dropout = dropout
32
+ self.weight_decay = weight_decay
33
+ self.lr = lr
34
+ self.output = None
35
+ self.best_model = None
36
+ self.best_output = None
37
+ self.name = 'APPNP'
38
+ self.with_bn = with_bn
39
+
40
+ def forward(self, x, edge_index, edge_weight=None):
41
+ x = F.dropout(x, p=self.dropout, training=self.training)
42
+ x = self.lin1(x)
43
+ if self.with_bn:
44
+ x = self.bn1(x)
45
+ x = F.relu(x)
46
+ x = F.dropout(x, p=self.dropout, training=self.training)
47
+ x = self.lin2(x)
48
+ if self.with_bn:
49
+ x = self.bn2(x)
50
+ x = self.prop1(x, edge_index, edge_weight)
51
+ return F.log_softmax(x, dim=1)
52
+
53
+
54
+ def initialize(self):
55
+ self.lin1.reset_parameters()
56
+ self.lin2.reset_parameters()
57
+ if self.with_bn:
58
+ self.bn1.reset_parameters()
59
+ self.bn2.reset_parameters()
60
+
61
+
62
+ if __name__ == "__main__":
63
+ from deeprobust.graph.data import Dataset, Dpr2Pyg
64
+ data = Dataset(root='/tmp/', name='cora', setting='gcn')
65
+ adj, features, labels = data.adj, data.features, data.labels
66
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
67
+ model = GCN(nfeat=features.shape[1],
68
+ nhid=16,
69
+ nclass=labels.max().item() + 1,
70
+ dropout=0.5, device='cuda')
71
+ model = model.to('cuda')
72
+ pyg_data = Dpr2Pyg(data)[0]
73
+
74
+ import ipdb
75
+ ipdb.set_trace()
76
+
77
+ model.fit(pyg_data, verbose=True) # train with earlystopping
78
+ model.test()
79
+ print(model.predict())
deeprobust/graph/defense_pyg/base_model.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.optim as optim
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from copy import deepcopy
5
+ from deeprobust.graph import utils
6
+ import torch
7
+
8
+
9
+ class BaseModel(nn.Module):
10
+ def __init__(self):
11
+ super(BaseModel, self).__init__()
12
+ pass
13
+
14
+ def fit(self, pyg_data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
15
+ if initialize:
16
+ self.initialize()
17
+
18
+ # self.data = pyg_data[0].to(self.device)
19
+ self.data = pyg_data.to(self.device)
20
+ # By default, it is trained with early stopping on validation
21
+ self.train_with_early_stopping(train_iters, patience, verbose)
22
+
23
+ def finetune(self, edge_index, edge_weight, feat=None, train_iters=10, verbose=True):
24
+ if verbose:
25
+ print(f'=== finetuning {self.name} model ===')
26
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
27
+ labels = self.data.y
28
+ if feat is None:
29
+ x = self.data.x
30
+ else:
31
+ x = feat
32
+ train_mask, val_mask = self.data.train_mask, self.data.val_mask
33
+ best_loss_val = 100
34
+ best_acc_val = 0
35
+ for i in range(train_iters):
36
+ self.train()
37
+ optimizer.zero_grad()
38
+ output = self.forward(x, edge_index, edge_weight)
39
+ loss_train = F.nll_loss(output[train_mask], labels[train_mask])
40
+ loss_train.backward()
41
+ optimizer.step()
42
+
43
+ if verbose and i % 50 == 0:
44
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
45
+
46
+ self.eval()
47
+ with torch.no_grad():
48
+ output = self.forward(x, edge_index)
49
+ loss_val = F.nll_loss(output[val_mask], labels[val_mask])
50
+ acc_val = utils.accuracy(output[val_mask], labels[val_mask])
51
+
52
+ # if best_loss_val > loss_val:
53
+ # best_loss_val = loss_val
54
+ # best_output = output
55
+ # weights = deepcopy(self.state_dict())
56
+
57
+ if best_acc_val < acc_val:
58
+ best_acc_val = acc_val
59
+ best_output = output
60
+ weights = deepcopy(self.state_dict())
61
+
62
+ print('best_acc_val:', best_acc_val.item())
63
+ self.load_state_dict(weights)
64
+ return best_output
65
+
66
+
67
+ def _fit_with_val(self, pyg_data, train_iters=1000, initialize=True, verbose=False, **kwargs):
68
+ if initialize:
69
+ self.initialize()
70
+
71
+ # self.data = pyg_data[0].to(self.device)
72
+ self.data = pyg_data.to(self.device)
73
+ if verbose:
74
+ print(f'=== training {self.name} model ===')
75
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
76
+
77
+ labels = self.data.y
78
+ train_mask, val_mask = self.data.train_mask, self.data.val_mask
79
+
80
+ x, edge_index = self.data.x, self.data.edge_index
81
+ for i in range(train_iters):
82
+ self.train()
83
+ optimizer.zero_grad()
84
+ output = self.forward(x, edge_index)
85
+ loss_train = F.nll_loss(output[train_mask+val_mask], labels[train_mask+val_mask])
86
+ loss_train.backward()
87
+ optimizer.step()
88
+
89
+ if verbose and i % 50 == 0:
90
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
91
+
92
+ def fit_with_val(self, pyg_data, train_iters=1000, initialize=True, patience=100, verbose=False, **kwargs):
93
+ if initialize:
94
+ self.initialize()
95
+
96
+ self.data = pyg_data.to(self.device)
97
+ self.data.train_mask = self.data.train_mask + self.data.val1_mask
98
+ self.data.val_mask = self.data.val2_mask
99
+ self.train_with_early_stopping(train_iters, patience, verbose)
100
+
101
+ def train_with_early_stopping(self, train_iters, patience, verbose):
102
+ """early stopping based on the validation loss
103
+ """
104
+ if verbose:
105
+ print(f'=== training {self.name} model ===')
106
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
107
+
108
+ labels = self.data.y
109
+ train_mask, val_mask = self.data.train_mask, self.data.val_mask
110
+
111
+ early_stopping = patience
112
+ best_loss_val = 100
113
+ best_acc_val = 0
114
+ best_epoch = 0
115
+
116
+ x, edge_index = self.data.x, self.data.edge_index
117
+ for i in range(train_iters):
118
+ self.train()
119
+ optimizer.zero_grad()
120
+
121
+ output = self.forward(x, edge_index)
122
+
123
+ loss_train = F.nll_loss(output[train_mask], labels[train_mask])
124
+ loss_train.backward()
125
+ optimizer.step()
126
+
127
+ if verbose and i % 50 == 0:
128
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
129
+
130
+ self.eval()
131
+ output = self.forward(x, edge_index)
132
+ loss_val = F.nll_loss(output[val_mask], labels[val_mask])
133
+ acc_val = utils.accuracy(output[val_mask], labels[val_mask])
134
+ # print(acc)
135
+
136
+ # if best_loss_val > loss_val:
137
+ # best_loss_val = loss_val
138
+ # self.output = output
139
+ # weights = deepcopy(self.state_dict())
140
+ # patience = early_stopping
141
+ # best_epoch = i
142
+ # else:
143
+ # patience -= 1
144
+
145
+ if best_acc_val < acc_val:
146
+ best_acc_val = acc_val
147
+ self.output = output
148
+ weights = deepcopy(self.state_dict())
149
+ patience = early_stopping
150
+ best_epoch = i
151
+ else:
152
+ patience -= 1
153
+
154
+ if i > early_stopping and patience <= 0:
155
+ break
156
+
157
+ if verbose:
158
+ # print('=== early stopping at {0}, loss_val = {1} ==='.format(best_epoch, best_loss_val) )
159
+ print('=== early stopping at {0}, acc_val = {1} ==='.format(best_epoch, best_acc_val) )
160
+ self.load_state_dict(weights)
161
+
162
+ def test(self):
163
+ """Evaluate model performance on test set.
164
+ Parameters
165
+ ----------
166
+ idx_test :
167
+ node testing indices
168
+ """
169
+ self.eval()
170
+ test_mask = self.data.test_mask
171
+ labels = self.data.y
172
+ output = self.forward(self.data.x, self.data.edge_index)
173
+ # output = self.output
174
+ loss_test = F.nll_loss(output[test_mask], labels[test_mask])
175
+ acc_test = utils.accuracy(output[test_mask], labels[test_mask])
176
+ print("Test set results:",
177
+ "loss= {:.4f}".format(loss_test.item()),
178
+ "accuracy= {:.4f}".format(acc_test.item()))
179
+ return acc_test.item()
180
+
181
+ def predict(self, x=None, edge_index=None, edge_weight=None):
182
+ """
183
+ Returns
184
+ -------
185
+ torch.FloatTensor
186
+ output (log probabilities)
187
+ """
188
+ self.eval()
189
+ if x is None or edge_index is None:
190
+ x, edge_index = self.data.x, self.data.edge_index
191
+ return self.forward(x, edge_index, edge_weight)
192
+
193
+ def _ensure_contiguousness(self,
194
+ x,
195
+ edge_idx,
196
+ edge_weight):
197
+ if not x.is_sparse:
198
+ x = x.contiguous()
199
+ if hasattr(edge_idx, 'contiguous'):
200
+ edge_idx = edge_idx.contiguous()
201
+ if edge_weight is not None:
202
+ edge_weight = edge_weight.contiguous()
203
+ return x, edge_idx, edge_weight
204
+
205
+
206
+
deeprobust/graph/defense_pyg/gpr.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch_sparse import SparseTensor, matmul
5
+ from torch_geometric.nn import GCNConv, SAGEConv, GATConv, APPNP, MessagePassing
6
+ from torch_geometric.nn.conv.gcn_conv import gcn_norm
7
+ import scipy.sparse
8
+ import numpy as np
9
+ from .base_model import BaseModel
10
+
11
+
12
+ class GPRGNN(BaseModel):
13
+ """GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN"""
14
+
15
+ def __init__(self, nfeat, nhid, nclass, Init='PPR', dprate=.5, dropout=.5,
16
+ lr=0.01, weight_decay=0, device='cpu',
17
+ K=10, alpha=.1, Gamma=None, ppnp='GPR_prop'):
18
+ super(GPRGNN, self).__init__()
19
+ self.lin1 = nn.Linear(nfeat, nhid)
20
+ self.lin2 = nn.Linear(nhid, nclass)
21
+
22
+ if ppnp == 'PPNP':
23
+ self.prop1 = APPNP(K, alpha)
24
+ elif ppnp == 'GPR_prop':
25
+ self.prop1 = GPR_prop(K, alpha, Init, Gamma)
26
+
27
+ self.Init = Init
28
+ self.dprate = dprate
29
+ self.dropout = dropout
30
+ self.name = "GPR"
31
+ self.weight_decay = weight_decay
32
+ self.lr = lr
33
+ self.device=device
34
+
35
+ def initialize(self):
36
+ self.reset_parameters()
37
+
38
+ def reset_parameters(self):
39
+ self.lin1.reset_parameters()
40
+ self.lin2.reset_parameters()
41
+ self.prop1.reset_parameters()
42
+
43
+ def forward(self, x, edge_index, edge_weight=None):
44
+
45
+ x = F.dropout(x, p=self.dropout, training=self.training)
46
+ x = F.relu(self.lin1(x))
47
+ x = F.dropout(x, p=self.dropout, training=self.training)
48
+ x = self.lin2(x)
49
+
50
+ if edge_weight is not None:
51
+ adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1])
52
+ if self.dprate == 0.0:
53
+ x = self.prop1(x, adj)
54
+ else:
55
+ x = F.dropout(x, p=self.dprate, training=self.training)
56
+ x = self.prop1(x, adj)
57
+ else:
58
+ if self.dprate == 0.0:
59
+ x = self.prop1(x, edge_index, edge_weight)
60
+ else:
61
+ x = F.dropout(x, p=self.dprate, training=self.training)
62
+ x = self.prop1(x, edge_index, edge_weight)
63
+
64
+ return F.log_softmax(x, dim=1)
65
+
66
+
67
+ class GPR_prop(MessagePassing):
68
+ '''
69
+ GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN
70
+ propagation class for GPR_GNN
71
+ '''
72
+
73
+ def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
74
+ super(GPR_prop, self).__init__(aggr='add', **kwargs)
75
+ self.K = K
76
+ self.Init = Init
77
+ self.alpha = alpha
78
+
79
+ assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
80
+ if Init == 'SGC':
81
+ # SGC-like
82
+ TEMP = 0.0*np.ones(K+1)
83
+ TEMP[alpha] = 1.0
84
+ elif Init == 'PPR':
85
+ # PPR-like
86
+ TEMP = alpha*(1-alpha)**np.arange(K+1)
87
+ TEMP[-1] = (1-alpha)**K
88
+ elif Init == 'NPPR':
89
+ # Negative PPR
90
+ TEMP = (alpha)**np.arange(K+1)
91
+ TEMP = TEMP/np.sum(np.abs(TEMP))
92
+ elif Init == 'Random':
93
+ # Random
94
+ bound = np.sqrt(3/(K+1))
95
+ TEMP = np.random.uniform(-bound, bound, K+1)
96
+ TEMP = TEMP/np.sum(np.abs(TEMP))
97
+ elif Init == 'WS':
98
+ # Specify Gamma
99
+ TEMP = Gamma
100
+
101
+ self.temp = nn.Parameter(torch.tensor(TEMP))
102
+
103
+ def reset_parameters(self):
104
+ nn.init.zeros_(self.temp)
105
+ for k in range(self.K+1):
106
+ self.temp.data[k] = self.alpha*(1-self.alpha)**k
107
+ self.temp.data[-1] = (1-self.alpha)**self.K
108
+
109
+ def forward(self, x, edge_index, edge_weight=None):
110
+ if isinstance(edge_index, torch.Tensor):
111
+ edge_index, norm = gcn_norm(
112
+ edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
113
+ elif isinstance(edge_index, SparseTensor):
114
+ edge_index = gcn_norm(
115
+ edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
116
+ norm = None
117
+
118
+ hidden = x*(self.temp[0])
119
+ for k in range(self.K):
120
+ x = self.propagate(edge_index, x=x, norm=norm)
121
+ gamma = self.temp[k+1]
122
+ hidden = hidden + gamma*x
123
+ return hidden
124
+
125
+ def message(self, x_j, norm):
126
+ return norm.view(-1, 1) * x_j
127
+
128
+ def message_and_aggregate(self, adj_t, x):
129
+ return matmul(adj_t, x, reduce=self.aggr)
130
+
131
+ def __repr__(self):
132
+ return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
133
+ self.temp)
134
+
135
+
deeprobust/graph/defense_pyg/mygat_conv.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Slightly modified torch_geometric.nn.GATConv to fit the structure learning module"""
2
+ from typing import Union, Tuple, Optional
3
+ from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
4
+ OptTensor)
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ import torch.nn.functional as F
9
+ from torch.nn import Parameter, Linear
10
+ from torch_sparse import SparseTensor, set_diag
11
+ from torch_geometric.nn.conv import MessagePassing
12
+ from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
13
+
14
+ from torch_geometric.nn.inits import glorot, zeros
15
+
16
+
17
+ class GATConv(MessagePassing):
18
+ r"""The graph attentional operator from the `"Graph Attention Networks"
19
+ <https://arxiv.org/abs/1710.10903>`_ paper
20
+
21
+ .. math::
22
+ \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
23
+ \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
24
+
25
+ where the attention coefficients :math:`\alpha_{i,j}` are computed as
26
+
27
+ .. math::
28
+ \alpha_{i,j} =
29
+ \frac{
30
+ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
31
+ [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
32
+ \right)\right)}
33
+ {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
34
+ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
35
+ [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
36
+ \right)\right)}.
37
+
38
+ Args:
39
+ in_channels (int or tuple): Size of each input sample. A tuple
40
+ corresponds to the sizes of source and target dimensionalities.
41
+ out_channels (int): Size of each output sample.
42
+ heads (int, optional): Number of multi-head-attentions.
43
+ (default: :obj:`1`)
44
+ concat (bool, optional): If set to :obj:`False`, the multi-head
45
+ attentions are averaged instead of concatenated.
46
+ (default: :obj:`True`)
47
+ negative_slope (float, optional): LeakyReLU angle of the negative
48
+ slope. (default: :obj:`0.2`)
49
+ dropout (float, optional): Dropout probability of the normalized
50
+ attention coefficients which exposes each node to a stochastically
51
+ sampled neighborhood during training. (default: :obj:`0`)
52
+ add_self_loops (bool, optional): If set to :obj:`False`, will not add
53
+ self-loops to the input graph. (default: :obj:`True`)
54
+ bias (bool, optional): If set to :obj:`False`, the layer will not learn
55
+ an additive bias. (default: :obj:`True`)
56
+ **kwargs (optional): Additional arguments of
57
+ :class:`torch_geometric.nn.conv.MessagePassing`.
58
+ """
59
+ _alpha: OptTensor
60
+
61
+ def __init__(self, in_channels: Union[int, Tuple[int, int]],
62
+ out_channels: int, heads: int = 1, concat: bool = True,
63
+ negative_slope: float = 0.2, dropout: float = 0.,
64
+ add_self_loops: bool = True, bias: bool = True, **kwargs):
65
+ kwargs.setdefault('aggr', 'add')
66
+ super(GATConv, self).__init__(node_dim=0, **kwargs)
67
+
68
+ self.in_channels = in_channels
69
+ self.out_channels = out_channels
70
+ self.heads = heads
71
+ self.concat = concat
72
+ self.negative_slope = negative_slope
73
+ self.dropout = dropout
74
+ self.add_self_loops = add_self_loops
75
+
76
+ if isinstance(in_channels, int):
77
+ self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
78
+ self.lin_r = self.lin_l
79
+ else:
80
+ self.lin_l = Linear(in_channels[0], heads * out_channels, False)
81
+ self.lin_r = Linear(in_channels[1], heads * out_channels, False)
82
+
83
+ self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
84
+ self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
85
+
86
+ if bias and concat:
87
+ self.bias = Parameter(torch.Tensor(heads * out_channels))
88
+ elif bias and not concat:
89
+ self.bias = Parameter(torch.Tensor(out_channels))
90
+ else:
91
+ self.register_parameter('bias', None)
92
+
93
+ self._alpha = None
94
+
95
+ self.reset_parameters()
96
+ self.edge_weight = None
97
+
98
+ def reset_parameters(self):
99
+ glorot(self.lin_l.weight)
100
+ glorot(self.lin_r.weight)
101
+ glorot(self.att_l)
102
+ glorot(self.att_r)
103
+ zeros(self.bias)
104
+
105
+ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight=None,
106
+ size: Size = None, return_attention_weights=None):
107
+ # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa
108
+ # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa
109
+ # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
110
+ # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
111
+ r"""
112
+
113
+ Args:
114
+ return_attention_weights (bool, optional): If set to :obj:`True`,
115
+ will additionally return the tuple
116
+ :obj:`(edge_index, attention_weights)`, holding the computed
117
+ attention weights for each edge. (default: :obj:`None`)
118
+ """
119
+ H, C = self.heads, self.out_channels
120
+
121
+ x_l: OptTensor = None
122
+ x_r: OptTensor = None
123
+ alpha_l: OptTensor = None
124
+ alpha_r: OptTensor = None
125
+ if isinstance(x, Tensor):
126
+ assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
127
+ x_l = x_r = self.lin_l(x).view(-1, H, C)
128
+ alpha_l = (x_l * self.att_l).sum(dim=-1)
129
+ alpha_r = (x_r * self.att_r).sum(dim=-1)
130
+ else:
131
+ x_l, x_r = x[0], x[1]
132
+ assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
133
+ x_l = self.lin_l(x_l).view(-1, H, C)
134
+ alpha_l = (x_l * self.att_l).sum(dim=-1)
135
+ if x_r is not None:
136
+ x_r = self.lin_r(x_r).view(-1, H, C)
137
+ alpha_r = (x_r * self.att_r).sum(dim=-1)
138
+
139
+ assert x_l is not None
140
+ assert alpha_l is not None
141
+
142
+ if self.add_self_loops:
143
+ if isinstance(edge_index, Tensor):
144
+ num_nodes = x_l.size(0)
145
+ if x_r is not None:
146
+ num_nodes = min(num_nodes, x_r.size(0))
147
+ if size is not None:
148
+ num_nodes = min(size[0], size[1])
149
+ # edge_index, _ = remove_self_loops(edge_index)
150
+ # edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
151
+ edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
152
+ edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes)
153
+ self.edge_weight = edge_weight
154
+
155
+ elif isinstance(edge_index, SparseTensor):
156
+ edge_index = set_diag(edge_index)
157
+
158
+ # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
159
+ out = self.propagate(edge_index, x=(x_l, x_r),
160
+ alpha=(alpha_l, alpha_r), size=size)
161
+
162
+ alpha = self._alpha
163
+ self._alpha = None
164
+
165
+ if self.concat:
166
+ out = out.view(-1, self.heads * self.out_channels)
167
+ else:
168
+ out = out.mean(dim=1)
169
+
170
+ if self.bias is not None:
171
+ out += self.bias
172
+
173
+ if isinstance(return_attention_weights, bool):
174
+ assert alpha is not None
175
+ if isinstance(edge_index, Tensor):
176
+ return out, (edge_index, alpha)
177
+ elif isinstance(edge_index, SparseTensor):
178
+ return out, edge_index.set_value(alpha, layout='coo')
179
+ else:
180
+ return out
181
+
182
+ def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
183
+ index: Tensor, ptr: OptTensor,
184
+ size_i: Optional[int]) -> Tensor:
185
+ alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
186
+ alpha = F.leaky_relu(alpha, self.negative_slope)
187
+ alpha = softmax(alpha, index, ptr, size_i)
188
+ self._alpha = alpha
189
+ alpha = F.dropout(alpha, p=self.dropout, training=self.training)
190
+
191
+ if self.edge_weight is not None:
192
+ x_j = self.edge_weight.view(-1, 1, 1) * x_j
193
+ return x_j * alpha.unsqueeze(-1)
194
+
195
+ def __repr__(self):
196
+ return '{}({}, {}, heads={})'.format(self.__class__.__name__,
197
+ self.in_channels,
198
+ self.out_channels, self.heads)
deeprobust/graph/rl/__init__.py ADDED
File without changes
deeprobust/graph/rl/env.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
3
+ https://arxiv.org/abs/1806.02371
4
+ Author's Implementation
5
+ https://github.com/Hanjun-Dai/graph_adversarial_attack
6
+ This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
7
+ to be integrated into the repository.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import numpy as np
13
+ import torch
14
+ import networkx as nx
15
+ import random
16
+ from torch.nn.parameter import Parameter
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import torch.optim as optim
20
+ from tqdm import tqdm
21
+ from copy import deepcopy
22
+ import pickle as cp
23
+ from deeprobust.graph.utils import *
24
+ import scipy.sparse as sp
25
+ from scipy.sparse.linalg import eigsh
26
+ from deeprobust.graph import utils
27
+
28
+ class StaticGraph(object):
29
+ graph = None
30
+
31
+ @staticmethod
32
+ def get_gsize():
33
+ return torch.Size( (len(StaticGraph.graph), len(StaticGraph.graph)) )
34
+
35
+ class GraphNormTool(object):
36
+
37
+ def __init__(self, normalize, gm, device):
38
+ self.adj_norm = normalize
39
+ self.gm = gm
40
+ g = StaticGraph.graph
41
+ edges = np.array(g.edges(), dtype=np.int64)
42
+ rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
43
+
44
+ # self_edges = np.array([range(len(g)), range(len(g))], dtype=np.int64)
45
+ # edges = np.hstack((edges.T, rev_edges, self_edges))
46
+ edges = np.hstack((edges.T, rev_edges))
47
+ idxes = torch.LongTensor(edges)
48
+ values = torch.ones(idxes.size()[1])
49
+
50
+ self.raw_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
51
+ self.raw_adj = self.raw_adj.to(device)
52
+
53
+ self.normed_adj = self.raw_adj.clone()
54
+ if self.adj_norm:
55
+ if self.gm == 'gcn':
56
+ self.normed_adj = utils.normalize_adj_tensor(self.normed_adj, sparse=True)
57
+ # GraphLaplacianNorm(self.normed_adj)
58
+ else:
59
+
60
+ self.normed_adj = utils.degree_normalize_adj_tensor(self.normed_adj, sparse=True)
61
+ # GraphDegreeNorm(self.normed_adj)
62
+
63
+ def norm_extra(self, added_adj = None):
64
+ if added_adj is None:
65
+ return self.normed_adj
66
+
67
+ new_adj = self.raw_adj + added_adj
68
+ if self.adj_norm:
69
+ if self.gm == 'gcn':
70
+ new_adj = utils.normalize_adj_tensor(new_adj, sparse=True)
71
+ else:
72
+ new_adj = utils.degree_normalize_adj_tensor(new_adj, sparse=True)
73
+
74
+ return new_adj
75
+
76
+
77
+ class ModifiedGraph(object):
78
+
79
+ def __init__(self, directed_edges = None, weights = None):
80
+ self.edge_set = set() #(first, second)
81
+ self.node_set = set(range(StaticGraph.get_gsize()[0]))
82
+ self.node_set = np.arange(StaticGraph.get_gsize()[0])
83
+ if directed_edges is not None:
84
+ self.directed_edges = deepcopy(directed_edges)
85
+ self.weights = deepcopy(weights)
86
+ else:
87
+ self.directed_edges = []
88
+ self.weights = []
89
+
90
+ def add_edge(self, x, y, z):
91
+ assert x is not None and y is not None
92
+ if x == y:
93
+ return
94
+ for e in self.directed_edges:
95
+ if e[0] == x and e[1] == y:
96
+ return
97
+ if e[1] == x and e[0] == y:
98
+ return
99
+ self.edge_set.add((x, y)) # (first, second)
100
+ self.edge_set.add((y, x)) # (second, first)
101
+ self.directed_edges.append((x, y))
102
+ # assert z < 0
103
+ self.weights.append(z)
104
+
105
+ def get_extra_adj(self, device):
106
+ if len(self.directed_edges):
107
+ edges = np.array(self.directed_edges, dtype=np.int64)
108
+ rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
109
+ edges = np.hstack((edges.T, rev_edges))
110
+
111
+ idxes = torch.LongTensor(edges)
112
+ values = torch.Tensor(self.weights + self.weights)
113
+
114
+ added_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
115
+
116
+ added_adj = added_adj.to(device)
117
+ return added_adj
118
+ else:
119
+ return None
120
+
121
+ def get_possible_nodes(self, target_node):
122
+ # connected = set()
123
+ connected = [target_node]
124
+ for n1, n2 in self.edge_set:
125
+ if n1 == target_node:
126
+ # connected.add(target_node)
127
+ connected.append(n2)
128
+ return np.setdiff1d(self.node_set, np.array(connected))
129
+
130
+ # return self.node_set - connected
131
+
132
+ class NodeAttackEnv(object):
133
+ """Node attack environment. It executes an action and then change the
134
+ environment status (modify the graph).
135
+ """
136
+
137
+ def __init__(self, features, labels, all_targets, list_action_space, classifier, num_mod=1, reward_type='binary'):
138
+
139
+ self.classifier = classifier
140
+ self.list_action_space = list_action_space
141
+ self.features = features
142
+ self.labels = labels
143
+ self.all_targets = all_targets
144
+ self.num_mod = num_mod
145
+ self.reward_type = reward_type
146
+
147
+ def setup(self, target_nodes):
148
+ self.target_nodes = target_nodes
149
+ self.n_steps = 0
150
+ self.first_nodes = None
151
+ self.rewards = None
152
+ self.binary_rewards = None
153
+ self.modified_list = []
154
+ for i in range(len(self.target_nodes)):
155
+ self.modified_list.append(ModifiedGraph())
156
+
157
+ self.list_acc_of_all = []
158
+
159
+ def step(self, actions):
160
+ """run actions and get rewards
161
+ """
162
+ if self.first_nodes is None: # pick the first node of edge
163
+ assert self.n_steps % 2 == 0
164
+ self.first_nodes = actions[:]
165
+ else:
166
+ for i in range(len(self.target_nodes)):
167
+ # assert self.first_nodes[i] != actions[i]
168
+ # deleta an edge from the graph
169
+ self.modified_list[i].add_edge(self.first_nodes[i], actions[i], -1.0)
170
+ self.first_nodes = None
171
+ self.banned_list = None
172
+ self.n_steps += 1
173
+
174
+ if self.isTerminal():
175
+ # only calc reward when its terminal
176
+ acc_list = []
177
+ loss_list = []
178
+ # for i in tqdm(range(len(self.target_nodes))):
179
+ for i in (range(len(self.target_nodes))):
180
+ device = self.labels.device
181
+ extra_adj = self.modified_list[i].get_extra_adj(device=device)
182
+ adj = self.classifier.norm_tool.norm_extra(extra_adj)
183
+
184
+ output = self.classifier(self.features, adj)
185
+
186
+ loss, acc = loss_acc(output, self.labels, self.all_targets, avg_loss=False)
187
+ # _, loss, acc = self.classifier(self.features, Variable(adj), self.all_targets, self.labels, avg_loss=False)
188
+
189
+ cur_idx = self.all_targets.index(self.target_nodes[i])
190
+ acc = np.copy(acc.double().cpu().view(-1).numpy())
191
+ loss = loss.data.cpu().view(-1).numpy()
192
+ self.list_acc_of_all.append(acc)
193
+ acc_list.append(acc[cur_idx])
194
+ loss_list.append(loss[cur_idx])
195
+
196
+ self.binary_rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
197
+ if self.reward_type == 'binary':
198
+ self.rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
199
+ else:
200
+ assert self.reward_type == 'nll'
201
+ self.rewards = np.array(loss_list).astype(np.float32)
202
+
203
+ def sample_pos_rewards(self, num_samples):
204
+ assert self.list_acc_of_all is not None
205
+ cands = []
206
+
207
+ for i in range(len(self.list_acc_of_all)):
208
+ succ = np.where( self.list_acc_of_all[i] < 0.9 )[0]
209
+
210
+ for j in range(len(succ)):
211
+
212
+ cands.append((i, self.all_targets[succ[j]]))
213
+
214
+ if num_samples > len(cands):
215
+ return cands
216
+ random.shuffle(cands)
217
+ return cands[0:num_samples]
218
+
219
+ def uniformRandActions(self):
220
+ # TODO: here only support deleting edges
221
+ # seems they sample first node from 2-hop neighbours
222
+ act_list = []
223
+ offset = 0
224
+ for i in range(len(self.target_nodes)):
225
+ cur_node = self.target_nodes[i]
226
+ region = self.list_action_space[cur_node]
227
+
228
+ if self.first_nodes is not None and self.first_nodes[i] is not None:
229
+ region = self.list_action_space[self.first_nodes[i]]
230
+
231
+ if region is None: # singleton node
232
+ cur_action = np.random.randint(len(self.list_action_space))
233
+ else: # select from neighbours or 2-hop neighbours
234
+ cur_action = region[np.random.randint(len(region))]
235
+
236
+ act_list.append(cur_action)
237
+ return act_list
238
+
239
+ def isTerminal(self):
240
+ if self.n_steps == 2 * self.num_mod:
241
+ return True
242
+ return False
243
+
244
+ def getStateRef(self):
245
+ cp_first = [None] * len(self.target_nodes)
246
+ if self.first_nodes is not None:
247
+ cp_first = self.first_nodes
248
+
249
+ return zip(self.target_nodes, self.modified_list, cp_first)
250
+
251
+ def cloneState(self):
252
+ cp_first = [None] * len(self.target_nodes)
253
+ if self.first_nodes is not None:
254
+ cp_first = self.first_nodes[:]
255
+
256
+ return list(zip(self.target_nodes[:], deepcopy(self.modified_list), cp_first))
257
+
258
+
deeprobust/graph/rl/q_net_node.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
3
+ https://arxiv.org/abs/1806.02371
4
+ Author's Implementation
5
+ https://github.com/Hanjun-Dai/graph_adversarial_attack
6
+ This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
7
+ to be integrated into the repository.
8
+ '''
9
+ import os
10
+ import sys
11
+ import numpy as np
12
+ import torch
13
+ import networkx as nx
14
+ import random
15
+ from torch.nn.parameter import Parameter
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ from tqdm import tqdm
20
+ from deeprobust.graph.rl.env import GraphNormTool
21
+
22
+ class QNetNode(nn.Module):
23
+
24
+ def __init__(self, node_features, node_labels, list_action_space, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
25
+ '''
26
+ bilin_q: bilinear q or not
27
+ mlp_hidden: mlp hidden layer size
28
+ mav_lv: max rounds of message passing
29
+ '''
30
+ super(QNetNode, self).__init__()
31
+ self.node_features = node_features
32
+ self.node_labels = node_labels
33
+ self.list_action_space = list_action_space
34
+ self.total_nodes = len(list_action_space)
35
+
36
+ self.bilin_q = bilin_q
37
+ self.embed_dim = embed_dim
38
+ self.mlp_hidden = mlp_hidden
39
+ self.max_lv = max_lv
40
+ self.gm = gm
41
+
42
+ if bilin_q:
43
+ last_wout = embed_dim
44
+ else:
45
+ last_wout = 1
46
+ self.bias_target = Parameter(torch.Tensor(1, embed_dim))
47
+
48
+ if mlp_hidden:
49
+ self.linear_1 = nn.Linear(embed_dim * 2, mlp_hidden)
50
+ self.linear_out = nn.Linear(mlp_hidden, last_wout)
51
+ else:
52
+ self.linear_out = nn.Linear(embed_dim * 2, last_wout)
53
+
54
+ self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
55
+ self.bias_n2l = Parameter(torch.Tensor(embed_dim))
56
+ self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
57
+ self.conv_params = nn.Linear(embed_dim, embed_dim)
58
+ self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
59
+ weights_init(self)
60
+
61
+ def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
62
+ idxes = torch.LongTensor([[row_idx], [col_idx]])
63
+ values = torch.ones(1)
64
+
65
+ sp = torch.sparse.FloatTensor(idxes, values, torch.Size([n_rows, n_cols]))
66
+ if next(self.parameters()).is_cuda:
67
+ sp = sp.cuda()
68
+ return sp
69
+
70
+ def forward(self, time_t, states, actions, greedy_acts=False, is_inference=False):
71
+
72
+ if self.node_features.data.is_sparse:
73
+ input_node_linear = torch.spmm(self.node_features, self.w_n2l)
74
+ else:
75
+ input_node_linear = torch.mm(self.node_features, self.w_n2l)
76
+
77
+ input_node_linear += self.bias_n2l
78
+
79
+ # TODO the number of target nodes is batch_size, it actually parallizes
80
+ target_nodes, batch_graph, picked_nodes = zip(*states)
81
+
82
+ list_pred = []
83
+ prefix_sum = []
84
+ for i in range(len(batch_graph)):
85
+ region = self.list_action_space[target_nodes[i]]
86
+
87
+ node_embed = input_node_linear.clone()
88
+ if picked_nodes is not None and picked_nodes[i] is not None:
89
+ with torch.set_grad_enabled(mode=not is_inference):
90
+ picked_sp = self.make_spmat(self.total_nodes, 1, picked_nodes[i], 0)
91
+ node_embed += torch.spmm(picked_sp, self.bias_picked)
92
+ region = self.list_action_space[picked_nodes[i]]
93
+
94
+ if not self.bilin_q:
95
+ with torch.set_grad_enabled(mode=not is_inference):
96
+ # with torch.no_grad():
97
+ target_sp = self.make_spmat(self.total_nodes, 1, target_nodes[i], 0)
98
+ node_embed += torch.spmm(target_sp, self.bias_target)
99
+
100
+ with torch.set_grad_enabled(mode=not is_inference):
101
+ device = self.node_features.device
102
+ adj = self.norm_tool.norm_extra( batch_graph[i].get_extra_adj(device))
103
+
104
+ lv = 0
105
+ input_message = node_embed
106
+
107
+ node_embed = F.relu(input_message)
108
+ while lv < self.max_lv:
109
+ n2npool = torch.spmm(adj, node_embed)
110
+ node_linear = self.conv_params( n2npool )
111
+ merged_linear = node_linear + input_message
112
+ node_embed = F.relu(merged_linear)
113
+ lv += 1
114
+
115
+ target_embed = node_embed[target_nodes[i], :].view(-1, 1)
116
+ if region is not None:
117
+ node_embed = node_embed[region]
118
+
119
+ graph_embed = torch.mean(node_embed, dim=0, keepdim=True)
120
+
121
+ if actions is None:
122
+ graph_embed = graph_embed.repeat(node_embed.size()[0], 1)
123
+ else:
124
+ if region is not None:
125
+ act_idx = region.index(actions[i])
126
+ else:
127
+ act_idx = actions[i]
128
+ node_embed = node_embed[act_idx, :].view(1, -1)
129
+
130
+ embed_s_a = torch.cat((node_embed, graph_embed), dim=1)
131
+ if self.mlp_hidden:
132
+ embed_s_a = F.relu( self.linear_1(embed_s_a) )
133
+ raw_pred = self.linear_out(embed_s_a)
134
+
135
+ if self.bilin_q:
136
+ raw_pred = torch.mm(raw_pred, target_embed)
137
+ list_pred.append(raw_pred)
138
+
139
+ if greedy_acts:
140
+ actions, _ = node_greedy_actions(target_nodes, picked_nodes, list_pred, self)
141
+
142
+ return actions, list_pred
143
+
144
+ class NStepQNetNode(nn.Module):
145
+
146
+ def __init__(self, num_steps, node_features, node_labels, list_action_space, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
147
+
148
+ super(NStepQNetNode, self).__init__()
149
+ self.node_features = node_features
150
+ self.node_labels = node_labels
151
+ self.list_action_space = list_action_space
152
+ self.total_nodes = len(list_action_space)
153
+
154
+ list_mod = []
155
+ for i in range(0, num_steps):
156
+ # list_mod.append(QNetNode(node_features, node_labels, list_action_space))
157
+ list_mod.append(QNetNode(node_features, node_labels, list_action_space, bilin_q, embed_dim, mlp_hidden, max_lv, gm=gm, device=device))
158
+
159
+ self.list_mod = nn.ModuleList(list_mod)
160
+ self.num_steps = num_steps
161
+
162
+ def forward(self, time_t, states, actions, greedy_acts = False, is_inference=False):
163
+ assert time_t >= 0 and time_t < self.num_steps
164
+
165
+ return self.list_mod[time_t](time_t, states, actions, greedy_acts, is_inference)
166
+
167
+
168
+ def glorot_uniform(t):
169
+ if len(t.size()) == 2:
170
+ fan_in, fan_out = t.size()
171
+ elif len(t.size()) == 3:
172
+ # out_ch, in_ch, kernel for Conv 1
173
+ fan_in = t.size()[1] * t.size()[2]
174
+ fan_out = t.size()[0] * t.size()[2]
175
+ else:
176
+ fan_in = np.prod(t.size())
177
+ fan_out = np.prod(t.size())
178
+
179
+ limit = np.sqrt(6.0 / (fan_in + fan_out))
180
+ t.uniform_(-limit, limit)
181
+
182
+
183
+ def _param_init(m):
184
+ if isinstance(m, Parameter):
185
+ glorot_uniform(m.data)
186
+ elif isinstance(m, nn.Linear):
187
+ m.bias.data.zero_()
188
+ glorot_uniform(m.weight.data)
189
+
190
+ def weights_init(m):
191
+ for p in m.modules():
192
+ if isinstance(p, nn.ParameterList):
193
+ for pp in p:
194
+ _param_init(pp)
195
+ else:
196
+ _param_init(p)
197
+
198
+ for name, p in m.named_parameters():
199
+ if not '.' in name: # top-level parameters
200
+ _param_init(p)
201
+
202
+ def node_greedy_actions(target_nodes, picked_nodes, list_q, net):
203
+ assert len(target_nodes) == len(list_q)
204
+
205
+ actions = []
206
+ values = []
207
+ for i in range(len(target_nodes)):
208
+ region = net.list_action_space[target_nodes[i]]
209
+ if picked_nodes is not None and picked_nodes[i] is not None:
210
+ region = net.list_action_space[picked_nodes[i]]
211
+ if region is None:
212
+ assert list_q[i].size()[0] == net.total_nodes
213
+ else:
214
+ assert len(region) == list_q[i].size()[0]
215
+
216
+ val, act = torch.max(list_q[i], dim=0)
217
+ values.append(val)
218
+ if region is not None:
219
+ act = region[act.data.cpu().numpy()[0]]
220
+ # act = Variable(torch.LongTensor([act]))
221
+ act = torch.LongTensor([act])
222
+ actions.append(act)
223
+ else:
224
+ actions.append(act)
225
+
226
+ return torch.cat(actions, dim=0).data, torch.cat(values, dim=0).data
227
+
228
+
deeprobust/image/adversary_examples/advexample.png ADDED
deeprobust/image/adversary_examples/cifar_advexample_orig.png ADDED
deeprobust/image/adversary_examples/cifar_advexample_pgd.png ADDED
deeprobust/image/adversary_examples/deepfool_diff.png ADDED
deeprobust/image/adversary_examples/imageexample.png ADDED
deeprobust/image/adversary_examples/test.jpg ADDED
deeprobust/image/adversary_examples/test1.jpg ADDED
deeprobust/image/evaluation_attack.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import torch
3
+ from torchvision import datasets,models,transforms
4
+ import torch.nn.functional as F
5
+ import os
6
+
7
+ import numpy as np
8
+ import argparse
9
+ import matplotlib.pyplot as plt
10
+ import random
11
+
12
+ from deeprobust.image import utils
13
+
14
+ def run_attack(attackmethod, batch_size, batch_num, device, test_loader, random_targeted = False, target_label = -1, **kwargs):
15
+ test_loss = 0
16
+ correct = 0
17
+ samplenum = 1000
18
+ count = 0
19
+ classnum = 10
20
+ for count, (data, target) in enumerate(test_loader):
21
+ if count == batch_num:
22
+ break
23
+ print('batch:{}'.format(count))
24
+
25
+ data, target = data.to(device), target.to(device)
26
+ if(random_targeted == True):
27
+ r = list(range(0, target)) + list(range(target+1, classnum))
28
+ target_label = random.choice(r)
29
+ adv_example = attackmethod.generate(data, target, target_label = target_label, **kwargs)
30
+
31
+ elif(target_label >= 0):
32
+ adv_example = attackmethod.generate(data, target, target_label = target_label, **kwargs)
33
+
34
+ else:
35
+ adv_example = attackmethod.generate(data, target, **kwargs)
36
+
37
+ output = model(adv_example)
38
+ test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
39
+
40
+ pred = output.argmax(dim = 1, keepdim = True) # get the index of the max log-probability.
41
+
42
+ correct += pred.eq(target.view_as(pred)).sum().item()
43
+
44
+ batch_num = count+1
45
+ test_loss /= len(test_loader.dataset)
46
+ print("===== ACCURACY =====")
47
+ print('Attack Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
48
+ test_loss, correct, batch_num * batch_size,
49
+ 100. * correct / (batch_num * batch_size)))
50
+
51
+ def load_net(attack_model, filename, path):
52
+ if(attack_model == "CNN"):
53
+ from deeprobust.image.netmodels.CNN import Net
54
+
55
+ model = Net()
56
+ if(attack_model == "ResNet18"):
57
+ import deeprobust.image.netmodels.resnet as Net
58
+ model = Net.ResNet18()
59
+
60
+ model.load_state_dict(torch.load(path + filename))
61
+ model.eval()
62
+ return model
63
+
64
+ def generate_dataloader(dataset, batch_size):
65
+ if(dataset == "MNIST"):
66
+ test_loader = torch.utils.data.DataLoader(
67
+ datasets.MNIST('deeprobust/image/data', train = False,
68
+ download = True,
69
+ transform = transforms.Compose([transforms.ToTensor()])),
70
+ batch_size = args.batch_size,
71
+ shuffle = True)
72
+ print("Loading MNIST dataset.")
73
+
74
+ elif(dataset == "CIFAR" or args.dataset == 'CIFAR10'):
75
+ test_loader = torch.utils.data.DataLoader(
76
+ datasets.CIFAR10('deeprobust/image/data', train = False,
77
+ download = True,
78
+ transform = transforms.Compose([transforms.ToTensor()])),
79
+ batch_size = args.batch_size,
80
+ shuffle = True)
81
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
82
+ print("Loading CIFAR10 dataset.")
83
+
84
+ elif(dataset == "ImageNet"):
85
+ test_loader = torch.utils.data.DataLoader(
86
+ datasets.CIFAR10('deeprobust/image/data', train=False,
87
+ download = True,
88
+ transform = transforms.Compose([transforms.ToTensor()])),
89
+ batch_size = args.batch_size,
90
+ shuffle = True)
91
+ print("Loading ImageNet dataset.")
92
+ return test_loader
93
+
94
+ def parameter_parser():
95
+ parser = argparse.ArgumentParser(description = "Run attack algorithms.", usage ='Use -h for more information.')
96
+
97
+ parser.add_argument("--attack_method",
98
+ default = 'PGD',
99
+ help = "Choose a attack algorithm from: PGD(default), FGSM, LBFGS, CW, deepfool, onepixel, Nattack")
100
+ parser.add_argument("--attack_model",
101
+ default = "CNN",
102
+ help = "Choose network structure from: CNN, ResNet")
103
+ parser.add_argument("--path",
104
+ default = "./trained_models/",
105
+ help = "Type the path where the model is saved.")
106
+ parser.add_argument("--file_name",
107
+ default = 'MNIST_CNN_epoch_20.pt',
108
+ help = "Type the file_name of the model that is to be attack. The model structure should be matched with the ATTACK_MODEL parameter.")
109
+ parser.add_argument("--dataset",
110
+ default = 'MNIST',
111
+ help = "Choose a dataset from: MNIST(default), CIFAR(or CIFAR10), ImageNet")
112
+ parser.add_argument("--epsilon", type = float, default = 0.3)
113
+ parser.add_argument("--batch_num", type = int, default = 1000)
114
+ parser.add_argument("--batch_size", type = int, default = 1000)
115
+ parser.add_argument("--num_steps", type = int, default = 40)
116
+ parser.add_argument("--step_size", type = float, default = 0.01)
117
+ parser.add_argument("--random_targeted", type = bool, default = False,
118
+ help = "default: False. By setting this parameter be True, the program would random generate target labels for the input samples.")
119
+ parser.add_argument("--target_label", type = int, default = -1,
120
+ help = "default: -1. Generate all attack Fixed target label.")
121
+ parser.add_argument("--device", default = 'cuda',
122
+ help = "Choose the device.")
123
+
124
+ return parser.parse_args()
125
+
126
+
127
+ if __name__ == "__main__":
128
+ # read arguments
129
+ args = parameter_parser() # read argument and creat an argparse object
130
+
131
+ # download example model
132
+ example_model_path = './trained_models/MNIST_CNN_epoch_20.pt'
133
+ if not (os.path.exists('./trained_models')):
134
+ os.mkdir('./trained_models')
135
+ print('create path: ./trained_models')
136
+ model_url = "https://github.com/I-am-Bot/deeprobust_trained_model/blob/master/MNIST_CNN_epoch_20.pt?raw=true"
137
+ r = requests.get(model_url)
138
+ print('Downloading example model...')
139
+ with open(example_model_path,'wb') as f:
140
+ f.write(r.content)
141
+ print('Downloaded.')
142
+ # load model
143
+ model = load_net(args.attack_model, args.file_name, args.path)
144
+
145
+ print("===== START ATTACK =====")
146
+ if(args.attack_method == "PGD"):
147
+ from deeprobust.image.attack.pgd import PGD
148
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
149
+ attack_method = PGD(model, args.device)
150
+ utils.tab_printer(args)
151
+ run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader, epsilon = args.epsilon)
152
+
153
+ elif(args.attack_method == "FGSM"):
154
+ from deeprobust.image.attack.fgsm import FGSM
155
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
156
+ attack_method = FGSM(model, args.device)
157
+ utils.tab_printer(args)
158
+ run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader, epsilon = args.epsilon)
159
+
160
+ elif(args.attack_method == "LBFGS"):
161
+ from deeprobust.image.attack.lbfgs import LBFGS
162
+ try:
163
+ if (args.batch_size >1):
164
+ raise ValueError("batch_size shouldn't be larger than 1.")
165
+ except ValueError:
166
+ args.batch_size = 1
167
+
168
+ try:
169
+ if (args.random_targeted == 0 and args.target_label == -1):
170
+ raise ValueError("No target label assigned. Random generate target for each input.")
171
+ except ValueError:
172
+ args.random_targeted = True
173
+
174
+ utils.tab_printer(args)
175
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
176
+ attack_method = LBFGS(model, args.device)
177
+ run_attack(attack_method, 1, args.batch_num, args.device, test_loader, random_targeted = args.random_targeted, target_label = args.target_label)
178
+
179
+ elif(args.attack_method == "CW"):
180
+ from deeprobust.image.attack.cw import CarliniWagner
181
+ attack_method = CarliniWagner(model, args.device)
182
+ try:
183
+ if (args.batch_size > 1):
184
+ raise ValueError("batch_size shouldn't be larger than 1.")
185
+ except ValueError:
186
+ args.batch_size = 1
187
+
188
+ try:
189
+ if (args.random_targeted == 0 and args.target_label == -1):
190
+ raise ValueError("No target label assigned. Random generate target for each input.")
191
+ except ValueError:
192
+ args.random_targeted = True
193
+
194
+ utils.tab_printer(args)
195
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
196
+ run_attack(attack_method, 1, args.batch_num, args.device, test_loader, random_targeted = args.random_targeted, target_label = args.target_label)
197
+
198
+ elif(args.attack_method == "deepfool"):
199
+ from deeprobust.image.attack.deepfool import DeepFool
200
+ attack_method = DeepFool(model, args.device)
201
+ try:
202
+ if (args.batch_size > 1):
203
+ raise ValueError("batch_size shouldn't be larger than 1.")
204
+ except ValueError:
205
+ args.batch_size = 1
206
+
207
+ utils.tab_printer(args)
208
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
209
+ run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader)
210
+
211
+ elif(args.attack_method == "onepixel"):
212
+ from deeprobust.image.attack.onepixel import Onepixel
213
+ attack_method = Onepixel(model, args.device)
214
+ try:
215
+ if (args.batch_size > 1):
216
+ raise ValueError("batch_size shouldn't be larger than 1.")
217
+ except ValueError:
218
+ args.batch_size = 1
219
+
220
+ utils.tab_printer(args)
221
+ test_loader = generate_dataloader(args.dataset, args.batch_size)
222
+ run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader)
223
+
224
+ elif(args.attack_method == "Nattack"):
225
+ pass
226
+
deeprobust/image/netmodels/CNN.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is an implementatio of a Convolution Neural Network with 2 Convolutional layer.
3
+ """
4
+
5
+ from __future__ import print_function
6
+ import argparse
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F #233
10
+ import torch.optim as optim
11
+ from torchvision import datasets, transforms
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ class Net(nn.Module):
16
+ """Model counterparts.
17
+ """
18
+
19
+ def __init__(self, in_channel1 = 1, out_channel1 = 32, out_channel2 = 64, H = 28, W = 28):
20
+ super(Net, self).__init__()
21
+ self.H = H
22
+ self.W = W
23
+ self.out_channel2 = out_channel2
24
+
25
+ ## define two convolutional layers
26
+ self.conv1 = nn.Conv2d(in_channels = in_channel1,
27
+ out_channels = out_channel1,
28
+ kernel_size = 5,
29
+ stride= 1,
30
+ padding = (2,2))
31
+ self.conv2 = nn.Conv2d(in_channels = out_channel1,
32
+ out_channels = out_channel2,
33
+ kernel_size = 5,
34
+ stride = 1,
35
+ padding = (2,2))
36
+
37
+ ## define two linear layers
38
+ self.fc1 = nn.Linear(int(self.H/4) * int(self.W/4) * out_channel2, 1024)
39
+ self.fc2 = nn.Linear(1024, 10)
40
+
41
+ def forward(self, x):
42
+
43
+ x = F.relu(self.conv1(x))
44
+ x = F.max_pool2d(x, 2, 2)
45
+ x = F.relu(self.conv2(x))
46
+ x = F.max_pool2d(x, 2, 2)
47
+ x = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
48
+ x = F.relu(self.fc1(x))
49
+ x = self.fc2(x)
50
+ return x
51
+
52
+ def get_logits(self, x):
53
+ x = F.relu(self.conv1(x))
54
+ x = F.max_pool2d(x, 2, 2)
55
+ x = F.relu(self.conv2(x))
56
+ x = F.max_pool2d(x, 2, 2)
57
+ x = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
58
+ x = F.relu(self.fc1(x))
59
+ x = self.fc2(x)
60
+ return x
61
+
62
+ def train(model, device, train_loader, optimizer, epoch):
63
+ """train network.
64
+
65
+ Parameters
66
+ ----------
67
+ model :
68
+ model
69
+ device :
70
+ device(option:'cpu','cuda')
71
+ train_loader :
72
+ training data loader
73
+ optimizer :
74
+ optimizer
75
+ epoch :
76
+ epoch
77
+ """
78
+ model.train()
79
+ for batch_idx, (data, target) in enumerate(train_loader):
80
+ data, target = data.to(device), target.to(device)
81
+ optimizer.zero_grad()
82
+ output = model(data)
83
+ loss = F.cross_entropy(output, target)
84
+ loss.backward()
85
+ optimizer.step()
86
+
87
+ #print every 10
88
+ if batch_idx % 10 == 0:
89
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
90
+ epoch, batch_idx * len(data), len(train_loader.dataset),
91
+ 100. * batch_idx / len(train_loader), loss.item()))
92
+
93
+
94
+ def test(model, device, test_loader):
95
+ """test network.
96
+
97
+ Parameters
98
+ ----------
99
+ model :
100
+ model
101
+ device :
102
+ device(option:'cpu', 'cuda')
103
+ test_loader :
104
+ testing data loader
105
+ """
106
+ model.eval()
107
+
108
+ test_loss = 0
109
+ correct = 0
110
+ with torch.no_grad():
111
+ for data, target in test_loader:
112
+ data, target = data.to(device), target.to(device)
113
+ output = model(data)
114
+ test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
115
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
116
+ correct += pred.eq(target.view_as(pred)).sum().item()
117
+
118
+ test_loss /= len(test_loader.dataset)
119
+
120
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
121
+ test_loss, correct, len(test_loader.dataset),
122
+ 100. * correct / len(test_loader.dataset)))
123
+
124
+
125
+
deeprobust/image/netmodels/CNN_multilayer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is an implementation of Convolution Neural Network with multi conv layer.
3
+ """
4
+
5
+ from __future__ import print_function
6
+ import argparse
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F #233
10
+ import torch.optim as optim
11
+ from torchvision import datasets, transforms
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ class Net(nn.Module):
16
+ def __init__(self, in_channel1 = 1, out_channel1 = 32, out_channel2 = 64, H = 28, W = 28):
17
+ super(Net, self).__init__()
18
+ self.H = H
19
+ self.W = W
20
+ self.out_channel2 = out_channel2
21
+
22
+ ## define two convolutional layers
23
+ self.conv1 = nn.Conv2d(in_channels = in_channel1,
24
+ out_channels = out_channel1,
25
+ kernel_size = 5,
26
+ stride= 1,
27
+ padding = (2,2))
28
+ self.conv2 = nn.Conv2d(in_channels = out_channel1,
29
+ out_channels = out_channel2,
30
+ kernel_size = 5,
31
+ stride = 1,
32
+ padding = (2,2))
33
+
34
+ ## define two linear layers
35
+ self.fc1 = nn.Linear(int(H/4)*int(W/4)* out_channel2, 1024)
36
+ self.fc2 = nn.Linear(1024, 10)
37
+
38
+ def forward(self, x):
39
+ self.layers[0] = F.relu(self.conv1(x))
40
+ self.layers[1] = F.max_pool2d(x, 2, 2)
41
+ self.layers[2] = F.relu(self.conv2(x))
42
+ self.layers[3] = F.max_pool2d(x, 2, 2)
43
+ self.layers[4] = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
44
+ self.layers[5] = F.relu(self.fc1(x))
45
+ self.layers[6] = self.fc2(x)
46
+ return F.log_softmax(layers[6], dim=1)
47
+
48
+ #def get_logits(self, x):
49
+ #x = F.relu(self.conv1(x))
50
+ #x = F.max_pool2d(x, 2, 2)
51
+ #x = F.relu(self.conv2(x))
52
+ #x = F.max_pool2d(x, 2, 2)
53
+ #x = x.view(-1, 4* 4 * 50)
54
+ #x = F.relu(self.fc1(x))
55
+ #x = self.fc2(x)
56
+ #return x
57
+
58
+
59
+ def train(model, device, train_loader, optimizer, epoch):
60
+ """train.
61
+
62
+ Parameters
63
+ ----------
64
+ model :
65
+ model
66
+ device :
67
+ device
68
+ train_loader :
69
+ train_loader
70
+ optimizer :
71
+ optimizer
72
+ epoch :
73
+ epoch
74
+ """
75
+ model.train()
76
+ for batch_idx, (data, target) in enumerate(train_loader):
77
+ data, target = data.to(device), target.to(device)
78
+ optimizer.zero_grad()
79
+ output = model(data)
80
+ loss = F.nll_loss(output, target)
81
+ loss.backward()
82
+ optimizer.step()
83
+
84
+ #print every 10
85
+ if batch_idx % 10 == 0:
86
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
87
+ epoch, batch_idx * len(data), len(train_loader.dataset),
88
+ 100. * batch_idx / len(train_loader), loss.item()))
89
+
90
+
91
+ def test(model, device, test_loader):
92
+ """test.
93
+
94
+ Parameters
95
+ ----------
96
+ model :
97
+ model
98
+ device :
99
+ device
100
+ test_loader :
101
+ test_loader
102
+ """
103
+ model.eval()
104
+
105
+ test_loss = 0
106
+ correct = 0
107
+ with torch.no_grad():
108
+ for data, target in test_loader:
109
+ data, target = data.to(device), target.to(device)
110
+ output = model(data)
111
+ test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
112
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
113
+ correct += pred.eq(target.view_as(pred)).sum().item()
114
+
115
+ test_loss /= len(test_loader.dataset)
116
+
117
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
118
+ test_loss, correct, len(test_loader.dataset),
119
+ 100. * correct / len(test_loader.dataset)))
120
+
121
+
122
+
deeprobust/image/netmodels/YOPOCNN.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model for YOPO.
3
+
4
+ Reference
5
+ ---------
6
+ ..[1]https://github.com/a1600012888/YOPO-You-Only-Propagate-Once
7
+
8
+ """
9
+
10
+ from collections import OrderedDict
11
+ import torch.nn as nn
12
+
13
+
14
+ class Net(nn.Module):
15
+ def __init__(self, drop=0.5):
16
+ super(Net, self).__init__()
17
+
18
+ self.num_channels = 1
19
+ self.num_labels = 10
20
+
21
+ activ = nn.ReLU(True)
22
+ self.conv1 = nn.Conv2d(self.num_channels, 32, 3)
23
+ self.layer_one = nn.Sequential(OrderedDict([
24
+ ('conv1', self.conv1),
25
+ ('relu1', activ),]))
26
+
27
+ self.feature_extractor = nn.Sequential(OrderedDict([
28
+ ('conv2', nn.Conv2d(32, 32, 3)),
29
+ ('relu2', activ),
30
+ ('maxpool1', nn.MaxPool2d(2, 2)),
31
+ ('conv3', nn.Conv2d(32, 64, 3)),
32
+ ('relu3', activ),
33
+ ('conv4', nn.Conv2d(64, 64, 3)),
34
+ ('relu4', activ),
35
+ ('maxpool2', nn.MaxPool2d(2, 2)),
36
+ ]))
37
+
38
+ self.classifier = nn.Sequential(OrderedDict([
39
+ ('fc1', nn.Linear(64 * 4 * 4, 200)),
40
+ ('relu1', activ),
41
+ ('drop', nn.Dropout(drop)),
42
+ ('fc2', nn.Linear(200, 200)),
43
+ ('relu2', activ),
44
+ ('fc3', nn.Linear(200, self.num_labels)),
45
+ ]))
46
+ self.other_layers = nn.ModuleList()
47
+ self.other_layers.append(self.feature_extractor)
48
+ self.other_layers.append(self.classifier)
49
+
50
+ for m in self.modules():
51
+ if isinstance(m, (nn.Conv2d)):
52
+ nn.init.kaiming_normal_(m.weight)
53
+ if m.bias is not None:
54
+ nn.init.constant_(m.bias, 0)
55
+ elif isinstance(m, nn.BatchNorm2d):
56
+ nn.init.constant_(m.weight, 1)
57
+ nn.init.constant_(m.bias, 0)
58
+ nn.init.constant_(self.classifier.fc3.weight, 0)
59
+ nn.init.constant_(self.classifier.fc3.bias, 0)
60
+
61
+ def forward(self, input):
62
+ y = self.layer_one(input)
63
+ self.layer_one_out = y
64
+ self.layer_one_out.requires_grad_()
65
+ self.layer_one_out.retain_grad()
66
+ features = self.feature_extractor(y)
67
+ logits = self.classifier(features.view(-1, 64 * 4 * 4))
68
+ return logits
69
+
70
+
deeprobust/image/netmodels/resnet.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Properly implemented ResNet-s for CIFAR10 as described in paper [1].
3
+
4
+ This implementation is from Yerlan Idelbayev.
5
+
6
+ Reference
7
+ ---------
8
+ ..[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
9
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
10
+ ..[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
11
+
12
+ '''
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class BasicBlock(nn.Module):
20
+ expansion = 1
21
+
22
+ def __init__(self, in_planes, planes, stride=1):
23
+ super(BasicBlock, self).__init__()
24
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
25
+ self.bn1 = nn.BatchNorm2d(planes)
26
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+
29
+ self.shortcut = nn.Sequential()
30
+ if stride != 1 or in_planes != self.expansion*planes:
31
+ self.shortcut = nn.Sequential(
32
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(self.expansion*planes)
34
+ )
35
+
36
+ def forward(self, x):
37
+ out = F.relu(self.bn1(self.conv1(x)))
38
+ out = self.bn2(self.conv2(out))
39
+ out += self.shortcut(x)
40
+ out = F.relu(out)
41
+ return out
42
+
43
+
44
+ class Bottleneck(nn.Module):
45
+ expansion = 4
46
+
47
+ def __init__(self, in_planes, planes, stride=1):
48
+ super(Bottleneck, self).__init__()
49
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
50
+ self.bn1 = nn.BatchNorm2d(planes)
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52
+ self.bn2 = nn.BatchNorm2d(planes)
53
+ self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
54
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
55
+
56
+ self.shortcut = nn.Sequential()
57
+ if stride != 1 or in_planes != self.expansion*planes:
58
+ self.shortcut = nn.Sequential(
59
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
60
+ nn.BatchNorm2d(self.expansion*planes)
61
+ )
62
+
63
+ def forward(self, x):
64
+ out = F.relu(self.bn1(self.conv1(x)))
65
+ out = F.relu(self.bn2(self.conv2(out)))
66
+ out = self.bn3(self.conv3(out))
67
+ out += self.shortcut(x)
68
+ out = F.relu(out)
69
+ return out
70
+
71
+
72
+ class Net(nn.Module):
73
+ def __init__(self, block, num_blocks, num_classes=10):
74
+ """__init__.
75
+
76
+ Parameters
77
+ ----------
78
+ block :
79
+ block
80
+ num_blocks :
81
+ num_blocks
82
+ num_classes :
83
+ num_classes
84
+ """
85
+ super(Net, self).__init__()
86
+ self.in_planes = 64
87
+
88
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
89
+ self.bn1 = nn.BatchNorm2d(64)
90
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
91
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
92
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
93
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
94
+ self.linear = nn.Linear(512*block.expansion, num_classes)
95
+
96
+ def _make_layer(self, block, planes, num_blocks, stride):
97
+ strides = [stride] + [1]*(num_blocks-1)
98
+ layers = []
99
+ for stride in strides:
100
+ layers.append(block(self.in_planes, planes, stride))
101
+ self.in_planes = planes * block.expansion
102
+ return nn.Sequential(*layers)
103
+
104
+ def forward(self, x):
105
+ out = F.relu(self.bn1(self.conv1(x)))
106
+ out = self.layer1(out)
107
+ out = self.layer2(out)
108
+ out = self.layer3(out)
109
+ out = self.layer4(out)
110
+ out = F.avg_pool2d(out, 4)
111
+ out = out.view(out.size(0), -1)
112
+ out = self.linear(out)
113
+ return out
114
+
115
+
116
+ def ResNet18():
117
+ return Net(BasicBlock, [2,2,2,2])
118
+
119
+ def ResNet34():
120
+ return Net(BasicBlock, [3,4,6,3])
121
+
122
+ def ResNet50():
123
+ return Net(Bottleneck, [3,4,6,3])
124
+
125
+ def ResNet101():
126
+ return Net(Bottleneck, [3,4,23,3])
127
+
128
+ def ResNet152():
129
+ return Net(Bottleneck, [3,8,36,3])
130
+
131
+
132
+ def test(model, device, test_loader):
133
+ model.eval()
134
+
135
+ test_loss = 0
136
+ correct = 0
137
+ with torch.no_grad():
138
+ for data, target in test_loader:
139
+ data, target = data.to(device), target.to(device)
140
+ output = model(data)
141
+ #test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
142
+ loss = F.cross_entropy(output, target)
143
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
144
+ correct += pred.eq(target.view_as(pred)).sum().item()
145
+
146
+ test_loss /= len(test_loader.dataset)
147
+
148
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
149
+ test_loss, correct, len(test_loader.dataset),
150
+ 100. * correct / len(test_loader.dataset)))
151
+
152
+ def train(model, device, train_loader, optimizer, epoch):
153
+ model.train()
154
+
155
+ # lr = util.adjust_learning_rate(optimizer, epoch, args) # don't need it if we use Adam
156
+
157
+ for batch_idx, (data, target) in enumerate(train_loader):
158
+ data, target = torch.tensor(data).to(device), torch.tensor(target).to(device)
159
+ optimizer.zero_grad()
160
+ output = model(data)
161
+ # loss = F.nll_loss(output, target)
162
+ loss = F.cross_entropy(output, target)
163
+ loss.backward()
164
+ optimizer.step()
165
+ if batch_idx % 10 == 0:
166
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
167
+ epoch, batch_idx * len(data), len(train_loader.dataset),
168
+ 100. * batch_idx / len(train_loader), loss.item()))
deeprobust/image/netmodels/train_model.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This function help to train model of different archtecture easily. Select model archtecture and training data, then output corresponding model.
3
+
4
+ """
5
+ from __future__ import print_function
6
+ import os
7
+ import argparse
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F #233
11
+ import torch.optim as optim
12
+ from torchvision import datasets, transforms
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+ def train(model, data, device, maxepoch, data_path = './', save_per_epoch = 10, seed = 100):
17
+ """train.
18
+
19
+ Parameters
20
+ ----------
21
+ model :
22
+ model(option:'CNN', 'ResNet18', 'ResNet34', 'ResNet50', 'densenet', 'vgg11', 'vgg13', 'vgg16', 'vgg19')
23
+ data :
24
+ data(option:'MNIST','CIFAR10')
25
+ device :
26
+ device(option:'cpu', 'cuda')
27
+ maxepoch :
28
+ training epoch
29
+ data_path :
30
+ data path(default = './')
31
+ save_per_epoch :
32
+ save_per_epoch(default = 10)
33
+ seed :
34
+ seed
35
+
36
+ Examples
37
+ --------
38
+ >>>import deeprobust.image.netmodels.train_model as trainmodel
39
+ >>>trainmodel.train('CNN', 'MNIST', 'cuda', 20)
40
+ """
41
+
42
+ torch.manual_seed(seed)
43
+
44
+ train_loader, test_loader = feed_dataset(data, data_path)
45
+
46
+ if (model == 'CNN'):
47
+ import deeprobust.image.netmodels.CNN as MODEL
48
+ #from deeprobust.image.netmodels.CNN import Net
49
+ train_net = MODEL.Net().to(device)
50
+
51
+ elif (model == 'ResNet18'):
52
+ import deeprobust.image.netmodels.resnet as MODEL
53
+ train_net = MODEL.ResNet18().to(device)
54
+
55
+ elif (model == 'ResNet34'):
56
+ import deeprobust.image.netmodels.resnet as MODEL
57
+ train_net = MODEL.ResNet34().to(device)
58
+
59
+ elif (model == 'ResNet50'):
60
+ import deeprobust.image.netmodels.resnet as MODEL
61
+ train_net = MODEL.ResNet50().to(device)
62
+
63
+ elif (model == 'densenet'):
64
+ import deeprobust.image.netmodels.densenet as MODEL
65
+ train_net = MODEL.densenet_cifar().to(device)
66
+
67
+ elif (model == 'vgg11'):
68
+ import deeprobust.image.netmodels.vgg as MODEL
69
+ train_net = MODEL.VGG('VGG11').to(device)
70
+ elif (model == 'vgg13'):
71
+ import deeprobust.image.netmodels.vgg as MODEL
72
+ train_net = MODEL.VGG('VGG13').to(device)
73
+ elif (model == 'vgg16'):
74
+ import deeprobust.image.netmodels.vgg as MODEL
75
+ train_net = MODEL.VGG('VGG16').to(device)
76
+ elif (model == 'vgg19'):
77
+ import deeprobust.image.netmodels.vgg as MODEL
78
+ train_net = MODEL.VGG('VGG19').to(device)
79
+
80
+
81
+
82
+ optimizer = optim.SGD(train_net.parameters(), lr= 0.1, momentum=0.5)
83
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
84
+ save_model = True
85
+ for epoch in range(1, maxepoch + 1): ## 5 batches
86
+
87
+ print(epoch)
88
+ MODEL.train(train_net, device, train_loader, optimizer, epoch)
89
+ MODEL.test(train_net, device, test_loader)
90
+
91
+ if (save_model and (epoch % (save_per_epoch) == 0 or epoch == maxepoch)):
92
+ if os.path.isdir('./trained_models/'):
93
+ print('Save model.')
94
+ torch.save(train_net.state_dict(), os.path.join('trained_models', data + "_" + model + "_epoch_" + str(epoch) + ".pt"))
95
+ else:
96
+ os.mkdir('./trained_models/')
97
+ print('Make directory and save model.')
98
+ torch.save(train_net.state_dict(), os.path.join('trained_models', data + "_" + model + "_epoch_" + str(epoch) + ".pt"))
99
+ scheduler.step()
100
+
101
+ def feed_dataset(data, data_dict):
102
+ if(data == 'CIFAR10'):
103
+ transform_train = transforms.Compose([
104
+ transforms.RandomCrop(32, padding=5),
105
+ transforms.RandomHorizontalFlip(),
106
+ transforms.ToTensor(),
107
+ #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
108
+ ])
109
+
110
+ transform_val = transforms.Compose([
111
+ transforms.ToTensor(),
112
+ #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
113
+ ])
114
+
115
+ train_loader = torch.utils.data.DataLoader(
116
+ datasets.CIFAR10(data_dict, train=True, download = True,
117
+ transform=transform_train),
118
+ batch_size= 128, shuffle=True) #, **kwargs)
119
+
120
+ test_loader = torch.utils.data.DataLoader(
121
+ datasets.CIFAR10(data_dict, train=False, download = True,
122
+ transform=transform_val),
123
+ batch_size= 1000, shuffle=True) #, **kwargs)
124
+
125
+ elif(data == 'MNIST'):
126
+ train_loader = torch.utils.data.DataLoader(
127
+ datasets.MNIST(data_dict, train=True, download = True,
128
+ transform=transforms.Compose([transforms.ToTensor(),
129
+ transforms.Normalize((0.1307,), (0.3081,))])),
130
+ batch_size=128,
131
+ shuffle=True)
132
+
133
+ test_loader = torch.utils.data.DataLoader(
134
+ datasets.MNIST(data_dict, train=False, download = True,
135
+ transform=transforms.Compose([transforms.ToTensor(),
136
+ transforms.Normalize((0.1307,), (0.3081,))])),
137
+ batch_size=1000,
138
+ shuffle=True)
139
+
140
+ elif(data == 'ImageNet'):
141
+ pass
142
+
143
+ return train_loader, test_loader
144
+
145
+
146
+
deeprobust/image/netmodels/train_resnet.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F #233
6
+ import torch.optim as optim
7
+ from torchvision import datasets, transforms
8
+ import numpy as np
9
+ from PIL import Image
10
+ import CNNmodel
11
+
12
+ torch.manual_seed(100)
13
+ device = torch.device("cuda")
14
+
15
+ train_loader = torch.utils.data.DataLoader(
16
+ datasets.MNIST('../data', train=True, download=True,
17
+ transform=transforms.Compose([transforms.ToTensor(),
18
+ transforms.Normalize((0.1307,), (0.3081,))])),
19
+ batch_size=64,
20
+ shuffle=True)
21
+
22
+ test_loader = torch.utils.data.DataLoader(
23
+ datasets.MNIST('../data', train=False,
24
+ transform=transforms.Compose([transforms.ToTensor(),
25
+ transforms.Normalize((0.1307,), (0.3081,))])),
26
+ batch_size=1000,
27
+ shuffle=True)
28
+
29
+ model = CNNmodel.Net().to(device)
30
+ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
31
+
32
+ save_model = True
33
+ for epoch in range(1, 5 + 1): ## 5 batches
34
+ print(epoch)
35
+ CNNmodel.train(model, device, train_loader, optimizer, epoch)
36
+ CNNmodel.test(model, device, test_loader)
37
+
38
+ if (save_model):
39
+ torch.save(model.state_dict(), "mnist_cnn.pt")
deeprobust/image/netmodels/vgg.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is an implementation of VGG net.
3
+
4
+ Reference
5
+ ---------
6
+ ..[1]Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
7
+ ..[2]Original implementation: https://github.com/kuangliu/pytorch-cifar
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ cfg = {
15
+ 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
16
+ 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
17
+ 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
18
+ 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
19
+ }
20
+
21
+
22
+ class VGG(nn.Module):
23
+ """VGG.
24
+ """
25
+
26
+ def __init__(self, vgg_name):
27
+ super(VGG, self).__init__()
28
+ self.features = self._make_layers(cfg[vgg_name])
29
+ self.classifier = nn.Linear(512, 10)
30
+
31
+ def forward(self, x):
32
+ out = self.features(x)
33
+ out = out.view(out.size(0), -1)
34
+ out = self.classifier(out)
35
+ return out
36
+
37
+ def _make_layers(self, cfg):
38
+ layers = []
39
+ in_channels = 3
40
+ for x in cfg:
41
+ if x == 'M':
42
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
43
+ else:
44
+ layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
45
+ nn.BatchNorm2d(x),
46
+ nn.ReLU(inplace=True)]
47
+ in_channels = x
48
+ layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
49
+ return nn.Sequential(*layers)
50
+
51
+
52
+
53
+ def test(model, device, test_loader):
54
+ """test.
55
+
56
+ Parameters
57
+ ----------
58
+ model :
59
+ model
60
+ device :
61
+ device
62
+ test_loader :
63
+ test_loader
64
+ """
65
+ model.eval()
66
+
67
+ test_loss = 0
68
+ correct = 0
69
+ with torch.no_grad():
70
+ for data, target in test_loader:
71
+ data, target = data.to(device), target.to(device)
72
+ output = model(data)
73
+ #test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
74
+
75
+ test_loss += F.cross_entropy(output, target)
76
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
77
+ correct += pred.eq(target.view_as(pred)).sum().item()
78
+
79
+ test_loss /= len(test_loader.dataset)
80
+
81
+ print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
82
+ test_loss, correct, len(test_loader.dataset),
83
+ 100. * correct / len(test_loader.dataset)))
84
+
85
+ def train(model, device, train_loader, optimizer, epoch):
86
+ """train.
87
+
88
+ Parameters
89
+ ----------
90
+ model :
91
+ model
92
+ device :
93
+ device
94
+ train_loader :
95
+ train_loader
96
+ optimizer :
97
+ optimizer
98
+ epoch :
99
+ epoch
100
+ """
101
+ model.train()
102
+
103
+ # lr = util.adjust_learning_rate(optimizer, epoch, args) # don't need it if we use Adam
104
+
105
+ for batch_idx, (data, target) in enumerate(train_loader):
106
+ data, target = torch.tensor(data).to(device), torch.tensor(target).to(device)
107
+ optimizer.zero_grad()
108
+ output = model(data)
109
+ # loss = F.nll_loss(output, target)
110
+ loss = F.cross_entropy(output, target)
111
+ loss.backward()
112
+ optimizer.step()
113
+ if batch_idx % 10 == 0:
114
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
115
+ epoch, batch_idx * len(data), len(train_loader.dataset),
116
+ 100. * batch_idx / len(train_loader), loss.item()/data.shape[0]))
deeprobust/image/synset_words.txt ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ n01440764 tench, Tinca tinca
2
+ n01443537 goldfish, Carassius auratus
3
+ n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4
+ n01491361 tiger shark, Galeocerdo cuvieri
5
+ n01494475 hammerhead, hammerhead shark
6
+ n01496331 electric ray, crampfish, numbfish, torpedo
7
+ n01498041 stingray
8
+ n01514668 cock
9
+ n01514859 hen
10
+ n01518878 ostrich, Struthio camelus
11
+ n01530575 brambling, Fringilla montifringilla
12
+ n01531178 goldfinch, Carduelis carduelis
13
+ n01532829 house finch, linnet, Carpodacus mexicanus
14
+ n01534433 junco, snowbird
15
+ n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea
16
+ n01558993 robin, American robin, Turdus migratorius
17
+ n01560419 bulbul
18
+ n01580077 jay
19
+ n01582220 magpie
20
+ n01592084 chickadee
21
+ n01601694 water ouzel, dipper
22
+ n01608432 kite
23
+ n01614925 bald eagle, American eagle, Haliaeetus leucocephalus
24
+ n01616318 vulture
25
+ n01622779 great grey owl, great gray owl, Strix nebulosa
26
+ n01629819 European fire salamander, Salamandra salamandra
27
+ n01630670 common newt, Triturus vulgaris
28
+ n01631663 eft
29
+ n01632458 spotted salamander, Ambystoma maculatum
30
+ n01632777 axolotl, mud puppy, Ambystoma mexicanum
31
+ n01641577 bullfrog, Rana catesbeiana
32
+ n01644373 tree frog, tree-frog
33
+ n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
34
+ n01664065 loggerhead, loggerhead turtle, Caretta caretta
35
+ n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
36
+ n01667114 mud turtle
37
+ n01667778 terrapin
38
+ n01669191 box turtle, box tortoise
39
+ n01675722 banded gecko
40
+ n01677366 common iguana, iguana, Iguana iguana
41
+ n01682714 American chameleon, anole, Anolis carolinensis
42
+ n01685808 whiptail, whiptail lizard
43
+ n01687978 agama
44
+ n01688243 frilled lizard, Chlamydosaurus kingi
45
+ n01689811 alligator lizard
46
+ n01692333 Gila monster, Heloderma suspectum
47
+ n01693334 green lizard, Lacerta viridis
48
+ n01694178 African chameleon, Chamaeleo chamaeleon
49
+ n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
50
+ n01697457 African crocodile, Nile crocodile, Crocodylus niloticus
51
+ n01698640 American alligator, Alligator mississipiensis
52
+ n01704323 triceratops
53
+ n01728572 thunder snake, worm snake, Carphophis amoenus
54
+ n01728920 ringneck snake, ring-necked snake, ring snake
55
+ n01729322 hognose snake, puff adder, sand viper
56
+ n01729977 green snake, grass snake
57
+ n01734418 king snake, kingsnake
58
+ n01735189 garter snake, grass snake
59
+ n01737021 water snake
60
+ n01739381 vine snake
61
+ n01740131 night snake, Hypsiglena torquata
62
+ n01742172 boa constrictor, Constrictor constrictor
63
+ n01744401 rock python, rock snake, Python sebae
64
+ n01748264 Indian cobra, Naja naja
65
+ n01749939 green mamba
66
+ n01751748 sea snake
67
+ n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
68
+ n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus
69
+ n01756291 sidewinder, horned rattlesnake, Crotalus cerastes
70
+ n01768244 trilobite
71
+ n01770081 harvestman, daddy longlegs, Phalangium opilio
72
+ n01770393 scorpion
73
+ n01773157 black and gold garden spider, Argiope aurantia
74
+ n01773549 barn spider, Araneus cavaticus
75
+ n01773797 garden spider, Aranea diademata
76
+ n01774384 black widow, Latrodectus mactans
77
+ n01774750 tarantula
78
+ n01775062 wolf spider, hunting spider
79
+ n01776313 tick
80
+ n01784675 centipede
81
+ n01795545 black grouse
82
+ n01796340 ptarmigan
83
+ n01797886 ruffed grouse, partridge, Bonasa umbellus
84
+ n01798484 prairie chicken, prairie grouse, prairie fowl
85
+ n01806143 peacock
86
+ n01806567 quail
87
+ n01807496 partridge
88
+ n01817953 African grey, African gray, Psittacus erithacus
89
+ n01818515 macaw
90
+ n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
91
+ n01820546 lorikeet
92
+ n01824575 coucal
93
+ n01828970 bee eater
94
+ n01829413 hornbill
95
+ n01833805 hummingbird
96
+ n01843065 jacamar
97
+ n01843383 toucan
98
+ n01847000 drake
99
+ n01855032 red-breasted merganser, Mergus serrator
100
+ n01855672 goose
101
+ n01860187 black swan, Cygnus atratus
102
+ n01871265 tusker
103
+ n01872401 echidna, spiny anteater, anteater
104
+ n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
105
+ n01877812 wallaby, brush kangaroo
106
+ n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
107
+ n01883070 wombat
108
+ n01910747 jellyfish
109
+ n01914609 sea anemone, anemone
110
+ n01917289 brain coral
111
+ n01924916 flatworm, platyhelminth
112
+ n01930112 nematode, nematode worm, roundworm
113
+ n01943899 conch
114
+ n01944390 snail
115
+ n01945685 slug
116
+ n01950731 sea slug, nudibranch
117
+ n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore
118
+ n01968897 chambered nautilus, pearly nautilus, nautilus
119
+ n01978287 Dungeness crab, Cancer magister
120
+ n01978455 rock crab, Cancer irroratus
121
+ n01980166 fiddler crab
122
+ n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
123
+ n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
124
+ n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
125
+ n01985128 crayfish, crawfish, crawdad, crawdaddy
126
+ n01986214 hermit crab
127
+ n01990800 isopod
128
+ n02002556 white stork, Ciconia ciconia
129
+ n02002724 black stork, Ciconia nigra
130
+ n02006656 spoonbill
131
+ n02007558 flamingo
132
+ n02009229 little blue heron, Egretta caerulea
133
+ n02009912 American egret, great white heron, Egretta albus
134
+ n02011460 bittern
135
+ n02012849 crane
136
+ n02013706 limpkin, Aramus pictus
137
+ n02017213 European gallinule, Porphyrio porphyrio
138
+ n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana
139
+ n02018795 bustard
140
+ n02025239 ruddy turnstone, Arenaria interpres
141
+ n02027492 red-backed sandpiper, dunlin, Erolia alpina
142
+ n02028035 redshank, Tringa totanus
143
+ n02033041 dowitcher
144
+ n02037110 oystercatcher, oyster catcher
145
+ n02051845 pelican
146
+ n02056570 king penguin, Aptenodytes patagonica
147
+ n02058221 albatross, mollymawk
148
+ n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
149
+ n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca
150
+ n02074367 dugong, Dugong dugon
151
+ n02077923 sea lion
152
+ n02085620 Chihuahua
153
+ n02085782 Japanese spaniel
154
+ n02085936 Maltese dog, Maltese terrier, Maltese
155
+ n02086079 Pekinese, Pekingese, Peke
156
+ n02086240 Shih-Tzu
157
+ n02086646 Blenheim spaniel
158
+ n02086910 papillon
159
+ n02087046 toy terrier
160
+ n02087394 Rhodesian ridgeback
161
+ n02088094 Afghan hound, Afghan
162
+ n02088238 basset, basset hound
163
+ n02088364 beagle
164
+ n02088466 bloodhound, sleuthhound
165
+ n02088632 bluetick
166
+ n02089078 black-and-tan coonhound
167
+ n02089867 Walker hound, Walker foxhound
168
+ n02089973 English foxhound
169
+ n02090379 redbone
170
+ n02090622 borzoi, Russian wolfhound
171
+ n02090721 Irish wolfhound
172
+ n02091032 Italian greyhound
173
+ n02091134 whippet
174
+ n02091244 Ibizan hound, Ibizan Podenco
175
+ n02091467 Norwegian elkhound, elkhound
176
+ n02091635 otterhound, otter hound
177
+ n02091831 Saluki, gazelle hound
178
+ n02092002 Scottish deerhound, deerhound
179
+ n02092339 Weimaraner
180
+ n02093256 Staffordshire bullterrier, Staffordshire bull terrier
181
+ n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
182
+ n02093647 Bedlington terrier
183
+ n02093754 Border terrier
184
+ n02093859 Kerry blue terrier
185
+ n02093991 Irish terrier
186
+ n02094114 Norfolk terrier
187
+ n02094258 Norwich terrier
188
+ n02094433 Yorkshire terrier
189
+ n02095314 wire-haired fox terrier
190
+ n02095570 Lakeland terrier
191
+ n02095889 Sealyham terrier, Sealyham
192
+ n02096051 Airedale, Airedale terrier
193
+ n02096177 cairn, cairn terrier
194
+ n02096294 Australian terrier
195
+ n02096437 Dandie Dinmont, Dandie Dinmont terrier
196
+ n02096585 Boston bull, Boston terrier
197
+ n02097047 miniature schnauzer
198
+ n02097130 giant schnauzer
199
+ n02097209 standard schnauzer
200
+ n02097298 Scotch terrier, Scottish terrier, Scottie
201
+ n02097474 Tibetan terrier, chrysanthemum dog
202
+ n02097658 silky terrier, Sydney silky
203
+ n02098105 soft-coated wheaten terrier
204
+ n02098286 West Highland white terrier
205
+ n02098413 Lhasa, Lhasa apso
206
+ n02099267 flat-coated retriever
207
+ n02099429 curly-coated retriever
208
+ n02099601 golden retriever
209
+ n02099712 Labrador retriever
210
+ n02099849 Chesapeake Bay retriever
211
+ n02100236 German short-haired pointer
212
+ n02100583 vizsla, Hungarian pointer
213
+ n02100735 English setter
214
+ n02100877 Irish setter, red setter
215
+ n02101006 Gordon setter
216
+ n02101388 Brittany spaniel
217
+ n02101556 clumber, clumber spaniel
218
+ n02102040 English springer, English springer spaniel
219
+ n02102177 Welsh springer spaniel
220
+ n02102318 cocker spaniel, English cocker spaniel, cocker
221
+ n02102480 Sussex spaniel
222
+ n02102973 Irish water spaniel
223
+ n02104029 kuvasz
224
+ n02104365 schipperke
225
+ n02105056 groenendael
226
+ n02105162 malinois
227
+ n02105251 briard
228
+ n02105412 kelpie
229
+ n02105505 komondor
230
+ n02105641 Old English sheepdog, bobtail
231
+ n02105855 Shetland sheepdog, Shetland sheep dog, Shetland
232
+ n02106030 collie
233
+ n02106166 Border collie
234
+ n02106382 Bouvier des Flandres, Bouviers des Flandres
235
+ n02106550 Rottweiler
236
+ n02106662 German shepherd, German shepherd dog, German police dog, alsatian
237
+ n02107142 Doberman, Doberman pinscher
238
+ n02107312 miniature pinscher
239
+ n02107574 Greater Swiss Mountain dog
240
+ n02107683 Bernese mountain dog
241
+ n02107908 Appenzeller
242
+ n02108000 EntleBucher``^
243
+ n02108089 boxer`
244
+ n02108422 bull mastif
245
+ n02108551 Tibetan mastiff
246
+ n02108915 French bulldog
247
+ n02109047 Great Dane
248
+ n02109525 Saint Bernard, St Bernard
249
+ n02109961 Eskimo dog, husky
250
+ n02110063 malamute, malemute, Alaskan malamute
251
+ n02110185 Siberian husky
252
+ n02110341 dalmatian, coach dog, carriage dog
253
+ n02110627 affenpinscher, monkey pinscher, monkey dog
254
+ n02110806 basenji
255
+ n02110958 pug, pug-dog
256
+ n02111129 Leonberg
257
+ n02111277 Newfoundland, Newfoundland dog
258
+ n02111500 Great Pyrenees
259
+ n02111889 Samoyed, Samoyede
260
+ n02112018 Pomeranian
261
+ n02112137 chow, chow chow
262
+ n02112350 keeshond
263
+ n02112706 Brabancon griffon
264
+ n02113023 Pembroke, Pembroke Welsh corgi
265
+ n02113186 Cardigan, Cardigan Welsh corgi
266
+ n02113624 toy poodle
267
+ n02113712 miniature poodle
268
+ n02113799 standard poodle
269
+ n02113978 Mexican hairless
270
+ n02114367 timber wolf, grey wolf, gray wolf, Canis lupus
271
+ n02114548 white wolf, Arctic wolf, Canis lupus tundrarum
272
+ n02114712 red wolf, maned wolf, Canis rufus, Canis niger
273
+ n02114855 coyote, prairie wolf, brush wolf, Canis latrans
274
+ n02115641 dingo, warrigal, warragal, Canis dingo
275
+ n02115913 dhole, Cuon alpinus
276
+ n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
277
+ n02117135 hyena, hyaena
278
+ n02119022 red fox, Vulpes vulpes
279
+ n02119789 kit fox, Vulpes macrotis
280
+ n02120079 Arctic fox, white fox, Alopex lagopus
281
+ n02120505 grey fox, gray fox, Urocyon cinereoargenteus
282
+ n02123045 tabby, tabby cat
283
+ n02123159 tiger cat
284
+ n02123394 Persian cat
285
+ n02123597 Siamese cat, Siamese
286
+ n02124075 Egyptian cat
287
+ n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
288
+ n02127052 lynx, catamount
289
+ n02128385 leopard, Panthera pardus
290
+ n02128757 snow leopard, ounce, Panthera uncia
291
+ n02128925 jaguar, panther, Panthera onca, Felis onca
292
+ n02129165 lion, king of beasts, Panthera leo
293
+ n02129604 tiger, Panthera tigris
294
+ n02130308 cheetah, chetah, Acinonyx jubatus
295
+ n02132136 brown bear, bruin, Ursus arctos
296
+ n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus
297
+ n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
298
+ n02134418 sloth bear, Melursus ursinus, Ursus ursinus
299
+ n02137549 mongoose
300
+ n02138441 meerkat, mierkat
301
+ n02165105 tiger beetle
302
+ n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
303
+ n02167151 ground beetle, carabid beetle
304
+ n02168699 long-horned beetle, longicorn, longicorn beetle
305
+ n02169497 leaf beetle, chrysomelid
306
+ n02172182 dung beetle
307
+ n02174001 rhinoceros beetle
308
+ n02177972 weevil
309
+ n02190166 fly
310
+ n02206856 bee
311
+ n02219486 ant, emmet, pismire
312
+ n02226429 grasshopper, hopper
313
+ n02229544 cricket
314
+ n02231487 walking stick, walkingstick, stick insect
315
+ n02233338 cockroach, roach
316
+ n02236044 mantis, mantid
317
+ n02256656 cicada, cicala
318
+ n02259212 leafhopper
319
+ n02264363 lacewing, lacewing fly
320
+ n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
321
+ n02268853 damselfly
322
+ n02276258 admiral
323
+ n02277742 ringlet, ringlet butterfly
324
+ n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
325
+ n02280649 cabbage butterfly
326
+ n02281406 sulphur butterfly, sulfur butterfly
327
+ n02281787 lycaenid, lycaenid butterfly
328
+ n02317335 starfish, sea star
329
+ n02319095 sea urchin
330
+ n02321529 sea cucumber, holothurian
331
+ n02325366 wood rabbit, cottontail, cottontail rabbit
332
+ n02326432 hare
333
+ n02328150 Angora, Angora rabbit
334
+ n02342885 hamster
335
+ n02346627 porcupine, hedgehog
336
+ n02356798 fox squirrel, eastern fox squirrel, Sciurus niger
337
+ n02361337 marmot
338
+ n02363005 beaver
339
+ n02364673 guinea pig, Cavia cobaya
340
+ n02389026 sorrel
341
+ n02391049 zebra
342
+ n02395406 hog, pig, grunter, squealer, Sus scrofa
343
+ n02396427 wild boar, boar, Sus scrofa
344
+ n02397096 warthog
345
+ n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius
346
+ n02403003 ox
347
+ n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348
+ n02410509 bison
349
+ n02412080 ram, tup
350
+ n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
351
+ n02417914 ibex, Capra ibex
352
+ n02422106 hartebeest
353
+ n02422699 impala, Aepyceros melampus
354
+ n02423022 gazelle
355
+ n02437312 Arabian camel, dromedary, Camelus dromedarius
356
+ n02437616 llama
357
+ n02441942 weasel
358
+ n02442845 mink
359
+ n02443114 polecat, fitch, foulmart, foumart, Mustela putorius
360
+ n02443484 black-footed ferret, ferret, Mustela nigripes
361
+ n02444819 otter
362
+ n02445715 skunk, polecat, wood pussy
363
+ n02447366 badger
364
+ n02454379 armadillo
365
+ n02457408 three-toed sloth, ai, Bradypus tridactylus
366
+ n02480495 orangutan, orang, orangutang, Pongo pygmaeus
367
+ n02480855 gorilla, Gorilla gorilla
368
+ n02481823 chimpanzee, chimp, Pan troglodytes
369
+ n02483362 gibbon, Hylobates lar
370
+ n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus
371
+ n02484975 guenon, guenon monkey
372
+ n02486261 patas, hussar monkey, Erythrocebus patas
373
+ n02486410 baboon
374
+ n02487347 macaque
375
+ n02488291 langur
376
+ n02488702 colobus, colobus monkey
377
+ n02489166 proboscis monkey, Nasalis larvatus
378
+ n02490219 marmoset
379
+ n02492035 capuchin, ringtail, Cebus capucinus
380
+ n02492660 howler monkey, howler
381
+ n02493509 titi, titi monkey
382
+ n02493793 spider monkey, Ateles geoffroyi
383
+ n02494079 squirrel monkey, Saimiri sciureus
384
+ n02497673 Madagascar cat, ring-tailed lemur, Lemur catta
385
+ n02500267 indri, indris, Indri indri, Indri brevicaudatus
386
+ n02504013 Indian elephant, Elephas maximus
387
+ n02504458 African elephant, Loxodonta africana
388
+ n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
389
+ n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
390
+ n02514041 barracouta, snoek
391
+ n02526121 eel
392
+ n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
393
+ n02606052 rock beauty, Holocanthus tricolor
394
+ n02607072 anemone fish
395
+ n02640242 sturgeon
396
+ n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus
397
+ n02643566 lionfish
398
+ n02655020 puffer, pufferfish, blowfish, globefish
399
+ n02666196 abacus
400
+ n02667093 abaya
401
+ n02669723 academic gown, academic robe, judge's robe
402
+ n02672831 accordion, piano accordion, squeeze box
403
+ n02676566 acoustic guitar
404
+ n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier
405
+ n02690373 airliner
406
+ n02692877 airship, dirigible
407
+ n02699494 altar
408
+ n02701002 ambulance
409
+ n02704792 amphibian, amphibious vehicle
410
+ n02708093 analog clock
411
+ n02727426 apiary, bee house
412
+ n02730930 apron
413
+ n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
414
+ n02749479 assault rifle, assault gun
415
+ n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack
416
+ n02776631 bakery, bakeshop, bakehouse
417
+ n02777292 balance beam, beam
418
+ n02782093 balloon
419
+ n02783161 ballpoint, ballpoint pen, ballpen, Biro
420
+ n02786058 Band Aid
421
+ n02787622 banjo
422
+ n02788148 bannister, banister, balustrade, balusters, handrail
423
+ n02790996 barbell
424
+ n02791124 barber chair
425
+ n02791270 barbershop
426
+ n02793495 barn
427
+ n02794156 barometer
428
+ n02795169 barrel, cask
429
+ n02797295 barrow, garden cart, lawn cart, wheelbarrow
430
+ n02799071 baseball
431
+ n02802426 basketball
432
+ n02804414 bassinet
433
+ n02804610 bassoon
434
+ n02807133 bathing cap, swimming cap
435
+ n02808304 bath towel
436
+ n02808440 bathtub, bathing tub, bath, tub
437
+ n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
438
+ n02814860 beacon, lighthouse, beacon light, pharos
439
+ n02815834 beaker
440
+ n02817516 bearskin, busby, shako
441
+ n02823428 beer bottle
442
+ n02823750 beer glass
443
+ n02825657 bell cote, bell cot
444
+ n02834397 bib
445
+ n02835271 bicycle-built-for-two, tandem bicycle, tandem
446
+ n02837789 bikini, two-piece
447
+ n02840245 binder, ring-binder
448
+ n02841315 binoculars, field glasses, opera glasses
449
+ n02843684 birdhouse
450
+ n02859443 boathouse
451
+ n02860847 bobsled, bobsleigh, bob
452
+ n02865351 bolo tie, bolo, bola tie, bola
453
+ n02869837 bonnet, poke bonnet
454
+ n02870880 bookcase
455
+ n02871525 bookshop, bookstore, bookstall
456
+ n02877765 bottlecap
457
+ n02879718 bow
458
+ n02883205 bow tie, bow-tie, bowtie
459
+ n02892201 brass, memorial tablet, plaque
460
+ n02892767 brassiere, bra, bandeau
461
+ n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty
462
+ n02895154 breastplate, aegis, egis
463
+ n02906734 broom
464
+ n02909870 bucket, pail
465
+ n02910353 buckle
466
+ n02916936 bulletproof vest
467
+ n02917067 bullet train, bullet
468
+ n02927161 butcher shop, meat market
469
+ n02930766 cab, hack, taxi, taxicab
470
+ n02939185 caldron, cauldron
471
+ n02948072 candle, taper, wax light
472
+ n02950826 cannon
473
+ n02951358 canoe
474
+ n02951585 can opener, tin opener
475
+ n02963159 cardigan
476
+ n02965783 car mirror
477
+ n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig
478
+ n02966687 carpenter's kit, tool kit
479
+ n02971356 carton
480
+ n02974003 car wheel
481
+ n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
482
+ n02978881 cassette
483
+ n02979186 cassette player
484
+ n02980441 castle
485
+ n02981792 catamaran
486
+ n02988304 CD player
487
+ n02992211 cello, violoncello
488
+ n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone
489
+ n02999410 chain
490
+ n03000134 chainlink fence
491
+ n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
492
+ n03000684 chain saw, chainsaw
493
+ n03014705 chest
494
+ n03016953 chiffonier, commode
495
+ n03017168 chime, bell, gong
496
+ n03018349 china cabinet, china closet
497
+ n03026506 Christmas stocking
498
+ n03028079 church, church building
499
+ n03032252 cinema, movie theater, movie theatre, movie house, picture palace
500
+ n03041632 cleaver, meat cleaver, chopper
501
+ n03042490 cliff dwelling
502
+ n03045698 cloak
503
+ n03047690 clog, geta, patten, sabot
504
+ n03062245 cocktail shaker
505
+ n03063599 coffee mug
506
+ n03063689 coffeepot
507
+ n03065424 coil, spiral, volute, whorl, helix
508
+ n03075370 combination lock
509
+ n03085013 computer keyboard, keypad
510
+ n03089624 confectionery, confectionary, candy store
511
+ n03095699 container ship, containership, container vessel
512
+ n03100240 convertible
513
+ n03109150 corkscrew, bottle screw
514
+ n03110669 cornet, horn, trumpet, trump
515
+ n03124043 cowboy boot
516
+ n03124170 cowboy hat, ten-gallon hat
517
+ n03125729 cradle
518
+ n03126707 crane
519
+ n03127747 crash helmet
520
+ n03127925 crate
521
+ n03131574 crib, cot
522
+ n03133878 Crock Pot
523
+ n03134739 croquet ball
524
+ n03141823 crutch
525
+ n03146219 cuirass
526
+ n03160309 dam, dike, dyke
527
+ n03179701 desk
528
+ n03180011 desktop computer
529
+ n03187595 dial telephone, dial phone
530
+ n03188531 diaper, nappy, napkin
531
+ n03196217 digital clock
532
+ n03197337 digital watch
533
+ n03201208 dining table, board
534
+ n03207743 dishrag, dishcloth
535
+ n03207941 dishwasher, dish washer, dishwashing machine
536
+ n03208938 disk brake, disc brake
537
+ n03216828 dock, dockage, docking facility
538
+ n03218198 dogsled, dog sled, dog sleigh
539
+ n03220513 dome
540
+ n03223299 doormat, welcome mat
541
+ n03240683 drilling platform, offshore rig
542
+ n03249569 drum, membranophone, tympan
543
+ n03250847 drumstick
544
+ n03255030 dumbbell
545
+ n03259280 Dutch oven
546
+ n03271574 electric fan, blower
547
+ n03272010 electric guitar
548
+ n03272562 electric locomotive
549
+ n03290653 entertainment center
550
+ n03291819 envelope
551
+ n03297495 espresso maker
552
+ n03314780 face powder
553
+ n03325584 feather boa, boa
554
+ n03337140 file, file cabinet, filing cabinet
555
+ n03344393 fireboat
556
+ n03345487 fire engine, fire truck
557
+ n03347037 fire screen, fireguard
558
+ n03355925 flagpole, flagstaff
559
+ n03372029 flute, transverse flute
560
+ n03376595 folding chair
561
+ n03379051 football helmet
562
+ n03384352 forklift
563
+ n03388043 fountain
564
+ n03388183 fountain pen
565
+ n03388549 four-poster
566
+ n03393912 freight car
567
+ n03394916 French horn, horn
568
+ n03400231 frying pan, frypan, skillet
569
+ n03404251 fur coat
570
+ n03417042 garbage truck, dustcart
571
+ n03424325 gasmask, respirator, gas helmet
572
+ n03425413 gas pump, gasoline pump, petrol pump, island dispenser
573
+ n03443371 goblet
574
+ n03444034 go-kart
575
+ n03445777 golf ball
576
+ n03445924 golfcart, golf cart
577
+ n03447447 gondola
578
+ n03447721 gong, tam-tam
579
+ n03450230 gown
580
+ n03452741 grand piano, grand
581
+ n03457902 greenhouse, nursery, glasshouse
582
+ n03459775 grille, radiator grille
583
+ n03461385 grocery store, grocery, food market, market
584
+ n03467068 guillotine
585
+ n03476684 hair slide
586
+ n03476991 hair spray
587
+ n03478589 half track
588
+ n03481172 hammer
589
+ n03482405 hamper
590
+ n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier
591
+ n03485407 hand-held computer, hand-held microcomputer
592
+ n03485794 handkerchief, hankie, hanky, hankey
593
+ n03492542 hard disc, hard disk, fixed disk
594
+ n03494278 harmonica, mouth organ, harp, mouth harp
595
+ n03495258 harp
596
+ n03496892 harvester, reaper
597
+ n03498962 hatchet
598
+ n03527444 holster
599
+ n03529860 home theater, home theatre
600
+ n03530642 honeycomb
601
+ n03532672 hook, claw
602
+ n03534580 hoopskirt, crinoline
603
+ n03535780 horizontal bar, high bar
604
+ n03538406 horse cart, horse-cart
605
+ n03544143 hourglass
606
+ n03584254 iPod
607
+ n03584829 iron, smoothing iron
608
+ n03590841 jack-o'-lantern
609
+ n03594734 jean, blue jean, denim
610
+ n03594945 jeep, landrover
611
+ n03595614 jersey, T-shirt, tee shirt
612
+ n03598930 jigsaw puzzle
613
+ n03599486 jinrikisha, ricksha, rickshaw
614
+ n03602883 joystick
615
+ n03617480 kimono
616
+ n03623198 knee pad
617
+ n03627232 knot
618
+ n03630383 lab coat, laboratory coat
619
+ n03633091 ladle
620
+ n03637318 lampshade, lamp shade
621
+ n03642806 laptop, laptop computer
622
+ n03649909 lawn mower, mower
623
+ n03657121 lens cap, lens cover
624
+ n03658185 letter opener, paper knife, paperknife
625
+ n03661043 library
626
+ n03662601 lifeboat
627
+ n03666591 lighter, light, igniter, ignitor
628
+ n03670208 limousine, limo
629
+ n03673027 liner, ocean liner
630
+ n03676483 lipstick, lip rouge
631
+ n03680355 Loafer
632
+ n03690938 lotion
633
+ n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
634
+ n03692522 loupe, jeweler's loupe
635
+ n03697007 lumbermill, sawmill
636
+ n03706229 magnetic compass
637
+ n03709823 mailbag, postbag
638
+ n03710193 mailbox, letter box
639
+ n03710637 maillot
640
+ n03710721 maillot, tank suit
641
+ n03717622 manhole cover
642
+ n03720891 maraca
643
+ n03721384 marimba, xylophone
644
+ n03724870 mask
645
+ n03729826 matchstick
646
+ n03733131 maypole
647
+ n03733281 maze, labyrinth
648
+ n03733805 measuring cup
649
+ n03742115 medicine chest, medicine cabinet
650
+ n03743016 megalith, megalithic structure
651
+ n03759954 microphone, mike
652
+ n03761084 microwave, microwave oven
653
+ n03763968 military uniform
654
+ n03764736 milk can
655
+ n03769881 minibus
656
+ n03770439 miniskirt, mini
657
+ n03770679 minivan
658
+ n03773504 missile
659
+ n03775071 mitten
660
+ n03775546 mixing bowl
661
+ n03776460 mobile home, manufactured home
662
+ n03777568 Model T
663
+ n03777754 modem
664
+ n03781244 monastery
665
+ n03782006 monitor
666
+ n03785016 moped
667
+ n03786901 mortar
668
+ n03787032 mortarboard
669
+ n03788195 mosque
670
+ n03788365 mosquito net
671
+ n03791053 motor scooter, scooter
672
+ n03792782 mountain bike, all-terrain bike, off-roader
673
+ n03792972 mountain tent
674
+ n03793489 mouse, computer mouse
675
+ n03794056 mousetrap
676
+ n03796401 moving van
677
+ n03803284 muzzle
678
+ n03804744 nail
679
+ n03814639 neck brace
680
+ n03814906 necklace
681
+ n03825788 nipple
682
+ n03832673 notebook, notebook computer
683
+ n03837869 obelisk
684
+ n03838899 oboe, hautboy, hautbois
685
+ n03840681 ocarina, sweet potato
686
+ n03841143 odometer, hodometer, mileometer, milometer
687
+ n03843555 oil filter
688
+ n03854065 organ, pipe organ
689
+ n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO
690
+ n03866082 overskirt
691
+ n03868242 oxcart
692
+ n03868863 oxygen mask
693
+ n03871628 packet
694
+ n03873416 paddle, boat paddle
695
+ n03874293 paddlewheel, paddle wheel
696
+ n03874599 padlock
697
+ n03876231 paintbrush
698
+ n03877472 pajama, pyjama, pj's, jammies
699
+ n03877845 palace
700
+ n03884397 panpipe, pandean pipe, syrinx
701
+ n03887697 paper towel
702
+ n03888257 parachute, chute
703
+ n03888605 parallel bars, bars
704
+ n03891251 park bench
705
+ n03891332 parking meter
706
+ n03895866 passenger car, coach, carriage
707
+ n03899768 patio, terrace
708
+ n03902125 pay-phone, pay-station
709
+ n03903868 pedestal, plinth, footstall
710
+ n03908618 pencil box, pencil case
711
+ n03908714 pencil sharpener
712
+ n03916031 perfume, essence
713
+ n03920288 Petri dish
714
+ n03924679 photocopier
715
+ n03929660 pick, plectrum, plectron
716
+ n03929855 pickelhaube
717
+ n03930313 picket fence, paling
718
+ n03930630 pickup, pickup truck
719
+ n03933933 pier
720
+ n03935335 piggy bank, penny bank
721
+ n03937543 pill bottle
722
+ n03938244 pillow
723
+ n03942813 ping-pong ball
724
+ n03944341 pinwheel
725
+ n03947888 pirate, pirate ship
726
+ n03950228 pitcher, ewer
727
+ n03954731 plane, carpenter's plane, woodworking plane
728
+ n03956157 planetarium
729
+ n03958227 plastic bag
730
+ n03961711 plate rack
731
+ n03967562 plow, plough
732
+ n03970156 plunger, plumber's helper
733
+ n03976467 Polaroid camera, Polaroid Land camera
734
+ n03976657 pole
735
+ n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
736
+ n03980874 poncho
737
+ n03982430 pool table, billiard table, snooker table
738
+ n03983396 pop bottle, soda bottle
739
+ n03991062 pot, flowerpot
740
+ n03992509 potter's wheel
741
+ n03995372 power drill
742
+ n03998194 prayer rug, prayer mat
743
+ n04004767 printer
744
+ n04005630 prison, prison house
745
+ n04008634 projectile, missile
746
+ n04009552 projector
747
+ n04019541 puck, hockey puck
748
+ n04023962 punching bag, punch bag, punching ball, punchball
749
+ n04026417 purse
750
+ n04033901 quill, quill pen
751
+ n04033995 quilt, comforter, comfort, puff
752
+ n04037443 racer, race car, racing car
753
+ n04039381 racket, racquet
754
+ n04040759 radiator
755
+ n04041544 radio, wireless
756
+ n04044716 radio telescope, radio reflector
757
+ n04049303 rain barrel
758
+ n04065272 recreational vehicle, RV, R.V.
759
+ n04067472 reel
760
+ n04069434 reflex camera
761
+ n04070727 refrigerator, icebox
762
+ n04074963 remote control, remote
763
+ n04081281 restaurant, eating house, eating place, eatery
764
+ n04086273 revolver, six-gun, six-shooter
765
+ n04090263 rifle
766
+ n04099969 rocking chair, rocker
767
+ n04111531 rotisserie
768
+ n04116512 rubber eraser, rubber, pencil eraser
769
+ n04118538 rugby ball
770
+ n04118776 rule, ruler
771
+ n04120489 running shoe
772
+ n04125021 safe
773
+ n04127249 safety pin
774
+ n04131690 saltshaker, salt shaker
775
+ n04133789 sandal
776
+ n04136333 sarong
777
+ n04141076 sax, saxophone
778
+ n04141327 scabbard
779
+ n04141975 scale, weighing machine
780
+ n04146614 school bus
781
+ n04147183 schooner
782
+ n04149813 scoreboard
783
+ n04152593 screen, CRT screen
784
+ n04153751 screw
785
+ n04154565 screwdriver
786
+ n04162706 seat belt, seatbelt
787
+ n04179913 sewing machine
788
+ n04192698 shield, buckler
789
+ n04200800 shoe shop, shoe-shop, shoe store
790
+ n04201297 shoji
791
+ n04204238 shopping basket
792
+ n04204347 shopping cart
793
+ n04208210 shovel
794
+ n04209133 shower cap
795
+ n04209239 shower curtain
796
+ n04228054 ski
797
+ n04229816 ski mask
798
+ n04235860 sleeping bag
799
+ n04238763 slide rule, slipstick
800
+ n04239074 sliding door
801
+ n04243546 slot, one-armed bandit
802
+ n04251144 snorkel
803
+ n04252077 snowmobile
804
+ n04252225 snowplow, snowplough
805
+ n04254120 soap dispenser
806
+ n04254680 soccer ball
807
+ n04254777 sock
808
+ n04258138 solar dish, solar collector, solar furnace
809
+ n04259630 sombrero
810
+ n04263257 soup bowl
811
+ n04264628 space bar
812
+ n04265275 space heater
813
+ n04266014 space shuttle
814
+ n04270147 spatula
815
+ n04273569 speedboat
816
+ n04275548 spider web, spider's web
817
+ n04277352 spindle
818
+ n04285008 sports car, sport car
819
+ n04286575 spotlight, spot
820
+ n04296562 stage
821
+ n04310018 steam locomotive
822
+ n04311004 steel arch bridge
823
+ n04311174 steel drum
824
+ n04317175 stethoscope
825
+ n04325704 stole
826
+ n04326547 stone wall
827
+ n04328186 stopwatch, stop watch
828
+ n04330267 stove
829
+ n04332243 strainer
830
+ n04335435 streetcar, tram, tramcar, trolley, trolley car
831
+ n04336792 stretcher
832
+ n04344873 studio couch, day bed
833
+ n04346328 stupa, tope
834
+ n04347754 submarine, pigboat, sub, U-boat
835
+ n04350905 suit, suit of clothes
836
+ n04355338 sundial
837
+ n04355933 sunglass
838
+ n04356056 sunglasses, dark glasses, shades
839
+ n04357314 sunscreen, sunblock, sun blocker
840
+ n04366367 suspension bridge
841
+ n04367480 swab, swob, mop
842
+ n04370456 sweatshirt
843
+ n04371430 swimming trunks, bathing trunks
844
+ n04371774 swing
845
+ n04372370 switch, electric switch, electrical switch
846
+ n04376876 syringe
847
+ n04380533 table lamp
848
+ n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle
849
+ n04392985 tape player
850
+ n04398044 teapot
851
+ n04399382 teddy, teddy bear
852
+ n04404412 television, television system
853
+ n04409515 tennis ball
854
+ n04417672 thatch, thatched roof
855
+ n04418357 theater curtain, theatre curtain
856
+ n04423845 thimble
857
+ n04428191 thresher, thrasher, threshing machine
858
+ n04429376 throne
859
+ n04435653 tile roof
860
+ n04442312 toaster
861
+ n04443257 tobacco shop, tobacconist shop, tobacconist
862
+ n04447861 toilet seat
863
+ n04456115 torch
864
+ n04458633 totem pole
865
+ n04461696 tow truck, tow car, wrecker
866
+ n04462240 toyshop
867
+ n04465501 tractor
868
+ n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
869
+ n04476259 tray
870
+ n04479046 trench coat
871
+ n04482393 tricycle, trike, velocipede
872
+ n04483307 trimaran
873
+ n04485082 tripod
874
+ n04486054 triumphal arch
875
+ n04487081 trolleybus, trolley coach, trackless trolley
876
+ n04487394 trombone
877
+ n04493381 tub, vat
878
+ n04501370 turnstile
879
+ n04505470 typewriter keyboard
880
+ n04507155 umbrella
881
+ n04509417 unicycle, monocycle
882
+ n04515003 upright, upright piano
883
+ n04517823 vacuum, vacuum cleaner
884
+ n04522168 vase
885
+ n04523525 vault
886
+ n04525038 velvet
887
+ n04525305 vending machine
888
+ n04532106 vestment
889
+ n04532670 viaduct
890
+ n04536866 violin, fiddle
891
+ n04540053 volleyball
892
+ n04542943 waffle iron
893
+ n04548280 wall clock
894
+ n04548362 wallet, billfold, notecase, pocketbook
895
+ n04550184 wardrobe, closet, press
896
+ n04552348 warplane, military plane
897
+ n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin
898
+ n04554684 washer, automatic washer, washing machine
899
+ n04557648 water bottle
900
+ n04560804 water jug
901
+ n04562935 water tower
902
+ n04579145 whiskey jug
903
+ n04579432 whistle
904
+ n04584207 wig
905
+ n04589890 window screen
906
+ n04590129 window shade
907
+ n04591157 Windsor tie
908
+ n04591713 wine bottle
909
+ n04592741 wing
910
+ n04596742 wok
911
+ n04597913 wooden spoon
912
+ n04599235 wool, woolen, woollen
913
+ n04604644 worm fence, snake fence, snake-rail fence, Virginia fence
914
+ n04606251 wreck
915
+ n04612504 yawl
916
+ n04613696 yurt
917
+ n06359193 web site, website, internet site, site
918
+ n06596364 comic book
919
+ n06785654 crossword puzzle, crossword
920
+ n06794110 street sign
921
+ n06874185 traffic light, traffic signal, stoplight
922
+ n07248320 book jacket, dust cover, dust jacket, dust wrapper
923
+ n07565083 menu
924
+ n07579787 plate
925
+ n07583066 guacamole
926
+ n07584110 consomme
927
+ n07590611 hot pot, hotpot
928
+ n07613480 trifle
929
+ n07614500 ice cream, icecream
930
+ n07615774 ice lolly, lolly, lollipop, popsicle
931
+ n07684084 French loaf
932
+ n07693725 bagel, beigel
933
+ n07695742 pretzel
934
+ n07697313 cheeseburger
935
+ n07697537 hotdog, hot dog, red hot
936
+ n07711569 mashed potato
937
+ n07714571 head cabbage
938
+ n07714990 broccoli
939
+ n07715103 cauliflower
940
+ n07716358 zucchini, courgette
941
+ n07716906 spaghetti squash
942
+ n07717410 acorn squash
943
+ n07717556 butternut squash
944
+ n07718472 cucumber, cuke
945
+ n07718747 artichoke, globe artichoke
946
+ n07720875 bell pepper
947
+ n07730033 cardoon
948
+ n07734744 mushroom
949
+ n07742313 Granny Smith
950
+ n07745940 strawberry
951
+ n07747607 orange
952
+ n07749582 lemon
953
+ n07753113 fig
954
+ n07753275 pineapple, ananas
955
+ n07753592 banana
956
+ n07754684 jackfruit, jak, jack
957
+ n07760859 custard apple
958
+ n07768694 pomegranate
959
+ n07802026 hay
960
+ n07831146 carbonara
961
+ n07836838 chocolate sauce, chocolate syrup
962
+ n07860988 dough
963
+ n07871810 meat loaf, meatloaf
964
+ n07873807 pizza, pizza pie
965
+ n07875152 potpie
966
+ n07880968 burrito
967
+ n07892512 red wine
968
+ n07920052 espresso
969
+ n07930864 cup
970
+ n07932039 eggnog
971
+ n09193705 alp
972
+ n09229709 bubble
973
+ n09246464 cliff, drop, drop-off
974
+ n09256479 coral reef
975
+ n09288635 geyser
976
+ n09332890 lakeside, lakeshore
977
+ n09399592 promontory, headland, head, foreland
978
+ n09421951 sandbar, sand bar
979
+ n09428293 seashore, coast, seacoast, sea-coast
980
+ n09468604 valley, vale
981
+ n09472597 volcano
982
+ n09835506 ballplayer, baseball player
983
+ n10148035 groom, bridegroom
984
+ n10565667 scuba diver
985
+ n11879895 rapeseed
986
+ n11939491 daisy
987
+ n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
988
+ n12144580 corn
989
+ n12267677 acorn
990
+ n12620546 hip, rose hip, rosehip
991
+ n12768682 buckeye, horse chestnut, conker
992
+ n12985857 coral fungus
993
+ n12998815 agaric
994
+ n13037406 gyromitra
995
+ n13040303 stinkhorn, carrion fungus
996
+ n13044778 earthstar
997
+ n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
998
+ n13054560 bolete
999
+ n13133613 ear, spike, capitulum
1000
+ n15075141 toilet tissue, toilet paper, bathroom tissue
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = .
9
+ BUILDDIR = _build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/conf.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # This file only contains a selection of the most common options. For a full
4
+ # list see the documentation:
5
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
6
+
7
+ # -- Path setup --------------------------------------------------------------
8
+
9
+ # If extensions (or modules to document with autodoc) are in another directory,
10
+ # add these directories to sys.path here. If the directory is relative to the
11
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
12
+ #
13
+ import os
14
+ import sys
15
+ # sys.path.insert(0, os.path.abspath('.'))
16
+ sys.path.insert(0, os.path.abspath('../'))
17
+ # sys.path.append('/home/jinwei/Laboratory/api')
18
+ sys.path.append('../..')
19
+
20
+
21
+ # -- Project information -----------------------------------------------------
22
+
23
+ project = 'DeepRobust'
24
+ copyright = ''
25
+ author = 'Yaxin Li, Wei Jin, Han Xu, Jiliang Tang'
26
+
27
+ # The full version, including alpha/beta/rc tags
28
+ release = '0.1.1'
29
+
30
+
31
+ # -- General configuration ---------------------------------------------------
32
+
33
+ # Add any Sphinx extension module names here, as strings. They can be
34
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
35
+ # ones.
36
+ extensions = ['sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon',
37
+ 'sphinx.ext.autosummary', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages' ]
38
+
39
+ # extensions = ['sphinx.ext.napoleon']
40
+ autodoc_mock_imports = ['torch', 'torchvision', 'texttable', 'tensorboardX',
41
+ 'torch_geometric', 'gensim', 'node2vec']
42
+
43
+ # remove undoc members
44
+ #autodoc_default_flags = ['members']
45
+
46
+ # Add any paths that contain templates here, relative to this directory.
47
+ templates_path = ['_templates']
48
+
49
+ # List of patterns, relative to source directory, that match files and
50
+ # directories to ignore when looking for source files.
51
+ # This pattern also affects html_static_path and html_extra_path.
52
+ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
53
+
54
+
55
+ # -- Options for HTML output -------------------------------------------------
56
+
57
+ # The theme to use for HTML and HTML Help pages. See the documentation for
58
+ # a list of builtin themes.
59
+ #
60
+ # html_theme = 'alabaster'
61
+ html_theme = 'sphinx_rtd_theme'
62
+
63
+ # Add any paths that contain custom static files (such as style sheets) here,
64
+ # relative to this directory. They are copied after the builtin static files,
65
+ # so a file named "default.css" will overwrite the builtin "default.css".
66
+ html_static_path = ['_static']
67
+
68
+ add_module_names = False
69
+
70
+ master_doc = 'index'
71
+
docs/index.rst ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. Deep documentation master file, created by
2
+ sphinx-quickstart on Fri Jul 3 12:19:59 2020.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ Start building your robust models with DeepRobust!
7
+ ================================
8
+ .. comments original size: 626*238
9
+
10
+ .. image:: ./DeepRobust.png
11
+ :width: 313px
12
+ :height: 119px
13
+
14
+ DeepRobust is a pytorch adversarial learning library, which contains most popular attack and defense algorithms in image domain and graph domain.
15
+
16
+ .. toctree::
17
+ :glob:
18
+ :maxdepth: 1
19
+ :caption: Installation
20
+
21
+ notes/installation
22
+
23
+ .. toctree::
24
+ :glob:
25
+ :maxdepth: 1
26
+ :caption: Graph Package
27
+
28
+ graph/data
29
+ graph/attack
30
+ graph/defense
31
+ graph/pyg
32
+ graph/node_embedding
33
+
34
+ .. toctree::
35
+ :glob:
36
+ :maxdepth: 1
37
+ :caption: Image Package
38
+
39
+ image/example
40
+
41
+ Package API
42
+ ===========
43
+ .. toctree::
44
+ :maxdepth: 1
45
+ :caption: Image Package
46
+
47
+ source/deeprobust.image.attack
48
+ source/deeprobust.image.defense
49
+ source/deeprobust.image.netmodels
50
+
51
+
52
+ .. toctree::
53
+ :maxdepth: 1
54
+ :caption: Graph Package
55
+
56
+ source/deeprobust.graph.global_attack
57
+ source/deeprobust.graph.targeted_attack
58
+ source/deeprobust.graph.defense
59
+ source/deeprobust.graph.data
60
+
61
+ Indices and tables
62
+ ==================
63
+
64
+ * :ref:`modindex`
65
+ * :ref:`search`
examples/graph/cgscore_datasets.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # compute cgscore for gcn
2
+ # author: Yaning
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as Fd
6
+ from deeprobust.graph.defense import GCNJaccard, GCN
7
+ from deeprobust.graph.defense import GCNScore
8
+ from deeprobust.graph.utils import *
9
+ from deeprobust.graph.data import Dataset, PrePtbDataset
10
+ from scipy.sparse import csr_matrix
11
+ import argparse
12
+ import pickle
13
+ from deeprobust.graph import utils
14
+ from collections import defaultdict
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
18
+ parser.add_argument('--dataset', type=str, default='polblogs', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
19
+ parser.add_argument('--ptb_rate', type=float, default=0.10, help='pertubation rate')
20
+
21
+ args = parser.parse_args()
22
+ args.cuda = torch.cuda.is_available()
23
+ print('cuda: %s' % args.cuda)
24
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+
26
+ # make sure you use the same data splits as you generated attacks
27
+ np.random.seed(args.seed)
28
+ if args.cuda:
29
+ torch.cuda.manual_seed(args.seed)
30
+
31
+ # Here the random seed is to split the train/val/test data,
32
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
33
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
34
+ # Or we can just use setting='prognn' to get the splits
35
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
36
+ adj, features, labels = data.adj, data.features, data.labels
37
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
38
+
39
+
40
+ perturbed_data = PrePtbDataset(root='/tmp/',
41
+ name=args.dataset,
42
+ attack_method='meta',
43
+ ptb_rate=args.ptb_rate)
44
+
45
+ perturbed_adj = perturbed_data.adj
46
+ # perturbed_adj = adj
47
+
48
+ def save_cg_scores(cg_scores, filename="cg_scores.npy"):
49
+ np.save(filename, cg_scores)
50
+ print(f"CG-scores saved to {filename}")
51
+
52
+ def load_cg_scores_numpy(filename="cg_scores.npy"):
53
+ cg_scores = np.load(filename, allow_pickle=True)
54
+ print(f"CG-scores loaded from {filename}")
55
+ return cg_scores
56
+
57
+ def calc_cg_score_gnn_with_sampling(
58
+ A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
59
+ ):
60
+ """
61
+ Calculate CG-score for each edge in a graph with node labels and random sampling.
62
+
63
+ Args:
64
+ A: torch.Tensor
65
+ Adjacency matrix of the graph (size: N x N).
66
+ X: torch.Tensor
67
+ Node features matrix (size: N x F).
68
+ labels: torch.Tensor
69
+ Node labels (size: N).
70
+ device: torch.device
71
+ Device to perform calculations.
72
+ rep_num: int
73
+ Number of repetitions for Monte Carlo sampling.
74
+ unbalance_ratio: float
75
+ Ratio of unbalanced data (1:unbalance_ratio).
76
+ sub_term: bool
77
+ If True, calculate and return sub-terms.
78
+
79
+ Returns:
80
+ cg_scores: dict
81
+ Dictionary containing CG-scores for edges and optionally sub-terms.
82
+ """
83
+ N = A.shape[0]
84
+ cg_scores = {
85
+ "vi": np.zeros((N, N)),
86
+ "ab": np.zeros((N, N)),
87
+ "a2": np.zeros((N, N)),
88
+ "b2": np.zeros((N, N)),
89
+ "times": np.zeros((N, N)),
90
+ }
91
+
92
+ with torch.no_grad():
93
+ for _ in range(rep_num):
94
+ # Compute AX (node representations)
95
+ AX = torch.matmul(A, X).to(device)
96
+ norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)
97
+
98
+ # Group nodes by their labels
99
+ dataset = defaultdict(list)
100
+ data_idx = defaultdict(list)
101
+ for i, label in enumerate(labels):
102
+ dataset[label.item()].append(norm_AX[i].unsqueeze(0)) # Store normalized data
103
+ data_idx[label.item()].append(i) # Store indices
104
+
105
+ # Convert to tensors
106
+ for label, data_list in dataset.items():
107
+ dataset[label] = torch.cat(data_list, dim=0)
108
+ data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
109
+
110
+ # Calculate CG-scores for each label group
111
+ for curr_label, curr_samples in dataset.items():
112
+ curr_indices = data_idx[curr_label]
113
+ curr_num = len(curr_samples)
114
+
115
+ # Randomly sample a subset of current label examples
116
+ chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
117
+ chosen_curr_samples = curr_samples[chosen_curr_idx]
118
+ chosen_curr_indices = curr_indices[chosen_curr_idx]
119
+
120
+ # Sample negative examples from other classes
121
+ neg_samples = torch.cat(
122
+ [dataset[l] for l in dataset if l != curr_label], dim=0
123
+ )
124
+ neg_indices = torch.cat(
125
+ [data_idx[l] for l in data_idx if l != curr_label], dim=0
126
+ )
127
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
128
+ chosen_neg_samples = neg_samples[
129
+ torch.randperm(len(neg_samples))[:neg_num]
130
+ ]
131
+
132
+ # Combine positive and negative samples
133
+ combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
134
+ y = torch.cat(
135
+ [torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
136
+ ).to(device)
137
+
138
+ # Compute the Gram matrix H^\infty
139
+ H_inner = torch.matmul(combined_samples, combined_samples.T)
140
+ del combined_samples
141
+ ###
142
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
143
+ ###
144
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
145
+ del H_inner
146
+
147
+ H.fill_diagonal_(0.5)
148
+ ##
149
+ epsilon = 1e-6
150
+ H = H + epsilon * torch.eye(H.size(0), device=H.device)
151
+ ##
152
+ invH = torch.inverse(H)
153
+ del H
154
+ original_error = y @ (invH @ y)
155
+
156
+ # Compute CG-scores for each edge
157
+ for i in chosen_curr_indices:
158
+ print("the node index:", i)
159
+ for j in range(i + 1, N): # Upper triangular traversal
160
+ # print(j)
161
+ if A[i, j] == 0: # Skip if no edge exists
162
+ continue
163
+
164
+ # Remove edge (i, j) to create A1
165
+ A1 = A.clone()
166
+ A1[i, j] = A1[j, i] = 0
167
+
168
+ # Recompute AX with A1
169
+ AX1 = torch.matmul(A1, X).to(device)
170
+ norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)
171
+
172
+ # Repeat error calculation with A1
173
+ curr_samples_A1 = norm_AX1[chosen_curr_indices]
174
+ neg_samples_A1 = norm_AX1[neg_indices]
175
+ chosen_neg_samples_A1 = neg_samples_A1[
176
+ torch.randperm(len(neg_samples_A1))[:neg_num]
177
+ ]
178
+ combined_samples_A1 = torch.cat(
179
+ [curr_samples_A1, chosen_neg_samples_A1], dim=0
180
+ )
181
+ H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
182
+
183
+ del combined_samples_A1
184
+
185
+ ### trick1
186
+ H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
187
+ ###
188
+
189
+ H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
190
+ del H_inner_A1
191
+ H_A1.fill_diagonal_(0.5)
192
+
193
+ ### trick2
194
+ epsilon = 1e-6
195
+ H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
196
+ ###
197
+ invH_A1 = torch.inverse(H_A1)
198
+ del H_A1
199
+
200
+ error_A1 = y @ (invH_A1 @ y)
201
+
202
+ print("i:", i)
203
+ print("j:", j)
204
+ print("current score:", (original_error - error_A1).item())
205
+ # Compute the difference in error (CG-score)
206
+ cg_scores["vi"][i, j] += (original_error - error_A1).item()
207
+ cg_scores["vi"][j, i] = cg_scores["vi"][i, j] # Symmetric
208
+ cg_scores["times"][i, j] += 1
209
+ cg_scores["times"][j, i] += 1
210
+
211
+ # Normalize CG-scores by repetition count
212
+ for key, values in cg_scores.items():
213
+ if key == "times":
214
+ continue
215
+ cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
216
+
217
+ return cg_scores if sub_term else cg_scores["vi"]
218
+
219
+ def is_symmetric_sparse(adj):
220
+ """
221
+ Check if a sparse matrix is symmetric.
222
+ """
223
+ # Check symmetry
224
+ return (adj != adj.transpose()).nnz == 0 # .nnz is the number of non-zero elements
225
+
226
+ def make_symmetric_sparse(adj):
227
+ """
228
+ Ensure the sparse adjacency matrix is symmetrical.
229
+ """
230
+ # Make the matrix symmetric
231
+ sym_adj = (adj + adj.transpose()) / 2
232
+ return sym_adj
233
+
234
+ perturbed_adj = make_symmetric_sparse(perturbed_adj)
235
+
236
+ if type(perturbed_adj) is not torch.Tensor:
237
+ features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
238
+ else:
239
+ features = features.to(device)
240
+ perturbed_adj = perturbed_adj.to(device)
241
+ labels = labels.to(device)
242
+
243
+ if utils.is_sparse_tensor(perturbed_adj):
244
+
245
+ adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
246
+ else:
247
+ adj_norm = utils.normalize_adj_tensor(perturbed_adj)
248
+
249
+ features = features.to_dense()
250
+ perturbed_adj = adj_norm.to_dense()
251
+
252
+
253
+ calc_cg_score = calc_cg_score_gnn_with_sampling(perturbed_adj, features, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False)
254
+ save_cg_scores(calc_cg_score, filename="cg_scores_polblogs_0.10.npy")
255
+ # print("completed")
examples/graph/cgscore_datasets_multigpus.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # compute cgscore for gcn
2
+ # author: Yaning
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as Fd
6
+ from deeprobust.graph.defense import GCNJaccard, GCN
7
+ from deeprobust.graph.defense import GCNScore
8
+ from deeprobust.graph.utils import *
9
+ from deeprobust.graph.data import Dataset, PrePtbDataset
10
+ from scipy.sparse import csr_matrix
11
+ import argparse
12
+ import pickle
13
+ from deeprobust.graph import utils
14
+ import torch.multiprocessing as mp
15
+ from collections import defaultdict
16
+ from tqdm import tqdm
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
20
+ parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
21
+ parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
22
+
23
+ args = parser.parse_args()
24
+ args.cuda = torch.cuda.is_available()
25
+ print('cuda: %s' % args.cuda)
26
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
27
+
28
+ # make sure you use the same data splits as you generated attacks
29
+ np.random.seed(args.seed)
30
+ if args.cuda:
31
+ torch.cuda.manual_seed(args.seed)
32
+
33
+ # Here the random seed is to split the train/val/test data,
34
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
35
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
36
+ # Or we can just use setting='prognn' to get the splits
37
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
38
+ adj, features, labels = data.adj, data.features, data.labels
39
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
40
+
41
+
42
+ perturbed_data = PrePtbDataset(root='/tmp/',
43
+ name=args.dataset,
44
+ attack_method='meta',
45
+ ptb_rate=args.ptb_rate)
46
+
47
+ perturbed_adj = perturbed_data.adj
48
+ # perturbed_adj = adj
49
+
50
+ def save_cg_scores(cg_scores, filename="cg_scores.npy"):
51
+ np.save(filename, cg_scores)
52
+ print(f"CG-scores saved to {filename}")
53
+
54
+ def load_cg_scores_numpy(filename="cg_scores.npy"):
55
+ cg_scores = np.load(filename, allow_pickle=True)
56
+ print(f"CG-scores loaded from {filename}")
57
+ return cg_scores
58
+
59
+ def calc_cg_score_gnn_with_sampling(
60
+ A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, label_filter=None
61
+ ):
62
+ N = A.shape[0]
63
+ cg_scores = {
64
+ "vi": np.zeros((N, N)),
65
+ "ab": np.zeros((N, N)),
66
+ "a2": np.zeros((N, N)),
67
+ "b2": np.zeros((N, N)),
68
+ "times": np.zeros((N, N)),
69
+ }
70
+
71
+ A = A.to(device)
72
+ X = X.to(device)
73
+ labels = labels.to(device)
74
+
75
+ @torch.no_grad()
76
+ def normalize(tensor):
77
+ return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
78
+
79
+ for _ in range(rep_num):
80
+ AX = torch.matmul(A, X)
81
+ norm_AX = normalize(AX)
82
+
83
+ unique_labels = torch.unique(labels)
84
+ if label_filter is not None:
85
+ unique_labels = [label for label in unique_labels if label.item() in label_filter]
86
+
87
+ label_to_indices = {
88
+ label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
89
+ }
90
+ dataset = {label: norm_AX[indices] for label, indices in label_to_indices.items()}
91
+
92
+ neg_samples_dict = {}
93
+ neg_indices_dict = {}
94
+ for label in unique_labels:
95
+ print("label:", label)
96
+ label = label.item()
97
+ mask = labels != label
98
+ neg_samples = norm_AX[mask]
99
+ neg_indices = mask.nonzero(as_tuple=True)[0]
100
+ neg_samples_dict[label] = neg_samples
101
+ neg_indices_dict[label] = neg_indices
102
+
103
+ for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
104
+ print("curr_label:", curr_label)
105
+ label_id = int(curr_label)
106
+ curr_samples = dataset[label_id]
107
+ curr_indices = label_to_indices[label_id]
108
+ curr_num = len(curr_samples)
109
+
110
+ chosen_curr_idx = torch.randperm(curr_num, device=device)
111
+ chosen_curr_samples = curr_samples[chosen_curr_idx]
112
+ chosen_curr_indices = curr_indices[chosen_curr_idx]
113
+
114
+ neg_samples = neg_samples_dict[label_id]
115
+ neg_indices = neg_indices_dict[label_id]
116
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
117
+ rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
118
+ chosen_neg_samples = neg_samples[rand_idx]
119
+ chosen_neg_indices = neg_indices[rand_idx]
120
+
121
+ combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
122
+ y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
123
+
124
+ H_inner = torch.matmul(combined_samples, combined_samples.T)
125
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
126
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
127
+ H.fill_diagonal_(0.5)
128
+ H += 1e-6 * torch.eye(H.size(0), device=device)
129
+ invH = torch.inverse(H)
130
+ original_error = y @ (invH @ y)
131
+
132
+ edge_batch = []
133
+ for idx_i in chosen_curr_indices.tolist():
134
+ for j in range(idx_i + 1, N):
135
+ if A[idx_i, j] != 0:
136
+ edge_batch.append((idx_i, j))
137
+
138
+ for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
139
+ batch = edge_batch[k: k + batch_size]
140
+ B = len(batch)
141
+
142
+ norm_AX1_batch = norm_AX.repeat(B, 1, 1).clone()
143
+ for b, (i, j) in enumerate(batch):
144
+ AX1_i = AX[i] - A[i, j] * X[j]
145
+ AX1_j = AX[j] - A[j, i] * X[i]
146
+ norm_AX1_batch[b, i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
147
+ norm_AX1_batch[b, j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
148
+
149
+ sample_idx = chosen_curr_indices.tolist() + chosen_neg_indices.tolist()
150
+ sample_batch = norm_AX1_batch[:, sample_idx, :]
151
+
152
+ H_inner = torch.matmul(sample_batch, sample_batch.transpose(1, 2))
153
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
154
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
155
+ eye = torch.eye(H.size(-1), device=device).unsqueeze(0).expand_as(H)
156
+ H = H + 1e-6 * eye
157
+ H.diagonal(dim1=-2, dim2=-1).copy_(0.5)
158
+
159
+ invH = torch.inverse(H)
160
+ y_expanded = y.unsqueeze(0).expand(B, -1)
161
+ error_A1 = torch.einsum("bi,bij,bj->b", y_expanded, invH, y_expanded)
162
+
163
+ for b, (i, j) in enumerate(batch):
164
+ score = (original_error - error_A1[b]).item()
165
+ cg_scores["vi"][i, j] += score
166
+ cg_scores["vi"][j, i] = score
167
+ cg_scores["times"][i, j] += 1
168
+ cg_scores["times"][j, i] += 1
169
+
170
+ for key in cg_scores:
171
+ if key != "times":
172
+ cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
173
+
174
+ # return cg_scores if sub_term else cg_scores["vi"]
175
+ return cg_scores
176
+
177
+
178
+ def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict):
179
+ device = torch.device(f"cuda:{gpu_id}")
180
+ unique_labels = torch.unique(labels).tolist()
181
+ label_chunks = np.array_split(unique_labels, world_size)
182
+ rank = torch.cuda.current_device()
183
+ label_filter = [int(l) for l in label_chunks[gpu_id % world_size]]
184
+
185
+ result = calc_cg_score_gnn_with_sampling(
186
+ A, X, labels, device,
187
+ rep_num=rep_num,
188
+ unbalance_ratio=unbalance_ratio,
189
+ sub_term=sub_term,
190
+ batch_size=batch_size,
191
+ label_filter=label_filter
192
+ )
193
+ return_dict[gpu_id] = result
194
+
195
+
196
+
197
+ def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, gpu_ids=None):
198
+ if gpu_ids is None:
199
+ gpu_ids = list(range(torch.cuda.device_count()))
200
+ world_size = len(gpu_ids)
201
+
202
+ manager = mp.Manager()
203
+ return_dict = manager.dict()
204
+ processes = []
205
+
206
+ for local_rank, gpu_id in enumerate(gpu_ids):
207
+ p = mp.Process(
208
+ target=run_worker,
209
+ args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict)
210
+ )
211
+ p.start()
212
+ processes.append(p)
213
+
214
+ for p in processes:
215
+ p.join()
216
+
217
+ # 初始化 final_score
218
+ final_score = None
219
+
220
+ for gpu_id, rank_result in return_dict.items():
221
+ if not isinstance(rank_result, dict):
222
+ print(f"[FATAL] GPU {gpu_id} result is not a dict: {type(rank_result)}")
223
+ continue
224
+
225
+ if final_score is None:
226
+ # 深拷贝防止指针复用
227
+ final_score = {k: np.copy(v) for k, v in rank_result.items()}
228
+ else:
229
+ for key in rank_result:
230
+ if key not in final_score:
231
+ print(f"[WARN] key '{key}' not in final_score. Skipping.")
232
+ continue
233
+ try:
234
+ if isinstance(final_score[key], np.ndarray) and isinstance(rank_result[key], np.ndarray):
235
+ final_score[key] += rank_result[key]
236
+ else:
237
+ print(f"[WARN] Skipped merging key '{key}' due to type mismatch.")
238
+ except Exception as e:
239
+ print(f"[ERROR] Failed merging key '{key}': {e}")
240
+
241
+ return final_score
242
+
243
+
244
+
245
+ def is_symmetric_sparse(adj):
246
+ """
247
+ Check if a sparse matrix is symmetric.
248
+ """
249
+ # Check symmetry
250
+ return (adj != adj.transpose()).nnz == 0 # .nnz is the number of non-zero elements
251
+
252
+ def make_symmetric_sparse(adj):
253
+ """
254
+ Ensure the sparse adjacency matrix is symmetrical.
255
+ """
256
+ # Make the matrix symmetric
257
+ sym_adj = (adj + adj.transpose()) / 2
258
+ return sym_adj
259
+
260
+ if __name__ == "__main__":
261
+ mp.set_start_method("spawn", force=True)
262
+
263
+ print("cuda:", torch.cuda.is_available())
264
+
265
+ # 选择使用的 GPU(根据你的实际情况)
266
+ selected_gpus = [0, 1, 2, 3]
267
+
268
+ # 稀疏矩阵对称处理
269
+ perturbed_adj = make_symmetric_sparse(perturbed_adj)
270
+
271
+ # 转 tensor
272
+ if type(perturbed_adj) is not torch.Tensor:
273
+ features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
274
+ else:
275
+ features = features.to(device)
276
+ perturbed_adj = perturbed_adj.to(device)
277
+ labels = labels.to(device)
278
+
279
+ # 标准化邻接
280
+ if utils.is_sparse_tensor(perturbed_adj):
281
+ adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
282
+ else:
283
+ adj_norm = utils.normalize_adj_tensor(perturbed_adj)
284
+
285
+ features = features.to_dense()
286
+ perturbed_adj = adj_norm.to_dense()
287
+
288
+ # 多GPU并行计算 CG-score
289
+ calc_cg_score = multi_gpu_wrapper(
290
+ perturbed_adj, features, labels,
291
+ rep_num=1,
292
+ unbalance_ratio=1,
293
+ sub_term=False,
294
+ batch_size=1024,
295
+ gpu_ids=selected_gpus
296
+ )
297
+ save_cg_scores(calc_cg_score["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
298
+
299
+ print(" CG-score computation completed.")
examples/graph/cgscore_datasets_multigpus2.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # compute cgscore for gcn - Final Optimized Complete Version
3
+ ## 精度有损失,但不多
4
+ import torch
5
+ import numpy as np
6
+ import torch.multiprocessing as mp
7
+ import argparse
8
+ from tqdm import tqdm
9
+ from deeprobust.graph.utils import *
10
+ from deeprobust.graph.data import Dataset, PrePtbDataset
11
+ from torch.cuda.amp import autocast
12
+ from deeprobust.graph import utils
13
+
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--seed', type=int, default=15)
17
+ parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'])
18
+ parser.add_argument('--ptb_rate', type=float, default=0.05)
19
+ args = parser.parse_args()
20
+
21
+ args.cuda = torch.cuda.is_available()
22
+ print('cuda: %s' % args.cuda)
23
+ device = torch.device("cuda:0" if args.cuda else "cpu")
24
+
25
+ np.random.seed(args.seed)
26
+ if args.cuda:
27
+ torch.cuda.manual_seed(args.seed)
28
+
29
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
30
+ adj, features, labels = data.adj, data.features, data.labels
31
+
32
+ perturbed_data = PrePtbDataset(root='/tmp/', name=args.dataset, attack_method='meta', ptb_rate=args.ptb_rate)
33
+ perturbed_adj = (perturbed_data.adj + perturbed_data.adj.T) / 2
34
+
35
+ def save_cg_scores(cg_scores, filename="cg_scores.npy"):
36
+ np.save(filename, cg_scores)
37
+ print(f"CG-scores saved to {filename}")
38
+
39
+ def calc_cg_score_gnn_with_sampling(A, X, labels, device, rep_num=1, unbalance_ratio=1, batch_size=1024, node_filter=None):
40
+ N = A.shape[0]
41
+ cg_scores = {"vi": np.zeros((N, N)), "times": np.zeros((N, N))}
42
+ A, X, labels = A.to(device), X.to(device), labels.to(device)
43
+
44
+ @torch.no_grad()
45
+ def normalize(tensor):
46
+ return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
47
+
48
+ for _ in range(rep_num):
49
+ AX = torch.matmul(A, X)
50
+ norm_AX = normalize(AX)
51
+
52
+ unique_labels = torch.unique(labels)
53
+ label_to_indices = {label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels}
54
+ dataset = {label: norm_AX[idx] for label, idx in label_to_indices.items()}
55
+
56
+ neg_samples_dict = {}
57
+ neg_indices_dict = {}
58
+ for label in unique_labels:
59
+ label = label.item()
60
+ mask = labels != label
61
+ neg_samples_dict[label] = norm_AX[mask]
62
+ neg_indices_dict[label] = mask.nonzero(as_tuple=True)[0]
63
+
64
+ if node_filter is not None:
65
+ node_filter = set(node_filter.tolist())
66
+ else:
67
+ node_filter = set(range(labels.size(0)))
68
+
69
+ for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
70
+ label_id = int(curr_label)
71
+ curr_samples = dataset[label_id]
72
+ curr_indices = label_to_indices[label_id]
73
+ curr_num = len(curr_samples)
74
+
75
+ chosen_curr_idx = torch.randperm(curr_num, device=device)
76
+ pos_samples = curr_samples[chosen_curr_idx]
77
+ pos_indices = curr_indices[chosen_curr_idx]
78
+
79
+ neg_samples = neg_samples_dict[label_id]
80
+ neg_indices = neg_indices_dict[label_id]
81
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
82
+ rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
83
+ neg_samples = neg_samples[rand_idx]
84
+ neg_indices = neg_indices[rand_idx]
85
+
86
+ sample_idx = pos_indices.tolist() + neg_indices.tolist()
87
+ sample_tensor = norm_AX[sample_idx] # [M, F]
88
+ y = torch.cat([
89
+ torch.ones(len(pos_samples)),
90
+ -torch.ones(len(neg_samples))
91
+ ], dim=0).to(device)
92
+
93
+ with autocast():
94
+ H_inner = torch.matmul(sample_tensor, sample_tensor.T).clamp(-1.0, 1.0)
95
+ H_base = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
96
+ H_base.fill_diagonal_(0.5)
97
+ H_base += 1e-6 * torch.eye(H_base.size(0), device=device)
98
+ ref_error = torch.dot(y, torch.linalg.solve(H_base, y))
99
+
100
+ edge_batch = [(i.item(), j)
101
+ for i in pos_indices if i.item() in node_filter
102
+ for j in range(i.item() + 1, N) if A[i, j] != 0]
103
+
104
+ for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
105
+ batch = edge_batch[k:k + batch_size]
106
+ if not batch: continue
107
+ i_idx, j_idx = zip(*batch)
108
+ i_idx = torch.tensor(i_idx, device=device)
109
+ j_idx = torch.tensor(j_idx, device=device)
110
+
111
+ AX1_i = AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx]
112
+ AX1_j = AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx]
113
+ norm_AX1_i = normalize(AX1_i)
114
+ norm_AX1_j = normalize(AX1_j)
115
+
116
+ for b, (i, j) in enumerate(batch):
117
+ i_int, j_int = i, j
118
+ sample_tensor_copy = sample_tensor.clone()
119
+ try:
120
+ i_pos = sample_idx.index(i_int)
121
+ sample_tensor_copy[i_pos] = norm_AX1_i[b]
122
+ except ValueError:
123
+ pass
124
+ try:
125
+ j_pos = sample_idx.index(j_int)
126
+ sample_tensor_copy[j_pos] = norm_AX1_j[b]
127
+ except ValueError:
128
+ pass
129
+
130
+ with autocast():
131
+ H_inner = torch.matmul(sample_tensor_copy, sample_tensor_copy.T).clamp(-1.0, 1.0)
132
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
133
+ H.fill_diagonal_(0.5)
134
+ H += 1e-6 * torch.eye(H.size(0), device=device)
135
+ sol = torch.linalg.solve(H, y)
136
+ err_new = torch.dot(y, sol)
137
+
138
+ score = (ref_error - err_new).item()
139
+ cg_scores["vi"][i, j] += score
140
+ cg_scores["vi"][j, i] = score
141
+ cg_scores["times"][i, j] += 1
142
+ cg_scores["times"][j, i] += 1
143
+
144
+ for key in ["vi"]:
145
+ cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
146
+
147
+ return cg_scores
148
+
149
+ def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict):
150
+ device = torch.device(f"cuda:{gpu_id}")
151
+
152
+ # 用 node ids 划分代替 label 分片
153
+ node_ids = torch.arange(labels.size(0))
154
+ node_chunks = np.array_split(node_ids.numpy(), world_size)
155
+ node_filter = torch.tensor(node_chunks[gpu_id], device=device)
156
+
157
+ result = calc_cg_score_gnn_with_sampling(
158
+ A, X, labels, device,
159
+ rep_num=rep_num,
160
+ unbalance_ratio=unbalance_ratio,
161
+ batch_size=batch_size,
162
+ node_filter=node_filter # 👈 改名字
163
+ )
164
+ return_dict[gpu_id] = result
165
+
166
+ def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, batch_size=1024, gpu_ids=None):
167
+ if gpu_ids is None:
168
+ gpu_ids = list(range(torch.cuda.device_count()))
169
+ world_size = len(gpu_ids)
170
+
171
+ mp.set_start_method("spawn", force=True)
172
+ manager = mp.Manager()
173
+ return_dict = manager.dict()
174
+ processes = []
175
+
176
+ for i, gpu_id in enumerate(gpu_ids):
177
+ p = mp.Process(target=run_worker, args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict))
178
+ p.start()
179
+ processes.append(p)
180
+
181
+ for p in processes:
182
+ p.join()
183
+
184
+ final_score = None
185
+ for res in return_dict.values():
186
+ if final_score is None:
187
+ final_score = {k: np.copy(v) for k, v in res.items()}
188
+ else:
189
+ for k in res:
190
+ final_score[k] += res[k]
191
+ return final_score
192
+
193
+ if __name__ == "__main__":
194
+ features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
195
+ features = features.to_dense()
196
+ if utils.is_sparse_tensor(perturbed_adj):
197
+ perturbed_adj = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
198
+ perturbed_adj = perturbed_adj.to_dense()
199
+
200
+ selected_gpus = [0,1,2,3]
201
+ cg_scores = multi_gpu_wrapper(perturbed_adj, features, labels,
202
+ rep_num=1,
203
+ unbalance_ratio=3,
204
+ batch_size=40280,
205
+ gpu_ids=selected_gpus)
206
+
207
+ save_cg_scores(cg_scores["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
208
+ print("🎉 CG-score computation completed.")
examples/graph/cgscore_env.yaml ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: cgscore
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - ca-certificates=2024.11.26=h06a4308_0
8
+ - ld_impl_linux-64=2.40=h12ee557_0
9
+ - libffi=3.4.4=h6a678d5_1
10
+ - libgcc-ng=11.2.0=h1234567_1
11
+ - libgomp=11.2.0=h1234567_1
12
+ - libstdcxx-ng=11.2.0=h1234567_1
13
+ - ncurses=6.4=h6a678d5_0
14
+ - openssl=3.0.15=h5eee18b_0
15
+ - python=3.9.21=he870216_1
16
+ - readline=8.2=h5eee18b_0
17
+ - sqlite=3.45.3=h5eee18b_0
18
+ - tk=8.6.14=h39e8969_0
19
+ - wheel=0.44.0=py39h06a4308_0
20
+ - xz=5.4.6=h5eee18b_1
21
+ - zlib=1.2.13=h5eee18b_1
22
+ - pip:
23
+ - aiohappyeyeballs==2.4.4
24
+ - aiohttp==3.11.11
25
+ - aiosignal==1.3.2
26
+ - alabaster==0.7.16
27
+ - alembic==1.15.1
28
+ - asttokens==3.0.0
29
+ - async-timeout==5.0.1
30
+ - attrs==24.3.0
31
+ - autopage==0.5.2
32
+ - babel==2.17.0
33
+ - certifi==2025.1.31
34
+ - cfgv==3.4.0
35
+ - charset-normalizer==3.4.1
36
+ - cliff==4.9.1
37
+ - cmaes==0.11.1
38
+ - cmake==3.31.2
39
+ - cmd2==2.5.11
40
+ - cogdl==0.6
41
+ - colorlog==6.9.0
42
+ - commonmark==0.9.1
43
+ - contourpy==1.3.0
44
+ - cycler==0.12.1
45
+ - decorator==5.1.1
46
+ - distlib==0.3.9
47
+ - docutils==0.21.2
48
+ - exceptiongroup==1.2.2
49
+ - executing==2.1.0
50
+ - filelock==3.18.0
51
+ - flake8==7.1.2
52
+ - fonttools==4.56.0
53
+ - frozenlist==1.5.0
54
+ - fsspec==2025.3.0
55
+ - fst-pso==1.8.1
56
+ - fuzzytm==2.0.9
57
+ - gensim==4.3.0
58
+ - grave==0.0.3
59
+ - grb==0.1.0
60
+ - greenlet==3.1.1
61
+ - huggingface-hub==0.29.3
62
+ - identify==2.6.9
63
+ - idna==3.10
64
+ - imageio==2.36.1
65
+ - imagesize==1.4.1
66
+ - importlib-metadata==4.13.0
67
+ - importlib-resources==6.5.2
68
+ - ipdb==0.13.13
69
+ - ipython==8.18.1
70
+ - jedi==0.19.2
71
+ - jinja2==3.1.6
72
+ - joblib==1.4.2
73
+ - kiwisolver==1.4.7
74
+ - lazy-loader==0.4
75
+ - lit==18.1.8
76
+ - littleutils==0.2.4
77
+ - llvmlite==0.40.1
78
+ - mako==1.3.9
79
+ - markupsafe==3.0.2
80
+ - matplotlib==3.6.0
81
+ - matplotlib-inline==0.1.7
82
+ - mccabe==0.7.0
83
+ - miniful==0.0.6
84
+ - mpmath==1.3.0
85
+ - multidict==6.1.0
86
+ - networkx==3.0
87
+ - ninja==1.11.1.4
88
+ - nodeenv==1.9.1
89
+ - numba==0.57.1
90
+ - numpy==1.23.5
91
+ - nvidia-cublas-cu11==11.10.3.66
92
+ - nvidia-cublas-cu12==12.1.3.1
93
+ - nvidia-cuda-cupti-cu11==11.7.101
94
+ - nvidia-cuda-cupti-cu12==12.1.105
95
+ - nvidia-cuda-nvrtc-cu11==11.7.99
96
+ - nvidia-cuda-nvrtc-cu12==12.1.105
97
+ - nvidia-cuda-runtime-cu11==11.7.99
98
+ - nvidia-cuda-runtime-cu12==12.1.105
99
+ - nvidia-cudnn-cu11==8.5.0.96
100
+ - nvidia-cudnn-cu12==8.9.2.26
101
+ - nvidia-cufft-cu11==10.9.0.58
102
+ - nvidia-cufft-cu12==11.0.2.54
103
+ - nvidia-curand-cu11==10.2.10.91
104
+ - nvidia-curand-cu12==10.3.2.106
105
+ - nvidia-cusolver-cu11==11.4.0.1
106
+ - nvidia-cusolver-cu12==11.4.5.107
107
+ - nvidia-cusparse-cu11==11.7.4.91
108
+ - nvidia-cusparse-cu12==12.1.0.106
109
+ - nvidia-cusparselt-cu12==0.6.2
110
+ - nvidia-nccl-cu11==2.14.3
111
+ - nvidia-nccl-cu12==2.20.5
112
+ - nvidia-nvjitlink-cu12==12.4.127
113
+ - nvidia-nvtx-cu11==11.7.91
114
+ - nvidia-nvtx-cu12==12.1.105
115
+ - ogb==1.3.6
116
+ - optuna==2.4.0
117
+ - outdated==0.2.2
118
+ - packaging==24.2
119
+ - pandas==2.2.3
120
+ - parso==0.8.4
121
+ - pbr==6.1.1
122
+ - pexpect==4.9.0
123
+ - pillow==9.4.0
124
+ - pip==25.0.1
125
+ - platformdirs==4.3.7
126
+ - pre-commit==4.2.0
127
+ - prettytable==3.16.0
128
+ - prompt-toolkit==3.0.48
129
+ - propcache==0.2.1
130
+ - protobuf==3.20.3
131
+ - psutil==6.1.1
132
+ - ptyprocess==0.7.0
133
+ - pure-eval==0.2.3
134
+ - pycodestyle==2.12.1
135
+ - pyflakes==3.2.0
136
+ - pyfume==0.3.1
137
+ - pygments==2.19.1
138
+ - pyparsing==3.2.3
139
+ - pyperclip==1.9.0
140
+ - python-dateutil==2.9.0.post0
141
+ - pytz==2025.2
142
+ - pyyaml==6.0.2
143
+ - recommonmark==0.7.1
144
+ - regex==2024.11.6
145
+ - requests==2.32.3
146
+ - safetensors==0.5.3
147
+ - scikit-image==0.24.0
148
+ - scikit-learn==1.2.0
149
+ - scipy==1.11.3
150
+ - sentencepiece==0.2.0
151
+ - setuptools==78.1.0
152
+ - simpful==2.12.0
153
+ - six==1.17.0
154
+ - smart-open==7.1.0
155
+ - snowballstemmer==2.2.0
156
+ - sphinx==7.3.7
157
+ - sphinxcontrib-applehelp==2.0.0
158
+ - sphinxcontrib-devhelp==2.0.0
159
+ - sphinxcontrib-htmlhelp==2.1.0
160
+ - sphinxcontrib-jsmath==1.0.1
161
+ - sphinxcontrib-qthelp==2.0.0
162
+ - sphinxcontrib-serializinghtml==2.0.0
163
+ - sqlalchemy==2.0.40
164
+ - stack-data==0.6.3
165
+ - stevedore==5.4.1
166
+ - sympy==1.13.1
167
+ - tabulate==0.9.0
168
+ - tensorboardx==2.6
169
+ - texttable==1.6.7
170
+ - threadpoolctl==3.6.0
171
+ - tifffile==2024.8.30
172
+ - tokenizers==0.21.1
173
+ - tomli==2.2.1
174
+ - torch==2.3.0+cu121
175
+ - torch-cluster==1.6.3+pt23cu121
176
+ - torch-geometric==2.6.1
177
+ - torch-scatter==2.1.2+pt23cu121
178
+ - torch-sparse==0.6.18+pt23cu121
179
+ - torch-spline-conv==1.2.2+pt23cu121
180
+ - torchvision==0.18.0+cu121
181
+ - tqdm==4.64.1
182
+ - traitlets==5.14.3
183
+ - transformers==4.50.2
184
+ - triton==2.3.0
185
+ - typing-extensions==4.13.0
186
+ - tzdata==2025.2
187
+ - urllib3==2.3.0
188
+ - virtualenv==20.29.3
189
+ - wcwidth==0.2.13
190
+ - wrapt==1.17.2
191
+ - yarl==1.18.3
192
+ - zipp==3.21.0
193
+ prefix: /home/yiren/new_ssd2/chunhui/miniconda/envs/cgscore
examples/graph/cgscore_experiments/attack_method/attack_minmax.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from deeprobust.graph.defense import GCN
6
+ from deeprobust.graph.global_attack import MinMax
7
+ from deeprobust.graph.utils import *
8
+ from deeprobust.graph.data import Dataset
9
+ import argparse
10
+ import os
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
14
+ parser.add_argument('--epochs', type=int, default=200,
15
+ help='Number of epochs to train.')
16
+ parser.add_argument('--lr', type=float, default=0.01,
17
+ help='Initial learning rate.')
18
+ parser.add_argument('--weight_decay', type=float, default=5e-4,
19
+ help='Weight decay (L2 loss on parameters).')
20
+ parser.add_argument('--hidden', type=int, default=16,
21
+ help='Number of hidden units.')
22
+ parser.add_argument('--dropout', type=float, default=0.5,
23
+ help='Dropout rate (1 - keep probability).')
24
+
25
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed', 'Flickr'], help='dataset')
26
+ parser.add_argument('--ptb_rate', type=float, default=0.25, help='pertubation rate')
27
+ parser.add_argument('--ptb_type', type=str, default='minmax', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
28
+ parser.add_argument('--model', type=str, default='min-max', choices=['PGD', 'min-max'], help='model variant')
29
+
30
+ args = parser.parse_args()
31
+
32
+ device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
33
+
34
+ np.random.seed(args.seed)
35
+ torch.manual_seed(args.seed)
36
+ if device != 'cpu':
37
+ torch.cuda.manual_seed(args.seed)
38
+
39
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
40
+ data = Dataset(root='/tmp/', name=args.dataset)
41
+ adj, features, labels = data.adj, data.features, data.labels
42
+ # features = normalize_feature(features)
43
+
44
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
45
+ idx_unlabeled = np.union1d(idx_val, idx_test)
46
+
47
+ perturbations = int(args.ptb_rate * (adj.sum()//2))
48
+
49
+ adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
50
+
51
+ # Setup Victim Model
52
+ victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
53
+ dropout=0.5, weight_decay=5e-4, device=device)
54
+
55
+ victim_model = victim_model.to(device)
56
+ victim_model.fit(features, adj, labels, idx_train)
57
+
58
+ # Setup Attack Model
59
+
60
+ model = MinMax(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device=device)
61
+
62
+ model = model.to(device)
63
+
64
+ def test(adj):
65
+ ''' test on GCN '''
66
+
67
+ # adj = normalize_adj_tensor(adj)
68
+ gcn = GCN(nfeat=features.shape[1],
69
+ nhid=args.hidden,
70
+ nclass=labels.max().item() + 1,
71
+ dropout=args.dropout, device=device)
72
+ gcn = gcn.to(device)
73
+ gcn.fit(features, adj, labels, idx_train) # train without model picking
74
+ # gcn.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
75
+ output = gcn.output.cpu()
76
+ loss_test = F.nll_loss(output[idx_test], labels[idx_test])
77
+ acc_test = accuracy(output[idx_test], labels[idx_test])
78
+ print("Test set results:",
79
+ "loss= {:.4f}".format(loss_test.item()),
80
+ "accuracy= {:.4f}".format(acc_test.item()))
81
+
82
+ return acc_test.item()
83
+
84
+
85
+ def main():
86
+ model.attack(features, adj, labels, idx_train, perturbations)
87
+ print('=== testing GCN on original(clean) graph ===')
88
+ test(adj)
89
+ modified_adj = model.modified_adj
90
+ # modified_features = model.modified_features
91
+
92
+ save_dir = f"../attacked_adj/{args.dataset}"
93
+ os.makedirs(save_dir, exist_ok=True)
94
+
95
+ save_path = os.path.join(save_dir, f"{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt")
96
+ torch.save(modified_adj, save_path)
97
+
98
+ test(modified_adj)
99
+
100
+ # # if you want to save the modified adj/features, uncomment the code below
101
+ # model.save_adj(root='./', name=f'mod_adj')
102
+ # model.save_features(root='./', name='mod_features')
103
+
104
+ if __name__ == '__main__':
105
+ main()
106
+
examples/graph/cgscore_experiments/attack_method/attack_nettack.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from deeprobust.graph.defense import GCN
6
+ from deeprobust.graph.targeted_attack import Nettack
7
+ from deeprobust.graph.utils import *
8
+ from deeprobust.graph.data import Dataset
9
+ import argparse
10
+ from tqdm import tqdm
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
14
+ parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
15
+
16
+
17
+ args = parser.parse_args()
18
+ args.cuda = torch.cuda.is_available()
19
+ print('cuda: %s' % args.cuda)
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+ np.random.seed(args.seed)
23
+ torch.manual_seed(args.seed)
24
+ if args.cuda:
25
+ torch.cuda.manual_seed(args.seed)
26
+
27
+ data = Dataset(root='/tmp/', name=args.dataset)
28
+ adj, features, labels = data.adj, data.features, data.labels
29
+
30
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
31
+
32
+ idx_unlabeled = np.union1d(idx_val, idx_test)
33
+
34
+ # Setup Surrogate model
35
+ surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
36
+ nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)
37
+
38
+ surrogate = surrogate.to(device)
39
+ surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
40
+
41
+ # Setup Attack Model
42
+ target_node = 0
43
+ assert target_node in idx_unlabeled
44
+
45
+ model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
46
+ model = model.to(device)
47
+
48
+ def main():
49
+ degrees = adj.sum(0).A1
50
+ # How many perturbations to perform. Default: Degree of the node
51
+ n_perturbations = int(degrees[target_node])
52
+
53
+ # direct attack
54
+ model.attack(features, adj, labels, target_node, n_perturbations)
55
+ # # indirect attack/ influencer attack
56
+ # model.attack(features, adj, labels, target_node, n_perturbations, direct=False, n_influencers=5)
57
+ modified_adj = model.modified_adj
58
+ modified_features = model.modified_features
59
+ print(model.structure_perturbations)
60
+ print('=== testing GCN on original(clean) graph ===')
61
+ test(adj, features, target_node)
62
+ print('=== testing GCN on perturbed graph ===')
63
+ test(modified_adj, modified_features, target_node)
64
+
65
+ def test(adj, features, target_node):
66
+ ''' test on GCN '''
67
+ gcn = GCN(nfeat=features.shape[1],
68
+ nhid=16,
69
+ nclass=labels.max().item() + 1,
70
+ dropout=0.5, device=device)
71
+
72
+ gcn = gcn.to(device)
73
+
74
+ gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
75
+
76
+ gcn.eval()
77
+ output = gcn.predict()
78
+ probs = torch.exp(output[[target_node]])[0]
79
+ print('Target node probs: {}'.format(probs.detach().cpu().numpy()))
80
+ acc_test = accuracy(output[idx_test], labels[idx_test])
81
+
82
+ print("Overall test set results:",
83
+ "accuracy= {:.4f}".format(acc_test.item()))
84
+
85
+ return acc_test.item()
86
+
87
+ def select_nodes(target_gcn=None):
88
+ '''
89
+ selecting nodes as reported in nettack paper:
90
+ (i) the 10 nodes with highest margin of classification, i.e. they are clearly correctly classified,
91
+ (ii) the 10 nodes with lowest margin (but still correctly classified) and
92
+ (iii) 20 more nodes randomly
93
+ '''
94
+
95
+ if target_gcn is None:
96
+ target_gcn = GCN(nfeat=features.shape[1],
97
+ nhid=16,
98
+ nclass=labels.max().item() + 1,
99
+ dropout=0.5, device=device)
100
+ target_gcn = target_gcn.to(device)
101
+ target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
102
+ target_gcn.eval()
103
+ output = target_gcn.predict()
104
+
105
+ margin_dict = {}
106
+ for idx in idx_test:
107
+ margin = classification_margin(output[idx], labels[idx])
108
+ if margin < 0: # only keep the nodes correctly classified
109
+ continue
110
+ margin_dict[idx] = margin
111
+ sorted_margins = sorted(margin_dict.items(), key=lambda x:x[1], reverse=True)
112
+ high = [x for x, y in sorted_margins[: 10]]
113
+ low = [x for x, y in sorted_margins[-10: ]]
114
+ other = [x for x, y in sorted_margins[10: -10]]
115
+ other = np.random.choice(other, 20, replace=False).tolist()
116
+
117
+ return high + low + other
118
+
119
+ def multi_test_poison_accumulative():
120
+ degrees = adj.sum(0).A1
121
+ node_list = select_nodes()
122
+ print("=== [Poisoning Accumulative] Attacking {} nodes sequentially ===".format(len(node_list)))
123
+
124
+ # 初始化结构
125
+ adj_attacked = adj.copy()
126
+
127
+ for target_node in tqdm(node_list):
128
+ n_perturbations = int(degrees[target_node])
129
+ model = Nettack(surrogate, nnodes=adj.shape[0],
130
+ attack_structure=True, attack_features=False, device=device)
131
+ model = model.to(device)
132
+ model.attack(features, adj_attacked, labels, target_node, n_perturbations, verbose=False)
133
+ adj_attacked = model.modified_adj
134
+
135
+ gcn = GCN(nfeat=features.shape[1],
136
+ nhid=16,
137
+ nclass=labels.max().item() + 1,
138
+ dropout=0.5, device=device)
139
+ gcn = gcn.to(device)
140
+ gcn.fit(features, adj_attacked, labels, idx_train, idx_val, patience=30)
141
+ gcn.eval()
142
+ output = gcn.predict()
143
+
144
+ # 在 node_list 上评估被误分类的节点数
145
+ preds = output.argmax(1)
146
+ node_preds = preds[node_list]
147
+ node_labels = labels[node_list]
148
+
149
+ # 转为 tensor,确保正确
150
+ correct_mask = torch.tensor(node_preds == node_labels, dtype=torch.bool)
151
+ correct = correct_mask.sum().item()
152
+ print("Accuracy on attacked nodes: {:.2f}%".format(100 * correct / len(node_list)))
153
+
154
+ def single_test(adj, features, target_node, gcn=None):
155
+ if gcn is None:
156
+ # test on GCN (poisoning attack)
157
+ gcn = GCN(nfeat=features.shape[1],
158
+ nhid=16,
159
+ nclass=labels.max().item() + 1,
160
+ dropout=0.5, device=device)
161
+
162
+ gcn = gcn.to(device)
163
+
164
+ gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
165
+ gcn.eval()
166
+ output = gcn.predict()
167
+ else:
168
+ # test on GCN (evasion attack)
169
+ output = gcn.predict(features, adj)
170
+ probs = torch.exp(output[[target_node]])
171
+
172
+ # acc_test = accuracy(output[[target_node]], labels[target_node])
173
+ acc_test = (output.argmax(1)[target_node] == labels[target_node])
174
+ return acc_test.item()
175
+
176
+ def multi_test_evasion():
177
+ # test on 40 nodes on evasion attack
178
+ target_gcn = GCN(nfeat=features.shape[1],
179
+ nhid=16,
180
+ nclass=labels.max().item() + 1,
181
+ dropout=0.5, device=device)
182
+
183
+ target_gcn = target_gcn.to(device)
184
+
185
+ target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
186
+
187
+ cnt = 0
188
+ degrees = adj.sum(0).A1
189
+ node_list = select_nodes(target_gcn)
190
+ num = len(node_list)
191
+
192
+ print('=== [Evasion] Attacking %s nodes respectively ===' % num)
193
+ for target_node in tqdm(node_list):
194
+ n_perturbations = int(degrees[target_node])
195
+ model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
196
+ model = model.to(device)
197
+ model.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
198
+ modified_adj = model.modified_adj
199
+ modified_features = model.modified_features
200
+
201
+ acc = single_test(modified_adj, modified_features, target_node, gcn=target_gcn)
202
+ if acc == 0:
203
+ cnt += 1
204
+ print('misclassification rate : %s' % (cnt/num))
205
+
206
+ if __name__ == '__main__':
207
+ # main()
208
+ # multi_test_poison()
209
+ # multi_test_poison()
210
+ multi_test_poison_accumulative()
211
+
212
+
examples/graph/cgscore_experiments/defense_method/GAT.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import scipy.sparse as sp
5
+ from deeprobust.graph.defense import GCNJaccard, GCN, GAT
6
+ from deeprobust.graph.utils import *
7
+ from deeprobust.graph.data import Dataset, PrePtbDataset, Dpr2Pyg
8
+ import argparse
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
12
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
13
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
14
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
15
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
16
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
17
+
18
+ args = parser.parse_args()
19
+ args.cuda = torch.cuda.is_available()
20
+ print('cuda: %s' % args.cuda)
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+
23
+ # make sure you use the same data splits as you generated attacks
24
+ np.random.seed(args.seed)
25
+ if args.cuda:
26
+ torch.cuda.manual_seed(args.seed)
27
+
28
+ # Here the random seed is to split the train/val/test data,
29
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
30
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
31
+ # Or we can just use setting='prognn' to get the splits
32
+
33
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
34
+ data = Dataset(root='/tmp/', name=args.dataset)
35
+ adj, features, labels = data.adj, data.features, data.labels
36
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
37
+ idx_unlabeled = np.union1d(idx_val, idx_test)
38
+ # print(type((adj)))
39
+ # adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
40
+
41
+
42
+ ## load from attacked_adj
43
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
44
+ perturbed_adj = torch.load(ptb_path)
45
+ print("type(perturbed_adj)", type(perturbed_adj))
46
+ perturbed_adj = sp.csr_matrix(perturbed_adj.cpu().numpy())
47
+
48
+ gat = GAT(nfeat=features.shape[1],
49
+ nhid=8, heads=8,
50
+ nclass=labels.max().item() + 1,
51
+ dropout=0.5, device=device)
52
+ gat = gat.to(device)
53
+
54
+ # test on clean graph
55
+ print('==================')
56
+ print('=== train on clean graph ===')
57
+
58
+ pyg_data = Dpr2Pyg(data)
59
+ pyg_data.update_edge_index(perturbed_adj) # inplace operation
60
+ gat.fit(pyg_data, train_iters=200, verbose=True) # train with earlystopping
61
+ gat.test()
examples/graph/cgscore_experiments/defense_method/GCN.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import GCNJaccard, GCN
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset, PrePtbDataset
7
+ import argparse
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
13
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
14
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
15
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
16
+
17
+ args = parser.parse_args()
18
+ args.cuda = torch.cuda.is_available()
19
+ print('cuda: %s' % args.cuda)
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+ # make sure you use the same data splits as you generated attacks
23
+ np.random.seed(args.seed)
24
+ if args.cuda:
25
+ torch.cuda.manual_seed(args.seed)
26
+
27
+ # Here the random seed is to split the train/val/test data,
28
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
29
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
30
+ # Or we can just use setting='prognn' to get the splits
31
+
32
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
33
+ data = Dataset(root='/tmp/', name=args.dataset)
34
+ adj, features, labels = data.adj, data.features, data.labels
35
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
36
+ idx_unlabeled = np.union1d(idx_val, idx_test)
37
+ adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
38
+
39
+
40
+ ## load from attacked_adj
41
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
42
+ perturbed_adj = torch.load(ptb_path)
43
+
44
+ def test(adj):
45
+ ''' test on GCN '''
46
+
47
+ # adj = normalize_adj_tensor(adj)
48
+ gcn = GCN(nfeat=features.shape[1],
49
+ nhid=args.hidden,
50
+ nclass=labels.max().item() + 1,
51
+ dropout=args.dropout, device=device)
52
+ gcn = gcn.to(device)
53
+ gcn.fit(features, adj, labels, idx_train) # train without model picking
54
+ # gcn.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
55
+ output = gcn.output.cpu()
56
+ loss_test = F.nll_loss(output[idx_test], labels[idx_test])
57
+ acc_test = accuracy(output[idx_test], labels[idx_test])
58
+ print("Test set results:",
59
+ "loss= {:.4f}".format(loss_test.item()),
60
+ "accuracy= {:.4f}".format(acc_test.item()))
61
+
62
+ return acc_test.item()
63
+
64
+ def main():
65
+
66
+ test(perturbed_adj)
67
+
68
+ # # if you want to save the modified adj/features, uncomment the code below
69
+ # model.save_adj(root='./', name=f'mod_adj')
70
+ # model.save_features(root='./', name='mod_features')
71
+
72
+ if __name__ == '__main__':
73
+ main()
examples/graph/cgscore_experiments/defense_method/GCNJaccard.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import GCNJaccard, GCN
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset, PrePtbDataset
7
+ import argparse
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
13
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
14
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
15
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
16
+ parser.add_argument('--threshold', type=float, default=0.1, help='jaccard coeficient')
17
+
18
+ args = parser.parse_args()
19
+ args.cuda = torch.cuda.is_available()
20
+ print('cuda: %s' % args.cuda)
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+
23
+ # make sure you use the same data splits as you generated attacks
24
+ np.random.seed(args.seed)
25
+ if args.cuda:
26
+ torch.cuda.manual_seed(args.seed)
27
+
28
+ # Here the random seed is to split the train/val/test data,
29
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
30
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
31
+ # Or we can just use setting='prognn' to get the splits
32
+
33
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
34
+ data = Dataset(root='/tmp/', name=args.dataset)
35
+ adj, features, labels = data.adj, data.features, data.labels
36
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
37
+ idx_unlabeled = np.union1d(idx_val, idx_test)
38
+ # adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
39
+
40
+
41
+ ## load from attacked_adj
42
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
43
+ perturbed_adj = torch.load(ptb_path)
44
+ perturbed_adj = perturbed_adj
45
+
46
+ def test_jaccard(adj):
47
+ ''' test on GCN '''
48
+
49
+ # adj = normalize_adj_tensor(adj)
50
+ gcn = GCNJaccard(nfeat=features.shape[1],
51
+ nhid=args.hidden,
52
+ nclass=labels.max().item() + 1,
53
+ dropout=args.dropout, device=device)
54
+ gcn = gcn.to(device)
55
+ gcn.fit(features, adj, labels, idx_train, idx_val, threshold=args.threshold)
56
+ gcn.eval()
57
+ gcn.test(idx_test)
58
+
59
+
60
+ def main():
61
+
62
+ test_jaccard(perturbed_adj)
63
+ # # if you want to save the modified adj/features, uncomment the code below
64
+ # model.save_adj(root='./', name=f'mod_adj')
65
+ # model.save_features(root='./', name='mod_features')
66
+
67
+ if __name__ == '__main__':
68
+ main()
examples/graph/cgscore_experiments/defense_method/GCNSVD.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import GCNJaccard, GCN, GCNSVD
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset, PrePtbDataset
7
+ import argparse
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
13
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
14
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
15
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
16
+ parser.add_argument('--k', type=int, default=50, help='Truncated Components.')
17
+
18
+ args = parser.parse_args()
19
+ args.cuda = torch.cuda.is_available()
20
+ print('cuda: %s' % args.cuda)
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+
23
+ # make sure you use the same data splits as you generated attacks
24
+ np.random.seed(args.seed)
25
+ if args.cuda:
26
+ torch.cuda.manual_seed(args.seed)
27
+
28
+ # Here the random seed is to split the train/val/test data,
29
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
30
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
31
+ # Or we can just use setting='prognn' to get the splits
32
+
33
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
34
+ data = Dataset(root='/tmp/', name=args.dataset)
35
+ adj, features, labels = data.adj, data.features, data.labels
36
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
37
+ # adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
38
+
39
+
40
+ ## load from attacked_adj
41
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
42
+ perturbed_adj = torch.load(ptb_path)
43
+
44
+ def test_svd(adj):
45
+
46
+ ''' test on GCNSVD '''
47
+
48
+ gcn = GCNSVD(nfeat=features.shape[1], nclass=labels.max()+1,nhid=args.hidden, device=device)
49
+ gcn = gcn.to(device)
50
+ gcn.fit(features, perturbed_adj, labels, idx_train, idx_val, k=args.k, verbose=True)
51
+ gcn.eval()
52
+ gcn.test(idx_test)
53
+
54
+
55
+ def main():
56
+
57
+ test_svd(perturbed_adj)
58
+ # # if you want to save the modified adj/features, uncomment the code below
59
+ # model.save_adj(root='./', name=f'mod_adj')
60
+ # model.save_features(root='./', name='mod_features')
61
+
62
+ if __name__ == '__main__':
63
+ main()
examples/graph/cgscore_experiments/defense_method/GNNGuard.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense.gcn_guard import GCNGuard
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset
7
+ from deeprobust.graph.data import PtbDataset, PrePtbDataset
8
+
9
+ import argparse
10
+ from scipy import sparse
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--seed', type=int, default=15, help='Random seed')
14
+ parser.add_argument('--GNNGuard', type=bool, default=True, choices=[True, False])
15
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed', 'flickr'], help='dataset')
16
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
17
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
18
+
19
+ args = parser.parse_args()
20
+ args.cuda = torch.cuda.is_available()
21
+ print('cuda: %s' % args.cuda)
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+
24
+ np.random.seed(args.seed)
25
+ torch.manual_seed(args.seed)
26
+ if args.cuda:
27
+ torch.cuda.manual_seed(args.seed)
28
+
29
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
30
+ data = Dataset(root='/tmp/', name=args.dataset)
31
+ adj, features, labels = data.adj, data.features, data.labels
32
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
33
+
34
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
35
+ perturbed_adj = torch.load(ptb_path)
36
+ perturbed_adj = sp.csr_matrix(perturbed_adj.to('cpu').numpy())
37
+
38
+
39
+ def test(adj):
40
+ # """defense models"""
41
+ ''' testing model '''
42
+ gcn = GCNGuard(nfeat=features.shape[1], nclass=labels.max().item() + 1, nhid=16,
43
+ dropout=0.5, with_relu=False, with_bias=True, weight_decay=5e-4, device=device)
44
+ gcn = gcn.to(device)
45
+
46
+ gcn.fit(features, adj, labels, idx_train, train_iters=200, idx_val=idx_val, idx_test=idx_test, verbose=True, attention=args.GNNGuard)
47
+ gcn.eval()
48
+
49
+ # classifier.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
50
+ acc_test, _ = gcn.test(idx_test)
51
+ # acc_test = classifier.test(idx_test)
52
+ return acc_test
53
+
54
+ def main():
55
+
56
+ # print('=== testing GCN on original(clean) graph ===')
57
+ # test(adj)
58
+ #
59
+ print('=== testing GCN on Mettacked graph ===')
60
+ test(perturbed_adj)
61
+
62
+ if __name__ == '__main__':
63
+ main()
64
+
examples/graph/cgscore_experiments/defense_method/ProGNN.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import GCNJaccard, GCN, ProGNN
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset, PrePtbDataset
7
+ import argparse
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
13
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
14
+ parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
15
+
16
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
17
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
18
+ parser.add_argument('--debug', action='store_true',default=False, help='debug mode')
19
+ parser.add_argument('--only_gcn', action='store_true',default=False, help='test the performance of gcn without other components')
20
+ # parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
21
+ parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
22
+ parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')
23
+ parser.add_argument('--alpha', type=float, default=5e-4, help='weight of l1 norm')
24
+ parser.add_argument('--beta', type=float, default=1.5, help='weight of nuclear norm')
25
+ parser.add_argument('--gamma', type=float, default=1, help='weight of l2 norm')
26
+ parser.add_argument('--lambda_', type=float, default=0, help='weight of feature smoothing')
27
+ parser.add_argument('--phi', type=float, default=0, help='weight of symmetric loss')
28
+ parser.add_argument('--inner_steps', type=int, default=2, help='steps for inner optimization')
29
+ parser.add_argument('--outer_steps', type=int, default=1, help='steps for outer optimization')
30
+ parser.add_argument('--lr_adj', type=float, default=0.01, help='lr for training adj')
31
+ parser.add_argument('--symmetric', action='store_true', default=False, help='whether use symmetric matrix')
32
+
33
+ args = parser.parse_args()
34
+ args.cuda = torch.cuda.is_available()
35
+ device = torch.device("cuda:5" if args.cuda else "cpu")
36
+ print('Using device:', device)
37
+
38
+ # make sure you use the same data splits as you generated attacks
39
+ np.random.seed(args.seed)
40
+ if args.cuda:
41
+ torch.cuda.manual_seed(args.seed)
42
+
43
+ # Here the random seed is to split the train/val/test data,
44
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
45
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
46
+ # Or we can just use setting='prognn' to get the splits
47
+
48
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
49
+ data = Dataset(root='/tmp/', name=args.dataset)
50
+ adj, features, labels = data.adj, data.features, data.labels
51
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
52
+ idx_unlabeled = np.union1d(idx_val, idx_test)
53
+ print(type(adj))
54
+ # adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
55
+
56
+
57
+ ## load from attacked_adj
58
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
59
+ perturbed_adj = torch.load(ptb_path)
60
+ perturbed_adj = sp.csr_matrix(perturbed_adj.to('cpu').numpy())
61
+
62
+ def test_prognn(features, adj, labels):
63
+
64
+
65
+ adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False, device=device)
66
+
67
+ model = GCN(nfeat=features.shape[1],
68
+ nhid=args.hidden,
69
+ nclass=labels.max().item() + 1,
70
+ dropout=args.dropout, device=device).to(device)
71
+ prognn = ProGNN(model, args, device)
72
+ prognn.fit(features, adj, labels, idx_train, idx_val)
73
+ prognn.test(features, labels, idx_test)
74
+
75
+ def main():
76
+
77
+ test_prognn(features, perturbed_adj, labels)
78
+
79
+ if __name__ == '__main__':
80
+ main()
examples/graph/cgscore_experiments/defense_method/RGCN.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import RGCN
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset, PrePtbDataset
7
+ import argparse
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
13
+ parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
14
+ parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
15
+ parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
16
+
17
+
18
+ args = parser.parse_args()
19
+ args.cuda = torch.cuda.is_available()
20
+ print('cuda: %s' % args.cuda)
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+
23
+ # make sure you use the same data splits as you generated attacks
24
+ np.random.seed(args.seed)
25
+ if args.cuda:
26
+ torch.cuda.manual_seed(args.seed)
27
+
28
+ # Here the random seed is to split the train/val/test data,
29
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
30
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
31
+ # Or we can just use setting='prognn' to get the splits
32
+
33
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
34
+ data = Dataset(root='/tmp/', name=args.dataset)
35
+ adj, features, labels = data.adj, data.features, data.labels
36
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
37
+
38
+ # adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
39
+
40
+
41
+ ## load from attacked_adj
42
+ ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
43
+ perturbed_adj = torch.load(ptb_path)
44
+ perturbed_adj = perturbed_adj
45
+
46
+ def test_rgcn(adj):
47
+ ''' test on GCN '''
48
+
49
+ # adj = normalize_adj_tensor(adj)
50
+ gcn = RGCN(nnodes=adj.shape[0], nfeat=features.shape[1], nclass=labels.max()+1,
51
+ nhid=args.hidden, device=device)
52
+ gcn = gcn.to(device)
53
+ gcn.fit(features, adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
54
+ gcn.eval()
55
+ gcn.test(idx_test)
56
+
57
+
58
+ def main():
59
+
60
+ test_rgcn(perturbed_adj)
61
+ # # if you want to save the modified adj/features, uncomment the code below
62
+ # model.save_adj(root='./', name=f'mod_adj')
63
+ # model.save_features(root='./', name='mod_features')
64
+
65
+ if __name__ == '__main__':
66
+ main()
examples/graph/cgscore_experiments/defense_method/cgscore.py ADDED
File without changes
examples/graph/cgscore_experiments/grb/grb_data.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch # pytorch backend
2
+ from grb.dataset import Dataset
3
+ from grb.model.torch import GCN
4
+ from grb.utils.trainer import Trainer
5
+
6
+ # Load data
7
+ # name: ["grb-cora", "grb-citeseer", "grb-aminer", "grb-reddit", "grb-flickr"].
8
+ # mode: [["easy", "medium", "hard", "full"]
9
+ # mode:
10
+
11
+ dataset = Dataset(name='grb-citeseer', mode='easy',feat_norm='None')
12
+
13
+ features = dataset.features # 注意:不是 tensor,而是 scipy.sparse
14
+ adj = dataset.adj # 是 numpy.ndarray 或 scipy.sparse
15
+ labels = dataset.labels
16
+ index = dataset.idx_train
17
+
18
+ print(type(features))
19
+
20
+ print(type(adj))
21
+ print(type(labels))
22
+ print(type(index))
23
+ # model = GCN(in_features=dataset.num_features,
24
+ # out_features=dataset.num_classes,
25
+ # hidden_features=[64, 64])
26
+ # # Training
27
+ # adam = torch.optim.Adam(model.parameters(), lr=0.01)
28
+ # trainer = Trainer(dataset=dataset, optimizer=adam,
29
+ # loss=torch.nn.functional.nll_loss)
30
+ # trainer.train(model=model, n_epoch=200, dropout=0.1,
31
+ # train_mode='inductive')
32
+
examples/graph/cgscore_save.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def calc_cg_score_gnn_with_sampling( # stable training and defense effect
2
+ A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
3
+ ):
4
+ """
5
+ Optimized CG-score calculation with edge sampling.
6
+ """
7
+
8
+ N = A.shape[0]
9
+ cg_scores = {
10
+ "vi": np.zeros((N, N)),
11
+ "ab": np.zeros((N, N)),
12
+ "a2": np.zeros((N, N)),
13
+ "b2": np.zeros((N, N)),
14
+ "times": np.zeros((N, N)),
15
+ }
16
+
17
+ A = A.to(device)
18
+ X = X.to(device)
19
+ labels = labels.to(device)
20
+
21
+ @torch.no_grad()
22
+ def normalize(tensor):
23
+ return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
24
+
25
+ for _ in range(rep_num):
26
+ AX = torch.matmul(A, X)
27
+ norm_AX = normalize(AX)
28
+
29
+ # Organize data by labels
30
+ dataset = defaultdict(list)
31
+ data_idx = defaultdict(list)
32
+ for i, label in enumerate(labels):
33
+ dataset[label.item()].append(norm_AX[i].unsqueeze(0))
34
+ data_idx[label.item()].append(i)
35
+
36
+ for label in dataset:
37
+ dataset[label] = torch.cat(dataset[label], dim=0)
38
+ data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
39
+
40
+ # Cache negative samples
41
+ neg_samples_dict = {}
42
+ neg_indices_dict = {}
43
+ for label in dataset:
44
+ neg_samples = torch.cat([dataset[l] for l in dataset if l != label])
45
+ neg_indices = torch.cat([data_idx[l] for l in data_idx if l != label])
46
+ neg_samples_dict[label] = neg_samples
47
+ neg_indices_dict[label] = neg_indices
48
+
49
+ # for curr_label, curr_samples in dataset.items():
50
+ for curr_label, curr_samples in tqdm(dataset.items(), desc="Label groups"):
51
+ curr_indices = data_idx[curr_label]
52
+ curr_num = len(curr_samples)
53
+
54
+ chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
55
+ chosen_curr_samples = curr_samples[chosen_curr_idx]
56
+ chosen_curr_indices = curr_indices[chosen_curr_idx]
57
+
58
+ # Get negative samples
59
+ neg_samples = neg_samples_dict[curr_label]
60
+ neg_indices = neg_indices_dict[curr_label]
61
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
62
+ rand_idx = torch.randperm(len(neg_samples))[:neg_num]
63
+ chosen_neg_samples = neg_samples[rand_idx]
64
+ chosen_neg_indices = neg_indices[rand_idx]
65
+
66
+ combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
67
+ y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
68
+
69
+ # Gram matrix H
70
+ H_inner = torch.matmul(combined_samples, combined_samples.T)
71
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
72
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
73
+ H.fill_diagonal_(0.5)
74
+ H += 1e-6 * torch.eye(H.size(0), device=device)
75
+ invH = torch.inverse(H)
76
+ original_error = y @ (invH @ y)
77
+
78
+ # for idx_i in chosen_curr_indices:
79
+ for idx_i in tqdm(chosen_curr_indices.tolist(), desc=f"Nodes in label {curr_label}"):
80
+ for j in range(idx_i + 1, N):
81
+ if A[idx_i, j] == 0:
82
+ continue
83
+
84
+ # Sparse AX1 update
85
+ AX1_i = AX[idx_i] - A[idx_i, j] * X[j]
86
+ AX1_j = AX[j] - A[j, idx_i] * X[idx_i]
87
+
88
+ norm_AX1 = norm_AX.clone()
89
+ norm_AX1[idx_i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
90
+ norm_AX1[j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
91
+
92
+ # Updated samples
93
+ curr_samples_A1 = norm_AX1[chosen_curr_indices]
94
+ neg_samples_A1 = norm_AX1[chosen_neg_indices]
95
+ combined_samples_A1 = torch.cat([curr_samples_A1, neg_samples_A1], dim=0)
96
+
97
+ # Recompute H_A1
98
+ H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
99
+ H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
100
+ H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
101
+ H_A1.fill_diagonal_(0.5)
102
+ H_A1 += 1e-6 * torch.eye(H_A1.size(0), device=device)
103
+ invH_A1 = torch.inverse(H_A1)
104
+ error_A1 = y @ (invH_A1 @ y)
105
+
106
+ score = (original_error - error_A1).item()
107
+ cg_scores["vi"][idx_i, j] += score
108
+ cg_scores["vi"][j, idx_i] = cg_scores["vi"][idx_i, j]
109
+ cg_scores["times"][idx_i, j] += 1
110
+ cg_scores["times"][j, idx_i] += 1
111
+
112
+ # Normalize
113
+ for key in cg_scores:
114
+ if key != "times":
115
+ cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
116
+
117
+ return cg_scores if sub_term else cg_scores["vi"]
118
+
119
+
120
+ def calc_cg_score_gnn_with_sampling(
121
+ A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
122
+ ):
123
+ """
124
+ Optimized CG-score calculation with edge batching and GPU acceleration.
125
+ """
126
+ # if hasattr(torch, "compile"):
127
+ # calc_cg_score_gnn_with_sampling = torch.compile(calc_cg_score_gnn_with_sampling)
128
+
129
+ N = A.shape[0]
130
+ cg_scores = {
131
+ "vi": np.zeros((N, N)),
132
+ "ab": np.zeros((N, N)),
133
+ "a2": np.zeros((N, N)),
134
+ "b2": np.zeros((N, N)),
135
+ "times": np.zeros((N, N)),
136
+ }
137
+
138
+ A = A.to(device)
139
+ X = X.to(device)
140
+ labels = labels.to(device)
141
+
142
+ @torch.no_grad()
143
+ def normalize(tensor):
144
+ return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
145
+
146
+ for _ in range(rep_num):
147
+ AX = torch.matmul(A, X)
148
+ norm_AX = normalize(AX)
149
+
150
+ # Group nodes by label
151
+ dataset = defaultdict(list)
152
+ data_idx = defaultdict(list)
153
+ for i, label in enumerate(labels):
154
+ dataset[label.item()].append(norm_AX[i].unsqueeze(0))
155
+ data_idx[label.item()].append(i)
156
+
157
+ for label in dataset:
158
+ dataset[label] = torch.cat(dataset[label], dim=0)
159
+ data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
160
+
161
+ # Prepare negative samples
162
+ neg_samples_dict = {}
163
+ neg_indices_dict = {}
164
+ for label in dataset:
165
+ neg_samples = torch.cat([dataset[l] for l in dataset if l != label])
166
+ neg_indices = torch.cat([data_idx[l] for l in data_idx if l != label])
167
+ neg_samples_dict[label] = neg_samples
168
+ neg_indices_dict[label] = neg_indices
169
+
170
+ for curr_label, curr_samples in tqdm(dataset.items(), desc="Label groups"):
171
+ curr_indices = data_idx[curr_label]
172
+ curr_num = len(curr_samples)
173
+
174
+ chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
175
+ chosen_curr_samples = curr_samples[chosen_curr_idx]
176
+ chosen_curr_indices = curr_indices[chosen_curr_idx]
177
+
178
+ neg_samples = neg_samples_dict[curr_label]
179
+ neg_indices = neg_indices_dict[curr_label]
180
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
181
+ rand_idx = torch.randperm(len(neg_samples))[:neg_num]
182
+ chosen_neg_samples = neg_samples[rand_idx]
183
+ chosen_neg_indices = neg_indices[rand_idx]
184
+
185
+ combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
186
+ y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
187
+
188
+ # Compute reference error
189
+ H_inner = torch.matmul(combined_samples, combined_samples.T)
190
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
191
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
192
+ H.fill_diagonal_(0.5)
193
+ H += 1e-6 * torch.eye(H.size(0), device=device)
194
+ invH = torch.inverse(H)
195
+ original_error = y @ (invH @ y)
196
+
197
+ # Gather candidate edges
198
+ edge_batch = []
199
+ for idx_i in chosen_curr_indices.tolist():
200
+ for j in range(idx_i + 1, N):
201
+ if A[idx_i, j] != 0:
202
+ edge_batch.append((idx_i, j))
203
+
204
+ # Process in batches
205
+ for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False):
206
+ batch = edge_batch[k : k + batch_size]
207
+ B = len(batch)
208
+
209
+ norm_AX1_batch = norm_AX.repeat(B, 1, 1)
210
+ updates = []
211
+ for b, (i, j) in enumerate(batch):
212
+ AX1_i = AX[i] - A[i, j] * X[j]
213
+ AX1_j = AX[j] - A[j, i] * X[i]
214
+ norm_AX1_batch[b, i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
215
+ norm_AX1_batch[b, j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
216
+
217
+ sample_idx = chosen_curr_indices.tolist() + chosen_neg_indices.tolist()
218
+ sample_batch = norm_AX1_batch[:, sample_idx, :] # [B, M, D]
219
+
220
+ H_inner = torch.matmul(sample_batch, sample_batch.transpose(1, 2))
221
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
222
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
223
+ eye = torch.eye(H.size(-1), device=device).unsqueeze(0).expand_as(H)
224
+ H = H + 1e-6 * eye
225
+ H.diagonal(dim1=-2, dim2=-1).copy_(0.5)
226
+
227
+ invH = torch.inverse(H)
228
+ y_expanded = y.unsqueeze(0).expand(B, -1)
229
+ error_A1 = torch.einsum('bi,bij,bj->b', y_expanded, invH, y_expanded)
230
+
231
+ for b, (i, j) in enumerate(batch):
232
+ score = (original_error - error_A1[b]).item()
233
+ cg_scores["vi"][i, j] += score
234
+ cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
235
+ cg_scores["times"][i, j] += 1
236
+ cg_scores["times"][j, i] += 1
237
+
238
+ for key in cg_scores:
239
+ if key != "times":
240
+ cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
241
+
242
+ return cg_scores if sub_term else cg_scores["vi"]
243
+
244
+ def calc_cg_score_gnn_with_sampling( # based on the front code, remove more data to GPU, effect is approxiamate
245
+ A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
246
+ ):
247
+ """
248
+ Calculate CG-score for each edge in a graph with node labels and random sampling.
249
+
250
+ Args:
251
+ A: torch.Tensor
252
+ Adjacency matrix of the graph (size: N x N).
253
+ X: torch.Tensor
254
+ Node features matrix (size: N x F).
255
+ labels: torch.Tensor
256
+ Node labels (size: N).
257
+ device: torch.device
258
+ Device to perform calculations.
259
+ rep_num: int
260
+ Number of repetitions for Monte Carlo sampling.
261
+ unbalance_ratio: float
262
+ Ratio of unbalanced data (1:unbalance_ratio).
263
+ sub_term: bool
264
+ If True, calculate and return sub-terms.
265
+
266
+ Returns:
267
+ cg_scores: dict
268
+ Dictionary containing CG-scores for edges and optionally sub-terms.
269
+ """
270
+ N = A.shape[0]
271
+ cg_scores = {
272
+ "vi": np.zeros((N, N)),
273
+ "ab": np.zeros((N, N)),
274
+ "a2": np.zeros((N, N)),
275
+ "b2": np.zeros((N, N)),
276
+ "times": np.zeros((N, N)),
277
+ }
278
+
279
+ with torch.no_grad():
280
+ for _ in range(rep_num):
281
+ # Compute AX (node representations)
282
+ AX = torch.matmul(A, X).to(device)
283
+ norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)
284
+
285
+ # Group nodes by their labels
286
+ dataset = defaultdict(list)
287
+ data_idx = defaultdict(list)
288
+ for i, label in enumerate(labels):
289
+ dataset[label.item()].append(norm_AX[i].unsqueeze(0)) # Store normalized data
290
+ data_idx[label.item()].append(i) # Store indices
291
+
292
+ # Convert to tensors
293
+ for label, data_list in dataset.items():
294
+ dataset[label] = torch.cat(data_list, dim=0)
295
+ data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
296
+
297
+ # Calculate CG-scores for each label group
298
+ for curr_label, curr_samples in dataset.items():
299
+ curr_indices = data_idx[curr_label]
300
+ curr_num = len(curr_samples)
301
+
302
+ # Randomly sample a subset of current label examples
303
+ chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
304
+ chosen_curr_samples = curr_samples[chosen_curr_idx]
305
+ chosen_curr_indices = curr_indices[chosen_curr_idx]
306
+
307
+ # Sample negative examples from other classes
308
+ neg_samples = torch.cat(
309
+ [dataset[l] for l in dataset if l != curr_label], dim=0
310
+ )
311
+ neg_indices = torch.cat(
312
+ [data_idx[l] for l in data_idx if l != curr_label], dim=0
313
+ )
314
+ neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
315
+ chosen_neg_samples = neg_samples[
316
+ torch.randperm(len(neg_samples))[:neg_num]
317
+ ]
318
+
319
+ # Combine positive and negative samples
320
+ combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
321
+ y = torch.cat(
322
+ [torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
323
+ ).to(device)
324
+
325
+ # Compute the Gram matrix H^\infty
326
+ H_inner = torch.matmul(combined_samples, combined_samples.T)
327
+ del combined_samples
328
+ ###
329
+ H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
330
+ ###
331
+ H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
332
+ del H_inner
333
+
334
+ H.fill_diagonal_(0.5)
335
+ ##
336
+ epsilon = 1e-6
337
+ H = H + epsilon * torch.eye(H.size(0), device=H.device)
338
+ ##
339
+ invH = torch.inverse(H)
340
+ del H
341
+ original_error = y @ (invH @ y)
342
+
343
+ # Compute CG-scores for each edge
344
+ for i in chosen_curr_indices:
345
+ print("the node index:", i)
346
+ for j in range(i + 1, N): # Upper triangular traversal
347
+ # print(j)
348
+ if A[i, j] == 0: # Skip if no edge exists
349
+ continue
350
+
351
+ # Remove edge (i, j) to create A1
352
+ A1 = A.clone()
353
+ A1[i, j] = A1[j, i] = 0
354
+
355
+ # Recompute AX with A1
356
+ AX1 = torch.matmul(A1, X).to(device)
357
+ norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)
358
+
359
+ # Repeat error calculation with A1
360
+ curr_samples_A1 = norm_AX1[chosen_curr_indices]
361
+ neg_samples_A1 = norm_AX1[neg_indices]
362
+ chosen_neg_samples_A1 = neg_samples_A1[
363
+ torch.randperm(len(neg_samples_A1))[:neg_num]
364
+ ]
365
+ combined_samples_A1 = torch.cat(
366
+ [curr_samples_A1, chosen_neg_samples_A1], dim=0
367
+ )
368
+ H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
369
+
370
+ del combined_samples_A1
371
+
372
+ ### trick1
373
+ H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
374
+ ###
375
+
376
+ H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
377
+ del H_inner_A1
378
+ H_A1.fill_diagonal_(0.5)
379
+
380
+ ### trick2
381
+ epsilon = 1e-6
382
+ H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
383
+ ###
384
+ invH_A1 = torch.inverse(H_A1)
385
+ del H_A1
386
+
387
+ error_A1 = y @ (invH_A1 @ y)
388
+
389
+ print("i:", i)
390
+ print("j:", j)
391
+ print("current score:", (original_error - error_A1).item())
392
+ # Compute the difference in error (CG-score)
393
+ cg_scores["vi"][i, j] += (original_error - error_A1).item()
394
+ cg_scores["vi"][j, i] = cg_scores["vi"][i, j] # Symmetric
395
+ cg_scores["times"][i, j] += 1
396
+ cg_scores["times"][j, i] += 1
397
+
398
+ # Normalize CG-scores by repetition count
399
+ for key, values in cg_scores.items():
400
+ if key == "times":
401
+ continue
402
+ cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
examples/graph/test_adv_train_evasion.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from deeprobust.graph.defense import GCN
6
+ from deeprobust.graph.global_attack import Random
7
+ from deeprobust.graph.targeted_attack import Nettack
8
+ from deeprobust.graph.utils import *
9
+ from deeprobust.graph.data import Dataset
10
+ from deeprobust.graph.data import PtbDataset
11
+ from tqdm import tqdm
12
+ import argparse
13
+
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
16
+ parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
17
+ parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
18
+
19
+ args = parser.parse_args()
20
+ args.cuda = torch.cuda.is_available()
21
+ print('cuda: %s' % args.cuda)
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+
24
+ # make sure you use the same data splits as you generated attacks
25
+ np.random.seed(args.seed)
26
+ if args.cuda:
27
+ torch.cuda.manual_seed(args.seed)
28
+
29
+ # load original dataset (to get clean features and labels)
30
+ data = Dataset(root='/tmp/', name=args.dataset)
31
+ adj, features, labels = data.adj, data.features, data.labels
32
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
33
+
34
+ # Setup Target Model
35
+ model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
36
+ nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
37
+
38
+ model = model.to(device)
39
+
40
+ # test on original adj
41
+ print('=== test on original adj ===')
42
+ model.fit(features, adj, labels, idx_train)
43
+ output = model.output
44
+ acc_test = accuracy(output[idx_test], labels[idx_test])
45
+ print("Test set results:",
46
+ "accuracy= {:.4f}".format(acc_test.item()))
47
+
48
+ print('=== Adversarial Training for Evasion Attack===')
49
+ adversary = Random()
50
+ adv_train_model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
51
+ nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
52
+
53
+ adv_train_model = adv_train_model.to(device)
54
+
55
+ adv_train_model.initialize()
56
+ n_perturbations = int(0.01 * (adj.sum()//2))
57
+ for i in tqdm(range(100)):
58
+ # modified_adj = adversary.attack(features, adj)
59
+ adversary.attack(adj, n_perturbations=n_perturbations, type='add')
60
+ modified_adj = adversary.modified_adj
61
+ adv_train_model.fit(features, modified_adj, labels, idx_train, train_iters=50, initialize=False)
62
+
63
+ adv_train_model.eval()
64
+ # test directly or fine tune
65
+ print('=== test on perturbed adj ===')
66
+ output = adv_train_model.predict()
67
+ acc_test = accuracy(output[idx_test], labels[idx_test])
68
+ print("Test set results:",
69
+ "accuracy= {:.4f}".format(acc_test.item()))
70
+
71
+
72
+ # set up Surrogate & Nettack to attack the graph
73
+ import random
74
+ target_nodes = random.sample(idx_test.tolist(), 20)
75
+ # Setup Surrogate model
76
+ surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
77
+ nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)
78
+ surrogate = surrogate.to(device)
79
+ surrogate.fit(features, adj, labels, idx_train)
80
+
81
+ all_margins = []
82
+ all_adv_margins = []
83
+
84
+ for target_node in target_nodes:
85
+ # set up Nettack
86
+ adversary = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
87
+ adversary = adversary.to(device)
88
+ degrees = adj.sum(0).A1
89
+ n_perturbations = int(degrees[target_node]) + 2
90
+ adversary.attack(features, adj, labels, target_node, n_perturbations)
91
+ perturbed_adj = adversary.modified_adj
92
+
93
+ model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
94
+ nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
95
+ model = model.to(device)
96
+
97
+ print('=== testing GCN on perturbed graph ===')
98
+ model.fit(features, perturbed_adj, labels, idx_train)
99
+ output = model.output
100
+ margin = classification_margin(output[target_node], labels[target_node])
101
+ all_margins.append(margin)
102
+
103
+ print('=== testing adv-GCN on perturbed graph ===')
104
+ output = adv_train_model.predict(features, perturbed_adj)
105
+ adv_margin = classification_margin(output[target_node], labels[target_node])
106
+ all_adv_margins.append(adv_margin)
107
+
108
+
109
+ print("No adversarial training: classfication margin for {0} nodes: {1}".format(len(target_nodes), np.mean(all_margins)))
110
+
111
+ print("Adversarial training: classfication margin for {0} nodes: {1}".format(len(target_nodes), np.mean(all_adv_margins)))
112
+
examples/graph/test_adv_train_poisoning.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from deeprobust.graph.defense import GCN
6
+ from deeprobust.graph.global_attack import Random
7
+ from deeprobust.graph.utils import *
8
+ from deeprobust.graph.data import Dataset
9
+ from deeprobust.graph.data import PtbDataset
10
+
11
+ import argparse
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
14
+ parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
15
+ parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
16
+
17
+ args = parser.parse_args()
18
+ args.cuda = torch.cuda.is_available()
19
+ print('cuda: %s' % args.cuda)
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+ # make sure you use the same data splits as you generated attacks
23
+ np.random.seed(args.seed)
24
+ if args.cuda:
25
+ torch.cuda.manual_seed(args.seed)
26
+
27
+ # load original dataset (to get clean features and labels)
28
+ data = Dataset(root='/tmp/', name=args.dataset)
29
+ adj, features, labels = data.adj, data.features, data.labels
30
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
31
+
32
+ # load pre-attacked graph
33
+ perturbed_data = PtbDataset(root='/tmp/', name=args.dataset)
34
+ perturbed_adj = perturbed_data.adj
35
+
36
+ # Setup Target Model
37
+ model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
38
+ nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
39
+
40
+ model = model.to(device)
41
+
42
+ adversary = Random()
43
+ # test on original adj
44
+ print('=== test on original adj ===')
45
+ model.fit(features, adj, labels, idx_train)
46
+ output = model.output
47
+ acc_test = accuracy(output[idx_test], labels[idx_test])
48
+ print("Test set results:",
49
+ "accuracy= {:.4f}".format(acc_test.item()))
50
+
51
+ print('=== testing GCN on perturbed graph ===')
52
+ model.fit(features, perturbed_adj, labels, idx_train)
53
+ output = model.output
54
+ acc_test = accuracy(output[idx_test], labels[idx_test])
55
+ print("Test set results:",
56
+ "accuracy= {:.4f}".format(acc_test.item()))
57
+
58
+
59
+ # For poisoning attack, the adjacency matrix you have
60
+ # is alreay perturbed
61
+ print('=== Adversarial Training for Poisoning Attack===')
62
+ model.initialize()
63
+ n_perturbations = int(0.01 * (adj.sum()//2))
64
+ for i in range(100):
65
+ # modified_adj = adversary.attack(features, adj)
66
+ adversary.attack(perturbed_adj, n_perturbations=n_perturbations, type='remove')
67
+ modified_adj = adversary.modified_adj
68
+ model.fit(features, modified_adj, labels, idx_train, train_iters=50, initialize=False)
69
+
70
+ model.eval()
71
+
72
+ # test directly or fine tune
73
+ print('=== test on perturbed adj ===')
74
+ output = model.predict(features, perturbed_adj)
75
+ acc_test = accuracy(output[idx_test], labels[idx_test])
76
+ print("Test set results:",
77
+ "accuracy= {:.4f}".format(acc_test.item()))
78
+
examples/graph/test_all.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+
4
+ for file in os.listdir('./'):
5
+ if "py" not in file:
6
+ continue
7
+ if 'rl' in file or 'nipa' in file or 'meta' in file or 'all' in file:
8
+ continue
9
+ if osp.isfile(file):
10
+ print(file)
11
+ os.system('CUDA_VISIBLE_DEVICES=0 python %s' % file)
12
+
13
+
examples/graph/test_chebnet.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ from deeprobust.graph.data import Dataset, Dpr2Pyg
4
+ from deeprobust.graph.defense import ChebNet
5
+ from deeprobust.graph.data import Dataset
6
+ from deeprobust.graph.data import PrePtbDataset
7
+
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
10
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
11
+ parser.add_argument('--ptb_rate', type=float, default=0.05, help='perturbation rate')
12
+
13
+ args = parser.parse_args()
14
+ args.cuda = torch.cuda.is_available()
15
+ print('cuda: %s' % args.cuda)
16
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
+
18
+ # use data splist provided by prognn
19
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
20
+ adj, features, labels = data.adj, data.features, data.labels
21
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
22
+
23
+ cheby = ChebNet(nfeat=features.shape[1],
24
+ nhid=16, num_hops=3,
25
+ nclass=labels.max().item() + 1,
26
+ dropout=0.5, device=device)
27
+ cheby = cheby.to(device)
28
+
29
+ # test on clean graph
30
+ print('==================')
31
+ print('=== train on clean graph ===')
32
+
33
+ pyg_data = Dpr2Pyg(data)
34
+ cheby.fit(pyg_data, verbose=True) # train with earlystopping
35
+ cheby.test()
36
+
37
+ # load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
38
+ print('==================')
39
+ print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
40
+ perturbed_data = PrePtbDataset(root='/tmp/',
41
+ name=args.dataset,
42
+ attack_method='meta',
43
+ ptb_rate=args.ptb_rate)
44
+ perturbed_adj = perturbed_data.adj
45
+ pyg_data.update_edge_index(perturbed_adj) # inplace operation
46
+ cheby.fit(pyg_data, verbose=True) # train with earlystopping
47
+ cheby.test()
48
+
examples/graph/test_deepwalk.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from deeprobust.graph.data import Dataset
2
+ from deeprobust.graph.defense import DeepWalk, Node2Vec
3
+ from deeprobust.graph.global_attack import NodeEmbeddingAttack
4
+ import numpy as np
5
+
6
+ dataset_str = 'cora_ml'
7
+ data = Dataset(root='/tmp/', name=dataset_str, seed=15)
8
+ adj, features, labels = data.adj, data.features, data.labels
9
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
10
+
11
+ attacker = NodeEmbeddingAttack()
12
+ attacker.attack(adj, attack_type="remove", n_perturbations=1000)
13
+ modified_adj = attacker.modified_adj
14
+
15
+ # train defense model
16
+ print("Test DeepWalk on clean graph")
17
+ model = DeepWalk()
18
+ model.fit(adj)
19
+ model.evaluate_node_classification(labels, idx_train, idx_test)
20
+ # model.evaluate_node_classification(labels, idx_train, idx_test, lr_params={"max_iter": 1000})
21
+
22
+ print("Test DeepWalk on attacked graph")
23
+ model.fit(modified_adj)
24
+ model.evaluate_node_classification(labels, idx_train, idx_test)
25
+
26
+ print("Test DeepWalk on link prediciton...")
27
+ model.evaluate_link_prediction(modified_adj, np.array(adj.nonzero()).T)
28
+
29
+ print("Test DeepWalk SVD on attacked graph")
30
+ model = DeepWalk(type="svd")
31
+ model.fit(modified_adj)
32
+ model.evaluate_node_classification(labels, idx_train, idx_test)
33
+
34
+ print("Test Node2vec on attacked graph")
35
+ model = Node2Vec()
36
+ model.fit(modified_adj)
37
+ model.evaluate_node_classification(labels, idx_train, idx_test)
38
+
39
+
examples/graph/test_gat.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ from deeprobust.graph.data import Dataset, Dpr2Pyg
4
+ from deeprobust.graph.defense import GAT
5
+ from deeprobust.graph.data import Dataset
6
+ from deeprobust.graph.data import PrePtbDataset
7
+ import scipy.sparse as sp
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
11
+ parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
12
+ parser.add_argument('--ptb_rate', type=float, default=0.10, help='perturbation rate')
13
+
14
+ args = parser.parse_args()
15
+ args.cuda = torch.cuda.is_available()
16
+ print('cuda: %s' % args.cuda)
17
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+
19
+ # use data splist provided by prognn
20
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
21
+ adj, features, labels = data.adj, data.features, data.labels
22
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
23
+
24
+ gat = GAT(nfeat=features.shape[1],
25
+ nhid=8, heads=8,
26
+ nclass=labels.max().item() + 1,
27
+ dropout=0.5, device=device)
28
+ gat = gat.to(device)
29
+
30
+
31
+ # test on clean graph
32
+ print('==================')
33
+ print('=== train on clean graph ===')
34
+
35
+ print(type(features))
36
+ print(type(adj))
37
+ pyg_data = Dpr2Pyg(data)
38
+ gat.fit(pyg_data, verbose=True) # train with earlystopping
39
+ gat.test()
40
+
41
+ # load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
42
+ print('==================')
43
+ print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
44
+ perturbed_data = PrePtbDataset(root='/tmp/',
45
+ name=args.dataset,
46
+ attack_method='meta',
47
+ ptb_rate=args.ptb_rate)
48
+ perturbed_adj = perturbed_data.adj
49
+ pyg_data.update_edge_index(perturbed_adj) # inplace operation
50
+ gat.fit(pyg_data, verbose=True) # train with earlystopping
51
+ gat.test()
52
+
53
+
54
+
55
+
examples/graph/test_gcn.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from deeprobust.graph.defense import GCN
5
+ from deeprobust.graph.utils import *
6
+ from deeprobust.graph.data import Dataset
7
+ from deeprobust.graph.data import PtbDataset, PrePtbDataset
8
+ import argparse
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--seed', type=int, default=15, help='Random seed.')
12
+ parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
13
+ parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
14
+
15
+ args = parser.parse_args()
16
+ args.cuda = torch.cuda.is_available()
17
+ print('cuda: %s' % args.cuda)
18
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+
20
+ # Here the random seed is to split the train/val/test data,
21
+ # we need to set the random seed to be the same as that when you generate the perturbed graph
22
+ # data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
23
+ # Or we can just use setting='prognn' to get the splits
24
+ data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
25
+ adj, features, labels = data.adj, data.features, data.labels
26
+ idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
27
+
28
+
29
+ # load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
30
+ print('==================')
31
+ print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
32
+ perturbed_data = PrePtbDataset(root='/tmp/',
33
+ name=args.dataset,
34
+ attack_method='meta',
35
+ ptb_rate=args.ptb_rate)
36
+ # perturbed_adj = perturbed_data.adj
37
+ perturbed_adj = adj
38
+
39
+ np.random.seed(args.seed)
40
+ torch.manual_seed(args.seed)
41
+ if args.cuda:
42
+ torch.cuda.manual_seed(args.seed)
43
+
44
+ # Setup GCN Model
45
+ model = GCN(nfeat=features.shape[1], nhid=16, nclass=labels.max()+1, device=device)
46
+ model = model.to(device)
47
+
48
+ # model.fit(features, perturbed_adj, labels, idx_train, train_iters=200, verbose=True)
49
+ # # using validation to pick model
50
+ model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
51
+ model.eval()
52
+ # You can use the inner function of model to test
53
+ model.test(idx_test)
54
+
55
+ # print('==================')
56
+ # print('=== load graph perturbed by DeepRobust 5% metattack (under prognn splits) ===')
57
+ # perturbed_data = PtbDataset(root='/tmp/',
58
+ # name=args.dataset,
59
+ # attack_method='meta')
60
+ # perturbed_adj = perturbed_data.adj
61
+
62
+ # print("dataset:", args.dataset)
63
+ # # model.fit(features, perturbed_adj, labels, idx_train, train_iters=200, verbose=True)
64
+ # # # using validation to pick model
65
+ # model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
66
+ # model.eval()
67
+ # # You can use the inner function of model to test
68
+ # model.test(idx_test)
69
+