Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- deeprobust/graph/defense_pyg/__init__.py +15 -0
- deeprobust/graph/defense_pyg/appnp.py +79 -0
- deeprobust/graph/defense_pyg/base_model.py +206 -0
- deeprobust/graph/defense_pyg/gpr.py +135 -0
- deeprobust/graph/defense_pyg/mygat_conv.py +198 -0
- deeprobust/graph/rl/__init__.py +0 -0
- deeprobust/graph/rl/env.py +258 -0
- deeprobust/graph/rl/q_net_node.py +228 -0
- deeprobust/image/adversary_examples/advexample.png +0 -0
- deeprobust/image/adversary_examples/cifar_advexample_orig.png +0 -0
- deeprobust/image/adversary_examples/cifar_advexample_pgd.png +0 -0
- deeprobust/image/adversary_examples/deepfool_diff.png +0 -0
- deeprobust/image/adversary_examples/imageexample.png +0 -0
- deeprobust/image/adversary_examples/test.jpg +0 -0
- deeprobust/image/adversary_examples/test1.jpg +0 -0
- deeprobust/image/evaluation_attack.py +226 -0
- deeprobust/image/netmodels/CNN.py +125 -0
- deeprobust/image/netmodels/CNN_multilayer.py +122 -0
- deeprobust/image/netmodels/YOPOCNN.py +70 -0
- deeprobust/image/netmodels/resnet.py +168 -0
- deeprobust/image/netmodels/train_model.py +146 -0
- deeprobust/image/netmodels/train_resnet.py +39 -0
- deeprobust/image/netmodels/vgg.py +116 -0
- deeprobust/image/synset_words.txt +1000 -0
- docs/Makefile +20 -0
- docs/conf.py +71 -0
- docs/index.rst +65 -0
- examples/graph/cgscore_datasets.py +255 -0
- examples/graph/cgscore_datasets_multigpus.py +299 -0
- examples/graph/cgscore_datasets_multigpus2.py +208 -0
- examples/graph/cgscore_env.yaml +193 -0
- examples/graph/cgscore_experiments/attack_method/attack_minmax.py +106 -0
- examples/graph/cgscore_experiments/attack_method/attack_nettack.py +212 -0
- examples/graph/cgscore_experiments/defense_method/GAT.py +61 -0
- examples/graph/cgscore_experiments/defense_method/GCN.py +73 -0
- examples/graph/cgscore_experiments/defense_method/GCNJaccard.py +68 -0
- examples/graph/cgscore_experiments/defense_method/GCNSVD.py +63 -0
- examples/graph/cgscore_experiments/defense_method/GNNGuard.py +64 -0
- examples/graph/cgscore_experiments/defense_method/ProGNN.py +80 -0
- examples/graph/cgscore_experiments/defense_method/RGCN.py +66 -0
- examples/graph/cgscore_experiments/defense_method/cgscore.py +0 -0
- examples/graph/cgscore_experiments/grb/grb_data.py +32 -0
- examples/graph/cgscore_save.py +402 -0
- examples/graph/test_adv_train_evasion.py +112 -0
- examples/graph/test_adv_train_poisoning.py +78 -0
- examples/graph/test_all.py +13 -0
- examples/graph/test_chebnet.py +48 -0
- examples/graph/test_deepwalk.py +39 -0
- examples/graph/test_gat.py +55 -0
- examples/graph/test_gcn.py +69 -0
deeprobust/graph/defense_pyg/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from .gcn import GCN
|
| 3 |
+
from .gat import GAT
|
| 4 |
+
from .appnp import APPNP
|
| 5 |
+
from .sage import SAGE
|
| 6 |
+
from .gpr import GPRGNN
|
| 7 |
+
from .airgnn import AirGNN
|
| 8 |
+
except ImportError as e:
|
| 9 |
+
print(e)
|
| 10 |
+
warnings.warn("Please install pytorch geometric if you " +
|
| 11 |
+
"would like to use the datasets from pytorch " +
|
| 12 |
+
"geometric. See details in https://pytorch-geom" +
|
| 13 |
+
"etric.readthedocs.io/en/latest/notes/installation.html")
|
| 14 |
+
|
| 15 |
+
__all__ = ["GCN", "GAT", "APPNP", "SAGE", "GPRGNN", "AirGNN"]
|
deeprobust/graph/defense_pyg/appnp.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn.parameter import Parameter
|
| 6 |
+
from torch.nn.modules.module import Module
|
| 7 |
+
from torch_geometric.nn import APPNP as APPNPConv
|
| 8 |
+
from torch.nn import Linear
|
| 9 |
+
from .base_model import BaseModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class APPNP(BaseModel):
|
| 13 |
+
|
| 14 |
+
def __init__(self, nfeat, nhid, nclass, K=10, alpha=0.1, dropout=0.5, lr=0.01,
|
| 15 |
+
with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
|
| 16 |
+
|
| 17 |
+
super(APPNP, self).__init__()
|
| 18 |
+
|
| 19 |
+
assert device is not None, "Please specify 'device'!"
|
| 20 |
+
self.device = device
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
self.lin1 = Linear(nfeat, nhid)
|
| 24 |
+
if with_bn:
|
| 25 |
+
self.bn1 = nn.BatchNorm1d(nhid)
|
| 26 |
+
self.bn2 = nn.BatchNorm1d(nclass)
|
| 27 |
+
|
| 28 |
+
self.lin2 = Linear(nhid, nclass)
|
| 29 |
+
self.prop1 = APPNPConv(K, alpha)
|
| 30 |
+
|
| 31 |
+
self.dropout = dropout
|
| 32 |
+
self.weight_decay = weight_decay
|
| 33 |
+
self.lr = lr
|
| 34 |
+
self.output = None
|
| 35 |
+
self.best_model = None
|
| 36 |
+
self.best_output = None
|
| 37 |
+
self.name = 'APPNP'
|
| 38 |
+
self.with_bn = with_bn
|
| 39 |
+
|
| 40 |
+
def forward(self, x, edge_index, edge_weight=None):
|
| 41 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 42 |
+
x = self.lin1(x)
|
| 43 |
+
if self.with_bn:
|
| 44 |
+
x = self.bn1(x)
|
| 45 |
+
x = F.relu(x)
|
| 46 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 47 |
+
x = self.lin2(x)
|
| 48 |
+
if self.with_bn:
|
| 49 |
+
x = self.bn2(x)
|
| 50 |
+
x = self.prop1(x, edge_index, edge_weight)
|
| 51 |
+
return F.log_softmax(x, dim=1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def initialize(self):
|
| 55 |
+
self.lin1.reset_parameters()
|
| 56 |
+
self.lin2.reset_parameters()
|
| 57 |
+
if self.with_bn:
|
| 58 |
+
self.bn1.reset_parameters()
|
| 59 |
+
self.bn2.reset_parameters()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
from deeprobust.graph.data import Dataset, Dpr2Pyg
|
| 64 |
+
data = Dataset(root='/tmp/', name='cora', setting='gcn')
|
| 65 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 66 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 67 |
+
model = GCN(nfeat=features.shape[1],
|
| 68 |
+
nhid=16,
|
| 69 |
+
nclass=labels.max().item() + 1,
|
| 70 |
+
dropout=0.5, device='cuda')
|
| 71 |
+
model = model.to('cuda')
|
| 72 |
+
pyg_data = Dpr2Pyg(data)[0]
|
| 73 |
+
|
| 74 |
+
import ipdb
|
| 75 |
+
ipdb.set_trace()
|
| 76 |
+
|
| 77 |
+
model.fit(pyg_data, verbose=True) # train with earlystopping
|
| 78 |
+
model.test()
|
| 79 |
+
print(model.predict())
|
deeprobust/graph/defense_pyg/base_model.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.optim as optim
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from deeprobust.graph import utils
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseModel(nn.Module):
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(BaseModel, self).__init__()
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
def fit(self, pyg_data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
|
| 15 |
+
if initialize:
|
| 16 |
+
self.initialize()
|
| 17 |
+
|
| 18 |
+
# self.data = pyg_data[0].to(self.device)
|
| 19 |
+
self.data = pyg_data.to(self.device)
|
| 20 |
+
# By default, it is trained with early stopping on validation
|
| 21 |
+
self.train_with_early_stopping(train_iters, patience, verbose)
|
| 22 |
+
|
| 23 |
+
def finetune(self, edge_index, edge_weight, feat=None, train_iters=10, verbose=True):
|
| 24 |
+
if verbose:
|
| 25 |
+
print(f'=== finetuning {self.name} model ===')
|
| 26 |
+
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
| 27 |
+
labels = self.data.y
|
| 28 |
+
if feat is None:
|
| 29 |
+
x = self.data.x
|
| 30 |
+
else:
|
| 31 |
+
x = feat
|
| 32 |
+
train_mask, val_mask = self.data.train_mask, self.data.val_mask
|
| 33 |
+
best_loss_val = 100
|
| 34 |
+
best_acc_val = 0
|
| 35 |
+
for i in range(train_iters):
|
| 36 |
+
self.train()
|
| 37 |
+
optimizer.zero_grad()
|
| 38 |
+
output = self.forward(x, edge_index, edge_weight)
|
| 39 |
+
loss_train = F.nll_loss(output[train_mask], labels[train_mask])
|
| 40 |
+
loss_train.backward()
|
| 41 |
+
optimizer.step()
|
| 42 |
+
|
| 43 |
+
if verbose and i % 50 == 0:
|
| 44 |
+
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
|
| 45 |
+
|
| 46 |
+
self.eval()
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
output = self.forward(x, edge_index)
|
| 49 |
+
loss_val = F.nll_loss(output[val_mask], labels[val_mask])
|
| 50 |
+
acc_val = utils.accuracy(output[val_mask], labels[val_mask])
|
| 51 |
+
|
| 52 |
+
# if best_loss_val > loss_val:
|
| 53 |
+
# best_loss_val = loss_val
|
| 54 |
+
# best_output = output
|
| 55 |
+
# weights = deepcopy(self.state_dict())
|
| 56 |
+
|
| 57 |
+
if best_acc_val < acc_val:
|
| 58 |
+
best_acc_val = acc_val
|
| 59 |
+
best_output = output
|
| 60 |
+
weights = deepcopy(self.state_dict())
|
| 61 |
+
|
| 62 |
+
print('best_acc_val:', best_acc_val.item())
|
| 63 |
+
self.load_state_dict(weights)
|
| 64 |
+
return best_output
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _fit_with_val(self, pyg_data, train_iters=1000, initialize=True, verbose=False, **kwargs):
|
| 68 |
+
if initialize:
|
| 69 |
+
self.initialize()
|
| 70 |
+
|
| 71 |
+
# self.data = pyg_data[0].to(self.device)
|
| 72 |
+
self.data = pyg_data.to(self.device)
|
| 73 |
+
if verbose:
|
| 74 |
+
print(f'=== training {self.name} model ===')
|
| 75 |
+
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
| 76 |
+
|
| 77 |
+
labels = self.data.y
|
| 78 |
+
train_mask, val_mask = self.data.train_mask, self.data.val_mask
|
| 79 |
+
|
| 80 |
+
x, edge_index = self.data.x, self.data.edge_index
|
| 81 |
+
for i in range(train_iters):
|
| 82 |
+
self.train()
|
| 83 |
+
optimizer.zero_grad()
|
| 84 |
+
output = self.forward(x, edge_index)
|
| 85 |
+
loss_train = F.nll_loss(output[train_mask+val_mask], labels[train_mask+val_mask])
|
| 86 |
+
loss_train.backward()
|
| 87 |
+
optimizer.step()
|
| 88 |
+
|
| 89 |
+
if verbose and i % 50 == 0:
|
| 90 |
+
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
|
| 91 |
+
|
| 92 |
+
def fit_with_val(self, pyg_data, train_iters=1000, initialize=True, patience=100, verbose=False, **kwargs):
|
| 93 |
+
if initialize:
|
| 94 |
+
self.initialize()
|
| 95 |
+
|
| 96 |
+
self.data = pyg_data.to(self.device)
|
| 97 |
+
self.data.train_mask = self.data.train_mask + self.data.val1_mask
|
| 98 |
+
self.data.val_mask = self.data.val2_mask
|
| 99 |
+
self.train_with_early_stopping(train_iters, patience, verbose)
|
| 100 |
+
|
| 101 |
+
def train_with_early_stopping(self, train_iters, patience, verbose):
|
| 102 |
+
"""early stopping based on the validation loss
|
| 103 |
+
"""
|
| 104 |
+
if verbose:
|
| 105 |
+
print(f'=== training {self.name} model ===')
|
| 106 |
+
optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
| 107 |
+
|
| 108 |
+
labels = self.data.y
|
| 109 |
+
train_mask, val_mask = self.data.train_mask, self.data.val_mask
|
| 110 |
+
|
| 111 |
+
early_stopping = patience
|
| 112 |
+
best_loss_val = 100
|
| 113 |
+
best_acc_val = 0
|
| 114 |
+
best_epoch = 0
|
| 115 |
+
|
| 116 |
+
x, edge_index = self.data.x, self.data.edge_index
|
| 117 |
+
for i in range(train_iters):
|
| 118 |
+
self.train()
|
| 119 |
+
optimizer.zero_grad()
|
| 120 |
+
|
| 121 |
+
output = self.forward(x, edge_index)
|
| 122 |
+
|
| 123 |
+
loss_train = F.nll_loss(output[train_mask], labels[train_mask])
|
| 124 |
+
loss_train.backward()
|
| 125 |
+
optimizer.step()
|
| 126 |
+
|
| 127 |
+
if verbose and i % 50 == 0:
|
| 128 |
+
print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
|
| 129 |
+
|
| 130 |
+
self.eval()
|
| 131 |
+
output = self.forward(x, edge_index)
|
| 132 |
+
loss_val = F.nll_loss(output[val_mask], labels[val_mask])
|
| 133 |
+
acc_val = utils.accuracy(output[val_mask], labels[val_mask])
|
| 134 |
+
# print(acc)
|
| 135 |
+
|
| 136 |
+
# if best_loss_val > loss_val:
|
| 137 |
+
# best_loss_val = loss_val
|
| 138 |
+
# self.output = output
|
| 139 |
+
# weights = deepcopy(self.state_dict())
|
| 140 |
+
# patience = early_stopping
|
| 141 |
+
# best_epoch = i
|
| 142 |
+
# else:
|
| 143 |
+
# patience -= 1
|
| 144 |
+
|
| 145 |
+
if best_acc_val < acc_val:
|
| 146 |
+
best_acc_val = acc_val
|
| 147 |
+
self.output = output
|
| 148 |
+
weights = deepcopy(self.state_dict())
|
| 149 |
+
patience = early_stopping
|
| 150 |
+
best_epoch = i
|
| 151 |
+
else:
|
| 152 |
+
patience -= 1
|
| 153 |
+
|
| 154 |
+
if i > early_stopping and patience <= 0:
|
| 155 |
+
break
|
| 156 |
+
|
| 157 |
+
if verbose:
|
| 158 |
+
# print('=== early stopping at {0}, loss_val = {1} ==='.format(best_epoch, best_loss_val) )
|
| 159 |
+
print('=== early stopping at {0}, acc_val = {1} ==='.format(best_epoch, best_acc_val) )
|
| 160 |
+
self.load_state_dict(weights)
|
| 161 |
+
|
| 162 |
+
def test(self):
|
| 163 |
+
"""Evaluate model performance on test set.
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
idx_test :
|
| 167 |
+
node testing indices
|
| 168 |
+
"""
|
| 169 |
+
self.eval()
|
| 170 |
+
test_mask = self.data.test_mask
|
| 171 |
+
labels = self.data.y
|
| 172 |
+
output = self.forward(self.data.x, self.data.edge_index)
|
| 173 |
+
# output = self.output
|
| 174 |
+
loss_test = F.nll_loss(output[test_mask], labels[test_mask])
|
| 175 |
+
acc_test = utils.accuracy(output[test_mask], labels[test_mask])
|
| 176 |
+
print("Test set results:",
|
| 177 |
+
"loss= {:.4f}".format(loss_test.item()),
|
| 178 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 179 |
+
return acc_test.item()
|
| 180 |
+
|
| 181 |
+
def predict(self, x=None, edge_index=None, edge_weight=None):
|
| 182 |
+
"""
|
| 183 |
+
Returns
|
| 184 |
+
-------
|
| 185 |
+
torch.FloatTensor
|
| 186 |
+
output (log probabilities)
|
| 187 |
+
"""
|
| 188 |
+
self.eval()
|
| 189 |
+
if x is None or edge_index is None:
|
| 190 |
+
x, edge_index = self.data.x, self.data.edge_index
|
| 191 |
+
return self.forward(x, edge_index, edge_weight)
|
| 192 |
+
|
| 193 |
+
def _ensure_contiguousness(self,
|
| 194 |
+
x,
|
| 195 |
+
edge_idx,
|
| 196 |
+
edge_weight):
|
| 197 |
+
if not x.is_sparse:
|
| 198 |
+
x = x.contiguous()
|
| 199 |
+
if hasattr(edge_idx, 'contiguous'):
|
| 200 |
+
edge_idx = edge_idx.contiguous()
|
| 201 |
+
if edge_weight is not None:
|
| 202 |
+
edge_weight = edge_weight.contiguous()
|
| 203 |
+
return x, edge_idx, edge_weight
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
deeprobust/graph/defense_pyg/gpr.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch_sparse import SparseTensor, matmul
|
| 5 |
+
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, APPNP, MessagePassing
|
| 6 |
+
from torch_geometric.nn.conv.gcn_conv import gcn_norm
|
| 7 |
+
import scipy.sparse
|
| 8 |
+
import numpy as np
|
| 9 |
+
from .base_model import BaseModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GPRGNN(BaseModel):
|
| 13 |
+
"""GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, nfeat, nhid, nclass, Init='PPR', dprate=.5, dropout=.5,
|
| 16 |
+
lr=0.01, weight_decay=0, device='cpu',
|
| 17 |
+
K=10, alpha=.1, Gamma=None, ppnp='GPR_prop'):
|
| 18 |
+
super(GPRGNN, self).__init__()
|
| 19 |
+
self.lin1 = nn.Linear(nfeat, nhid)
|
| 20 |
+
self.lin2 = nn.Linear(nhid, nclass)
|
| 21 |
+
|
| 22 |
+
if ppnp == 'PPNP':
|
| 23 |
+
self.prop1 = APPNP(K, alpha)
|
| 24 |
+
elif ppnp == 'GPR_prop':
|
| 25 |
+
self.prop1 = GPR_prop(K, alpha, Init, Gamma)
|
| 26 |
+
|
| 27 |
+
self.Init = Init
|
| 28 |
+
self.dprate = dprate
|
| 29 |
+
self.dropout = dropout
|
| 30 |
+
self.name = "GPR"
|
| 31 |
+
self.weight_decay = weight_decay
|
| 32 |
+
self.lr = lr
|
| 33 |
+
self.device=device
|
| 34 |
+
|
| 35 |
+
def initialize(self):
|
| 36 |
+
self.reset_parameters()
|
| 37 |
+
|
| 38 |
+
def reset_parameters(self):
|
| 39 |
+
self.lin1.reset_parameters()
|
| 40 |
+
self.lin2.reset_parameters()
|
| 41 |
+
self.prop1.reset_parameters()
|
| 42 |
+
|
| 43 |
+
def forward(self, x, edge_index, edge_weight=None):
|
| 44 |
+
|
| 45 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 46 |
+
x = F.relu(self.lin1(x))
|
| 47 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 48 |
+
x = self.lin2(x)
|
| 49 |
+
|
| 50 |
+
if edge_weight is not None:
|
| 51 |
+
adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1])
|
| 52 |
+
if self.dprate == 0.0:
|
| 53 |
+
x = self.prop1(x, adj)
|
| 54 |
+
else:
|
| 55 |
+
x = F.dropout(x, p=self.dprate, training=self.training)
|
| 56 |
+
x = self.prop1(x, adj)
|
| 57 |
+
else:
|
| 58 |
+
if self.dprate == 0.0:
|
| 59 |
+
x = self.prop1(x, edge_index, edge_weight)
|
| 60 |
+
else:
|
| 61 |
+
x = F.dropout(x, p=self.dprate, training=self.training)
|
| 62 |
+
x = self.prop1(x, edge_index, edge_weight)
|
| 63 |
+
|
| 64 |
+
return F.log_softmax(x, dim=1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class GPR_prop(MessagePassing):
|
| 68 |
+
'''
|
| 69 |
+
GPRGNN, from original repo https://github.com/jianhao2016/GPRGNN
|
| 70 |
+
propagation class for GPR_GNN
|
| 71 |
+
'''
|
| 72 |
+
|
| 73 |
+
def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
|
| 74 |
+
super(GPR_prop, self).__init__(aggr='add', **kwargs)
|
| 75 |
+
self.K = K
|
| 76 |
+
self.Init = Init
|
| 77 |
+
self.alpha = alpha
|
| 78 |
+
|
| 79 |
+
assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
|
| 80 |
+
if Init == 'SGC':
|
| 81 |
+
# SGC-like
|
| 82 |
+
TEMP = 0.0*np.ones(K+1)
|
| 83 |
+
TEMP[alpha] = 1.0
|
| 84 |
+
elif Init == 'PPR':
|
| 85 |
+
# PPR-like
|
| 86 |
+
TEMP = alpha*(1-alpha)**np.arange(K+1)
|
| 87 |
+
TEMP[-1] = (1-alpha)**K
|
| 88 |
+
elif Init == 'NPPR':
|
| 89 |
+
# Negative PPR
|
| 90 |
+
TEMP = (alpha)**np.arange(K+1)
|
| 91 |
+
TEMP = TEMP/np.sum(np.abs(TEMP))
|
| 92 |
+
elif Init == 'Random':
|
| 93 |
+
# Random
|
| 94 |
+
bound = np.sqrt(3/(K+1))
|
| 95 |
+
TEMP = np.random.uniform(-bound, bound, K+1)
|
| 96 |
+
TEMP = TEMP/np.sum(np.abs(TEMP))
|
| 97 |
+
elif Init == 'WS':
|
| 98 |
+
# Specify Gamma
|
| 99 |
+
TEMP = Gamma
|
| 100 |
+
|
| 101 |
+
self.temp = nn.Parameter(torch.tensor(TEMP))
|
| 102 |
+
|
| 103 |
+
def reset_parameters(self):
|
| 104 |
+
nn.init.zeros_(self.temp)
|
| 105 |
+
for k in range(self.K+1):
|
| 106 |
+
self.temp.data[k] = self.alpha*(1-self.alpha)**k
|
| 107 |
+
self.temp.data[-1] = (1-self.alpha)**self.K
|
| 108 |
+
|
| 109 |
+
def forward(self, x, edge_index, edge_weight=None):
|
| 110 |
+
if isinstance(edge_index, torch.Tensor):
|
| 111 |
+
edge_index, norm = gcn_norm(
|
| 112 |
+
edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
|
| 113 |
+
elif isinstance(edge_index, SparseTensor):
|
| 114 |
+
edge_index = gcn_norm(
|
| 115 |
+
edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
|
| 116 |
+
norm = None
|
| 117 |
+
|
| 118 |
+
hidden = x*(self.temp[0])
|
| 119 |
+
for k in range(self.K):
|
| 120 |
+
x = self.propagate(edge_index, x=x, norm=norm)
|
| 121 |
+
gamma = self.temp[k+1]
|
| 122 |
+
hidden = hidden + gamma*x
|
| 123 |
+
return hidden
|
| 124 |
+
|
| 125 |
+
def message(self, x_j, norm):
|
| 126 |
+
return norm.view(-1, 1) * x_j
|
| 127 |
+
|
| 128 |
+
def message_and_aggregate(self, adj_t, x):
|
| 129 |
+
return matmul(adj_t, x, reduce=self.aggr)
|
| 130 |
+
|
| 131 |
+
def __repr__(self):
|
| 132 |
+
return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
|
| 133 |
+
self.temp)
|
| 134 |
+
|
| 135 |
+
|
deeprobust/graph/defense_pyg/mygat_conv.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Slightly modified torch_geometric.nn.GATConv to fit the structure learning module"""
|
| 2 |
+
from typing import Union, Tuple, Optional
|
| 3 |
+
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
|
| 4 |
+
OptTensor)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.nn import Parameter, Linear
|
| 10 |
+
from torch_sparse import SparseTensor, set_diag
|
| 11 |
+
from torch_geometric.nn.conv import MessagePassing
|
| 12 |
+
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
|
| 13 |
+
|
| 14 |
+
from torch_geometric.nn.inits import glorot, zeros
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GATConv(MessagePassing):
|
| 18 |
+
r"""The graph attentional operator from the `"Graph Attention Networks"
|
| 19 |
+
<https://arxiv.org/abs/1710.10903>`_ paper
|
| 20 |
+
|
| 21 |
+
.. math::
|
| 22 |
+
\mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
|
| 23 |
+
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
|
| 24 |
+
|
| 25 |
+
where the attention coefficients :math:`\alpha_{i,j}` are computed as
|
| 26 |
+
|
| 27 |
+
.. math::
|
| 28 |
+
\alpha_{i,j} =
|
| 29 |
+
\frac{
|
| 30 |
+
\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
|
| 31 |
+
[\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
|
| 32 |
+
\right)\right)}
|
| 33 |
+
{\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
|
| 34 |
+
\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
|
| 35 |
+
[\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
|
| 36 |
+
\right)\right)}.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
in_channels (int or tuple): Size of each input sample. A tuple
|
| 40 |
+
corresponds to the sizes of source and target dimensionalities.
|
| 41 |
+
out_channels (int): Size of each output sample.
|
| 42 |
+
heads (int, optional): Number of multi-head-attentions.
|
| 43 |
+
(default: :obj:`1`)
|
| 44 |
+
concat (bool, optional): If set to :obj:`False`, the multi-head
|
| 45 |
+
attentions are averaged instead of concatenated.
|
| 46 |
+
(default: :obj:`True`)
|
| 47 |
+
negative_slope (float, optional): LeakyReLU angle of the negative
|
| 48 |
+
slope. (default: :obj:`0.2`)
|
| 49 |
+
dropout (float, optional): Dropout probability of the normalized
|
| 50 |
+
attention coefficients which exposes each node to a stochastically
|
| 51 |
+
sampled neighborhood during training. (default: :obj:`0`)
|
| 52 |
+
add_self_loops (bool, optional): If set to :obj:`False`, will not add
|
| 53 |
+
self-loops to the input graph. (default: :obj:`True`)
|
| 54 |
+
bias (bool, optional): If set to :obj:`False`, the layer will not learn
|
| 55 |
+
an additive bias. (default: :obj:`True`)
|
| 56 |
+
**kwargs (optional): Additional arguments of
|
| 57 |
+
:class:`torch_geometric.nn.conv.MessagePassing`.
|
| 58 |
+
"""
|
| 59 |
+
_alpha: OptTensor
|
| 60 |
+
|
| 61 |
+
def __init__(self, in_channels: Union[int, Tuple[int, int]],
|
| 62 |
+
out_channels: int, heads: int = 1, concat: bool = True,
|
| 63 |
+
negative_slope: float = 0.2, dropout: float = 0.,
|
| 64 |
+
add_self_loops: bool = True, bias: bool = True, **kwargs):
|
| 65 |
+
kwargs.setdefault('aggr', 'add')
|
| 66 |
+
super(GATConv, self).__init__(node_dim=0, **kwargs)
|
| 67 |
+
|
| 68 |
+
self.in_channels = in_channels
|
| 69 |
+
self.out_channels = out_channels
|
| 70 |
+
self.heads = heads
|
| 71 |
+
self.concat = concat
|
| 72 |
+
self.negative_slope = negative_slope
|
| 73 |
+
self.dropout = dropout
|
| 74 |
+
self.add_self_loops = add_self_loops
|
| 75 |
+
|
| 76 |
+
if isinstance(in_channels, int):
|
| 77 |
+
self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
|
| 78 |
+
self.lin_r = self.lin_l
|
| 79 |
+
else:
|
| 80 |
+
self.lin_l = Linear(in_channels[0], heads * out_channels, False)
|
| 81 |
+
self.lin_r = Linear(in_channels[1], heads * out_channels, False)
|
| 82 |
+
|
| 83 |
+
self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
|
| 84 |
+
self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
|
| 85 |
+
|
| 86 |
+
if bias and concat:
|
| 87 |
+
self.bias = Parameter(torch.Tensor(heads * out_channels))
|
| 88 |
+
elif bias and not concat:
|
| 89 |
+
self.bias = Parameter(torch.Tensor(out_channels))
|
| 90 |
+
else:
|
| 91 |
+
self.register_parameter('bias', None)
|
| 92 |
+
|
| 93 |
+
self._alpha = None
|
| 94 |
+
|
| 95 |
+
self.reset_parameters()
|
| 96 |
+
self.edge_weight = None
|
| 97 |
+
|
| 98 |
+
def reset_parameters(self):
|
| 99 |
+
glorot(self.lin_l.weight)
|
| 100 |
+
glorot(self.lin_r.weight)
|
| 101 |
+
glorot(self.att_l)
|
| 102 |
+
glorot(self.att_r)
|
| 103 |
+
zeros(self.bias)
|
| 104 |
+
|
| 105 |
+
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight=None,
|
| 106 |
+
size: Size = None, return_attention_weights=None):
|
| 107 |
+
# type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa
|
| 108 |
+
# type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa
|
| 109 |
+
# type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
|
| 110 |
+
# type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
|
| 111 |
+
r"""
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
return_attention_weights (bool, optional): If set to :obj:`True`,
|
| 115 |
+
will additionally return the tuple
|
| 116 |
+
:obj:`(edge_index, attention_weights)`, holding the computed
|
| 117 |
+
attention weights for each edge. (default: :obj:`None`)
|
| 118 |
+
"""
|
| 119 |
+
H, C = self.heads, self.out_channels
|
| 120 |
+
|
| 121 |
+
x_l: OptTensor = None
|
| 122 |
+
x_r: OptTensor = None
|
| 123 |
+
alpha_l: OptTensor = None
|
| 124 |
+
alpha_r: OptTensor = None
|
| 125 |
+
if isinstance(x, Tensor):
|
| 126 |
+
assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
|
| 127 |
+
x_l = x_r = self.lin_l(x).view(-1, H, C)
|
| 128 |
+
alpha_l = (x_l * self.att_l).sum(dim=-1)
|
| 129 |
+
alpha_r = (x_r * self.att_r).sum(dim=-1)
|
| 130 |
+
else:
|
| 131 |
+
x_l, x_r = x[0], x[1]
|
| 132 |
+
assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
|
| 133 |
+
x_l = self.lin_l(x_l).view(-1, H, C)
|
| 134 |
+
alpha_l = (x_l * self.att_l).sum(dim=-1)
|
| 135 |
+
if x_r is not None:
|
| 136 |
+
x_r = self.lin_r(x_r).view(-1, H, C)
|
| 137 |
+
alpha_r = (x_r * self.att_r).sum(dim=-1)
|
| 138 |
+
|
| 139 |
+
assert x_l is not None
|
| 140 |
+
assert alpha_l is not None
|
| 141 |
+
|
| 142 |
+
if self.add_self_loops:
|
| 143 |
+
if isinstance(edge_index, Tensor):
|
| 144 |
+
num_nodes = x_l.size(0)
|
| 145 |
+
if x_r is not None:
|
| 146 |
+
num_nodes = min(num_nodes, x_r.size(0))
|
| 147 |
+
if size is not None:
|
| 148 |
+
num_nodes = min(size[0], size[1])
|
| 149 |
+
# edge_index, _ = remove_self_loops(edge_index)
|
| 150 |
+
# edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
|
| 151 |
+
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
|
| 152 |
+
edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes)
|
| 153 |
+
self.edge_weight = edge_weight
|
| 154 |
+
|
| 155 |
+
elif isinstance(edge_index, SparseTensor):
|
| 156 |
+
edge_index = set_diag(edge_index)
|
| 157 |
+
|
| 158 |
+
# propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
|
| 159 |
+
out = self.propagate(edge_index, x=(x_l, x_r),
|
| 160 |
+
alpha=(alpha_l, alpha_r), size=size)
|
| 161 |
+
|
| 162 |
+
alpha = self._alpha
|
| 163 |
+
self._alpha = None
|
| 164 |
+
|
| 165 |
+
if self.concat:
|
| 166 |
+
out = out.view(-1, self.heads * self.out_channels)
|
| 167 |
+
else:
|
| 168 |
+
out = out.mean(dim=1)
|
| 169 |
+
|
| 170 |
+
if self.bias is not None:
|
| 171 |
+
out += self.bias
|
| 172 |
+
|
| 173 |
+
if isinstance(return_attention_weights, bool):
|
| 174 |
+
assert alpha is not None
|
| 175 |
+
if isinstance(edge_index, Tensor):
|
| 176 |
+
return out, (edge_index, alpha)
|
| 177 |
+
elif isinstance(edge_index, SparseTensor):
|
| 178 |
+
return out, edge_index.set_value(alpha, layout='coo')
|
| 179 |
+
else:
|
| 180 |
+
return out
|
| 181 |
+
|
| 182 |
+
def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
|
| 183 |
+
index: Tensor, ptr: OptTensor,
|
| 184 |
+
size_i: Optional[int]) -> Tensor:
|
| 185 |
+
alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
|
| 186 |
+
alpha = F.leaky_relu(alpha, self.negative_slope)
|
| 187 |
+
alpha = softmax(alpha, index, ptr, size_i)
|
| 188 |
+
self._alpha = alpha
|
| 189 |
+
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
|
| 190 |
+
|
| 191 |
+
if self.edge_weight is not None:
|
| 192 |
+
x_j = self.edge_weight.view(-1, 1, 1) * x_j
|
| 193 |
+
return x_j * alpha.unsqueeze(-1)
|
| 194 |
+
|
| 195 |
+
def __repr__(self):
|
| 196 |
+
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
|
| 197 |
+
self.in_channels,
|
| 198 |
+
self.out_channels, self.heads)
|
deeprobust/graph/rl/__init__.py
ADDED
|
File without changes
|
deeprobust/graph/rl/env.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
|
| 3 |
+
https://arxiv.org/abs/1806.02371
|
| 4 |
+
Author's Implementation
|
| 5 |
+
https://github.com/Hanjun-Dai/graph_adversarial_attack
|
| 6 |
+
This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
|
| 7 |
+
to be integrated into the repository.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import networkx as nx
|
| 15 |
+
import random
|
| 16 |
+
from torch.nn.parameter import Parameter
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
import torch.optim as optim
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from copy import deepcopy
|
| 22 |
+
import pickle as cp
|
| 23 |
+
from deeprobust.graph.utils import *
|
| 24 |
+
import scipy.sparse as sp
|
| 25 |
+
from scipy.sparse.linalg import eigsh
|
| 26 |
+
from deeprobust.graph import utils
|
| 27 |
+
|
| 28 |
+
class StaticGraph(object):
|
| 29 |
+
graph = None
|
| 30 |
+
|
| 31 |
+
@staticmethod
|
| 32 |
+
def get_gsize():
|
| 33 |
+
return torch.Size( (len(StaticGraph.graph), len(StaticGraph.graph)) )
|
| 34 |
+
|
| 35 |
+
class GraphNormTool(object):
|
| 36 |
+
|
| 37 |
+
def __init__(self, normalize, gm, device):
|
| 38 |
+
self.adj_norm = normalize
|
| 39 |
+
self.gm = gm
|
| 40 |
+
g = StaticGraph.graph
|
| 41 |
+
edges = np.array(g.edges(), dtype=np.int64)
|
| 42 |
+
rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
|
| 43 |
+
|
| 44 |
+
# self_edges = np.array([range(len(g)), range(len(g))], dtype=np.int64)
|
| 45 |
+
# edges = np.hstack((edges.T, rev_edges, self_edges))
|
| 46 |
+
edges = np.hstack((edges.T, rev_edges))
|
| 47 |
+
idxes = torch.LongTensor(edges)
|
| 48 |
+
values = torch.ones(idxes.size()[1])
|
| 49 |
+
|
| 50 |
+
self.raw_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
|
| 51 |
+
self.raw_adj = self.raw_adj.to(device)
|
| 52 |
+
|
| 53 |
+
self.normed_adj = self.raw_adj.clone()
|
| 54 |
+
if self.adj_norm:
|
| 55 |
+
if self.gm == 'gcn':
|
| 56 |
+
self.normed_adj = utils.normalize_adj_tensor(self.normed_adj, sparse=True)
|
| 57 |
+
# GraphLaplacianNorm(self.normed_adj)
|
| 58 |
+
else:
|
| 59 |
+
|
| 60 |
+
self.normed_adj = utils.degree_normalize_adj_tensor(self.normed_adj, sparse=True)
|
| 61 |
+
# GraphDegreeNorm(self.normed_adj)
|
| 62 |
+
|
| 63 |
+
def norm_extra(self, added_adj = None):
|
| 64 |
+
if added_adj is None:
|
| 65 |
+
return self.normed_adj
|
| 66 |
+
|
| 67 |
+
new_adj = self.raw_adj + added_adj
|
| 68 |
+
if self.adj_norm:
|
| 69 |
+
if self.gm == 'gcn':
|
| 70 |
+
new_adj = utils.normalize_adj_tensor(new_adj, sparse=True)
|
| 71 |
+
else:
|
| 72 |
+
new_adj = utils.degree_normalize_adj_tensor(new_adj, sparse=True)
|
| 73 |
+
|
| 74 |
+
return new_adj
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class ModifiedGraph(object):
|
| 78 |
+
|
| 79 |
+
def __init__(self, directed_edges = None, weights = None):
|
| 80 |
+
self.edge_set = set() #(first, second)
|
| 81 |
+
self.node_set = set(range(StaticGraph.get_gsize()[0]))
|
| 82 |
+
self.node_set = np.arange(StaticGraph.get_gsize()[0])
|
| 83 |
+
if directed_edges is not None:
|
| 84 |
+
self.directed_edges = deepcopy(directed_edges)
|
| 85 |
+
self.weights = deepcopy(weights)
|
| 86 |
+
else:
|
| 87 |
+
self.directed_edges = []
|
| 88 |
+
self.weights = []
|
| 89 |
+
|
| 90 |
+
def add_edge(self, x, y, z):
|
| 91 |
+
assert x is not None and y is not None
|
| 92 |
+
if x == y:
|
| 93 |
+
return
|
| 94 |
+
for e in self.directed_edges:
|
| 95 |
+
if e[0] == x and e[1] == y:
|
| 96 |
+
return
|
| 97 |
+
if e[1] == x and e[0] == y:
|
| 98 |
+
return
|
| 99 |
+
self.edge_set.add((x, y)) # (first, second)
|
| 100 |
+
self.edge_set.add((y, x)) # (second, first)
|
| 101 |
+
self.directed_edges.append((x, y))
|
| 102 |
+
# assert z < 0
|
| 103 |
+
self.weights.append(z)
|
| 104 |
+
|
| 105 |
+
def get_extra_adj(self, device):
|
| 106 |
+
if len(self.directed_edges):
|
| 107 |
+
edges = np.array(self.directed_edges, dtype=np.int64)
|
| 108 |
+
rev_edges = np.array([edges[:, 1], edges[:, 0]], dtype=np.int64)
|
| 109 |
+
edges = np.hstack((edges.T, rev_edges))
|
| 110 |
+
|
| 111 |
+
idxes = torch.LongTensor(edges)
|
| 112 |
+
values = torch.Tensor(self.weights + self.weights)
|
| 113 |
+
|
| 114 |
+
added_adj = torch.sparse.FloatTensor(idxes, values, StaticGraph.get_gsize())
|
| 115 |
+
|
| 116 |
+
added_adj = added_adj.to(device)
|
| 117 |
+
return added_adj
|
| 118 |
+
else:
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
def get_possible_nodes(self, target_node):
|
| 122 |
+
# connected = set()
|
| 123 |
+
connected = [target_node]
|
| 124 |
+
for n1, n2 in self.edge_set:
|
| 125 |
+
if n1 == target_node:
|
| 126 |
+
# connected.add(target_node)
|
| 127 |
+
connected.append(n2)
|
| 128 |
+
return np.setdiff1d(self.node_set, np.array(connected))
|
| 129 |
+
|
| 130 |
+
# return self.node_set - connected
|
| 131 |
+
|
| 132 |
+
class NodeAttackEnv(object):
|
| 133 |
+
"""Node attack environment. It executes an action and then change the
|
| 134 |
+
environment status (modify the graph).
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
def __init__(self, features, labels, all_targets, list_action_space, classifier, num_mod=1, reward_type='binary'):
|
| 138 |
+
|
| 139 |
+
self.classifier = classifier
|
| 140 |
+
self.list_action_space = list_action_space
|
| 141 |
+
self.features = features
|
| 142 |
+
self.labels = labels
|
| 143 |
+
self.all_targets = all_targets
|
| 144 |
+
self.num_mod = num_mod
|
| 145 |
+
self.reward_type = reward_type
|
| 146 |
+
|
| 147 |
+
def setup(self, target_nodes):
|
| 148 |
+
self.target_nodes = target_nodes
|
| 149 |
+
self.n_steps = 0
|
| 150 |
+
self.first_nodes = None
|
| 151 |
+
self.rewards = None
|
| 152 |
+
self.binary_rewards = None
|
| 153 |
+
self.modified_list = []
|
| 154 |
+
for i in range(len(self.target_nodes)):
|
| 155 |
+
self.modified_list.append(ModifiedGraph())
|
| 156 |
+
|
| 157 |
+
self.list_acc_of_all = []
|
| 158 |
+
|
| 159 |
+
def step(self, actions):
|
| 160 |
+
"""run actions and get rewards
|
| 161 |
+
"""
|
| 162 |
+
if self.first_nodes is None: # pick the first node of edge
|
| 163 |
+
assert self.n_steps % 2 == 0
|
| 164 |
+
self.first_nodes = actions[:]
|
| 165 |
+
else:
|
| 166 |
+
for i in range(len(self.target_nodes)):
|
| 167 |
+
# assert self.first_nodes[i] != actions[i]
|
| 168 |
+
# deleta an edge from the graph
|
| 169 |
+
self.modified_list[i].add_edge(self.first_nodes[i], actions[i], -1.0)
|
| 170 |
+
self.first_nodes = None
|
| 171 |
+
self.banned_list = None
|
| 172 |
+
self.n_steps += 1
|
| 173 |
+
|
| 174 |
+
if self.isTerminal():
|
| 175 |
+
# only calc reward when its terminal
|
| 176 |
+
acc_list = []
|
| 177 |
+
loss_list = []
|
| 178 |
+
# for i in tqdm(range(len(self.target_nodes))):
|
| 179 |
+
for i in (range(len(self.target_nodes))):
|
| 180 |
+
device = self.labels.device
|
| 181 |
+
extra_adj = self.modified_list[i].get_extra_adj(device=device)
|
| 182 |
+
adj = self.classifier.norm_tool.norm_extra(extra_adj)
|
| 183 |
+
|
| 184 |
+
output = self.classifier(self.features, adj)
|
| 185 |
+
|
| 186 |
+
loss, acc = loss_acc(output, self.labels, self.all_targets, avg_loss=False)
|
| 187 |
+
# _, loss, acc = self.classifier(self.features, Variable(adj), self.all_targets, self.labels, avg_loss=False)
|
| 188 |
+
|
| 189 |
+
cur_idx = self.all_targets.index(self.target_nodes[i])
|
| 190 |
+
acc = np.copy(acc.double().cpu().view(-1).numpy())
|
| 191 |
+
loss = loss.data.cpu().view(-1).numpy()
|
| 192 |
+
self.list_acc_of_all.append(acc)
|
| 193 |
+
acc_list.append(acc[cur_idx])
|
| 194 |
+
loss_list.append(loss[cur_idx])
|
| 195 |
+
|
| 196 |
+
self.binary_rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
|
| 197 |
+
if self.reward_type == 'binary':
|
| 198 |
+
self.rewards = (np.array(acc_list) * -2.0 + 1.0).astype(np.float32)
|
| 199 |
+
else:
|
| 200 |
+
assert self.reward_type == 'nll'
|
| 201 |
+
self.rewards = np.array(loss_list).astype(np.float32)
|
| 202 |
+
|
| 203 |
+
def sample_pos_rewards(self, num_samples):
|
| 204 |
+
assert self.list_acc_of_all is not None
|
| 205 |
+
cands = []
|
| 206 |
+
|
| 207 |
+
for i in range(len(self.list_acc_of_all)):
|
| 208 |
+
succ = np.where( self.list_acc_of_all[i] < 0.9 )[0]
|
| 209 |
+
|
| 210 |
+
for j in range(len(succ)):
|
| 211 |
+
|
| 212 |
+
cands.append((i, self.all_targets[succ[j]]))
|
| 213 |
+
|
| 214 |
+
if num_samples > len(cands):
|
| 215 |
+
return cands
|
| 216 |
+
random.shuffle(cands)
|
| 217 |
+
return cands[0:num_samples]
|
| 218 |
+
|
| 219 |
+
def uniformRandActions(self):
|
| 220 |
+
# TODO: here only support deleting edges
|
| 221 |
+
# seems they sample first node from 2-hop neighbours
|
| 222 |
+
act_list = []
|
| 223 |
+
offset = 0
|
| 224 |
+
for i in range(len(self.target_nodes)):
|
| 225 |
+
cur_node = self.target_nodes[i]
|
| 226 |
+
region = self.list_action_space[cur_node]
|
| 227 |
+
|
| 228 |
+
if self.first_nodes is not None and self.first_nodes[i] is not None:
|
| 229 |
+
region = self.list_action_space[self.first_nodes[i]]
|
| 230 |
+
|
| 231 |
+
if region is None: # singleton node
|
| 232 |
+
cur_action = np.random.randint(len(self.list_action_space))
|
| 233 |
+
else: # select from neighbours or 2-hop neighbours
|
| 234 |
+
cur_action = region[np.random.randint(len(region))]
|
| 235 |
+
|
| 236 |
+
act_list.append(cur_action)
|
| 237 |
+
return act_list
|
| 238 |
+
|
| 239 |
+
def isTerminal(self):
|
| 240 |
+
if self.n_steps == 2 * self.num_mod:
|
| 241 |
+
return True
|
| 242 |
+
return False
|
| 243 |
+
|
| 244 |
+
def getStateRef(self):
|
| 245 |
+
cp_first = [None] * len(self.target_nodes)
|
| 246 |
+
if self.first_nodes is not None:
|
| 247 |
+
cp_first = self.first_nodes
|
| 248 |
+
|
| 249 |
+
return zip(self.target_nodes, self.modified_list, cp_first)
|
| 250 |
+
|
| 251 |
+
def cloneState(self):
|
| 252 |
+
cp_first = [None] * len(self.target_nodes)
|
| 253 |
+
if self.first_nodes is not None:
|
| 254 |
+
cp_first = self.first_nodes[:]
|
| 255 |
+
|
| 256 |
+
return list(zip(self.target_nodes[:], deepcopy(self.modified_list), cp_first))
|
| 257 |
+
|
| 258 |
+
|
deeprobust/graph/rl/q_net_node.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Adversarial Attacks on Neural Networks for Graph Data. ICML 2018.
|
| 3 |
+
https://arxiv.org/abs/1806.02371
|
| 4 |
+
Author's Implementation
|
| 5 |
+
https://github.com/Hanjun-Dai/graph_adversarial_attack
|
| 6 |
+
This part of code is adopted from the author's implementation (Copyright (c) 2018 Dai, Hanjun and Li, Hui and Tian, Tian and Huang, Xin and Wang, Lin and Zhu, Jun and Song, Le) but modified
|
| 7 |
+
to be integrated into the repository.
|
| 8 |
+
'''
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import networkx as nx
|
| 14 |
+
import random
|
| 15 |
+
from torch.nn.parameter import Parameter
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from deeprobust.graph.rl.env import GraphNormTool
|
| 21 |
+
|
| 22 |
+
class QNetNode(nn.Module):
|
| 23 |
+
|
| 24 |
+
def __init__(self, node_features, node_labels, list_action_space, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
|
| 25 |
+
'''
|
| 26 |
+
bilin_q: bilinear q or not
|
| 27 |
+
mlp_hidden: mlp hidden layer size
|
| 28 |
+
mav_lv: max rounds of message passing
|
| 29 |
+
'''
|
| 30 |
+
super(QNetNode, self).__init__()
|
| 31 |
+
self.node_features = node_features
|
| 32 |
+
self.node_labels = node_labels
|
| 33 |
+
self.list_action_space = list_action_space
|
| 34 |
+
self.total_nodes = len(list_action_space)
|
| 35 |
+
|
| 36 |
+
self.bilin_q = bilin_q
|
| 37 |
+
self.embed_dim = embed_dim
|
| 38 |
+
self.mlp_hidden = mlp_hidden
|
| 39 |
+
self.max_lv = max_lv
|
| 40 |
+
self.gm = gm
|
| 41 |
+
|
| 42 |
+
if bilin_q:
|
| 43 |
+
last_wout = embed_dim
|
| 44 |
+
else:
|
| 45 |
+
last_wout = 1
|
| 46 |
+
self.bias_target = Parameter(torch.Tensor(1, embed_dim))
|
| 47 |
+
|
| 48 |
+
if mlp_hidden:
|
| 49 |
+
self.linear_1 = nn.Linear(embed_dim * 2, mlp_hidden)
|
| 50 |
+
self.linear_out = nn.Linear(mlp_hidden, last_wout)
|
| 51 |
+
else:
|
| 52 |
+
self.linear_out = nn.Linear(embed_dim * 2, last_wout)
|
| 53 |
+
|
| 54 |
+
self.w_n2l = Parameter(torch.Tensor(node_features.size()[1], embed_dim))
|
| 55 |
+
self.bias_n2l = Parameter(torch.Tensor(embed_dim))
|
| 56 |
+
self.bias_picked = Parameter(torch.Tensor(1, embed_dim))
|
| 57 |
+
self.conv_params = nn.Linear(embed_dim, embed_dim)
|
| 58 |
+
self.norm_tool = GraphNormTool(normalize=True, gm=self.gm, device=device)
|
| 59 |
+
weights_init(self)
|
| 60 |
+
|
| 61 |
+
def make_spmat(self, n_rows, n_cols, row_idx, col_idx):
|
| 62 |
+
idxes = torch.LongTensor([[row_idx], [col_idx]])
|
| 63 |
+
values = torch.ones(1)
|
| 64 |
+
|
| 65 |
+
sp = torch.sparse.FloatTensor(idxes, values, torch.Size([n_rows, n_cols]))
|
| 66 |
+
if next(self.parameters()).is_cuda:
|
| 67 |
+
sp = sp.cuda()
|
| 68 |
+
return sp
|
| 69 |
+
|
| 70 |
+
def forward(self, time_t, states, actions, greedy_acts=False, is_inference=False):
|
| 71 |
+
|
| 72 |
+
if self.node_features.data.is_sparse:
|
| 73 |
+
input_node_linear = torch.spmm(self.node_features, self.w_n2l)
|
| 74 |
+
else:
|
| 75 |
+
input_node_linear = torch.mm(self.node_features, self.w_n2l)
|
| 76 |
+
|
| 77 |
+
input_node_linear += self.bias_n2l
|
| 78 |
+
|
| 79 |
+
# TODO the number of target nodes is batch_size, it actually parallizes
|
| 80 |
+
target_nodes, batch_graph, picked_nodes = zip(*states)
|
| 81 |
+
|
| 82 |
+
list_pred = []
|
| 83 |
+
prefix_sum = []
|
| 84 |
+
for i in range(len(batch_graph)):
|
| 85 |
+
region = self.list_action_space[target_nodes[i]]
|
| 86 |
+
|
| 87 |
+
node_embed = input_node_linear.clone()
|
| 88 |
+
if picked_nodes is not None and picked_nodes[i] is not None:
|
| 89 |
+
with torch.set_grad_enabled(mode=not is_inference):
|
| 90 |
+
picked_sp = self.make_spmat(self.total_nodes, 1, picked_nodes[i], 0)
|
| 91 |
+
node_embed += torch.spmm(picked_sp, self.bias_picked)
|
| 92 |
+
region = self.list_action_space[picked_nodes[i]]
|
| 93 |
+
|
| 94 |
+
if not self.bilin_q:
|
| 95 |
+
with torch.set_grad_enabled(mode=not is_inference):
|
| 96 |
+
# with torch.no_grad():
|
| 97 |
+
target_sp = self.make_spmat(self.total_nodes, 1, target_nodes[i], 0)
|
| 98 |
+
node_embed += torch.spmm(target_sp, self.bias_target)
|
| 99 |
+
|
| 100 |
+
with torch.set_grad_enabled(mode=not is_inference):
|
| 101 |
+
device = self.node_features.device
|
| 102 |
+
adj = self.norm_tool.norm_extra( batch_graph[i].get_extra_adj(device))
|
| 103 |
+
|
| 104 |
+
lv = 0
|
| 105 |
+
input_message = node_embed
|
| 106 |
+
|
| 107 |
+
node_embed = F.relu(input_message)
|
| 108 |
+
while lv < self.max_lv:
|
| 109 |
+
n2npool = torch.spmm(adj, node_embed)
|
| 110 |
+
node_linear = self.conv_params( n2npool )
|
| 111 |
+
merged_linear = node_linear + input_message
|
| 112 |
+
node_embed = F.relu(merged_linear)
|
| 113 |
+
lv += 1
|
| 114 |
+
|
| 115 |
+
target_embed = node_embed[target_nodes[i], :].view(-1, 1)
|
| 116 |
+
if region is not None:
|
| 117 |
+
node_embed = node_embed[region]
|
| 118 |
+
|
| 119 |
+
graph_embed = torch.mean(node_embed, dim=0, keepdim=True)
|
| 120 |
+
|
| 121 |
+
if actions is None:
|
| 122 |
+
graph_embed = graph_embed.repeat(node_embed.size()[0], 1)
|
| 123 |
+
else:
|
| 124 |
+
if region is not None:
|
| 125 |
+
act_idx = region.index(actions[i])
|
| 126 |
+
else:
|
| 127 |
+
act_idx = actions[i]
|
| 128 |
+
node_embed = node_embed[act_idx, :].view(1, -1)
|
| 129 |
+
|
| 130 |
+
embed_s_a = torch.cat((node_embed, graph_embed), dim=1)
|
| 131 |
+
if self.mlp_hidden:
|
| 132 |
+
embed_s_a = F.relu( self.linear_1(embed_s_a) )
|
| 133 |
+
raw_pred = self.linear_out(embed_s_a)
|
| 134 |
+
|
| 135 |
+
if self.bilin_q:
|
| 136 |
+
raw_pred = torch.mm(raw_pred, target_embed)
|
| 137 |
+
list_pred.append(raw_pred)
|
| 138 |
+
|
| 139 |
+
if greedy_acts:
|
| 140 |
+
actions, _ = node_greedy_actions(target_nodes, picked_nodes, list_pred, self)
|
| 141 |
+
|
| 142 |
+
return actions, list_pred
|
| 143 |
+
|
| 144 |
+
class NStepQNetNode(nn.Module):
|
| 145 |
+
|
| 146 |
+
def __init__(self, num_steps, node_features, node_labels, list_action_space, bilin_q=1, embed_dim=64, mlp_hidden=64, max_lv=1, gm='mean_field', device='cpu'):
|
| 147 |
+
|
| 148 |
+
super(NStepQNetNode, self).__init__()
|
| 149 |
+
self.node_features = node_features
|
| 150 |
+
self.node_labels = node_labels
|
| 151 |
+
self.list_action_space = list_action_space
|
| 152 |
+
self.total_nodes = len(list_action_space)
|
| 153 |
+
|
| 154 |
+
list_mod = []
|
| 155 |
+
for i in range(0, num_steps):
|
| 156 |
+
# list_mod.append(QNetNode(node_features, node_labels, list_action_space))
|
| 157 |
+
list_mod.append(QNetNode(node_features, node_labels, list_action_space, bilin_q, embed_dim, mlp_hidden, max_lv, gm=gm, device=device))
|
| 158 |
+
|
| 159 |
+
self.list_mod = nn.ModuleList(list_mod)
|
| 160 |
+
self.num_steps = num_steps
|
| 161 |
+
|
| 162 |
+
def forward(self, time_t, states, actions, greedy_acts = False, is_inference=False):
|
| 163 |
+
assert time_t >= 0 and time_t < self.num_steps
|
| 164 |
+
|
| 165 |
+
return self.list_mod[time_t](time_t, states, actions, greedy_acts, is_inference)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def glorot_uniform(t):
|
| 169 |
+
if len(t.size()) == 2:
|
| 170 |
+
fan_in, fan_out = t.size()
|
| 171 |
+
elif len(t.size()) == 3:
|
| 172 |
+
# out_ch, in_ch, kernel for Conv 1
|
| 173 |
+
fan_in = t.size()[1] * t.size()[2]
|
| 174 |
+
fan_out = t.size()[0] * t.size()[2]
|
| 175 |
+
else:
|
| 176 |
+
fan_in = np.prod(t.size())
|
| 177 |
+
fan_out = np.prod(t.size())
|
| 178 |
+
|
| 179 |
+
limit = np.sqrt(6.0 / (fan_in + fan_out))
|
| 180 |
+
t.uniform_(-limit, limit)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _param_init(m):
|
| 184 |
+
if isinstance(m, Parameter):
|
| 185 |
+
glorot_uniform(m.data)
|
| 186 |
+
elif isinstance(m, nn.Linear):
|
| 187 |
+
m.bias.data.zero_()
|
| 188 |
+
glorot_uniform(m.weight.data)
|
| 189 |
+
|
| 190 |
+
def weights_init(m):
|
| 191 |
+
for p in m.modules():
|
| 192 |
+
if isinstance(p, nn.ParameterList):
|
| 193 |
+
for pp in p:
|
| 194 |
+
_param_init(pp)
|
| 195 |
+
else:
|
| 196 |
+
_param_init(p)
|
| 197 |
+
|
| 198 |
+
for name, p in m.named_parameters():
|
| 199 |
+
if not '.' in name: # top-level parameters
|
| 200 |
+
_param_init(p)
|
| 201 |
+
|
| 202 |
+
def node_greedy_actions(target_nodes, picked_nodes, list_q, net):
|
| 203 |
+
assert len(target_nodes) == len(list_q)
|
| 204 |
+
|
| 205 |
+
actions = []
|
| 206 |
+
values = []
|
| 207 |
+
for i in range(len(target_nodes)):
|
| 208 |
+
region = net.list_action_space[target_nodes[i]]
|
| 209 |
+
if picked_nodes is not None and picked_nodes[i] is not None:
|
| 210 |
+
region = net.list_action_space[picked_nodes[i]]
|
| 211 |
+
if region is None:
|
| 212 |
+
assert list_q[i].size()[0] == net.total_nodes
|
| 213 |
+
else:
|
| 214 |
+
assert len(region) == list_q[i].size()[0]
|
| 215 |
+
|
| 216 |
+
val, act = torch.max(list_q[i], dim=0)
|
| 217 |
+
values.append(val)
|
| 218 |
+
if region is not None:
|
| 219 |
+
act = region[act.data.cpu().numpy()[0]]
|
| 220 |
+
# act = Variable(torch.LongTensor([act]))
|
| 221 |
+
act = torch.LongTensor([act])
|
| 222 |
+
actions.append(act)
|
| 223 |
+
else:
|
| 224 |
+
actions.append(act)
|
| 225 |
+
|
| 226 |
+
return torch.cat(actions, dim=0).data, torch.cat(values, dim=0).data
|
| 227 |
+
|
| 228 |
+
|
deeprobust/image/adversary_examples/advexample.png
ADDED
|
deeprobust/image/adversary_examples/cifar_advexample_orig.png
ADDED
|
deeprobust/image/adversary_examples/cifar_advexample_pgd.png
ADDED
|
deeprobust/image/adversary_examples/deepfool_diff.png
ADDED
|
deeprobust/image/adversary_examples/imageexample.png
ADDED
|
deeprobust/image/adversary_examples/test.jpg
ADDED
|
deeprobust/image/adversary_examples/test1.jpg
ADDED
|
deeprobust/image/evaluation_attack.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import datasets,models,transforms
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
from deeprobust.image import utils
|
| 13 |
+
|
| 14 |
+
def run_attack(attackmethod, batch_size, batch_num, device, test_loader, random_targeted = False, target_label = -1, **kwargs):
|
| 15 |
+
test_loss = 0
|
| 16 |
+
correct = 0
|
| 17 |
+
samplenum = 1000
|
| 18 |
+
count = 0
|
| 19 |
+
classnum = 10
|
| 20 |
+
for count, (data, target) in enumerate(test_loader):
|
| 21 |
+
if count == batch_num:
|
| 22 |
+
break
|
| 23 |
+
print('batch:{}'.format(count))
|
| 24 |
+
|
| 25 |
+
data, target = data.to(device), target.to(device)
|
| 26 |
+
if(random_targeted == True):
|
| 27 |
+
r = list(range(0, target)) + list(range(target+1, classnum))
|
| 28 |
+
target_label = random.choice(r)
|
| 29 |
+
adv_example = attackmethod.generate(data, target, target_label = target_label, **kwargs)
|
| 30 |
+
|
| 31 |
+
elif(target_label >= 0):
|
| 32 |
+
adv_example = attackmethod.generate(data, target, target_label = target_label, **kwargs)
|
| 33 |
+
|
| 34 |
+
else:
|
| 35 |
+
adv_example = attackmethod.generate(data, target, **kwargs)
|
| 36 |
+
|
| 37 |
+
output = model(adv_example)
|
| 38 |
+
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
| 39 |
+
|
| 40 |
+
pred = output.argmax(dim = 1, keepdim = True) # get the index of the max log-probability.
|
| 41 |
+
|
| 42 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
| 43 |
+
|
| 44 |
+
batch_num = count+1
|
| 45 |
+
test_loss /= len(test_loader.dataset)
|
| 46 |
+
print("===== ACCURACY =====")
|
| 47 |
+
print('Attack Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
| 48 |
+
test_loss, correct, batch_num * batch_size,
|
| 49 |
+
100. * correct / (batch_num * batch_size)))
|
| 50 |
+
|
| 51 |
+
def load_net(attack_model, filename, path):
|
| 52 |
+
if(attack_model == "CNN"):
|
| 53 |
+
from deeprobust.image.netmodels.CNN import Net
|
| 54 |
+
|
| 55 |
+
model = Net()
|
| 56 |
+
if(attack_model == "ResNet18"):
|
| 57 |
+
import deeprobust.image.netmodels.resnet as Net
|
| 58 |
+
model = Net.ResNet18()
|
| 59 |
+
|
| 60 |
+
model.load_state_dict(torch.load(path + filename))
|
| 61 |
+
model.eval()
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
def generate_dataloader(dataset, batch_size):
|
| 65 |
+
if(dataset == "MNIST"):
|
| 66 |
+
test_loader = torch.utils.data.DataLoader(
|
| 67 |
+
datasets.MNIST('deeprobust/image/data', train = False,
|
| 68 |
+
download = True,
|
| 69 |
+
transform = transforms.Compose([transforms.ToTensor()])),
|
| 70 |
+
batch_size = args.batch_size,
|
| 71 |
+
shuffle = True)
|
| 72 |
+
print("Loading MNIST dataset.")
|
| 73 |
+
|
| 74 |
+
elif(dataset == "CIFAR" or args.dataset == 'CIFAR10'):
|
| 75 |
+
test_loader = torch.utils.data.DataLoader(
|
| 76 |
+
datasets.CIFAR10('deeprobust/image/data', train = False,
|
| 77 |
+
download = True,
|
| 78 |
+
transform = transforms.Compose([transforms.ToTensor()])),
|
| 79 |
+
batch_size = args.batch_size,
|
| 80 |
+
shuffle = True)
|
| 81 |
+
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 82 |
+
print("Loading CIFAR10 dataset.")
|
| 83 |
+
|
| 84 |
+
elif(dataset == "ImageNet"):
|
| 85 |
+
test_loader = torch.utils.data.DataLoader(
|
| 86 |
+
datasets.CIFAR10('deeprobust/image/data', train=False,
|
| 87 |
+
download = True,
|
| 88 |
+
transform = transforms.Compose([transforms.ToTensor()])),
|
| 89 |
+
batch_size = args.batch_size,
|
| 90 |
+
shuffle = True)
|
| 91 |
+
print("Loading ImageNet dataset.")
|
| 92 |
+
return test_loader
|
| 93 |
+
|
| 94 |
+
def parameter_parser():
|
| 95 |
+
parser = argparse.ArgumentParser(description = "Run attack algorithms.", usage ='Use -h for more information.')
|
| 96 |
+
|
| 97 |
+
parser.add_argument("--attack_method",
|
| 98 |
+
default = 'PGD',
|
| 99 |
+
help = "Choose a attack algorithm from: PGD(default), FGSM, LBFGS, CW, deepfool, onepixel, Nattack")
|
| 100 |
+
parser.add_argument("--attack_model",
|
| 101 |
+
default = "CNN",
|
| 102 |
+
help = "Choose network structure from: CNN, ResNet")
|
| 103 |
+
parser.add_argument("--path",
|
| 104 |
+
default = "./trained_models/",
|
| 105 |
+
help = "Type the path where the model is saved.")
|
| 106 |
+
parser.add_argument("--file_name",
|
| 107 |
+
default = 'MNIST_CNN_epoch_20.pt',
|
| 108 |
+
help = "Type the file_name of the model that is to be attack. The model structure should be matched with the ATTACK_MODEL parameter.")
|
| 109 |
+
parser.add_argument("--dataset",
|
| 110 |
+
default = 'MNIST',
|
| 111 |
+
help = "Choose a dataset from: MNIST(default), CIFAR(or CIFAR10), ImageNet")
|
| 112 |
+
parser.add_argument("--epsilon", type = float, default = 0.3)
|
| 113 |
+
parser.add_argument("--batch_num", type = int, default = 1000)
|
| 114 |
+
parser.add_argument("--batch_size", type = int, default = 1000)
|
| 115 |
+
parser.add_argument("--num_steps", type = int, default = 40)
|
| 116 |
+
parser.add_argument("--step_size", type = float, default = 0.01)
|
| 117 |
+
parser.add_argument("--random_targeted", type = bool, default = False,
|
| 118 |
+
help = "default: False. By setting this parameter be True, the program would random generate target labels for the input samples.")
|
| 119 |
+
parser.add_argument("--target_label", type = int, default = -1,
|
| 120 |
+
help = "default: -1. Generate all attack Fixed target label.")
|
| 121 |
+
parser.add_argument("--device", default = 'cuda',
|
| 122 |
+
help = "Choose the device.")
|
| 123 |
+
|
| 124 |
+
return parser.parse_args()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
# read arguments
|
| 129 |
+
args = parameter_parser() # read argument and creat an argparse object
|
| 130 |
+
|
| 131 |
+
# download example model
|
| 132 |
+
example_model_path = './trained_models/MNIST_CNN_epoch_20.pt'
|
| 133 |
+
if not (os.path.exists('./trained_models')):
|
| 134 |
+
os.mkdir('./trained_models')
|
| 135 |
+
print('create path: ./trained_models')
|
| 136 |
+
model_url = "https://github.com/I-am-Bot/deeprobust_trained_model/blob/master/MNIST_CNN_epoch_20.pt?raw=true"
|
| 137 |
+
r = requests.get(model_url)
|
| 138 |
+
print('Downloading example model...')
|
| 139 |
+
with open(example_model_path,'wb') as f:
|
| 140 |
+
f.write(r.content)
|
| 141 |
+
print('Downloaded.')
|
| 142 |
+
# load model
|
| 143 |
+
model = load_net(args.attack_model, args.file_name, args.path)
|
| 144 |
+
|
| 145 |
+
print("===== START ATTACK =====")
|
| 146 |
+
if(args.attack_method == "PGD"):
|
| 147 |
+
from deeprobust.image.attack.pgd import PGD
|
| 148 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 149 |
+
attack_method = PGD(model, args.device)
|
| 150 |
+
utils.tab_printer(args)
|
| 151 |
+
run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader, epsilon = args.epsilon)
|
| 152 |
+
|
| 153 |
+
elif(args.attack_method == "FGSM"):
|
| 154 |
+
from deeprobust.image.attack.fgsm import FGSM
|
| 155 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 156 |
+
attack_method = FGSM(model, args.device)
|
| 157 |
+
utils.tab_printer(args)
|
| 158 |
+
run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader, epsilon = args.epsilon)
|
| 159 |
+
|
| 160 |
+
elif(args.attack_method == "LBFGS"):
|
| 161 |
+
from deeprobust.image.attack.lbfgs import LBFGS
|
| 162 |
+
try:
|
| 163 |
+
if (args.batch_size >1):
|
| 164 |
+
raise ValueError("batch_size shouldn't be larger than 1.")
|
| 165 |
+
except ValueError:
|
| 166 |
+
args.batch_size = 1
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
if (args.random_targeted == 0 and args.target_label == -1):
|
| 170 |
+
raise ValueError("No target label assigned. Random generate target for each input.")
|
| 171 |
+
except ValueError:
|
| 172 |
+
args.random_targeted = True
|
| 173 |
+
|
| 174 |
+
utils.tab_printer(args)
|
| 175 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 176 |
+
attack_method = LBFGS(model, args.device)
|
| 177 |
+
run_attack(attack_method, 1, args.batch_num, args.device, test_loader, random_targeted = args.random_targeted, target_label = args.target_label)
|
| 178 |
+
|
| 179 |
+
elif(args.attack_method == "CW"):
|
| 180 |
+
from deeprobust.image.attack.cw import CarliniWagner
|
| 181 |
+
attack_method = CarliniWagner(model, args.device)
|
| 182 |
+
try:
|
| 183 |
+
if (args.batch_size > 1):
|
| 184 |
+
raise ValueError("batch_size shouldn't be larger than 1.")
|
| 185 |
+
except ValueError:
|
| 186 |
+
args.batch_size = 1
|
| 187 |
+
|
| 188 |
+
try:
|
| 189 |
+
if (args.random_targeted == 0 and args.target_label == -1):
|
| 190 |
+
raise ValueError("No target label assigned. Random generate target for each input.")
|
| 191 |
+
except ValueError:
|
| 192 |
+
args.random_targeted = True
|
| 193 |
+
|
| 194 |
+
utils.tab_printer(args)
|
| 195 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 196 |
+
run_attack(attack_method, 1, args.batch_num, args.device, test_loader, random_targeted = args.random_targeted, target_label = args.target_label)
|
| 197 |
+
|
| 198 |
+
elif(args.attack_method == "deepfool"):
|
| 199 |
+
from deeprobust.image.attack.deepfool import DeepFool
|
| 200 |
+
attack_method = DeepFool(model, args.device)
|
| 201 |
+
try:
|
| 202 |
+
if (args.batch_size > 1):
|
| 203 |
+
raise ValueError("batch_size shouldn't be larger than 1.")
|
| 204 |
+
except ValueError:
|
| 205 |
+
args.batch_size = 1
|
| 206 |
+
|
| 207 |
+
utils.tab_printer(args)
|
| 208 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 209 |
+
run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader)
|
| 210 |
+
|
| 211 |
+
elif(args.attack_method == "onepixel"):
|
| 212 |
+
from deeprobust.image.attack.onepixel import Onepixel
|
| 213 |
+
attack_method = Onepixel(model, args.device)
|
| 214 |
+
try:
|
| 215 |
+
if (args.batch_size > 1):
|
| 216 |
+
raise ValueError("batch_size shouldn't be larger than 1.")
|
| 217 |
+
except ValueError:
|
| 218 |
+
args.batch_size = 1
|
| 219 |
+
|
| 220 |
+
utils.tab_printer(args)
|
| 221 |
+
test_loader = generate_dataloader(args.dataset, args.batch_size)
|
| 222 |
+
run_attack(attack_method, args.batch_size, args.batch_num, args.device, test_loader)
|
| 223 |
+
|
| 224 |
+
elif(args.attack_method == "Nattack"):
|
| 225 |
+
pass
|
| 226 |
+
|
deeprobust/image/netmodels/CNN.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an implementatio of a Convolution Neural Network with 2 Convolutional layer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import print_function
|
| 6 |
+
import argparse
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F #233
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
from torchvision import datasets, transforms
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
class Net(nn.Module):
|
| 16 |
+
"""Model counterparts.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, in_channel1 = 1, out_channel1 = 32, out_channel2 = 64, H = 28, W = 28):
|
| 20 |
+
super(Net, self).__init__()
|
| 21 |
+
self.H = H
|
| 22 |
+
self.W = W
|
| 23 |
+
self.out_channel2 = out_channel2
|
| 24 |
+
|
| 25 |
+
## define two convolutional layers
|
| 26 |
+
self.conv1 = nn.Conv2d(in_channels = in_channel1,
|
| 27 |
+
out_channels = out_channel1,
|
| 28 |
+
kernel_size = 5,
|
| 29 |
+
stride= 1,
|
| 30 |
+
padding = (2,2))
|
| 31 |
+
self.conv2 = nn.Conv2d(in_channels = out_channel1,
|
| 32 |
+
out_channels = out_channel2,
|
| 33 |
+
kernel_size = 5,
|
| 34 |
+
stride = 1,
|
| 35 |
+
padding = (2,2))
|
| 36 |
+
|
| 37 |
+
## define two linear layers
|
| 38 |
+
self.fc1 = nn.Linear(int(self.H/4) * int(self.W/4) * out_channel2, 1024)
|
| 39 |
+
self.fc2 = nn.Linear(1024, 10)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
|
| 43 |
+
x = F.relu(self.conv1(x))
|
| 44 |
+
x = F.max_pool2d(x, 2, 2)
|
| 45 |
+
x = F.relu(self.conv2(x))
|
| 46 |
+
x = F.max_pool2d(x, 2, 2)
|
| 47 |
+
x = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
|
| 48 |
+
x = F.relu(self.fc1(x))
|
| 49 |
+
x = self.fc2(x)
|
| 50 |
+
return x
|
| 51 |
+
|
| 52 |
+
def get_logits(self, x):
|
| 53 |
+
x = F.relu(self.conv1(x))
|
| 54 |
+
x = F.max_pool2d(x, 2, 2)
|
| 55 |
+
x = F.relu(self.conv2(x))
|
| 56 |
+
x = F.max_pool2d(x, 2, 2)
|
| 57 |
+
x = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
|
| 58 |
+
x = F.relu(self.fc1(x))
|
| 59 |
+
x = self.fc2(x)
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
def train(model, device, train_loader, optimizer, epoch):
|
| 63 |
+
"""train network.
|
| 64 |
+
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
model :
|
| 68 |
+
model
|
| 69 |
+
device :
|
| 70 |
+
device(option:'cpu','cuda')
|
| 71 |
+
train_loader :
|
| 72 |
+
training data loader
|
| 73 |
+
optimizer :
|
| 74 |
+
optimizer
|
| 75 |
+
epoch :
|
| 76 |
+
epoch
|
| 77 |
+
"""
|
| 78 |
+
model.train()
|
| 79 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
| 80 |
+
data, target = data.to(device), target.to(device)
|
| 81 |
+
optimizer.zero_grad()
|
| 82 |
+
output = model(data)
|
| 83 |
+
loss = F.cross_entropy(output, target)
|
| 84 |
+
loss.backward()
|
| 85 |
+
optimizer.step()
|
| 86 |
+
|
| 87 |
+
#print every 10
|
| 88 |
+
if batch_idx % 10 == 0:
|
| 89 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
| 90 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
| 91 |
+
100. * batch_idx / len(train_loader), loss.item()))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test(model, device, test_loader):
|
| 95 |
+
"""test network.
|
| 96 |
+
|
| 97 |
+
Parameters
|
| 98 |
+
----------
|
| 99 |
+
model :
|
| 100 |
+
model
|
| 101 |
+
device :
|
| 102 |
+
device(option:'cpu', 'cuda')
|
| 103 |
+
test_loader :
|
| 104 |
+
testing data loader
|
| 105 |
+
"""
|
| 106 |
+
model.eval()
|
| 107 |
+
|
| 108 |
+
test_loss = 0
|
| 109 |
+
correct = 0
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
for data, target in test_loader:
|
| 112 |
+
data, target = data.to(device), target.to(device)
|
| 113 |
+
output = model(data)
|
| 114 |
+
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
|
| 115 |
+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
| 116 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
| 117 |
+
|
| 118 |
+
test_loss /= len(test_loader.dataset)
|
| 119 |
+
|
| 120 |
+
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
| 121 |
+
test_loss, correct, len(test_loader.dataset),
|
| 122 |
+
100. * correct / len(test_loader.dataset)))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
deeprobust/image/netmodels/CNN_multilayer.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an implementation of Convolution Neural Network with multi conv layer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import print_function
|
| 6 |
+
import argparse
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F #233
|
| 10 |
+
import torch.optim as optim
|
| 11 |
+
from torchvision import datasets, transforms
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
class Net(nn.Module):
|
| 16 |
+
def __init__(self, in_channel1 = 1, out_channel1 = 32, out_channel2 = 64, H = 28, W = 28):
|
| 17 |
+
super(Net, self).__init__()
|
| 18 |
+
self.H = H
|
| 19 |
+
self.W = W
|
| 20 |
+
self.out_channel2 = out_channel2
|
| 21 |
+
|
| 22 |
+
## define two convolutional layers
|
| 23 |
+
self.conv1 = nn.Conv2d(in_channels = in_channel1,
|
| 24 |
+
out_channels = out_channel1,
|
| 25 |
+
kernel_size = 5,
|
| 26 |
+
stride= 1,
|
| 27 |
+
padding = (2,2))
|
| 28 |
+
self.conv2 = nn.Conv2d(in_channels = out_channel1,
|
| 29 |
+
out_channels = out_channel2,
|
| 30 |
+
kernel_size = 5,
|
| 31 |
+
stride = 1,
|
| 32 |
+
padding = (2,2))
|
| 33 |
+
|
| 34 |
+
## define two linear layers
|
| 35 |
+
self.fc1 = nn.Linear(int(H/4)*int(W/4)* out_channel2, 1024)
|
| 36 |
+
self.fc2 = nn.Linear(1024, 10)
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
self.layers[0] = F.relu(self.conv1(x))
|
| 40 |
+
self.layers[1] = F.max_pool2d(x, 2, 2)
|
| 41 |
+
self.layers[2] = F.relu(self.conv2(x))
|
| 42 |
+
self.layers[3] = F.max_pool2d(x, 2, 2)
|
| 43 |
+
self.layers[4] = x.view(-1, int(self.H/4) * int(self.W/4) * self.out_channel2)
|
| 44 |
+
self.layers[5] = F.relu(self.fc1(x))
|
| 45 |
+
self.layers[6] = self.fc2(x)
|
| 46 |
+
return F.log_softmax(layers[6], dim=1)
|
| 47 |
+
|
| 48 |
+
#def get_logits(self, x):
|
| 49 |
+
#x = F.relu(self.conv1(x))
|
| 50 |
+
#x = F.max_pool2d(x, 2, 2)
|
| 51 |
+
#x = F.relu(self.conv2(x))
|
| 52 |
+
#x = F.max_pool2d(x, 2, 2)
|
| 53 |
+
#x = x.view(-1, 4* 4 * 50)
|
| 54 |
+
#x = F.relu(self.fc1(x))
|
| 55 |
+
#x = self.fc2(x)
|
| 56 |
+
#return x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def train(model, device, train_loader, optimizer, epoch):
|
| 60 |
+
"""train.
|
| 61 |
+
|
| 62 |
+
Parameters
|
| 63 |
+
----------
|
| 64 |
+
model :
|
| 65 |
+
model
|
| 66 |
+
device :
|
| 67 |
+
device
|
| 68 |
+
train_loader :
|
| 69 |
+
train_loader
|
| 70 |
+
optimizer :
|
| 71 |
+
optimizer
|
| 72 |
+
epoch :
|
| 73 |
+
epoch
|
| 74 |
+
"""
|
| 75 |
+
model.train()
|
| 76 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
| 77 |
+
data, target = data.to(device), target.to(device)
|
| 78 |
+
optimizer.zero_grad()
|
| 79 |
+
output = model(data)
|
| 80 |
+
loss = F.nll_loss(output, target)
|
| 81 |
+
loss.backward()
|
| 82 |
+
optimizer.step()
|
| 83 |
+
|
| 84 |
+
#print every 10
|
| 85 |
+
if batch_idx % 10 == 0:
|
| 86 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
| 87 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
| 88 |
+
100. * batch_idx / len(train_loader), loss.item()))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test(model, device, test_loader):
|
| 92 |
+
"""test.
|
| 93 |
+
|
| 94 |
+
Parameters
|
| 95 |
+
----------
|
| 96 |
+
model :
|
| 97 |
+
model
|
| 98 |
+
device :
|
| 99 |
+
device
|
| 100 |
+
test_loader :
|
| 101 |
+
test_loader
|
| 102 |
+
"""
|
| 103 |
+
model.eval()
|
| 104 |
+
|
| 105 |
+
test_loss = 0
|
| 106 |
+
correct = 0
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
for data, target in test_loader:
|
| 109 |
+
data, target = data.to(device), target.to(device)
|
| 110 |
+
output = model(data)
|
| 111 |
+
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
| 112 |
+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
| 113 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
| 114 |
+
|
| 115 |
+
test_loss /= len(test_loader.dataset)
|
| 116 |
+
|
| 117 |
+
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
| 118 |
+
test_loss, correct, len(test_loader.dataset),
|
| 119 |
+
100. * correct / len(test_loader.dataset)))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
deeprobust/image/netmodels/YOPOCNN.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model for YOPO.
|
| 3 |
+
|
| 4 |
+
Reference
|
| 5 |
+
---------
|
| 6 |
+
..[1]https://github.com/a1600012888/YOPO-You-Only-Propagate-Once
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from collections import OrderedDict
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Net(nn.Module):
|
| 15 |
+
def __init__(self, drop=0.5):
|
| 16 |
+
super(Net, self).__init__()
|
| 17 |
+
|
| 18 |
+
self.num_channels = 1
|
| 19 |
+
self.num_labels = 10
|
| 20 |
+
|
| 21 |
+
activ = nn.ReLU(True)
|
| 22 |
+
self.conv1 = nn.Conv2d(self.num_channels, 32, 3)
|
| 23 |
+
self.layer_one = nn.Sequential(OrderedDict([
|
| 24 |
+
('conv1', self.conv1),
|
| 25 |
+
('relu1', activ),]))
|
| 26 |
+
|
| 27 |
+
self.feature_extractor = nn.Sequential(OrderedDict([
|
| 28 |
+
('conv2', nn.Conv2d(32, 32, 3)),
|
| 29 |
+
('relu2', activ),
|
| 30 |
+
('maxpool1', nn.MaxPool2d(2, 2)),
|
| 31 |
+
('conv3', nn.Conv2d(32, 64, 3)),
|
| 32 |
+
('relu3', activ),
|
| 33 |
+
('conv4', nn.Conv2d(64, 64, 3)),
|
| 34 |
+
('relu4', activ),
|
| 35 |
+
('maxpool2', nn.MaxPool2d(2, 2)),
|
| 36 |
+
]))
|
| 37 |
+
|
| 38 |
+
self.classifier = nn.Sequential(OrderedDict([
|
| 39 |
+
('fc1', nn.Linear(64 * 4 * 4, 200)),
|
| 40 |
+
('relu1', activ),
|
| 41 |
+
('drop', nn.Dropout(drop)),
|
| 42 |
+
('fc2', nn.Linear(200, 200)),
|
| 43 |
+
('relu2', activ),
|
| 44 |
+
('fc3', nn.Linear(200, self.num_labels)),
|
| 45 |
+
]))
|
| 46 |
+
self.other_layers = nn.ModuleList()
|
| 47 |
+
self.other_layers.append(self.feature_extractor)
|
| 48 |
+
self.other_layers.append(self.classifier)
|
| 49 |
+
|
| 50 |
+
for m in self.modules():
|
| 51 |
+
if isinstance(m, (nn.Conv2d)):
|
| 52 |
+
nn.init.kaiming_normal_(m.weight)
|
| 53 |
+
if m.bias is not None:
|
| 54 |
+
nn.init.constant_(m.bias, 0)
|
| 55 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 56 |
+
nn.init.constant_(m.weight, 1)
|
| 57 |
+
nn.init.constant_(m.bias, 0)
|
| 58 |
+
nn.init.constant_(self.classifier.fc3.weight, 0)
|
| 59 |
+
nn.init.constant_(self.classifier.fc3.bias, 0)
|
| 60 |
+
|
| 61 |
+
def forward(self, input):
|
| 62 |
+
y = self.layer_one(input)
|
| 63 |
+
self.layer_one_out = y
|
| 64 |
+
self.layer_one_out.requires_grad_()
|
| 65 |
+
self.layer_one_out.retain_grad()
|
| 66 |
+
features = self.feature_extractor(y)
|
| 67 |
+
logits = self.classifier(features.view(-1, 64 * 4 * 4))
|
| 68 |
+
return logits
|
| 69 |
+
|
| 70 |
+
|
deeprobust/image/netmodels/resnet.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Properly implemented ResNet-s for CIFAR10 as described in paper [1].
|
| 3 |
+
|
| 4 |
+
This implementation is from Yerlan Idelbayev.
|
| 5 |
+
|
| 6 |
+
Reference
|
| 7 |
+
---------
|
| 8 |
+
..[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
| 9 |
+
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
| 10 |
+
..[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
| 11 |
+
|
| 12 |
+
'''
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BasicBlock(nn.Module):
|
| 20 |
+
expansion = 1
|
| 21 |
+
|
| 22 |
+
def __init__(self, in_planes, planes, stride=1):
|
| 23 |
+
super(BasicBlock, self).__init__()
|
| 24 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 25 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 26 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
| 27 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 28 |
+
|
| 29 |
+
self.shortcut = nn.Sequential()
|
| 30 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
| 31 |
+
self.shortcut = nn.Sequential(
|
| 32 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
| 33 |
+
nn.BatchNorm2d(self.expansion*planes)
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
| 38 |
+
out = self.bn2(self.conv2(out))
|
| 39 |
+
out += self.shortcut(x)
|
| 40 |
+
out = F.relu(out)
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Bottleneck(nn.Module):
|
| 45 |
+
expansion = 4
|
| 46 |
+
|
| 47 |
+
def __init__(self, in_planes, planes, stride=1):
|
| 48 |
+
super(Bottleneck, self).__init__()
|
| 49 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
| 50 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 51 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
| 52 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 53 |
+
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
| 54 |
+
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
| 55 |
+
|
| 56 |
+
self.shortcut = nn.Sequential()
|
| 57 |
+
if stride != 1 or in_planes != self.expansion*planes:
|
| 58 |
+
self.shortcut = nn.Sequential(
|
| 59 |
+
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
| 60 |
+
nn.BatchNorm2d(self.expansion*planes)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
| 65 |
+
out = F.relu(self.bn2(self.conv2(out)))
|
| 66 |
+
out = self.bn3(self.conv3(out))
|
| 67 |
+
out += self.shortcut(x)
|
| 68 |
+
out = F.relu(out)
|
| 69 |
+
return out
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Net(nn.Module):
|
| 73 |
+
def __init__(self, block, num_blocks, num_classes=10):
|
| 74 |
+
"""__init__.
|
| 75 |
+
|
| 76 |
+
Parameters
|
| 77 |
+
----------
|
| 78 |
+
block :
|
| 79 |
+
block
|
| 80 |
+
num_blocks :
|
| 81 |
+
num_blocks
|
| 82 |
+
num_classes :
|
| 83 |
+
num_classes
|
| 84 |
+
"""
|
| 85 |
+
super(Net, self).__init__()
|
| 86 |
+
self.in_planes = 64
|
| 87 |
+
|
| 88 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
| 89 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 90 |
+
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
|
| 91 |
+
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
|
| 92 |
+
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
|
| 93 |
+
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
|
| 94 |
+
self.linear = nn.Linear(512*block.expansion, num_classes)
|
| 95 |
+
|
| 96 |
+
def _make_layer(self, block, planes, num_blocks, stride):
|
| 97 |
+
strides = [stride] + [1]*(num_blocks-1)
|
| 98 |
+
layers = []
|
| 99 |
+
for stride in strides:
|
| 100 |
+
layers.append(block(self.in_planes, planes, stride))
|
| 101 |
+
self.in_planes = planes * block.expansion
|
| 102 |
+
return nn.Sequential(*layers)
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
out = F.relu(self.bn1(self.conv1(x)))
|
| 106 |
+
out = self.layer1(out)
|
| 107 |
+
out = self.layer2(out)
|
| 108 |
+
out = self.layer3(out)
|
| 109 |
+
out = self.layer4(out)
|
| 110 |
+
out = F.avg_pool2d(out, 4)
|
| 111 |
+
out = out.view(out.size(0), -1)
|
| 112 |
+
out = self.linear(out)
|
| 113 |
+
return out
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def ResNet18():
|
| 117 |
+
return Net(BasicBlock, [2,2,2,2])
|
| 118 |
+
|
| 119 |
+
def ResNet34():
|
| 120 |
+
return Net(BasicBlock, [3,4,6,3])
|
| 121 |
+
|
| 122 |
+
def ResNet50():
|
| 123 |
+
return Net(Bottleneck, [3,4,6,3])
|
| 124 |
+
|
| 125 |
+
def ResNet101():
|
| 126 |
+
return Net(Bottleneck, [3,4,23,3])
|
| 127 |
+
|
| 128 |
+
def ResNet152():
|
| 129 |
+
return Net(Bottleneck, [3,8,36,3])
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def test(model, device, test_loader):
|
| 133 |
+
model.eval()
|
| 134 |
+
|
| 135 |
+
test_loss = 0
|
| 136 |
+
correct = 0
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
for data, target in test_loader:
|
| 139 |
+
data, target = data.to(device), target.to(device)
|
| 140 |
+
output = model(data)
|
| 141 |
+
#test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
| 142 |
+
loss = F.cross_entropy(output, target)
|
| 143 |
+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
| 144 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
| 145 |
+
|
| 146 |
+
test_loss /= len(test_loader.dataset)
|
| 147 |
+
|
| 148 |
+
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
| 149 |
+
test_loss, correct, len(test_loader.dataset),
|
| 150 |
+
100. * correct / len(test_loader.dataset)))
|
| 151 |
+
|
| 152 |
+
def train(model, device, train_loader, optimizer, epoch):
|
| 153 |
+
model.train()
|
| 154 |
+
|
| 155 |
+
# lr = util.adjust_learning_rate(optimizer, epoch, args) # don't need it if we use Adam
|
| 156 |
+
|
| 157 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
| 158 |
+
data, target = torch.tensor(data).to(device), torch.tensor(target).to(device)
|
| 159 |
+
optimizer.zero_grad()
|
| 160 |
+
output = model(data)
|
| 161 |
+
# loss = F.nll_loss(output, target)
|
| 162 |
+
loss = F.cross_entropy(output, target)
|
| 163 |
+
loss.backward()
|
| 164 |
+
optimizer.step()
|
| 165 |
+
if batch_idx % 10 == 0:
|
| 166 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
| 167 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
| 168 |
+
100. * batch_idx / len(train_loader), loss.item()))
|
deeprobust/image/netmodels/train_model.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This function help to train model of different archtecture easily. Select model archtecture and training data, then output corresponding model.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import print_function
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F #233
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from torchvision import datasets, transforms
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
def train(model, data, device, maxepoch, data_path = './', save_per_epoch = 10, seed = 100):
|
| 17 |
+
"""train.
|
| 18 |
+
|
| 19 |
+
Parameters
|
| 20 |
+
----------
|
| 21 |
+
model :
|
| 22 |
+
model(option:'CNN', 'ResNet18', 'ResNet34', 'ResNet50', 'densenet', 'vgg11', 'vgg13', 'vgg16', 'vgg19')
|
| 23 |
+
data :
|
| 24 |
+
data(option:'MNIST','CIFAR10')
|
| 25 |
+
device :
|
| 26 |
+
device(option:'cpu', 'cuda')
|
| 27 |
+
maxepoch :
|
| 28 |
+
training epoch
|
| 29 |
+
data_path :
|
| 30 |
+
data path(default = './')
|
| 31 |
+
save_per_epoch :
|
| 32 |
+
save_per_epoch(default = 10)
|
| 33 |
+
seed :
|
| 34 |
+
seed
|
| 35 |
+
|
| 36 |
+
Examples
|
| 37 |
+
--------
|
| 38 |
+
>>>import deeprobust.image.netmodels.train_model as trainmodel
|
| 39 |
+
>>>trainmodel.train('CNN', 'MNIST', 'cuda', 20)
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
torch.manual_seed(seed)
|
| 43 |
+
|
| 44 |
+
train_loader, test_loader = feed_dataset(data, data_path)
|
| 45 |
+
|
| 46 |
+
if (model == 'CNN'):
|
| 47 |
+
import deeprobust.image.netmodels.CNN as MODEL
|
| 48 |
+
#from deeprobust.image.netmodels.CNN import Net
|
| 49 |
+
train_net = MODEL.Net().to(device)
|
| 50 |
+
|
| 51 |
+
elif (model == 'ResNet18'):
|
| 52 |
+
import deeprobust.image.netmodels.resnet as MODEL
|
| 53 |
+
train_net = MODEL.ResNet18().to(device)
|
| 54 |
+
|
| 55 |
+
elif (model == 'ResNet34'):
|
| 56 |
+
import deeprobust.image.netmodels.resnet as MODEL
|
| 57 |
+
train_net = MODEL.ResNet34().to(device)
|
| 58 |
+
|
| 59 |
+
elif (model == 'ResNet50'):
|
| 60 |
+
import deeprobust.image.netmodels.resnet as MODEL
|
| 61 |
+
train_net = MODEL.ResNet50().to(device)
|
| 62 |
+
|
| 63 |
+
elif (model == 'densenet'):
|
| 64 |
+
import deeprobust.image.netmodels.densenet as MODEL
|
| 65 |
+
train_net = MODEL.densenet_cifar().to(device)
|
| 66 |
+
|
| 67 |
+
elif (model == 'vgg11'):
|
| 68 |
+
import deeprobust.image.netmodels.vgg as MODEL
|
| 69 |
+
train_net = MODEL.VGG('VGG11').to(device)
|
| 70 |
+
elif (model == 'vgg13'):
|
| 71 |
+
import deeprobust.image.netmodels.vgg as MODEL
|
| 72 |
+
train_net = MODEL.VGG('VGG13').to(device)
|
| 73 |
+
elif (model == 'vgg16'):
|
| 74 |
+
import deeprobust.image.netmodels.vgg as MODEL
|
| 75 |
+
train_net = MODEL.VGG('VGG16').to(device)
|
| 76 |
+
elif (model == 'vgg19'):
|
| 77 |
+
import deeprobust.image.netmodels.vgg as MODEL
|
| 78 |
+
train_net = MODEL.VGG('VGG19').to(device)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
optimizer = optim.SGD(train_net.parameters(), lr= 0.1, momentum=0.5)
|
| 83 |
+
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
|
| 84 |
+
save_model = True
|
| 85 |
+
for epoch in range(1, maxepoch + 1): ## 5 batches
|
| 86 |
+
|
| 87 |
+
print(epoch)
|
| 88 |
+
MODEL.train(train_net, device, train_loader, optimizer, epoch)
|
| 89 |
+
MODEL.test(train_net, device, test_loader)
|
| 90 |
+
|
| 91 |
+
if (save_model and (epoch % (save_per_epoch) == 0 or epoch == maxepoch)):
|
| 92 |
+
if os.path.isdir('./trained_models/'):
|
| 93 |
+
print('Save model.')
|
| 94 |
+
torch.save(train_net.state_dict(), os.path.join('trained_models', data + "_" + model + "_epoch_" + str(epoch) + ".pt"))
|
| 95 |
+
else:
|
| 96 |
+
os.mkdir('./trained_models/')
|
| 97 |
+
print('Make directory and save model.')
|
| 98 |
+
torch.save(train_net.state_dict(), os.path.join('trained_models', data + "_" + model + "_epoch_" + str(epoch) + ".pt"))
|
| 99 |
+
scheduler.step()
|
| 100 |
+
|
| 101 |
+
def feed_dataset(data, data_dict):
|
| 102 |
+
if(data == 'CIFAR10'):
|
| 103 |
+
transform_train = transforms.Compose([
|
| 104 |
+
transforms.RandomCrop(32, padding=5),
|
| 105 |
+
transforms.RandomHorizontalFlip(),
|
| 106 |
+
transforms.ToTensor(),
|
| 107 |
+
#transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
| 108 |
+
])
|
| 109 |
+
|
| 110 |
+
transform_val = transforms.Compose([
|
| 111 |
+
transforms.ToTensor(),
|
| 112 |
+
#transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
| 113 |
+
])
|
| 114 |
+
|
| 115 |
+
train_loader = torch.utils.data.DataLoader(
|
| 116 |
+
datasets.CIFAR10(data_dict, train=True, download = True,
|
| 117 |
+
transform=transform_train),
|
| 118 |
+
batch_size= 128, shuffle=True) #, **kwargs)
|
| 119 |
+
|
| 120 |
+
test_loader = torch.utils.data.DataLoader(
|
| 121 |
+
datasets.CIFAR10(data_dict, train=False, download = True,
|
| 122 |
+
transform=transform_val),
|
| 123 |
+
batch_size= 1000, shuffle=True) #, **kwargs)
|
| 124 |
+
|
| 125 |
+
elif(data == 'MNIST'):
|
| 126 |
+
train_loader = torch.utils.data.DataLoader(
|
| 127 |
+
datasets.MNIST(data_dict, train=True, download = True,
|
| 128 |
+
transform=transforms.Compose([transforms.ToTensor(),
|
| 129 |
+
transforms.Normalize((0.1307,), (0.3081,))])),
|
| 130 |
+
batch_size=128,
|
| 131 |
+
shuffle=True)
|
| 132 |
+
|
| 133 |
+
test_loader = torch.utils.data.DataLoader(
|
| 134 |
+
datasets.MNIST(data_dict, train=False, download = True,
|
| 135 |
+
transform=transforms.Compose([transforms.ToTensor(),
|
| 136 |
+
transforms.Normalize((0.1307,), (0.3081,))])),
|
| 137 |
+
batch_size=1000,
|
| 138 |
+
shuffle=True)
|
| 139 |
+
|
| 140 |
+
elif(data == 'ImageNet'):
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
return train_loader, test_loader
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
deeprobust/image/netmodels/train_resnet.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F #233
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torchvision import datasets, transforms
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import CNNmodel
|
| 11 |
+
|
| 12 |
+
torch.manual_seed(100)
|
| 13 |
+
device = torch.device("cuda")
|
| 14 |
+
|
| 15 |
+
train_loader = torch.utils.data.DataLoader(
|
| 16 |
+
datasets.MNIST('../data', train=True, download=True,
|
| 17 |
+
transform=transforms.Compose([transforms.ToTensor(),
|
| 18 |
+
transforms.Normalize((0.1307,), (0.3081,))])),
|
| 19 |
+
batch_size=64,
|
| 20 |
+
shuffle=True)
|
| 21 |
+
|
| 22 |
+
test_loader = torch.utils.data.DataLoader(
|
| 23 |
+
datasets.MNIST('../data', train=False,
|
| 24 |
+
transform=transforms.Compose([transforms.ToTensor(),
|
| 25 |
+
transforms.Normalize((0.1307,), (0.3081,))])),
|
| 26 |
+
batch_size=1000,
|
| 27 |
+
shuffle=True)
|
| 28 |
+
|
| 29 |
+
model = CNNmodel.Net().to(device)
|
| 30 |
+
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
|
| 31 |
+
|
| 32 |
+
save_model = True
|
| 33 |
+
for epoch in range(1, 5 + 1): ## 5 batches
|
| 34 |
+
print(epoch)
|
| 35 |
+
CNNmodel.train(model, device, train_loader, optimizer, epoch)
|
| 36 |
+
CNNmodel.test(model, device, test_loader)
|
| 37 |
+
|
| 38 |
+
if (save_model):
|
| 39 |
+
torch.save(model.state_dict(), "mnist_cnn.pt")
|
deeprobust/image/netmodels/vgg.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This is an implementation of VGG net.
|
| 3 |
+
|
| 4 |
+
Reference
|
| 5 |
+
---------
|
| 6 |
+
..[1]Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
|
| 7 |
+
..[2]Original implementation: https://github.com/kuangliu/pytorch-cifar
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
cfg = {
|
| 15 |
+
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
| 16 |
+
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
| 17 |
+
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
| 18 |
+
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VGG(nn.Module):
|
| 23 |
+
"""VGG.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, vgg_name):
|
| 27 |
+
super(VGG, self).__init__()
|
| 28 |
+
self.features = self._make_layers(cfg[vgg_name])
|
| 29 |
+
self.classifier = nn.Linear(512, 10)
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
out = self.features(x)
|
| 33 |
+
out = out.view(out.size(0), -1)
|
| 34 |
+
out = self.classifier(out)
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
def _make_layers(self, cfg):
|
| 38 |
+
layers = []
|
| 39 |
+
in_channels = 3
|
| 40 |
+
for x in cfg:
|
| 41 |
+
if x == 'M':
|
| 42 |
+
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
| 43 |
+
else:
|
| 44 |
+
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
|
| 45 |
+
nn.BatchNorm2d(x),
|
| 46 |
+
nn.ReLU(inplace=True)]
|
| 47 |
+
in_channels = x
|
| 48 |
+
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
|
| 49 |
+
return nn.Sequential(*layers)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test(model, device, test_loader):
|
| 54 |
+
"""test.
|
| 55 |
+
|
| 56 |
+
Parameters
|
| 57 |
+
----------
|
| 58 |
+
model :
|
| 59 |
+
model
|
| 60 |
+
device :
|
| 61 |
+
device
|
| 62 |
+
test_loader :
|
| 63 |
+
test_loader
|
| 64 |
+
"""
|
| 65 |
+
model.eval()
|
| 66 |
+
|
| 67 |
+
test_loss = 0
|
| 68 |
+
correct = 0
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
for data, target in test_loader:
|
| 71 |
+
data, target = data.to(device), target.to(device)
|
| 72 |
+
output = model(data)
|
| 73 |
+
#test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
|
| 74 |
+
|
| 75 |
+
test_loss += F.cross_entropy(output, target)
|
| 76 |
+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
|
| 77 |
+
correct += pred.eq(target.view_as(pred)).sum().item()
|
| 78 |
+
|
| 79 |
+
test_loss /= len(test_loader.dataset)
|
| 80 |
+
|
| 81 |
+
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
| 82 |
+
test_loss, correct, len(test_loader.dataset),
|
| 83 |
+
100. * correct / len(test_loader.dataset)))
|
| 84 |
+
|
| 85 |
+
def train(model, device, train_loader, optimizer, epoch):
|
| 86 |
+
"""train.
|
| 87 |
+
|
| 88 |
+
Parameters
|
| 89 |
+
----------
|
| 90 |
+
model :
|
| 91 |
+
model
|
| 92 |
+
device :
|
| 93 |
+
device
|
| 94 |
+
train_loader :
|
| 95 |
+
train_loader
|
| 96 |
+
optimizer :
|
| 97 |
+
optimizer
|
| 98 |
+
epoch :
|
| 99 |
+
epoch
|
| 100 |
+
"""
|
| 101 |
+
model.train()
|
| 102 |
+
|
| 103 |
+
# lr = util.adjust_learning_rate(optimizer, epoch, args) # don't need it if we use Adam
|
| 104 |
+
|
| 105 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
| 106 |
+
data, target = torch.tensor(data).to(device), torch.tensor(target).to(device)
|
| 107 |
+
optimizer.zero_grad()
|
| 108 |
+
output = model(data)
|
| 109 |
+
# loss = F.nll_loss(output, target)
|
| 110 |
+
loss = F.cross_entropy(output, target)
|
| 111 |
+
loss.backward()
|
| 112 |
+
optimizer.step()
|
| 113 |
+
if batch_idx % 10 == 0:
|
| 114 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
| 115 |
+
epoch, batch_idx * len(data), len(train_loader.dataset),
|
| 116 |
+
100. * batch_idx / len(train_loader), loss.item()/data.shape[0]))
|
deeprobust/image/synset_words.txt
ADDED
|
@@ -0,0 +1,1000 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
n01440764 tench, Tinca tinca
|
| 2 |
+
n01443537 goldfish, Carassius auratus
|
| 3 |
+
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
|
| 4 |
+
n01491361 tiger shark, Galeocerdo cuvieri
|
| 5 |
+
n01494475 hammerhead, hammerhead shark
|
| 6 |
+
n01496331 electric ray, crampfish, numbfish, torpedo
|
| 7 |
+
n01498041 stingray
|
| 8 |
+
n01514668 cock
|
| 9 |
+
n01514859 hen
|
| 10 |
+
n01518878 ostrich, Struthio camelus
|
| 11 |
+
n01530575 brambling, Fringilla montifringilla
|
| 12 |
+
n01531178 goldfinch, Carduelis carduelis
|
| 13 |
+
n01532829 house finch, linnet, Carpodacus mexicanus
|
| 14 |
+
n01534433 junco, snowbird
|
| 15 |
+
n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea
|
| 16 |
+
n01558993 robin, American robin, Turdus migratorius
|
| 17 |
+
n01560419 bulbul
|
| 18 |
+
n01580077 jay
|
| 19 |
+
n01582220 magpie
|
| 20 |
+
n01592084 chickadee
|
| 21 |
+
n01601694 water ouzel, dipper
|
| 22 |
+
n01608432 kite
|
| 23 |
+
n01614925 bald eagle, American eagle, Haliaeetus leucocephalus
|
| 24 |
+
n01616318 vulture
|
| 25 |
+
n01622779 great grey owl, great gray owl, Strix nebulosa
|
| 26 |
+
n01629819 European fire salamander, Salamandra salamandra
|
| 27 |
+
n01630670 common newt, Triturus vulgaris
|
| 28 |
+
n01631663 eft
|
| 29 |
+
n01632458 spotted salamander, Ambystoma maculatum
|
| 30 |
+
n01632777 axolotl, mud puppy, Ambystoma mexicanum
|
| 31 |
+
n01641577 bullfrog, Rana catesbeiana
|
| 32 |
+
n01644373 tree frog, tree-frog
|
| 33 |
+
n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
|
| 34 |
+
n01664065 loggerhead, loggerhead turtle, Caretta caretta
|
| 35 |
+
n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
|
| 36 |
+
n01667114 mud turtle
|
| 37 |
+
n01667778 terrapin
|
| 38 |
+
n01669191 box turtle, box tortoise
|
| 39 |
+
n01675722 banded gecko
|
| 40 |
+
n01677366 common iguana, iguana, Iguana iguana
|
| 41 |
+
n01682714 American chameleon, anole, Anolis carolinensis
|
| 42 |
+
n01685808 whiptail, whiptail lizard
|
| 43 |
+
n01687978 agama
|
| 44 |
+
n01688243 frilled lizard, Chlamydosaurus kingi
|
| 45 |
+
n01689811 alligator lizard
|
| 46 |
+
n01692333 Gila monster, Heloderma suspectum
|
| 47 |
+
n01693334 green lizard, Lacerta viridis
|
| 48 |
+
n01694178 African chameleon, Chamaeleo chamaeleon
|
| 49 |
+
n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
|
| 50 |
+
n01697457 African crocodile, Nile crocodile, Crocodylus niloticus
|
| 51 |
+
n01698640 American alligator, Alligator mississipiensis
|
| 52 |
+
n01704323 triceratops
|
| 53 |
+
n01728572 thunder snake, worm snake, Carphophis amoenus
|
| 54 |
+
n01728920 ringneck snake, ring-necked snake, ring snake
|
| 55 |
+
n01729322 hognose snake, puff adder, sand viper
|
| 56 |
+
n01729977 green snake, grass snake
|
| 57 |
+
n01734418 king snake, kingsnake
|
| 58 |
+
n01735189 garter snake, grass snake
|
| 59 |
+
n01737021 water snake
|
| 60 |
+
n01739381 vine snake
|
| 61 |
+
n01740131 night snake, Hypsiglena torquata
|
| 62 |
+
n01742172 boa constrictor, Constrictor constrictor
|
| 63 |
+
n01744401 rock python, rock snake, Python sebae
|
| 64 |
+
n01748264 Indian cobra, Naja naja
|
| 65 |
+
n01749939 green mamba
|
| 66 |
+
n01751748 sea snake
|
| 67 |
+
n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
|
| 68 |
+
n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus
|
| 69 |
+
n01756291 sidewinder, horned rattlesnake, Crotalus cerastes
|
| 70 |
+
n01768244 trilobite
|
| 71 |
+
n01770081 harvestman, daddy longlegs, Phalangium opilio
|
| 72 |
+
n01770393 scorpion
|
| 73 |
+
n01773157 black and gold garden spider, Argiope aurantia
|
| 74 |
+
n01773549 barn spider, Araneus cavaticus
|
| 75 |
+
n01773797 garden spider, Aranea diademata
|
| 76 |
+
n01774384 black widow, Latrodectus mactans
|
| 77 |
+
n01774750 tarantula
|
| 78 |
+
n01775062 wolf spider, hunting spider
|
| 79 |
+
n01776313 tick
|
| 80 |
+
n01784675 centipede
|
| 81 |
+
n01795545 black grouse
|
| 82 |
+
n01796340 ptarmigan
|
| 83 |
+
n01797886 ruffed grouse, partridge, Bonasa umbellus
|
| 84 |
+
n01798484 prairie chicken, prairie grouse, prairie fowl
|
| 85 |
+
n01806143 peacock
|
| 86 |
+
n01806567 quail
|
| 87 |
+
n01807496 partridge
|
| 88 |
+
n01817953 African grey, African gray, Psittacus erithacus
|
| 89 |
+
n01818515 macaw
|
| 90 |
+
n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
|
| 91 |
+
n01820546 lorikeet
|
| 92 |
+
n01824575 coucal
|
| 93 |
+
n01828970 bee eater
|
| 94 |
+
n01829413 hornbill
|
| 95 |
+
n01833805 hummingbird
|
| 96 |
+
n01843065 jacamar
|
| 97 |
+
n01843383 toucan
|
| 98 |
+
n01847000 drake
|
| 99 |
+
n01855032 red-breasted merganser, Mergus serrator
|
| 100 |
+
n01855672 goose
|
| 101 |
+
n01860187 black swan, Cygnus atratus
|
| 102 |
+
n01871265 tusker
|
| 103 |
+
n01872401 echidna, spiny anteater, anteater
|
| 104 |
+
n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
|
| 105 |
+
n01877812 wallaby, brush kangaroo
|
| 106 |
+
n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
|
| 107 |
+
n01883070 wombat
|
| 108 |
+
n01910747 jellyfish
|
| 109 |
+
n01914609 sea anemone, anemone
|
| 110 |
+
n01917289 brain coral
|
| 111 |
+
n01924916 flatworm, platyhelminth
|
| 112 |
+
n01930112 nematode, nematode worm, roundworm
|
| 113 |
+
n01943899 conch
|
| 114 |
+
n01944390 snail
|
| 115 |
+
n01945685 slug
|
| 116 |
+
n01950731 sea slug, nudibranch
|
| 117 |
+
n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore
|
| 118 |
+
n01968897 chambered nautilus, pearly nautilus, nautilus
|
| 119 |
+
n01978287 Dungeness crab, Cancer magister
|
| 120 |
+
n01978455 rock crab, Cancer irroratus
|
| 121 |
+
n01980166 fiddler crab
|
| 122 |
+
n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
|
| 123 |
+
n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
|
| 124 |
+
n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
|
| 125 |
+
n01985128 crayfish, crawfish, crawdad, crawdaddy
|
| 126 |
+
n01986214 hermit crab
|
| 127 |
+
n01990800 isopod
|
| 128 |
+
n02002556 white stork, Ciconia ciconia
|
| 129 |
+
n02002724 black stork, Ciconia nigra
|
| 130 |
+
n02006656 spoonbill
|
| 131 |
+
n02007558 flamingo
|
| 132 |
+
n02009229 little blue heron, Egretta caerulea
|
| 133 |
+
n02009912 American egret, great white heron, Egretta albus
|
| 134 |
+
n02011460 bittern
|
| 135 |
+
n02012849 crane
|
| 136 |
+
n02013706 limpkin, Aramus pictus
|
| 137 |
+
n02017213 European gallinule, Porphyrio porphyrio
|
| 138 |
+
n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana
|
| 139 |
+
n02018795 bustard
|
| 140 |
+
n02025239 ruddy turnstone, Arenaria interpres
|
| 141 |
+
n02027492 red-backed sandpiper, dunlin, Erolia alpina
|
| 142 |
+
n02028035 redshank, Tringa totanus
|
| 143 |
+
n02033041 dowitcher
|
| 144 |
+
n02037110 oystercatcher, oyster catcher
|
| 145 |
+
n02051845 pelican
|
| 146 |
+
n02056570 king penguin, Aptenodytes patagonica
|
| 147 |
+
n02058221 albatross, mollymawk
|
| 148 |
+
n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
|
| 149 |
+
n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca
|
| 150 |
+
n02074367 dugong, Dugong dugon
|
| 151 |
+
n02077923 sea lion
|
| 152 |
+
n02085620 Chihuahua
|
| 153 |
+
n02085782 Japanese spaniel
|
| 154 |
+
n02085936 Maltese dog, Maltese terrier, Maltese
|
| 155 |
+
n02086079 Pekinese, Pekingese, Peke
|
| 156 |
+
n02086240 Shih-Tzu
|
| 157 |
+
n02086646 Blenheim spaniel
|
| 158 |
+
n02086910 papillon
|
| 159 |
+
n02087046 toy terrier
|
| 160 |
+
n02087394 Rhodesian ridgeback
|
| 161 |
+
n02088094 Afghan hound, Afghan
|
| 162 |
+
n02088238 basset, basset hound
|
| 163 |
+
n02088364 beagle
|
| 164 |
+
n02088466 bloodhound, sleuthhound
|
| 165 |
+
n02088632 bluetick
|
| 166 |
+
n02089078 black-and-tan coonhound
|
| 167 |
+
n02089867 Walker hound, Walker foxhound
|
| 168 |
+
n02089973 English foxhound
|
| 169 |
+
n02090379 redbone
|
| 170 |
+
n02090622 borzoi, Russian wolfhound
|
| 171 |
+
n02090721 Irish wolfhound
|
| 172 |
+
n02091032 Italian greyhound
|
| 173 |
+
n02091134 whippet
|
| 174 |
+
n02091244 Ibizan hound, Ibizan Podenco
|
| 175 |
+
n02091467 Norwegian elkhound, elkhound
|
| 176 |
+
n02091635 otterhound, otter hound
|
| 177 |
+
n02091831 Saluki, gazelle hound
|
| 178 |
+
n02092002 Scottish deerhound, deerhound
|
| 179 |
+
n02092339 Weimaraner
|
| 180 |
+
n02093256 Staffordshire bullterrier, Staffordshire bull terrier
|
| 181 |
+
n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
|
| 182 |
+
n02093647 Bedlington terrier
|
| 183 |
+
n02093754 Border terrier
|
| 184 |
+
n02093859 Kerry blue terrier
|
| 185 |
+
n02093991 Irish terrier
|
| 186 |
+
n02094114 Norfolk terrier
|
| 187 |
+
n02094258 Norwich terrier
|
| 188 |
+
n02094433 Yorkshire terrier
|
| 189 |
+
n02095314 wire-haired fox terrier
|
| 190 |
+
n02095570 Lakeland terrier
|
| 191 |
+
n02095889 Sealyham terrier, Sealyham
|
| 192 |
+
n02096051 Airedale, Airedale terrier
|
| 193 |
+
n02096177 cairn, cairn terrier
|
| 194 |
+
n02096294 Australian terrier
|
| 195 |
+
n02096437 Dandie Dinmont, Dandie Dinmont terrier
|
| 196 |
+
n02096585 Boston bull, Boston terrier
|
| 197 |
+
n02097047 miniature schnauzer
|
| 198 |
+
n02097130 giant schnauzer
|
| 199 |
+
n02097209 standard schnauzer
|
| 200 |
+
n02097298 Scotch terrier, Scottish terrier, Scottie
|
| 201 |
+
n02097474 Tibetan terrier, chrysanthemum dog
|
| 202 |
+
n02097658 silky terrier, Sydney silky
|
| 203 |
+
n02098105 soft-coated wheaten terrier
|
| 204 |
+
n02098286 West Highland white terrier
|
| 205 |
+
n02098413 Lhasa, Lhasa apso
|
| 206 |
+
n02099267 flat-coated retriever
|
| 207 |
+
n02099429 curly-coated retriever
|
| 208 |
+
n02099601 golden retriever
|
| 209 |
+
n02099712 Labrador retriever
|
| 210 |
+
n02099849 Chesapeake Bay retriever
|
| 211 |
+
n02100236 German short-haired pointer
|
| 212 |
+
n02100583 vizsla, Hungarian pointer
|
| 213 |
+
n02100735 English setter
|
| 214 |
+
n02100877 Irish setter, red setter
|
| 215 |
+
n02101006 Gordon setter
|
| 216 |
+
n02101388 Brittany spaniel
|
| 217 |
+
n02101556 clumber, clumber spaniel
|
| 218 |
+
n02102040 English springer, English springer spaniel
|
| 219 |
+
n02102177 Welsh springer spaniel
|
| 220 |
+
n02102318 cocker spaniel, English cocker spaniel, cocker
|
| 221 |
+
n02102480 Sussex spaniel
|
| 222 |
+
n02102973 Irish water spaniel
|
| 223 |
+
n02104029 kuvasz
|
| 224 |
+
n02104365 schipperke
|
| 225 |
+
n02105056 groenendael
|
| 226 |
+
n02105162 malinois
|
| 227 |
+
n02105251 briard
|
| 228 |
+
n02105412 kelpie
|
| 229 |
+
n02105505 komondor
|
| 230 |
+
n02105641 Old English sheepdog, bobtail
|
| 231 |
+
n02105855 Shetland sheepdog, Shetland sheep dog, Shetland
|
| 232 |
+
n02106030 collie
|
| 233 |
+
n02106166 Border collie
|
| 234 |
+
n02106382 Bouvier des Flandres, Bouviers des Flandres
|
| 235 |
+
n02106550 Rottweiler
|
| 236 |
+
n02106662 German shepherd, German shepherd dog, German police dog, alsatian
|
| 237 |
+
n02107142 Doberman, Doberman pinscher
|
| 238 |
+
n02107312 miniature pinscher
|
| 239 |
+
n02107574 Greater Swiss Mountain dog
|
| 240 |
+
n02107683 Bernese mountain dog
|
| 241 |
+
n02107908 Appenzeller
|
| 242 |
+
n02108000 EntleBucher``^
|
| 243 |
+
n02108089 boxer`
|
| 244 |
+
n02108422 bull mastif
|
| 245 |
+
n02108551 Tibetan mastiff
|
| 246 |
+
n02108915 French bulldog
|
| 247 |
+
n02109047 Great Dane
|
| 248 |
+
n02109525 Saint Bernard, St Bernard
|
| 249 |
+
n02109961 Eskimo dog, husky
|
| 250 |
+
n02110063 malamute, malemute, Alaskan malamute
|
| 251 |
+
n02110185 Siberian husky
|
| 252 |
+
n02110341 dalmatian, coach dog, carriage dog
|
| 253 |
+
n02110627 affenpinscher, monkey pinscher, monkey dog
|
| 254 |
+
n02110806 basenji
|
| 255 |
+
n02110958 pug, pug-dog
|
| 256 |
+
n02111129 Leonberg
|
| 257 |
+
n02111277 Newfoundland, Newfoundland dog
|
| 258 |
+
n02111500 Great Pyrenees
|
| 259 |
+
n02111889 Samoyed, Samoyede
|
| 260 |
+
n02112018 Pomeranian
|
| 261 |
+
n02112137 chow, chow chow
|
| 262 |
+
n02112350 keeshond
|
| 263 |
+
n02112706 Brabancon griffon
|
| 264 |
+
n02113023 Pembroke, Pembroke Welsh corgi
|
| 265 |
+
n02113186 Cardigan, Cardigan Welsh corgi
|
| 266 |
+
n02113624 toy poodle
|
| 267 |
+
n02113712 miniature poodle
|
| 268 |
+
n02113799 standard poodle
|
| 269 |
+
n02113978 Mexican hairless
|
| 270 |
+
n02114367 timber wolf, grey wolf, gray wolf, Canis lupus
|
| 271 |
+
n02114548 white wolf, Arctic wolf, Canis lupus tundrarum
|
| 272 |
+
n02114712 red wolf, maned wolf, Canis rufus, Canis niger
|
| 273 |
+
n02114855 coyote, prairie wolf, brush wolf, Canis latrans
|
| 274 |
+
n02115641 dingo, warrigal, warragal, Canis dingo
|
| 275 |
+
n02115913 dhole, Cuon alpinus
|
| 276 |
+
n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
|
| 277 |
+
n02117135 hyena, hyaena
|
| 278 |
+
n02119022 red fox, Vulpes vulpes
|
| 279 |
+
n02119789 kit fox, Vulpes macrotis
|
| 280 |
+
n02120079 Arctic fox, white fox, Alopex lagopus
|
| 281 |
+
n02120505 grey fox, gray fox, Urocyon cinereoargenteus
|
| 282 |
+
n02123045 tabby, tabby cat
|
| 283 |
+
n02123159 tiger cat
|
| 284 |
+
n02123394 Persian cat
|
| 285 |
+
n02123597 Siamese cat, Siamese
|
| 286 |
+
n02124075 Egyptian cat
|
| 287 |
+
n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
|
| 288 |
+
n02127052 lynx, catamount
|
| 289 |
+
n02128385 leopard, Panthera pardus
|
| 290 |
+
n02128757 snow leopard, ounce, Panthera uncia
|
| 291 |
+
n02128925 jaguar, panther, Panthera onca, Felis onca
|
| 292 |
+
n02129165 lion, king of beasts, Panthera leo
|
| 293 |
+
n02129604 tiger, Panthera tigris
|
| 294 |
+
n02130308 cheetah, chetah, Acinonyx jubatus
|
| 295 |
+
n02132136 brown bear, bruin, Ursus arctos
|
| 296 |
+
n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus
|
| 297 |
+
n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
|
| 298 |
+
n02134418 sloth bear, Melursus ursinus, Ursus ursinus
|
| 299 |
+
n02137549 mongoose
|
| 300 |
+
n02138441 meerkat, mierkat
|
| 301 |
+
n02165105 tiger beetle
|
| 302 |
+
n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
|
| 303 |
+
n02167151 ground beetle, carabid beetle
|
| 304 |
+
n02168699 long-horned beetle, longicorn, longicorn beetle
|
| 305 |
+
n02169497 leaf beetle, chrysomelid
|
| 306 |
+
n02172182 dung beetle
|
| 307 |
+
n02174001 rhinoceros beetle
|
| 308 |
+
n02177972 weevil
|
| 309 |
+
n02190166 fly
|
| 310 |
+
n02206856 bee
|
| 311 |
+
n02219486 ant, emmet, pismire
|
| 312 |
+
n02226429 grasshopper, hopper
|
| 313 |
+
n02229544 cricket
|
| 314 |
+
n02231487 walking stick, walkingstick, stick insect
|
| 315 |
+
n02233338 cockroach, roach
|
| 316 |
+
n02236044 mantis, mantid
|
| 317 |
+
n02256656 cicada, cicala
|
| 318 |
+
n02259212 leafhopper
|
| 319 |
+
n02264363 lacewing, lacewing fly
|
| 320 |
+
n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
|
| 321 |
+
n02268853 damselfly
|
| 322 |
+
n02276258 admiral
|
| 323 |
+
n02277742 ringlet, ringlet butterfly
|
| 324 |
+
n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
|
| 325 |
+
n02280649 cabbage butterfly
|
| 326 |
+
n02281406 sulphur butterfly, sulfur butterfly
|
| 327 |
+
n02281787 lycaenid, lycaenid butterfly
|
| 328 |
+
n02317335 starfish, sea star
|
| 329 |
+
n02319095 sea urchin
|
| 330 |
+
n02321529 sea cucumber, holothurian
|
| 331 |
+
n02325366 wood rabbit, cottontail, cottontail rabbit
|
| 332 |
+
n02326432 hare
|
| 333 |
+
n02328150 Angora, Angora rabbit
|
| 334 |
+
n02342885 hamster
|
| 335 |
+
n02346627 porcupine, hedgehog
|
| 336 |
+
n02356798 fox squirrel, eastern fox squirrel, Sciurus niger
|
| 337 |
+
n02361337 marmot
|
| 338 |
+
n02363005 beaver
|
| 339 |
+
n02364673 guinea pig, Cavia cobaya
|
| 340 |
+
n02389026 sorrel
|
| 341 |
+
n02391049 zebra
|
| 342 |
+
n02395406 hog, pig, grunter, squealer, Sus scrofa
|
| 343 |
+
n02396427 wild boar, boar, Sus scrofa
|
| 344 |
+
n02397096 warthog
|
| 345 |
+
n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius
|
| 346 |
+
n02403003 ox
|
| 347 |
+
n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
|
| 348 |
+
n02410509 bison
|
| 349 |
+
n02412080 ram, tup
|
| 350 |
+
n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
|
| 351 |
+
n02417914 ibex, Capra ibex
|
| 352 |
+
n02422106 hartebeest
|
| 353 |
+
n02422699 impala, Aepyceros melampus
|
| 354 |
+
n02423022 gazelle
|
| 355 |
+
n02437312 Arabian camel, dromedary, Camelus dromedarius
|
| 356 |
+
n02437616 llama
|
| 357 |
+
n02441942 weasel
|
| 358 |
+
n02442845 mink
|
| 359 |
+
n02443114 polecat, fitch, foulmart, foumart, Mustela putorius
|
| 360 |
+
n02443484 black-footed ferret, ferret, Mustela nigripes
|
| 361 |
+
n02444819 otter
|
| 362 |
+
n02445715 skunk, polecat, wood pussy
|
| 363 |
+
n02447366 badger
|
| 364 |
+
n02454379 armadillo
|
| 365 |
+
n02457408 three-toed sloth, ai, Bradypus tridactylus
|
| 366 |
+
n02480495 orangutan, orang, orangutang, Pongo pygmaeus
|
| 367 |
+
n02480855 gorilla, Gorilla gorilla
|
| 368 |
+
n02481823 chimpanzee, chimp, Pan troglodytes
|
| 369 |
+
n02483362 gibbon, Hylobates lar
|
| 370 |
+
n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus
|
| 371 |
+
n02484975 guenon, guenon monkey
|
| 372 |
+
n02486261 patas, hussar monkey, Erythrocebus patas
|
| 373 |
+
n02486410 baboon
|
| 374 |
+
n02487347 macaque
|
| 375 |
+
n02488291 langur
|
| 376 |
+
n02488702 colobus, colobus monkey
|
| 377 |
+
n02489166 proboscis monkey, Nasalis larvatus
|
| 378 |
+
n02490219 marmoset
|
| 379 |
+
n02492035 capuchin, ringtail, Cebus capucinus
|
| 380 |
+
n02492660 howler monkey, howler
|
| 381 |
+
n02493509 titi, titi monkey
|
| 382 |
+
n02493793 spider monkey, Ateles geoffroyi
|
| 383 |
+
n02494079 squirrel monkey, Saimiri sciureus
|
| 384 |
+
n02497673 Madagascar cat, ring-tailed lemur, Lemur catta
|
| 385 |
+
n02500267 indri, indris, Indri indri, Indri brevicaudatus
|
| 386 |
+
n02504013 Indian elephant, Elephas maximus
|
| 387 |
+
n02504458 African elephant, Loxodonta africana
|
| 388 |
+
n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
|
| 389 |
+
n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
|
| 390 |
+
n02514041 barracouta, snoek
|
| 391 |
+
n02526121 eel
|
| 392 |
+
n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
|
| 393 |
+
n02606052 rock beauty, Holocanthus tricolor
|
| 394 |
+
n02607072 anemone fish
|
| 395 |
+
n02640242 sturgeon
|
| 396 |
+
n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus
|
| 397 |
+
n02643566 lionfish
|
| 398 |
+
n02655020 puffer, pufferfish, blowfish, globefish
|
| 399 |
+
n02666196 abacus
|
| 400 |
+
n02667093 abaya
|
| 401 |
+
n02669723 academic gown, academic robe, judge's robe
|
| 402 |
+
n02672831 accordion, piano accordion, squeeze box
|
| 403 |
+
n02676566 acoustic guitar
|
| 404 |
+
n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier
|
| 405 |
+
n02690373 airliner
|
| 406 |
+
n02692877 airship, dirigible
|
| 407 |
+
n02699494 altar
|
| 408 |
+
n02701002 ambulance
|
| 409 |
+
n02704792 amphibian, amphibious vehicle
|
| 410 |
+
n02708093 analog clock
|
| 411 |
+
n02727426 apiary, bee house
|
| 412 |
+
n02730930 apron
|
| 413 |
+
n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
|
| 414 |
+
n02749479 assault rifle, assault gun
|
| 415 |
+
n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack
|
| 416 |
+
n02776631 bakery, bakeshop, bakehouse
|
| 417 |
+
n02777292 balance beam, beam
|
| 418 |
+
n02782093 balloon
|
| 419 |
+
n02783161 ballpoint, ballpoint pen, ballpen, Biro
|
| 420 |
+
n02786058 Band Aid
|
| 421 |
+
n02787622 banjo
|
| 422 |
+
n02788148 bannister, banister, balustrade, balusters, handrail
|
| 423 |
+
n02790996 barbell
|
| 424 |
+
n02791124 barber chair
|
| 425 |
+
n02791270 barbershop
|
| 426 |
+
n02793495 barn
|
| 427 |
+
n02794156 barometer
|
| 428 |
+
n02795169 barrel, cask
|
| 429 |
+
n02797295 barrow, garden cart, lawn cart, wheelbarrow
|
| 430 |
+
n02799071 baseball
|
| 431 |
+
n02802426 basketball
|
| 432 |
+
n02804414 bassinet
|
| 433 |
+
n02804610 bassoon
|
| 434 |
+
n02807133 bathing cap, swimming cap
|
| 435 |
+
n02808304 bath towel
|
| 436 |
+
n02808440 bathtub, bathing tub, bath, tub
|
| 437 |
+
n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
|
| 438 |
+
n02814860 beacon, lighthouse, beacon light, pharos
|
| 439 |
+
n02815834 beaker
|
| 440 |
+
n02817516 bearskin, busby, shako
|
| 441 |
+
n02823428 beer bottle
|
| 442 |
+
n02823750 beer glass
|
| 443 |
+
n02825657 bell cote, bell cot
|
| 444 |
+
n02834397 bib
|
| 445 |
+
n02835271 bicycle-built-for-two, tandem bicycle, tandem
|
| 446 |
+
n02837789 bikini, two-piece
|
| 447 |
+
n02840245 binder, ring-binder
|
| 448 |
+
n02841315 binoculars, field glasses, opera glasses
|
| 449 |
+
n02843684 birdhouse
|
| 450 |
+
n02859443 boathouse
|
| 451 |
+
n02860847 bobsled, bobsleigh, bob
|
| 452 |
+
n02865351 bolo tie, bolo, bola tie, bola
|
| 453 |
+
n02869837 bonnet, poke bonnet
|
| 454 |
+
n02870880 bookcase
|
| 455 |
+
n02871525 bookshop, bookstore, bookstall
|
| 456 |
+
n02877765 bottlecap
|
| 457 |
+
n02879718 bow
|
| 458 |
+
n02883205 bow tie, bow-tie, bowtie
|
| 459 |
+
n02892201 brass, memorial tablet, plaque
|
| 460 |
+
n02892767 brassiere, bra, bandeau
|
| 461 |
+
n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty
|
| 462 |
+
n02895154 breastplate, aegis, egis
|
| 463 |
+
n02906734 broom
|
| 464 |
+
n02909870 bucket, pail
|
| 465 |
+
n02910353 buckle
|
| 466 |
+
n02916936 bulletproof vest
|
| 467 |
+
n02917067 bullet train, bullet
|
| 468 |
+
n02927161 butcher shop, meat market
|
| 469 |
+
n02930766 cab, hack, taxi, taxicab
|
| 470 |
+
n02939185 caldron, cauldron
|
| 471 |
+
n02948072 candle, taper, wax light
|
| 472 |
+
n02950826 cannon
|
| 473 |
+
n02951358 canoe
|
| 474 |
+
n02951585 can opener, tin opener
|
| 475 |
+
n02963159 cardigan
|
| 476 |
+
n02965783 car mirror
|
| 477 |
+
n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig
|
| 478 |
+
n02966687 carpenter's kit, tool kit
|
| 479 |
+
n02971356 carton
|
| 480 |
+
n02974003 car wheel
|
| 481 |
+
n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
|
| 482 |
+
n02978881 cassette
|
| 483 |
+
n02979186 cassette player
|
| 484 |
+
n02980441 castle
|
| 485 |
+
n02981792 catamaran
|
| 486 |
+
n02988304 CD player
|
| 487 |
+
n02992211 cello, violoncello
|
| 488 |
+
n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone
|
| 489 |
+
n02999410 chain
|
| 490 |
+
n03000134 chainlink fence
|
| 491 |
+
n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
|
| 492 |
+
n03000684 chain saw, chainsaw
|
| 493 |
+
n03014705 chest
|
| 494 |
+
n03016953 chiffonier, commode
|
| 495 |
+
n03017168 chime, bell, gong
|
| 496 |
+
n03018349 china cabinet, china closet
|
| 497 |
+
n03026506 Christmas stocking
|
| 498 |
+
n03028079 church, church building
|
| 499 |
+
n03032252 cinema, movie theater, movie theatre, movie house, picture palace
|
| 500 |
+
n03041632 cleaver, meat cleaver, chopper
|
| 501 |
+
n03042490 cliff dwelling
|
| 502 |
+
n03045698 cloak
|
| 503 |
+
n03047690 clog, geta, patten, sabot
|
| 504 |
+
n03062245 cocktail shaker
|
| 505 |
+
n03063599 coffee mug
|
| 506 |
+
n03063689 coffeepot
|
| 507 |
+
n03065424 coil, spiral, volute, whorl, helix
|
| 508 |
+
n03075370 combination lock
|
| 509 |
+
n03085013 computer keyboard, keypad
|
| 510 |
+
n03089624 confectionery, confectionary, candy store
|
| 511 |
+
n03095699 container ship, containership, container vessel
|
| 512 |
+
n03100240 convertible
|
| 513 |
+
n03109150 corkscrew, bottle screw
|
| 514 |
+
n03110669 cornet, horn, trumpet, trump
|
| 515 |
+
n03124043 cowboy boot
|
| 516 |
+
n03124170 cowboy hat, ten-gallon hat
|
| 517 |
+
n03125729 cradle
|
| 518 |
+
n03126707 crane
|
| 519 |
+
n03127747 crash helmet
|
| 520 |
+
n03127925 crate
|
| 521 |
+
n03131574 crib, cot
|
| 522 |
+
n03133878 Crock Pot
|
| 523 |
+
n03134739 croquet ball
|
| 524 |
+
n03141823 crutch
|
| 525 |
+
n03146219 cuirass
|
| 526 |
+
n03160309 dam, dike, dyke
|
| 527 |
+
n03179701 desk
|
| 528 |
+
n03180011 desktop computer
|
| 529 |
+
n03187595 dial telephone, dial phone
|
| 530 |
+
n03188531 diaper, nappy, napkin
|
| 531 |
+
n03196217 digital clock
|
| 532 |
+
n03197337 digital watch
|
| 533 |
+
n03201208 dining table, board
|
| 534 |
+
n03207743 dishrag, dishcloth
|
| 535 |
+
n03207941 dishwasher, dish washer, dishwashing machine
|
| 536 |
+
n03208938 disk brake, disc brake
|
| 537 |
+
n03216828 dock, dockage, docking facility
|
| 538 |
+
n03218198 dogsled, dog sled, dog sleigh
|
| 539 |
+
n03220513 dome
|
| 540 |
+
n03223299 doormat, welcome mat
|
| 541 |
+
n03240683 drilling platform, offshore rig
|
| 542 |
+
n03249569 drum, membranophone, tympan
|
| 543 |
+
n03250847 drumstick
|
| 544 |
+
n03255030 dumbbell
|
| 545 |
+
n03259280 Dutch oven
|
| 546 |
+
n03271574 electric fan, blower
|
| 547 |
+
n03272010 electric guitar
|
| 548 |
+
n03272562 electric locomotive
|
| 549 |
+
n03290653 entertainment center
|
| 550 |
+
n03291819 envelope
|
| 551 |
+
n03297495 espresso maker
|
| 552 |
+
n03314780 face powder
|
| 553 |
+
n03325584 feather boa, boa
|
| 554 |
+
n03337140 file, file cabinet, filing cabinet
|
| 555 |
+
n03344393 fireboat
|
| 556 |
+
n03345487 fire engine, fire truck
|
| 557 |
+
n03347037 fire screen, fireguard
|
| 558 |
+
n03355925 flagpole, flagstaff
|
| 559 |
+
n03372029 flute, transverse flute
|
| 560 |
+
n03376595 folding chair
|
| 561 |
+
n03379051 football helmet
|
| 562 |
+
n03384352 forklift
|
| 563 |
+
n03388043 fountain
|
| 564 |
+
n03388183 fountain pen
|
| 565 |
+
n03388549 four-poster
|
| 566 |
+
n03393912 freight car
|
| 567 |
+
n03394916 French horn, horn
|
| 568 |
+
n03400231 frying pan, frypan, skillet
|
| 569 |
+
n03404251 fur coat
|
| 570 |
+
n03417042 garbage truck, dustcart
|
| 571 |
+
n03424325 gasmask, respirator, gas helmet
|
| 572 |
+
n03425413 gas pump, gasoline pump, petrol pump, island dispenser
|
| 573 |
+
n03443371 goblet
|
| 574 |
+
n03444034 go-kart
|
| 575 |
+
n03445777 golf ball
|
| 576 |
+
n03445924 golfcart, golf cart
|
| 577 |
+
n03447447 gondola
|
| 578 |
+
n03447721 gong, tam-tam
|
| 579 |
+
n03450230 gown
|
| 580 |
+
n03452741 grand piano, grand
|
| 581 |
+
n03457902 greenhouse, nursery, glasshouse
|
| 582 |
+
n03459775 grille, radiator grille
|
| 583 |
+
n03461385 grocery store, grocery, food market, market
|
| 584 |
+
n03467068 guillotine
|
| 585 |
+
n03476684 hair slide
|
| 586 |
+
n03476991 hair spray
|
| 587 |
+
n03478589 half track
|
| 588 |
+
n03481172 hammer
|
| 589 |
+
n03482405 hamper
|
| 590 |
+
n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier
|
| 591 |
+
n03485407 hand-held computer, hand-held microcomputer
|
| 592 |
+
n03485794 handkerchief, hankie, hanky, hankey
|
| 593 |
+
n03492542 hard disc, hard disk, fixed disk
|
| 594 |
+
n03494278 harmonica, mouth organ, harp, mouth harp
|
| 595 |
+
n03495258 harp
|
| 596 |
+
n03496892 harvester, reaper
|
| 597 |
+
n03498962 hatchet
|
| 598 |
+
n03527444 holster
|
| 599 |
+
n03529860 home theater, home theatre
|
| 600 |
+
n03530642 honeycomb
|
| 601 |
+
n03532672 hook, claw
|
| 602 |
+
n03534580 hoopskirt, crinoline
|
| 603 |
+
n03535780 horizontal bar, high bar
|
| 604 |
+
n03538406 horse cart, horse-cart
|
| 605 |
+
n03544143 hourglass
|
| 606 |
+
n03584254 iPod
|
| 607 |
+
n03584829 iron, smoothing iron
|
| 608 |
+
n03590841 jack-o'-lantern
|
| 609 |
+
n03594734 jean, blue jean, denim
|
| 610 |
+
n03594945 jeep, landrover
|
| 611 |
+
n03595614 jersey, T-shirt, tee shirt
|
| 612 |
+
n03598930 jigsaw puzzle
|
| 613 |
+
n03599486 jinrikisha, ricksha, rickshaw
|
| 614 |
+
n03602883 joystick
|
| 615 |
+
n03617480 kimono
|
| 616 |
+
n03623198 knee pad
|
| 617 |
+
n03627232 knot
|
| 618 |
+
n03630383 lab coat, laboratory coat
|
| 619 |
+
n03633091 ladle
|
| 620 |
+
n03637318 lampshade, lamp shade
|
| 621 |
+
n03642806 laptop, laptop computer
|
| 622 |
+
n03649909 lawn mower, mower
|
| 623 |
+
n03657121 lens cap, lens cover
|
| 624 |
+
n03658185 letter opener, paper knife, paperknife
|
| 625 |
+
n03661043 library
|
| 626 |
+
n03662601 lifeboat
|
| 627 |
+
n03666591 lighter, light, igniter, ignitor
|
| 628 |
+
n03670208 limousine, limo
|
| 629 |
+
n03673027 liner, ocean liner
|
| 630 |
+
n03676483 lipstick, lip rouge
|
| 631 |
+
n03680355 Loafer
|
| 632 |
+
n03690938 lotion
|
| 633 |
+
n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
|
| 634 |
+
n03692522 loupe, jeweler's loupe
|
| 635 |
+
n03697007 lumbermill, sawmill
|
| 636 |
+
n03706229 magnetic compass
|
| 637 |
+
n03709823 mailbag, postbag
|
| 638 |
+
n03710193 mailbox, letter box
|
| 639 |
+
n03710637 maillot
|
| 640 |
+
n03710721 maillot, tank suit
|
| 641 |
+
n03717622 manhole cover
|
| 642 |
+
n03720891 maraca
|
| 643 |
+
n03721384 marimba, xylophone
|
| 644 |
+
n03724870 mask
|
| 645 |
+
n03729826 matchstick
|
| 646 |
+
n03733131 maypole
|
| 647 |
+
n03733281 maze, labyrinth
|
| 648 |
+
n03733805 measuring cup
|
| 649 |
+
n03742115 medicine chest, medicine cabinet
|
| 650 |
+
n03743016 megalith, megalithic structure
|
| 651 |
+
n03759954 microphone, mike
|
| 652 |
+
n03761084 microwave, microwave oven
|
| 653 |
+
n03763968 military uniform
|
| 654 |
+
n03764736 milk can
|
| 655 |
+
n03769881 minibus
|
| 656 |
+
n03770439 miniskirt, mini
|
| 657 |
+
n03770679 minivan
|
| 658 |
+
n03773504 missile
|
| 659 |
+
n03775071 mitten
|
| 660 |
+
n03775546 mixing bowl
|
| 661 |
+
n03776460 mobile home, manufactured home
|
| 662 |
+
n03777568 Model T
|
| 663 |
+
n03777754 modem
|
| 664 |
+
n03781244 monastery
|
| 665 |
+
n03782006 monitor
|
| 666 |
+
n03785016 moped
|
| 667 |
+
n03786901 mortar
|
| 668 |
+
n03787032 mortarboard
|
| 669 |
+
n03788195 mosque
|
| 670 |
+
n03788365 mosquito net
|
| 671 |
+
n03791053 motor scooter, scooter
|
| 672 |
+
n03792782 mountain bike, all-terrain bike, off-roader
|
| 673 |
+
n03792972 mountain tent
|
| 674 |
+
n03793489 mouse, computer mouse
|
| 675 |
+
n03794056 mousetrap
|
| 676 |
+
n03796401 moving van
|
| 677 |
+
n03803284 muzzle
|
| 678 |
+
n03804744 nail
|
| 679 |
+
n03814639 neck brace
|
| 680 |
+
n03814906 necklace
|
| 681 |
+
n03825788 nipple
|
| 682 |
+
n03832673 notebook, notebook computer
|
| 683 |
+
n03837869 obelisk
|
| 684 |
+
n03838899 oboe, hautboy, hautbois
|
| 685 |
+
n03840681 ocarina, sweet potato
|
| 686 |
+
n03841143 odometer, hodometer, mileometer, milometer
|
| 687 |
+
n03843555 oil filter
|
| 688 |
+
n03854065 organ, pipe organ
|
| 689 |
+
n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO
|
| 690 |
+
n03866082 overskirt
|
| 691 |
+
n03868242 oxcart
|
| 692 |
+
n03868863 oxygen mask
|
| 693 |
+
n03871628 packet
|
| 694 |
+
n03873416 paddle, boat paddle
|
| 695 |
+
n03874293 paddlewheel, paddle wheel
|
| 696 |
+
n03874599 padlock
|
| 697 |
+
n03876231 paintbrush
|
| 698 |
+
n03877472 pajama, pyjama, pj's, jammies
|
| 699 |
+
n03877845 palace
|
| 700 |
+
n03884397 panpipe, pandean pipe, syrinx
|
| 701 |
+
n03887697 paper towel
|
| 702 |
+
n03888257 parachute, chute
|
| 703 |
+
n03888605 parallel bars, bars
|
| 704 |
+
n03891251 park bench
|
| 705 |
+
n03891332 parking meter
|
| 706 |
+
n03895866 passenger car, coach, carriage
|
| 707 |
+
n03899768 patio, terrace
|
| 708 |
+
n03902125 pay-phone, pay-station
|
| 709 |
+
n03903868 pedestal, plinth, footstall
|
| 710 |
+
n03908618 pencil box, pencil case
|
| 711 |
+
n03908714 pencil sharpener
|
| 712 |
+
n03916031 perfume, essence
|
| 713 |
+
n03920288 Petri dish
|
| 714 |
+
n03924679 photocopier
|
| 715 |
+
n03929660 pick, plectrum, plectron
|
| 716 |
+
n03929855 pickelhaube
|
| 717 |
+
n03930313 picket fence, paling
|
| 718 |
+
n03930630 pickup, pickup truck
|
| 719 |
+
n03933933 pier
|
| 720 |
+
n03935335 piggy bank, penny bank
|
| 721 |
+
n03937543 pill bottle
|
| 722 |
+
n03938244 pillow
|
| 723 |
+
n03942813 ping-pong ball
|
| 724 |
+
n03944341 pinwheel
|
| 725 |
+
n03947888 pirate, pirate ship
|
| 726 |
+
n03950228 pitcher, ewer
|
| 727 |
+
n03954731 plane, carpenter's plane, woodworking plane
|
| 728 |
+
n03956157 planetarium
|
| 729 |
+
n03958227 plastic bag
|
| 730 |
+
n03961711 plate rack
|
| 731 |
+
n03967562 plow, plough
|
| 732 |
+
n03970156 plunger, plumber's helper
|
| 733 |
+
n03976467 Polaroid camera, Polaroid Land camera
|
| 734 |
+
n03976657 pole
|
| 735 |
+
n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
|
| 736 |
+
n03980874 poncho
|
| 737 |
+
n03982430 pool table, billiard table, snooker table
|
| 738 |
+
n03983396 pop bottle, soda bottle
|
| 739 |
+
n03991062 pot, flowerpot
|
| 740 |
+
n03992509 potter's wheel
|
| 741 |
+
n03995372 power drill
|
| 742 |
+
n03998194 prayer rug, prayer mat
|
| 743 |
+
n04004767 printer
|
| 744 |
+
n04005630 prison, prison house
|
| 745 |
+
n04008634 projectile, missile
|
| 746 |
+
n04009552 projector
|
| 747 |
+
n04019541 puck, hockey puck
|
| 748 |
+
n04023962 punching bag, punch bag, punching ball, punchball
|
| 749 |
+
n04026417 purse
|
| 750 |
+
n04033901 quill, quill pen
|
| 751 |
+
n04033995 quilt, comforter, comfort, puff
|
| 752 |
+
n04037443 racer, race car, racing car
|
| 753 |
+
n04039381 racket, racquet
|
| 754 |
+
n04040759 radiator
|
| 755 |
+
n04041544 radio, wireless
|
| 756 |
+
n04044716 radio telescope, radio reflector
|
| 757 |
+
n04049303 rain barrel
|
| 758 |
+
n04065272 recreational vehicle, RV, R.V.
|
| 759 |
+
n04067472 reel
|
| 760 |
+
n04069434 reflex camera
|
| 761 |
+
n04070727 refrigerator, icebox
|
| 762 |
+
n04074963 remote control, remote
|
| 763 |
+
n04081281 restaurant, eating house, eating place, eatery
|
| 764 |
+
n04086273 revolver, six-gun, six-shooter
|
| 765 |
+
n04090263 rifle
|
| 766 |
+
n04099969 rocking chair, rocker
|
| 767 |
+
n04111531 rotisserie
|
| 768 |
+
n04116512 rubber eraser, rubber, pencil eraser
|
| 769 |
+
n04118538 rugby ball
|
| 770 |
+
n04118776 rule, ruler
|
| 771 |
+
n04120489 running shoe
|
| 772 |
+
n04125021 safe
|
| 773 |
+
n04127249 safety pin
|
| 774 |
+
n04131690 saltshaker, salt shaker
|
| 775 |
+
n04133789 sandal
|
| 776 |
+
n04136333 sarong
|
| 777 |
+
n04141076 sax, saxophone
|
| 778 |
+
n04141327 scabbard
|
| 779 |
+
n04141975 scale, weighing machine
|
| 780 |
+
n04146614 school bus
|
| 781 |
+
n04147183 schooner
|
| 782 |
+
n04149813 scoreboard
|
| 783 |
+
n04152593 screen, CRT screen
|
| 784 |
+
n04153751 screw
|
| 785 |
+
n04154565 screwdriver
|
| 786 |
+
n04162706 seat belt, seatbelt
|
| 787 |
+
n04179913 sewing machine
|
| 788 |
+
n04192698 shield, buckler
|
| 789 |
+
n04200800 shoe shop, shoe-shop, shoe store
|
| 790 |
+
n04201297 shoji
|
| 791 |
+
n04204238 shopping basket
|
| 792 |
+
n04204347 shopping cart
|
| 793 |
+
n04208210 shovel
|
| 794 |
+
n04209133 shower cap
|
| 795 |
+
n04209239 shower curtain
|
| 796 |
+
n04228054 ski
|
| 797 |
+
n04229816 ski mask
|
| 798 |
+
n04235860 sleeping bag
|
| 799 |
+
n04238763 slide rule, slipstick
|
| 800 |
+
n04239074 sliding door
|
| 801 |
+
n04243546 slot, one-armed bandit
|
| 802 |
+
n04251144 snorkel
|
| 803 |
+
n04252077 snowmobile
|
| 804 |
+
n04252225 snowplow, snowplough
|
| 805 |
+
n04254120 soap dispenser
|
| 806 |
+
n04254680 soccer ball
|
| 807 |
+
n04254777 sock
|
| 808 |
+
n04258138 solar dish, solar collector, solar furnace
|
| 809 |
+
n04259630 sombrero
|
| 810 |
+
n04263257 soup bowl
|
| 811 |
+
n04264628 space bar
|
| 812 |
+
n04265275 space heater
|
| 813 |
+
n04266014 space shuttle
|
| 814 |
+
n04270147 spatula
|
| 815 |
+
n04273569 speedboat
|
| 816 |
+
n04275548 spider web, spider's web
|
| 817 |
+
n04277352 spindle
|
| 818 |
+
n04285008 sports car, sport car
|
| 819 |
+
n04286575 spotlight, spot
|
| 820 |
+
n04296562 stage
|
| 821 |
+
n04310018 steam locomotive
|
| 822 |
+
n04311004 steel arch bridge
|
| 823 |
+
n04311174 steel drum
|
| 824 |
+
n04317175 stethoscope
|
| 825 |
+
n04325704 stole
|
| 826 |
+
n04326547 stone wall
|
| 827 |
+
n04328186 stopwatch, stop watch
|
| 828 |
+
n04330267 stove
|
| 829 |
+
n04332243 strainer
|
| 830 |
+
n04335435 streetcar, tram, tramcar, trolley, trolley car
|
| 831 |
+
n04336792 stretcher
|
| 832 |
+
n04344873 studio couch, day bed
|
| 833 |
+
n04346328 stupa, tope
|
| 834 |
+
n04347754 submarine, pigboat, sub, U-boat
|
| 835 |
+
n04350905 suit, suit of clothes
|
| 836 |
+
n04355338 sundial
|
| 837 |
+
n04355933 sunglass
|
| 838 |
+
n04356056 sunglasses, dark glasses, shades
|
| 839 |
+
n04357314 sunscreen, sunblock, sun blocker
|
| 840 |
+
n04366367 suspension bridge
|
| 841 |
+
n04367480 swab, swob, mop
|
| 842 |
+
n04370456 sweatshirt
|
| 843 |
+
n04371430 swimming trunks, bathing trunks
|
| 844 |
+
n04371774 swing
|
| 845 |
+
n04372370 switch, electric switch, electrical switch
|
| 846 |
+
n04376876 syringe
|
| 847 |
+
n04380533 table lamp
|
| 848 |
+
n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle
|
| 849 |
+
n04392985 tape player
|
| 850 |
+
n04398044 teapot
|
| 851 |
+
n04399382 teddy, teddy bear
|
| 852 |
+
n04404412 television, television system
|
| 853 |
+
n04409515 tennis ball
|
| 854 |
+
n04417672 thatch, thatched roof
|
| 855 |
+
n04418357 theater curtain, theatre curtain
|
| 856 |
+
n04423845 thimble
|
| 857 |
+
n04428191 thresher, thrasher, threshing machine
|
| 858 |
+
n04429376 throne
|
| 859 |
+
n04435653 tile roof
|
| 860 |
+
n04442312 toaster
|
| 861 |
+
n04443257 tobacco shop, tobacconist shop, tobacconist
|
| 862 |
+
n04447861 toilet seat
|
| 863 |
+
n04456115 torch
|
| 864 |
+
n04458633 totem pole
|
| 865 |
+
n04461696 tow truck, tow car, wrecker
|
| 866 |
+
n04462240 toyshop
|
| 867 |
+
n04465501 tractor
|
| 868 |
+
n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
|
| 869 |
+
n04476259 tray
|
| 870 |
+
n04479046 trench coat
|
| 871 |
+
n04482393 tricycle, trike, velocipede
|
| 872 |
+
n04483307 trimaran
|
| 873 |
+
n04485082 tripod
|
| 874 |
+
n04486054 triumphal arch
|
| 875 |
+
n04487081 trolleybus, trolley coach, trackless trolley
|
| 876 |
+
n04487394 trombone
|
| 877 |
+
n04493381 tub, vat
|
| 878 |
+
n04501370 turnstile
|
| 879 |
+
n04505470 typewriter keyboard
|
| 880 |
+
n04507155 umbrella
|
| 881 |
+
n04509417 unicycle, monocycle
|
| 882 |
+
n04515003 upright, upright piano
|
| 883 |
+
n04517823 vacuum, vacuum cleaner
|
| 884 |
+
n04522168 vase
|
| 885 |
+
n04523525 vault
|
| 886 |
+
n04525038 velvet
|
| 887 |
+
n04525305 vending machine
|
| 888 |
+
n04532106 vestment
|
| 889 |
+
n04532670 viaduct
|
| 890 |
+
n04536866 violin, fiddle
|
| 891 |
+
n04540053 volleyball
|
| 892 |
+
n04542943 waffle iron
|
| 893 |
+
n04548280 wall clock
|
| 894 |
+
n04548362 wallet, billfold, notecase, pocketbook
|
| 895 |
+
n04550184 wardrobe, closet, press
|
| 896 |
+
n04552348 warplane, military plane
|
| 897 |
+
n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin
|
| 898 |
+
n04554684 washer, automatic washer, washing machine
|
| 899 |
+
n04557648 water bottle
|
| 900 |
+
n04560804 water jug
|
| 901 |
+
n04562935 water tower
|
| 902 |
+
n04579145 whiskey jug
|
| 903 |
+
n04579432 whistle
|
| 904 |
+
n04584207 wig
|
| 905 |
+
n04589890 window screen
|
| 906 |
+
n04590129 window shade
|
| 907 |
+
n04591157 Windsor tie
|
| 908 |
+
n04591713 wine bottle
|
| 909 |
+
n04592741 wing
|
| 910 |
+
n04596742 wok
|
| 911 |
+
n04597913 wooden spoon
|
| 912 |
+
n04599235 wool, woolen, woollen
|
| 913 |
+
n04604644 worm fence, snake fence, snake-rail fence, Virginia fence
|
| 914 |
+
n04606251 wreck
|
| 915 |
+
n04612504 yawl
|
| 916 |
+
n04613696 yurt
|
| 917 |
+
n06359193 web site, website, internet site, site
|
| 918 |
+
n06596364 comic book
|
| 919 |
+
n06785654 crossword puzzle, crossword
|
| 920 |
+
n06794110 street sign
|
| 921 |
+
n06874185 traffic light, traffic signal, stoplight
|
| 922 |
+
n07248320 book jacket, dust cover, dust jacket, dust wrapper
|
| 923 |
+
n07565083 menu
|
| 924 |
+
n07579787 plate
|
| 925 |
+
n07583066 guacamole
|
| 926 |
+
n07584110 consomme
|
| 927 |
+
n07590611 hot pot, hotpot
|
| 928 |
+
n07613480 trifle
|
| 929 |
+
n07614500 ice cream, icecream
|
| 930 |
+
n07615774 ice lolly, lolly, lollipop, popsicle
|
| 931 |
+
n07684084 French loaf
|
| 932 |
+
n07693725 bagel, beigel
|
| 933 |
+
n07695742 pretzel
|
| 934 |
+
n07697313 cheeseburger
|
| 935 |
+
n07697537 hotdog, hot dog, red hot
|
| 936 |
+
n07711569 mashed potato
|
| 937 |
+
n07714571 head cabbage
|
| 938 |
+
n07714990 broccoli
|
| 939 |
+
n07715103 cauliflower
|
| 940 |
+
n07716358 zucchini, courgette
|
| 941 |
+
n07716906 spaghetti squash
|
| 942 |
+
n07717410 acorn squash
|
| 943 |
+
n07717556 butternut squash
|
| 944 |
+
n07718472 cucumber, cuke
|
| 945 |
+
n07718747 artichoke, globe artichoke
|
| 946 |
+
n07720875 bell pepper
|
| 947 |
+
n07730033 cardoon
|
| 948 |
+
n07734744 mushroom
|
| 949 |
+
n07742313 Granny Smith
|
| 950 |
+
n07745940 strawberry
|
| 951 |
+
n07747607 orange
|
| 952 |
+
n07749582 lemon
|
| 953 |
+
n07753113 fig
|
| 954 |
+
n07753275 pineapple, ananas
|
| 955 |
+
n07753592 banana
|
| 956 |
+
n07754684 jackfruit, jak, jack
|
| 957 |
+
n07760859 custard apple
|
| 958 |
+
n07768694 pomegranate
|
| 959 |
+
n07802026 hay
|
| 960 |
+
n07831146 carbonara
|
| 961 |
+
n07836838 chocolate sauce, chocolate syrup
|
| 962 |
+
n07860988 dough
|
| 963 |
+
n07871810 meat loaf, meatloaf
|
| 964 |
+
n07873807 pizza, pizza pie
|
| 965 |
+
n07875152 potpie
|
| 966 |
+
n07880968 burrito
|
| 967 |
+
n07892512 red wine
|
| 968 |
+
n07920052 espresso
|
| 969 |
+
n07930864 cup
|
| 970 |
+
n07932039 eggnog
|
| 971 |
+
n09193705 alp
|
| 972 |
+
n09229709 bubble
|
| 973 |
+
n09246464 cliff, drop, drop-off
|
| 974 |
+
n09256479 coral reef
|
| 975 |
+
n09288635 geyser
|
| 976 |
+
n09332890 lakeside, lakeshore
|
| 977 |
+
n09399592 promontory, headland, head, foreland
|
| 978 |
+
n09421951 sandbar, sand bar
|
| 979 |
+
n09428293 seashore, coast, seacoast, sea-coast
|
| 980 |
+
n09468604 valley, vale
|
| 981 |
+
n09472597 volcano
|
| 982 |
+
n09835506 ballplayer, baseball player
|
| 983 |
+
n10148035 groom, bridegroom
|
| 984 |
+
n10565667 scuba diver
|
| 985 |
+
n11879895 rapeseed
|
| 986 |
+
n11939491 daisy
|
| 987 |
+
n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
|
| 988 |
+
n12144580 corn
|
| 989 |
+
n12267677 acorn
|
| 990 |
+
n12620546 hip, rose hip, rosehip
|
| 991 |
+
n12768682 buckeye, horse chestnut, conker
|
| 992 |
+
n12985857 coral fungus
|
| 993 |
+
n12998815 agaric
|
| 994 |
+
n13037406 gyromitra
|
| 995 |
+
n13040303 stinkhorn, carrion fungus
|
| 996 |
+
n13044778 earthstar
|
| 997 |
+
n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
|
| 998 |
+
n13054560 bolete
|
| 999 |
+
n13133613 ear, spike, capitulum
|
| 1000 |
+
n15075141 toilet tissue, toilet paper, bathroom tissue
|
docs/Makefile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal makefile for Sphinx documentation
|
| 2 |
+
#
|
| 3 |
+
|
| 4 |
+
# You can set these variables from the command line, and also
|
| 5 |
+
# from the environment for the first two.
|
| 6 |
+
SPHINXOPTS ?=
|
| 7 |
+
SPHINXBUILD ?= sphinx-build
|
| 8 |
+
SOURCEDIR = .
|
| 9 |
+
BUILDDIR = _build
|
| 10 |
+
|
| 11 |
+
# Put it first so that "make" without argument is like "make help".
|
| 12 |
+
help:
|
| 13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
| 14 |
+
|
| 15 |
+
.PHONY: help Makefile
|
| 16 |
+
|
| 17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 19 |
+
%: Makefile
|
| 20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/conf.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration file for the Sphinx documentation builder.
|
| 2 |
+
#
|
| 3 |
+
# This file only contains a selection of the most common options. For a full
|
| 4 |
+
# list see the documentation:
|
| 5 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
| 6 |
+
|
| 7 |
+
# -- Path setup --------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
| 10 |
+
# add these directories to sys.path here. If the directory is relative to the
|
| 11 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
| 12 |
+
#
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
# sys.path.insert(0, os.path.abspath('.'))
|
| 16 |
+
sys.path.insert(0, os.path.abspath('../'))
|
| 17 |
+
# sys.path.append('/home/jinwei/Laboratory/api')
|
| 18 |
+
sys.path.append('../..')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# -- Project information -----------------------------------------------------
|
| 22 |
+
|
| 23 |
+
project = 'DeepRobust'
|
| 24 |
+
copyright = ''
|
| 25 |
+
author = 'Yaxin Li, Wei Jin, Han Xu, Jiliang Tang'
|
| 26 |
+
|
| 27 |
+
# The full version, including alpha/beta/rc tags
|
| 28 |
+
release = '0.1.1'
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# -- General configuration ---------------------------------------------------
|
| 32 |
+
|
| 33 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
| 34 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
| 35 |
+
# ones.
|
| 36 |
+
extensions = ['sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon',
|
| 37 |
+
'sphinx.ext.autosummary', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages' ]
|
| 38 |
+
|
| 39 |
+
# extensions = ['sphinx.ext.napoleon']
|
| 40 |
+
autodoc_mock_imports = ['torch', 'torchvision', 'texttable', 'tensorboardX',
|
| 41 |
+
'torch_geometric', 'gensim', 'node2vec']
|
| 42 |
+
|
| 43 |
+
# remove undoc members
|
| 44 |
+
#autodoc_default_flags = ['members']
|
| 45 |
+
|
| 46 |
+
# Add any paths that contain templates here, relative to this directory.
|
| 47 |
+
templates_path = ['_templates']
|
| 48 |
+
|
| 49 |
+
# List of patterns, relative to source directory, that match files and
|
| 50 |
+
# directories to ignore when looking for source files.
|
| 51 |
+
# This pattern also affects html_static_path and html_extra_path.
|
| 52 |
+
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# -- Options for HTML output -------------------------------------------------
|
| 56 |
+
|
| 57 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
| 58 |
+
# a list of builtin themes.
|
| 59 |
+
#
|
| 60 |
+
# html_theme = 'alabaster'
|
| 61 |
+
html_theme = 'sphinx_rtd_theme'
|
| 62 |
+
|
| 63 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
| 64 |
+
# relative to this directory. They are copied after the builtin static files,
|
| 65 |
+
# so a file named "default.css" will overwrite the builtin "default.css".
|
| 66 |
+
html_static_path = ['_static']
|
| 67 |
+
|
| 68 |
+
add_module_names = False
|
| 69 |
+
|
| 70 |
+
master_doc = 'index'
|
| 71 |
+
|
docs/index.rst
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.. Deep documentation master file, created by
|
| 2 |
+
sphinx-quickstart on Fri Jul 3 12:19:59 2020.
|
| 3 |
+
You can adapt this file completely to your liking, but it should at least
|
| 4 |
+
contain the root `toctree` directive.
|
| 5 |
+
|
| 6 |
+
Start building your robust models with DeepRobust!
|
| 7 |
+
================================
|
| 8 |
+
.. comments original size: 626*238
|
| 9 |
+
|
| 10 |
+
.. image:: ./DeepRobust.png
|
| 11 |
+
:width: 313px
|
| 12 |
+
:height: 119px
|
| 13 |
+
|
| 14 |
+
DeepRobust is a pytorch adversarial learning library, which contains most popular attack and defense algorithms in image domain and graph domain.
|
| 15 |
+
|
| 16 |
+
.. toctree::
|
| 17 |
+
:glob:
|
| 18 |
+
:maxdepth: 1
|
| 19 |
+
:caption: Installation
|
| 20 |
+
|
| 21 |
+
notes/installation
|
| 22 |
+
|
| 23 |
+
.. toctree::
|
| 24 |
+
:glob:
|
| 25 |
+
:maxdepth: 1
|
| 26 |
+
:caption: Graph Package
|
| 27 |
+
|
| 28 |
+
graph/data
|
| 29 |
+
graph/attack
|
| 30 |
+
graph/defense
|
| 31 |
+
graph/pyg
|
| 32 |
+
graph/node_embedding
|
| 33 |
+
|
| 34 |
+
.. toctree::
|
| 35 |
+
:glob:
|
| 36 |
+
:maxdepth: 1
|
| 37 |
+
:caption: Image Package
|
| 38 |
+
|
| 39 |
+
image/example
|
| 40 |
+
|
| 41 |
+
Package API
|
| 42 |
+
===========
|
| 43 |
+
.. toctree::
|
| 44 |
+
:maxdepth: 1
|
| 45 |
+
:caption: Image Package
|
| 46 |
+
|
| 47 |
+
source/deeprobust.image.attack
|
| 48 |
+
source/deeprobust.image.defense
|
| 49 |
+
source/deeprobust.image.netmodels
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
.. toctree::
|
| 53 |
+
:maxdepth: 1
|
| 54 |
+
:caption: Graph Package
|
| 55 |
+
|
| 56 |
+
source/deeprobust.graph.global_attack
|
| 57 |
+
source/deeprobust.graph.targeted_attack
|
| 58 |
+
source/deeprobust.graph.defense
|
| 59 |
+
source/deeprobust.graph.data
|
| 60 |
+
|
| 61 |
+
Indices and tables
|
| 62 |
+
==================
|
| 63 |
+
|
| 64 |
+
* :ref:`modindex`
|
| 65 |
+
* :ref:`search`
|
examples/graph/cgscore_datasets.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# compute cgscore for gcn
|
| 2 |
+
# author: Yaning
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as Fd
|
| 6 |
+
from deeprobust.graph.defense import GCNJaccard, GCN
|
| 7 |
+
from deeprobust.graph.defense import GCNScore
|
| 8 |
+
from deeprobust.graph.utils import *
|
| 9 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 10 |
+
from scipy.sparse import csr_matrix
|
| 11 |
+
import argparse
|
| 12 |
+
import pickle
|
| 13 |
+
from deeprobust.graph import utils
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 18 |
+
parser.add_argument('--dataset', type=str, default='polblogs', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 19 |
+
parser.add_argument('--ptb_rate', type=float, default=0.10, help='pertubation rate')
|
| 20 |
+
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
args.cuda = torch.cuda.is_available()
|
| 23 |
+
print('cuda: %s' % args.cuda)
|
| 24 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
|
| 26 |
+
# make sure you use the same data splits as you generated attacks
|
| 27 |
+
np.random.seed(args.seed)
|
| 28 |
+
if args.cuda:
|
| 29 |
+
torch.cuda.manual_seed(args.seed)
|
| 30 |
+
|
| 31 |
+
# Here the random seed is to split the train/val/test data,
|
| 32 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 33 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 34 |
+
# Or we can just use setting='prognn' to get the splits
|
| 35 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 36 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 37 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
perturbed_data = PrePtbDataset(root='/tmp/',
|
| 41 |
+
name=args.dataset,
|
| 42 |
+
attack_method='meta',
|
| 43 |
+
ptb_rate=args.ptb_rate)
|
| 44 |
+
|
| 45 |
+
perturbed_adj = perturbed_data.adj
|
| 46 |
+
# perturbed_adj = adj
|
| 47 |
+
|
| 48 |
+
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
|
| 49 |
+
np.save(filename, cg_scores)
|
| 50 |
+
print(f"CG-scores saved to {filename}")
|
| 51 |
+
|
| 52 |
+
def load_cg_scores_numpy(filename="cg_scores.npy"):
|
| 53 |
+
cg_scores = np.load(filename, allow_pickle=True)
|
| 54 |
+
print(f"CG-scores loaded from {filename}")
|
| 55 |
+
return cg_scores
|
| 56 |
+
|
| 57 |
+
def calc_cg_score_gnn_with_sampling(
|
| 58 |
+
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
|
| 59 |
+
):
|
| 60 |
+
"""
|
| 61 |
+
Calculate CG-score for each edge in a graph with node labels and random sampling.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
A: torch.Tensor
|
| 65 |
+
Adjacency matrix of the graph (size: N x N).
|
| 66 |
+
X: torch.Tensor
|
| 67 |
+
Node features matrix (size: N x F).
|
| 68 |
+
labels: torch.Tensor
|
| 69 |
+
Node labels (size: N).
|
| 70 |
+
device: torch.device
|
| 71 |
+
Device to perform calculations.
|
| 72 |
+
rep_num: int
|
| 73 |
+
Number of repetitions for Monte Carlo sampling.
|
| 74 |
+
unbalance_ratio: float
|
| 75 |
+
Ratio of unbalanced data (1:unbalance_ratio).
|
| 76 |
+
sub_term: bool
|
| 77 |
+
If True, calculate and return sub-terms.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
cg_scores: dict
|
| 81 |
+
Dictionary containing CG-scores for edges and optionally sub-terms.
|
| 82 |
+
"""
|
| 83 |
+
N = A.shape[0]
|
| 84 |
+
cg_scores = {
|
| 85 |
+
"vi": np.zeros((N, N)),
|
| 86 |
+
"ab": np.zeros((N, N)),
|
| 87 |
+
"a2": np.zeros((N, N)),
|
| 88 |
+
"b2": np.zeros((N, N)),
|
| 89 |
+
"times": np.zeros((N, N)),
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
for _ in range(rep_num):
|
| 94 |
+
# Compute AX (node representations)
|
| 95 |
+
AX = torch.matmul(A, X).to(device)
|
| 96 |
+
norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)
|
| 97 |
+
|
| 98 |
+
# Group nodes by their labels
|
| 99 |
+
dataset = defaultdict(list)
|
| 100 |
+
data_idx = defaultdict(list)
|
| 101 |
+
for i, label in enumerate(labels):
|
| 102 |
+
dataset[label.item()].append(norm_AX[i].unsqueeze(0)) # Store normalized data
|
| 103 |
+
data_idx[label.item()].append(i) # Store indices
|
| 104 |
+
|
| 105 |
+
# Convert to tensors
|
| 106 |
+
for label, data_list in dataset.items():
|
| 107 |
+
dataset[label] = torch.cat(data_list, dim=0)
|
| 108 |
+
data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
|
| 109 |
+
|
| 110 |
+
# Calculate CG-scores for each label group
|
| 111 |
+
for curr_label, curr_samples in dataset.items():
|
| 112 |
+
curr_indices = data_idx[curr_label]
|
| 113 |
+
curr_num = len(curr_samples)
|
| 114 |
+
|
| 115 |
+
# Randomly sample a subset of current label examples
|
| 116 |
+
chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
|
| 117 |
+
chosen_curr_samples = curr_samples[chosen_curr_idx]
|
| 118 |
+
chosen_curr_indices = curr_indices[chosen_curr_idx]
|
| 119 |
+
|
| 120 |
+
# Sample negative examples from other classes
|
| 121 |
+
neg_samples = torch.cat(
|
| 122 |
+
[dataset[l] for l in dataset if l != curr_label], dim=0
|
| 123 |
+
)
|
| 124 |
+
neg_indices = torch.cat(
|
| 125 |
+
[data_idx[l] for l in data_idx if l != curr_label], dim=0
|
| 126 |
+
)
|
| 127 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 128 |
+
chosen_neg_samples = neg_samples[
|
| 129 |
+
torch.randperm(len(neg_samples))[:neg_num]
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
# Combine positive and negative samples
|
| 133 |
+
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
|
| 134 |
+
y = torch.cat(
|
| 135 |
+
[torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
|
| 136 |
+
).to(device)
|
| 137 |
+
|
| 138 |
+
# Compute the Gram matrix H^\infty
|
| 139 |
+
H_inner = torch.matmul(combined_samples, combined_samples.T)
|
| 140 |
+
del combined_samples
|
| 141 |
+
###
|
| 142 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 143 |
+
###
|
| 144 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 145 |
+
del H_inner
|
| 146 |
+
|
| 147 |
+
H.fill_diagonal_(0.5)
|
| 148 |
+
##
|
| 149 |
+
epsilon = 1e-6
|
| 150 |
+
H = H + epsilon * torch.eye(H.size(0), device=H.device)
|
| 151 |
+
##
|
| 152 |
+
invH = torch.inverse(H)
|
| 153 |
+
del H
|
| 154 |
+
original_error = y @ (invH @ y)
|
| 155 |
+
|
| 156 |
+
# Compute CG-scores for each edge
|
| 157 |
+
for i in chosen_curr_indices:
|
| 158 |
+
print("the node index:", i)
|
| 159 |
+
for j in range(i + 1, N): # Upper triangular traversal
|
| 160 |
+
# print(j)
|
| 161 |
+
if A[i, j] == 0: # Skip if no edge exists
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
# Remove edge (i, j) to create A1
|
| 165 |
+
A1 = A.clone()
|
| 166 |
+
A1[i, j] = A1[j, i] = 0
|
| 167 |
+
|
| 168 |
+
# Recompute AX with A1
|
| 169 |
+
AX1 = torch.matmul(A1, X).to(device)
|
| 170 |
+
norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)
|
| 171 |
+
|
| 172 |
+
# Repeat error calculation with A1
|
| 173 |
+
curr_samples_A1 = norm_AX1[chosen_curr_indices]
|
| 174 |
+
neg_samples_A1 = norm_AX1[neg_indices]
|
| 175 |
+
chosen_neg_samples_A1 = neg_samples_A1[
|
| 176 |
+
torch.randperm(len(neg_samples_A1))[:neg_num]
|
| 177 |
+
]
|
| 178 |
+
combined_samples_A1 = torch.cat(
|
| 179 |
+
[curr_samples_A1, chosen_neg_samples_A1], dim=0
|
| 180 |
+
)
|
| 181 |
+
H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
|
| 182 |
+
|
| 183 |
+
del combined_samples_A1
|
| 184 |
+
|
| 185 |
+
### trick1
|
| 186 |
+
H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
|
| 187 |
+
###
|
| 188 |
+
|
| 189 |
+
H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
|
| 190 |
+
del H_inner_A1
|
| 191 |
+
H_A1.fill_diagonal_(0.5)
|
| 192 |
+
|
| 193 |
+
### trick2
|
| 194 |
+
epsilon = 1e-6
|
| 195 |
+
H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
|
| 196 |
+
###
|
| 197 |
+
invH_A1 = torch.inverse(H_A1)
|
| 198 |
+
del H_A1
|
| 199 |
+
|
| 200 |
+
error_A1 = y @ (invH_A1 @ y)
|
| 201 |
+
|
| 202 |
+
print("i:", i)
|
| 203 |
+
print("j:", j)
|
| 204 |
+
print("current score:", (original_error - error_A1).item())
|
| 205 |
+
# Compute the difference in error (CG-score)
|
| 206 |
+
cg_scores["vi"][i, j] += (original_error - error_A1).item()
|
| 207 |
+
cg_scores["vi"][j, i] = cg_scores["vi"][i, j] # Symmetric
|
| 208 |
+
cg_scores["times"][i, j] += 1
|
| 209 |
+
cg_scores["times"][j, i] += 1
|
| 210 |
+
|
| 211 |
+
# Normalize CG-scores by repetition count
|
| 212 |
+
for key, values in cg_scores.items():
|
| 213 |
+
if key == "times":
|
| 214 |
+
continue
|
| 215 |
+
cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
| 216 |
+
|
| 217 |
+
return cg_scores if sub_term else cg_scores["vi"]
|
| 218 |
+
|
| 219 |
+
def is_symmetric_sparse(adj):
|
| 220 |
+
"""
|
| 221 |
+
Check if a sparse matrix is symmetric.
|
| 222 |
+
"""
|
| 223 |
+
# Check symmetry
|
| 224 |
+
return (adj != adj.transpose()).nnz == 0 # .nnz is the number of non-zero elements
|
| 225 |
+
|
| 226 |
+
def make_symmetric_sparse(adj):
|
| 227 |
+
"""
|
| 228 |
+
Ensure the sparse adjacency matrix is symmetrical.
|
| 229 |
+
"""
|
| 230 |
+
# Make the matrix symmetric
|
| 231 |
+
sym_adj = (adj + adj.transpose()) / 2
|
| 232 |
+
return sym_adj
|
| 233 |
+
|
| 234 |
+
perturbed_adj = make_symmetric_sparse(perturbed_adj)
|
| 235 |
+
|
| 236 |
+
if type(perturbed_adj) is not torch.Tensor:
|
| 237 |
+
features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
|
| 238 |
+
else:
|
| 239 |
+
features = features.to(device)
|
| 240 |
+
perturbed_adj = perturbed_adj.to(device)
|
| 241 |
+
labels = labels.to(device)
|
| 242 |
+
|
| 243 |
+
if utils.is_sparse_tensor(perturbed_adj):
|
| 244 |
+
|
| 245 |
+
adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
|
| 246 |
+
else:
|
| 247 |
+
adj_norm = utils.normalize_adj_tensor(perturbed_adj)
|
| 248 |
+
|
| 249 |
+
features = features.to_dense()
|
| 250 |
+
perturbed_adj = adj_norm.to_dense()
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
calc_cg_score = calc_cg_score_gnn_with_sampling(perturbed_adj, features, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False)
|
| 254 |
+
save_cg_scores(calc_cg_score, filename="cg_scores_polblogs_0.10.npy")
|
| 255 |
+
# print("completed")
|
examples/graph/cgscore_datasets_multigpus.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# compute cgscore for gcn
|
| 2 |
+
# author: Yaning
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as Fd
|
| 6 |
+
from deeprobust.graph.defense import GCNJaccard, GCN
|
| 7 |
+
from deeprobust.graph.defense import GCNScore
|
| 8 |
+
from deeprobust.graph.utils import *
|
| 9 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 10 |
+
from scipy.sparse import csr_matrix
|
| 11 |
+
import argparse
|
| 12 |
+
import pickle
|
| 13 |
+
from deeprobust.graph import utils
|
| 14 |
+
import torch.multiprocessing as mp
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
parser = argparse.ArgumentParser()
|
| 19 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 20 |
+
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 21 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
|
| 22 |
+
|
| 23 |
+
args = parser.parse_args()
|
| 24 |
+
args.cuda = torch.cuda.is_available()
|
| 25 |
+
print('cuda: %s' % args.cuda)
|
| 26 |
+
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
| 27 |
+
|
| 28 |
+
# make sure you use the same data splits as you generated attacks
|
| 29 |
+
np.random.seed(args.seed)
|
| 30 |
+
if args.cuda:
|
| 31 |
+
torch.cuda.manual_seed(args.seed)
|
| 32 |
+
|
| 33 |
+
# Here the random seed is to split the train/val/test data,
|
| 34 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 35 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 36 |
+
# Or we can just use setting='prognn' to get the splits
|
| 37 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 38 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 39 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
perturbed_data = PrePtbDataset(root='/tmp/',
|
| 43 |
+
name=args.dataset,
|
| 44 |
+
attack_method='meta',
|
| 45 |
+
ptb_rate=args.ptb_rate)
|
| 46 |
+
|
| 47 |
+
perturbed_adj = perturbed_data.adj
|
| 48 |
+
# perturbed_adj = adj
|
| 49 |
+
|
| 50 |
+
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
|
| 51 |
+
np.save(filename, cg_scores)
|
| 52 |
+
print(f"CG-scores saved to {filename}")
|
| 53 |
+
|
| 54 |
+
def load_cg_scores_numpy(filename="cg_scores.npy"):
|
| 55 |
+
cg_scores = np.load(filename, allow_pickle=True)
|
| 56 |
+
print(f"CG-scores loaded from {filename}")
|
| 57 |
+
return cg_scores
|
| 58 |
+
|
| 59 |
+
def calc_cg_score_gnn_with_sampling(
|
| 60 |
+
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, label_filter=None
|
| 61 |
+
):
|
| 62 |
+
N = A.shape[0]
|
| 63 |
+
cg_scores = {
|
| 64 |
+
"vi": np.zeros((N, N)),
|
| 65 |
+
"ab": np.zeros((N, N)),
|
| 66 |
+
"a2": np.zeros((N, N)),
|
| 67 |
+
"b2": np.zeros((N, N)),
|
| 68 |
+
"times": np.zeros((N, N)),
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
A = A.to(device)
|
| 72 |
+
X = X.to(device)
|
| 73 |
+
labels = labels.to(device)
|
| 74 |
+
|
| 75 |
+
@torch.no_grad()
|
| 76 |
+
def normalize(tensor):
|
| 77 |
+
return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
|
| 78 |
+
|
| 79 |
+
for _ in range(rep_num):
|
| 80 |
+
AX = torch.matmul(A, X)
|
| 81 |
+
norm_AX = normalize(AX)
|
| 82 |
+
|
| 83 |
+
unique_labels = torch.unique(labels)
|
| 84 |
+
if label_filter is not None:
|
| 85 |
+
unique_labels = [label for label in unique_labels if label.item() in label_filter]
|
| 86 |
+
|
| 87 |
+
label_to_indices = {
|
| 88 |
+
label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
|
| 89 |
+
}
|
| 90 |
+
dataset = {label: norm_AX[indices] for label, indices in label_to_indices.items()}
|
| 91 |
+
|
| 92 |
+
neg_samples_dict = {}
|
| 93 |
+
neg_indices_dict = {}
|
| 94 |
+
for label in unique_labels:
|
| 95 |
+
print("label:", label)
|
| 96 |
+
label = label.item()
|
| 97 |
+
mask = labels != label
|
| 98 |
+
neg_samples = norm_AX[mask]
|
| 99 |
+
neg_indices = mask.nonzero(as_tuple=True)[0]
|
| 100 |
+
neg_samples_dict[label] = neg_samples
|
| 101 |
+
neg_indices_dict[label] = neg_indices
|
| 102 |
+
|
| 103 |
+
for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
|
| 104 |
+
print("curr_label:", curr_label)
|
| 105 |
+
label_id = int(curr_label)
|
| 106 |
+
curr_samples = dataset[label_id]
|
| 107 |
+
curr_indices = label_to_indices[label_id]
|
| 108 |
+
curr_num = len(curr_samples)
|
| 109 |
+
|
| 110 |
+
chosen_curr_idx = torch.randperm(curr_num, device=device)
|
| 111 |
+
chosen_curr_samples = curr_samples[chosen_curr_idx]
|
| 112 |
+
chosen_curr_indices = curr_indices[chosen_curr_idx]
|
| 113 |
+
|
| 114 |
+
neg_samples = neg_samples_dict[label_id]
|
| 115 |
+
neg_indices = neg_indices_dict[label_id]
|
| 116 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 117 |
+
rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
|
| 118 |
+
chosen_neg_samples = neg_samples[rand_idx]
|
| 119 |
+
chosen_neg_indices = neg_indices[rand_idx]
|
| 120 |
+
|
| 121 |
+
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
|
| 122 |
+
y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
|
| 123 |
+
|
| 124 |
+
H_inner = torch.matmul(combined_samples, combined_samples.T)
|
| 125 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 126 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 127 |
+
H.fill_diagonal_(0.5)
|
| 128 |
+
H += 1e-6 * torch.eye(H.size(0), device=device)
|
| 129 |
+
invH = torch.inverse(H)
|
| 130 |
+
original_error = y @ (invH @ y)
|
| 131 |
+
|
| 132 |
+
edge_batch = []
|
| 133 |
+
for idx_i in chosen_curr_indices.tolist():
|
| 134 |
+
for j in range(idx_i + 1, N):
|
| 135 |
+
if A[idx_i, j] != 0:
|
| 136 |
+
edge_batch.append((idx_i, j))
|
| 137 |
+
|
| 138 |
+
for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
|
| 139 |
+
batch = edge_batch[k: k + batch_size]
|
| 140 |
+
B = len(batch)
|
| 141 |
+
|
| 142 |
+
norm_AX1_batch = norm_AX.repeat(B, 1, 1).clone()
|
| 143 |
+
for b, (i, j) in enumerate(batch):
|
| 144 |
+
AX1_i = AX[i] - A[i, j] * X[j]
|
| 145 |
+
AX1_j = AX[j] - A[j, i] * X[i]
|
| 146 |
+
norm_AX1_batch[b, i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
|
| 147 |
+
norm_AX1_batch[b, j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
|
| 148 |
+
|
| 149 |
+
sample_idx = chosen_curr_indices.tolist() + chosen_neg_indices.tolist()
|
| 150 |
+
sample_batch = norm_AX1_batch[:, sample_idx, :]
|
| 151 |
+
|
| 152 |
+
H_inner = torch.matmul(sample_batch, sample_batch.transpose(1, 2))
|
| 153 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 154 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 155 |
+
eye = torch.eye(H.size(-1), device=device).unsqueeze(0).expand_as(H)
|
| 156 |
+
H = H + 1e-6 * eye
|
| 157 |
+
H.diagonal(dim1=-2, dim2=-1).copy_(0.5)
|
| 158 |
+
|
| 159 |
+
invH = torch.inverse(H)
|
| 160 |
+
y_expanded = y.unsqueeze(0).expand(B, -1)
|
| 161 |
+
error_A1 = torch.einsum("bi,bij,bj->b", y_expanded, invH, y_expanded)
|
| 162 |
+
|
| 163 |
+
for b, (i, j) in enumerate(batch):
|
| 164 |
+
score = (original_error - error_A1[b]).item()
|
| 165 |
+
cg_scores["vi"][i, j] += score
|
| 166 |
+
cg_scores["vi"][j, i] = score
|
| 167 |
+
cg_scores["times"][i, j] += 1
|
| 168 |
+
cg_scores["times"][j, i] += 1
|
| 169 |
+
|
| 170 |
+
for key in cg_scores:
|
| 171 |
+
if key != "times":
|
| 172 |
+
cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
| 173 |
+
|
| 174 |
+
# return cg_scores if sub_term else cg_scores["vi"]
|
| 175 |
+
return cg_scores
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict):
|
| 179 |
+
device = torch.device(f"cuda:{gpu_id}")
|
| 180 |
+
unique_labels = torch.unique(labels).tolist()
|
| 181 |
+
label_chunks = np.array_split(unique_labels, world_size)
|
| 182 |
+
rank = torch.cuda.current_device()
|
| 183 |
+
label_filter = [int(l) for l in label_chunks[gpu_id % world_size]]
|
| 184 |
+
|
| 185 |
+
result = calc_cg_score_gnn_with_sampling(
|
| 186 |
+
A, X, labels, device,
|
| 187 |
+
rep_num=rep_num,
|
| 188 |
+
unbalance_ratio=unbalance_ratio,
|
| 189 |
+
sub_term=sub_term,
|
| 190 |
+
batch_size=batch_size,
|
| 191 |
+
label_filter=label_filter
|
| 192 |
+
)
|
| 193 |
+
return_dict[gpu_id] = result
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64, gpu_ids=None):
|
| 198 |
+
if gpu_ids is None:
|
| 199 |
+
gpu_ids = list(range(torch.cuda.device_count()))
|
| 200 |
+
world_size = len(gpu_ids)
|
| 201 |
+
|
| 202 |
+
manager = mp.Manager()
|
| 203 |
+
return_dict = manager.dict()
|
| 204 |
+
processes = []
|
| 205 |
+
|
| 206 |
+
for local_rank, gpu_id in enumerate(gpu_ids):
|
| 207 |
+
p = mp.Process(
|
| 208 |
+
target=run_worker,
|
| 209 |
+
args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, sub_term, batch_size, return_dict)
|
| 210 |
+
)
|
| 211 |
+
p.start()
|
| 212 |
+
processes.append(p)
|
| 213 |
+
|
| 214 |
+
for p in processes:
|
| 215 |
+
p.join()
|
| 216 |
+
|
| 217 |
+
# 初始化 final_score
|
| 218 |
+
final_score = None
|
| 219 |
+
|
| 220 |
+
for gpu_id, rank_result in return_dict.items():
|
| 221 |
+
if not isinstance(rank_result, dict):
|
| 222 |
+
print(f"[FATAL] GPU {gpu_id} result is not a dict: {type(rank_result)}")
|
| 223 |
+
continue
|
| 224 |
+
|
| 225 |
+
if final_score is None:
|
| 226 |
+
# 深拷贝防止指针复用
|
| 227 |
+
final_score = {k: np.copy(v) for k, v in rank_result.items()}
|
| 228 |
+
else:
|
| 229 |
+
for key in rank_result:
|
| 230 |
+
if key not in final_score:
|
| 231 |
+
print(f"[WARN] key '{key}' not in final_score. Skipping.")
|
| 232 |
+
continue
|
| 233 |
+
try:
|
| 234 |
+
if isinstance(final_score[key], np.ndarray) and isinstance(rank_result[key], np.ndarray):
|
| 235 |
+
final_score[key] += rank_result[key]
|
| 236 |
+
else:
|
| 237 |
+
print(f"[WARN] Skipped merging key '{key}' due to type mismatch.")
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"[ERROR] Failed merging key '{key}': {e}")
|
| 240 |
+
|
| 241 |
+
return final_score
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def is_symmetric_sparse(adj):
|
| 246 |
+
"""
|
| 247 |
+
Check if a sparse matrix is symmetric.
|
| 248 |
+
"""
|
| 249 |
+
# Check symmetry
|
| 250 |
+
return (adj != adj.transpose()).nnz == 0 # .nnz is the number of non-zero elements
|
| 251 |
+
|
| 252 |
+
def make_symmetric_sparse(adj):
|
| 253 |
+
"""
|
| 254 |
+
Ensure the sparse adjacency matrix is symmetrical.
|
| 255 |
+
"""
|
| 256 |
+
# Make the matrix symmetric
|
| 257 |
+
sym_adj = (adj + adj.transpose()) / 2
|
| 258 |
+
return sym_adj
|
| 259 |
+
|
| 260 |
+
if __name__ == "__main__":
|
| 261 |
+
mp.set_start_method("spawn", force=True)
|
| 262 |
+
|
| 263 |
+
print("cuda:", torch.cuda.is_available())
|
| 264 |
+
|
| 265 |
+
# 选择使用的 GPU(根据你的实际情况)
|
| 266 |
+
selected_gpus = [0, 1, 2, 3]
|
| 267 |
+
|
| 268 |
+
# 稀疏矩阵对称处理
|
| 269 |
+
perturbed_adj = make_symmetric_sparse(perturbed_adj)
|
| 270 |
+
|
| 271 |
+
# 转 tensor
|
| 272 |
+
if type(perturbed_adj) is not torch.Tensor:
|
| 273 |
+
features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
|
| 274 |
+
else:
|
| 275 |
+
features = features.to(device)
|
| 276 |
+
perturbed_adj = perturbed_adj.to(device)
|
| 277 |
+
labels = labels.to(device)
|
| 278 |
+
|
| 279 |
+
# 标准化邻接
|
| 280 |
+
if utils.is_sparse_tensor(perturbed_adj):
|
| 281 |
+
adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
|
| 282 |
+
else:
|
| 283 |
+
adj_norm = utils.normalize_adj_tensor(perturbed_adj)
|
| 284 |
+
|
| 285 |
+
features = features.to_dense()
|
| 286 |
+
perturbed_adj = adj_norm.to_dense()
|
| 287 |
+
|
| 288 |
+
# 多GPU并行计算 CG-score
|
| 289 |
+
calc_cg_score = multi_gpu_wrapper(
|
| 290 |
+
perturbed_adj, features, labels,
|
| 291 |
+
rep_num=1,
|
| 292 |
+
unbalance_ratio=1,
|
| 293 |
+
sub_term=False,
|
| 294 |
+
batch_size=1024,
|
| 295 |
+
gpu_ids=selected_gpus
|
| 296 |
+
)
|
| 297 |
+
save_cg_scores(calc_cg_score["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
|
| 298 |
+
|
| 299 |
+
print(" CG-score computation completed.")
|
examples/graph/cgscore_datasets_multigpus2.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# compute cgscore for gcn - Final Optimized Complete Version
|
| 3 |
+
## 精度有损失,但不多
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.multiprocessing as mp
|
| 7 |
+
import argparse
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from deeprobust.graph.utils import *
|
| 10 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 11 |
+
from torch.cuda.amp import autocast
|
| 12 |
+
from deeprobust.graph import utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
parser.add_argument('--seed', type=int, default=15)
|
| 17 |
+
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'])
|
| 18 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05)
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
args.cuda = torch.cuda.is_available()
|
| 22 |
+
print('cuda: %s' % args.cuda)
|
| 23 |
+
device = torch.device("cuda:0" if args.cuda else "cpu")
|
| 24 |
+
|
| 25 |
+
np.random.seed(args.seed)
|
| 26 |
+
if args.cuda:
|
| 27 |
+
torch.cuda.manual_seed(args.seed)
|
| 28 |
+
|
| 29 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 30 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 31 |
+
|
| 32 |
+
perturbed_data = PrePtbDataset(root='/tmp/', name=args.dataset, attack_method='meta', ptb_rate=args.ptb_rate)
|
| 33 |
+
perturbed_adj = (perturbed_data.adj + perturbed_data.adj.T) / 2
|
| 34 |
+
|
| 35 |
+
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
|
| 36 |
+
np.save(filename, cg_scores)
|
| 37 |
+
print(f"CG-scores saved to {filename}")
|
| 38 |
+
|
| 39 |
+
def calc_cg_score_gnn_with_sampling(A, X, labels, device, rep_num=1, unbalance_ratio=1, batch_size=1024, node_filter=None):
|
| 40 |
+
N = A.shape[0]
|
| 41 |
+
cg_scores = {"vi": np.zeros((N, N)), "times": np.zeros((N, N))}
|
| 42 |
+
A, X, labels = A.to(device), X.to(device), labels.to(device)
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def normalize(tensor):
|
| 46 |
+
return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
|
| 47 |
+
|
| 48 |
+
for _ in range(rep_num):
|
| 49 |
+
AX = torch.matmul(A, X)
|
| 50 |
+
norm_AX = normalize(AX)
|
| 51 |
+
|
| 52 |
+
unique_labels = torch.unique(labels)
|
| 53 |
+
label_to_indices = {label.item(): (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels}
|
| 54 |
+
dataset = {label: norm_AX[idx] for label, idx in label_to_indices.items()}
|
| 55 |
+
|
| 56 |
+
neg_samples_dict = {}
|
| 57 |
+
neg_indices_dict = {}
|
| 58 |
+
for label in unique_labels:
|
| 59 |
+
label = label.item()
|
| 60 |
+
mask = labels != label
|
| 61 |
+
neg_samples_dict[label] = norm_AX[mask]
|
| 62 |
+
neg_indices_dict[label] = mask.nonzero(as_tuple=True)[0]
|
| 63 |
+
|
| 64 |
+
if node_filter is not None:
|
| 65 |
+
node_filter = set(node_filter.tolist())
|
| 66 |
+
else:
|
| 67 |
+
node_filter = set(range(labels.size(0)))
|
| 68 |
+
|
| 69 |
+
for curr_label in tqdm(unique_labels, desc="Label groups", position=device.index):
|
| 70 |
+
label_id = int(curr_label)
|
| 71 |
+
curr_samples = dataset[label_id]
|
| 72 |
+
curr_indices = label_to_indices[label_id]
|
| 73 |
+
curr_num = len(curr_samples)
|
| 74 |
+
|
| 75 |
+
chosen_curr_idx = torch.randperm(curr_num, device=device)
|
| 76 |
+
pos_samples = curr_samples[chosen_curr_idx]
|
| 77 |
+
pos_indices = curr_indices[chosen_curr_idx]
|
| 78 |
+
|
| 79 |
+
neg_samples = neg_samples_dict[label_id]
|
| 80 |
+
neg_indices = neg_indices_dict[label_id]
|
| 81 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 82 |
+
rand_idx = torch.randperm(len(neg_samples), device=device)[:neg_num]
|
| 83 |
+
neg_samples = neg_samples[rand_idx]
|
| 84 |
+
neg_indices = neg_indices[rand_idx]
|
| 85 |
+
|
| 86 |
+
sample_idx = pos_indices.tolist() + neg_indices.tolist()
|
| 87 |
+
sample_tensor = norm_AX[sample_idx] # [M, F]
|
| 88 |
+
y = torch.cat([
|
| 89 |
+
torch.ones(len(pos_samples)),
|
| 90 |
+
-torch.ones(len(neg_samples))
|
| 91 |
+
], dim=0).to(device)
|
| 92 |
+
|
| 93 |
+
with autocast():
|
| 94 |
+
H_inner = torch.matmul(sample_tensor, sample_tensor.T).clamp(-1.0, 1.0)
|
| 95 |
+
H_base = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 96 |
+
H_base.fill_diagonal_(0.5)
|
| 97 |
+
H_base += 1e-6 * torch.eye(H_base.size(0), device=device)
|
| 98 |
+
ref_error = torch.dot(y, torch.linalg.solve(H_base, y))
|
| 99 |
+
|
| 100 |
+
edge_batch = [(i.item(), j)
|
| 101 |
+
for i in pos_indices if i.item() in node_filter
|
| 102 |
+
for j in range(i.item() + 1, N) if A[i, j] != 0]
|
| 103 |
+
|
| 104 |
+
for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False, position=device.index):
|
| 105 |
+
batch = edge_batch[k:k + batch_size]
|
| 106 |
+
if not batch: continue
|
| 107 |
+
i_idx, j_idx = zip(*batch)
|
| 108 |
+
i_idx = torch.tensor(i_idx, device=device)
|
| 109 |
+
j_idx = torch.tensor(j_idx, device=device)
|
| 110 |
+
|
| 111 |
+
AX1_i = AX[i_idx] - A[i_idx, j_idx].unsqueeze(1) * X[j_idx]
|
| 112 |
+
AX1_j = AX[j_idx] - A[j_idx, i_idx].unsqueeze(1) * X[i_idx]
|
| 113 |
+
norm_AX1_i = normalize(AX1_i)
|
| 114 |
+
norm_AX1_j = normalize(AX1_j)
|
| 115 |
+
|
| 116 |
+
for b, (i, j) in enumerate(batch):
|
| 117 |
+
i_int, j_int = i, j
|
| 118 |
+
sample_tensor_copy = sample_tensor.clone()
|
| 119 |
+
try:
|
| 120 |
+
i_pos = sample_idx.index(i_int)
|
| 121 |
+
sample_tensor_copy[i_pos] = norm_AX1_i[b]
|
| 122 |
+
except ValueError:
|
| 123 |
+
pass
|
| 124 |
+
try:
|
| 125 |
+
j_pos = sample_idx.index(j_int)
|
| 126 |
+
sample_tensor_copy[j_pos] = norm_AX1_j[b]
|
| 127 |
+
except ValueError:
|
| 128 |
+
pass
|
| 129 |
+
|
| 130 |
+
with autocast():
|
| 131 |
+
H_inner = torch.matmul(sample_tensor_copy, sample_tensor_copy.T).clamp(-1.0, 1.0)
|
| 132 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 133 |
+
H.fill_diagonal_(0.5)
|
| 134 |
+
H += 1e-6 * torch.eye(H.size(0), device=device)
|
| 135 |
+
sol = torch.linalg.solve(H, y)
|
| 136 |
+
err_new = torch.dot(y, sol)
|
| 137 |
+
|
| 138 |
+
score = (ref_error - err_new).item()
|
| 139 |
+
cg_scores["vi"][i, j] += score
|
| 140 |
+
cg_scores["vi"][j, i] = score
|
| 141 |
+
cg_scores["times"][i, j] += 1
|
| 142 |
+
cg_scores["times"][j, i] += 1
|
| 143 |
+
|
| 144 |
+
for key in ["vi"]:
|
| 145 |
+
cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
| 146 |
+
|
| 147 |
+
return cg_scores
|
| 148 |
+
|
| 149 |
+
def run_worker(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict):
|
| 150 |
+
device = torch.device(f"cuda:{gpu_id}")
|
| 151 |
+
|
| 152 |
+
# 用 node ids 划分代替 label 分片
|
| 153 |
+
node_ids = torch.arange(labels.size(0))
|
| 154 |
+
node_chunks = np.array_split(node_ids.numpy(), world_size)
|
| 155 |
+
node_filter = torch.tensor(node_chunks[gpu_id], device=device)
|
| 156 |
+
|
| 157 |
+
result = calc_cg_score_gnn_with_sampling(
|
| 158 |
+
A, X, labels, device,
|
| 159 |
+
rep_num=rep_num,
|
| 160 |
+
unbalance_ratio=unbalance_ratio,
|
| 161 |
+
batch_size=batch_size,
|
| 162 |
+
node_filter=node_filter # 👈 改名字
|
| 163 |
+
)
|
| 164 |
+
return_dict[gpu_id] = result
|
| 165 |
+
|
| 166 |
+
def multi_gpu_wrapper(A, X, labels, rep_num=1, unbalance_ratio=1, batch_size=1024, gpu_ids=None):
|
| 167 |
+
if gpu_ids is None:
|
| 168 |
+
gpu_ids = list(range(torch.cuda.device_count()))
|
| 169 |
+
world_size = len(gpu_ids)
|
| 170 |
+
|
| 171 |
+
mp.set_start_method("spawn", force=True)
|
| 172 |
+
manager = mp.Manager()
|
| 173 |
+
return_dict = manager.dict()
|
| 174 |
+
processes = []
|
| 175 |
+
|
| 176 |
+
for i, gpu_id in enumerate(gpu_ids):
|
| 177 |
+
p = mp.Process(target=run_worker, args=(gpu_id, world_size, A, X, labels, rep_num, unbalance_ratio, batch_size, return_dict))
|
| 178 |
+
p.start()
|
| 179 |
+
processes.append(p)
|
| 180 |
+
|
| 181 |
+
for p in processes:
|
| 182 |
+
p.join()
|
| 183 |
+
|
| 184 |
+
final_score = None
|
| 185 |
+
for res in return_dict.values():
|
| 186 |
+
if final_score is None:
|
| 187 |
+
final_score = {k: np.copy(v) for k, v in res.items()}
|
| 188 |
+
else:
|
| 189 |
+
for k in res:
|
| 190 |
+
final_score[k] += res[k]
|
| 191 |
+
return final_score
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
|
| 195 |
+
features = features.to_dense()
|
| 196 |
+
if utils.is_sparse_tensor(perturbed_adj):
|
| 197 |
+
perturbed_adj = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
|
| 198 |
+
perturbed_adj = perturbed_adj.to_dense()
|
| 199 |
+
|
| 200 |
+
selected_gpus = [0,1,2,3]
|
| 201 |
+
cg_scores = multi_gpu_wrapper(perturbed_adj, features, labels,
|
| 202 |
+
rep_num=1,
|
| 203 |
+
unbalance_ratio=3,
|
| 204 |
+
batch_size=40280,
|
| 205 |
+
gpu_ids=selected_gpus)
|
| 206 |
+
|
| 207 |
+
save_cg_scores(cg_scores["vi"], filename=f"{args.dataset}_{args.ptb_rate}.npy")
|
| 208 |
+
print("🎉 CG-score computation completed.")
|
examples/graph/cgscore_env.yaml
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: cgscore
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- ca-certificates=2024.11.26=h06a4308_0
|
| 8 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 9 |
+
- libffi=3.4.4=h6a678d5_1
|
| 10 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 11 |
+
- libgomp=11.2.0=h1234567_1
|
| 12 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 13 |
+
- ncurses=6.4=h6a678d5_0
|
| 14 |
+
- openssl=3.0.15=h5eee18b_0
|
| 15 |
+
- python=3.9.21=he870216_1
|
| 16 |
+
- readline=8.2=h5eee18b_0
|
| 17 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 18 |
+
- tk=8.6.14=h39e8969_0
|
| 19 |
+
- wheel=0.44.0=py39h06a4308_0
|
| 20 |
+
- xz=5.4.6=h5eee18b_1
|
| 21 |
+
- zlib=1.2.13=h5eee18b_1
|
| 22 |
+
- pip:
|
| 23 |
+
- aiohappyeyeballs==2.4.4
|
| 24 |
+
- aiohttp==3.11.11
|
| 25 |
+
- aiosignal==1.3.2
|
| 26 |
+
- alabaster==0.7.16
|
| 27 |
+
- alembic==1.15.1
|
| 28 |
+
- asttokens==3.0.0
|
| 29 |
+
- async-timeout==5.0.1
|
| 30 |
+
- attrs==24.3.0
|
| 31 |
+
- autopage==0.5.2
|
| 32 |
+
- babel==2.17.0
|
| 33 |
+
- certifi==2025.1.31
|
| 34 |
+
- cfgv==3.4.0
|
| 35 |
+
- charset-normalizer==3.4.1
|
| 36 |
+
- cliff==4.9.1
|
| 37 |
+
- cmaes==0.11.1
|
| 38 |
+
- cmake==3.31.2
|
| 39 |
+
- cmd2==2.5.11
|
| 40 |
+
- cogdl==0.6
|
| 41 |
+
- colorlog==6.9.0
|
| 42 |
+
- commonmark==0.9.1
|
| 43 |
+
- contourpy==1.3.0
|
| 44 |
+
- cycler==0.12.1
|
| 45 |
+
- decorator==5.1.1
|
| 46 |
+
- distlib==0.3.9
|
| 47 |
+
- docutils==0.21.2
|
| 48 |
+
- exceptiongroup==1.2.2
|
| 49 |
+
- executing==2.1.0
|
| 50 |
+
- filelock==3.18.0
|
| 51 |
+
- flake8==7.1.2
|
| 52 |
+
- fonttools==4.56.0
|
| 53 |
+
- frozenlist==1.5.0
|
| 54 |
+
- fsspec==2025.3.0
|
| 55 |
+
- fst-pso==1.8.1
|
| 56 |
+
- fuzzytm==2.0.9
|
| 57 |
+
- gensim==4.3.0
|
| 58 |
+
- grave==0.0.3
|
| 59 |
+
- grb==0.1.0
|
| 60 |
+
- greenlet==3.1.1
|
| 61 |
+
- huggingface-hub==0.29.3
|
| 62 |
+
- identify==2.6.9
|
| 63 |
+
- idna==3.10
|
| 64 |
+
- imageio==2.36.1
|
| 65 |
+
- imagesize==1.4.1
|
| 66 |
+
- importlib-metadata==4.13.0
|
| 67 |
+
- importlib-resources==6.5.2
|
| 68 |
+
- ipdb==0.13.13
|
| 69 |
+
- ipython==8.18.1
|
| 70 |
+
- jedi==0.19.2
|
| 71 |
+
- jinja2==3.1.6
|
| 72 |
+
- joblib==1.4.2
|
| 73 |
+
- kiwisolver==1.4.7
|
| 74 |
+
- lazy-loader==0.4
|
| 75 |
+
- lit==18.1.8
|
| 76 |
+
- littleutils==0.2.4
|
| 77 |
+
- llvmlite==0.40.1
|
| 78 |
+
- mako==1.3.9
|
| 79 |
+
- markupsafe==3.0.2
|
| 80 |
+
- matplotlib==3.6.0
|
| 81 |
+
- matplotlib-inline==0.1.7
|
| 82 |
+
- mccabe==0.7.0
|
| 83 |
+
- miniful==0.0.6
|
| 84 |
+
- mpmath==1.3.0
|
| 85 |
+
- multidict==6.1.0
|
| 86 |
+
- networkx==3.0
|
| 87 |
+
- ninja==1.11.1.4
|
| 88 |
+
- nodeenv==1.9.1
|
| 89 |
+
- numba==0.57.1
|
| 90 |
+
- numpy==1.23.5
|
| 91 |
+
- nvidia-cublas-cu11==11.10.3.66
|
| 92 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 93 |
+
- nvidia-cuda-cupti-cu11==11.7.101
|
| 94 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 95 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
| 96 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 97 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
| 98 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 99 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
| 100 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
| 101 |
+
- nvidia-cufft-cu11==10.9.0.58
|
| 102 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 103 |
+
- nvidia-curand-cu11==10.2.10.91
|
| 104 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 105 |
+
- nvidia-cusolver-cu11==11.4.0.1
|
| 106 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 107 |
+
- nvidia-cusparse-cu11==11.7.4.91
|
| 108 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 109 |
+
- nvidia-cusparselt-cu12==0.6.2
|
| 110 |
+
- nvidia-nccl-cu11==2.14.3
|
| 111 |
+
- nvidia-nccl-cu12==2.20.5
|
| 112 |
+
- nvidia-nvjitlink-cu12==12.4.127
|
| 113 |
+
- nvidia-nvtx-cu11==11.7.91
|
| 114 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 115 |
+
- ogb==1.3.6
|
| 116 |
+
- optuna==2.4.0
|
| 117 |
+
- outdated==0.2.2
|
| 118 |
+
- packaging==24.2
|
| 119 |
+
- pandas==2.2.3
|
| 120 |
+
- parso==0.8.4
|
| 121 |
+
- pbr==6.1.1
|
| 122 |
+
- pexpect==4.9.0
|
| 123 |
+
- pillow==9.4.0
|
| 124 |
+
- pip==25.0.1
|
| 125 |
+
- platformdirs==4.3.7
|
| 126 |
+
- pre-commit==4.2.0
|
| 127 |
+
- prettytable==3.16.0
|
| 128 |
+
- prompt-toolkit==3.0.48
|
| 129 |
+
- propcache==0.2.1
|
| 130 |
+
- protobuf==3.20.3
|
| 131 |
+
- psutil==6.1.1
|
| 132 |
+
- ptyprocess==0.7.0
|
| 133 |
+
- pure-eval==0.2.3
|
| 134 |
+
- pycodestyle==2.12.1
|
| 135 |
+
- pyflakes==3.2.0
|
| 136 |
+
- pyfume==0.3.1
|
| 137 |
+
- pygments==2.19.1
|
| 138 |
+
- pyparsing==3.2.3
|
| 139 |
+
- pyperclip==1.9.0
|
| 140 |
+
- python-dateutil==2.9.0.post0
|
| 141 |
+
- pytz==2025.2
|
| 142 |
+
- pyyaml==6.0.2
|
| 143 |
+
- recommonmark==0.7.1
|
| 144 |
+
- regex==2024.11.6
|
| 145 |
+
- requests==2.32.3
|
| 146 |
+
- safetensors==0.5.3
|
| 147 |
+
- scikit-image==0.24.0
|
| 148 |
+
- scikit-learn==1.2.0
|
| 149 |
+
- scipy==1.11.3
|
| 150 |
+
- sentencepiece==0.2.0
|
| 151 |
+
- setuptools==78.1.0
|
| 152 |
+
- simpful==2.12.0
|
| 153 |
+
- six==1.17.0
|
| 154 |
+
- smart-open==7.1.0
|
| 155 |
+
- snowballstemmer==2.2.0
|
| 156 |
+
- sphinx==7.3.7
|
| 157 |
+
- sphinxcontrib-applehelp==2.0.0
|
| 158 |
+
- sphinxcontrib-devhelp==2.0.0
|
| 159 |
+
- sphinxcontrib-htmlhelp==2.1.0
|
| 160 |
+
- sphinxcontrib-jsmath==1.0.1
|
| 161 |
+
- sphinxcontrib-qthelp==2.0.0
|
| 162 |
+
- sphinxcontrib-serializinghtml==2.0.0
|
| 163 |
+
- sqlalchemy==2.0.40
|
| 164 |
+
- stack-data==0.6.3
|
| 165 |
+
- stevedore==5.4.1
|
| 166 |
+
- sympy==1.13.1
|
| 167 |
+
- tabulate==0.9.0
|
| 168 |
+
- tensorboardx==2.6
|
| 169 |
+
- texttable==1.6.7
|
| 170 |
+
- threadpoolctl==3.6.0
|
| 171 |
+
- tifffile==2024.8.30
|
| 172 |
+
- tokenizers==0.21.1
|
| 173 |
+
- tomli==2.2.1
|
| 174 |
+
- torch==2.3.0+cu121
|
| 175 |
+
- torch-cluster==1.6.3+pt23cu121
|
| 176 |
+
- torch-geometric==2.6.1
|
| 177 |
+
- torch-scatter==2.1.2+pt23cu121
|
| 178 |
+
- torch-sparse==0.6.18+pt23cu121
|
| 179 |
+
- torch-spline-conv==1.2.2+pt23cu121
|
| 180 |
+
- torchvision==0.18.0+cu121
|
| 181 |
+
- tqdm==4.64.1
|
| 182 |
+
- traitlets==5.14.3
|
| 183 |
+
- transformers==4.50.2
|
| 184 |
+
- triton==2.3.0
|
| 185 |
+
- typing-extensions==4.13.0
|
| 186 |
+
- tzdata==2025.2
|
| 187 |
+
- urllib3==2.3.0
|
| 188 |
+
- virtualenv==20.29.3
|
| 189 |
+
- wcwidth==0.2.13
|
| 190 |
+
- wrapt==1.17.2
|
| 191 |
+
- yarl==1.18.3
|
| 192 |
+
- zipp==3.21.0
|
| 193 |
+
prefix: /home/yiren/new_ssd2/chunhui/miniconda/envs/cgscore
|
examples/graph/cgscore_experiments/attack_method/attack_minmax.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from deeprobust.graph.defense import GCN
|
| 6 |
+
from deeprobust.graph.global_attack import MinMax
|
| 7 |
+
from deeprobust.graph.utils import *
|
| 8 |
+
from deeprobust.graph.data import Dataset
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 14 |
+
parser.add_argument('--epochs', type=int, default=200,
|
| 15 |
+
help='Number of epochs to train.')
|
| 16 |
+
parser.add_argument('--lr', type=float, default=0.01,
|
| 17 |
+
help='Initial learning rate.')
|
| 18 |
+
parser.add_argument('--weight_decay', type=float, default=5e-4,
|
| 19 |
+
help='Weight decay (L2 loss on parameters).')
|
| 20 |
+
parser.add_argument('--hidden', type=int, default=16,
|
| 21 |
+
help='Number of hidden units.')
|
| 22 |
+
parser.add_argument('--dropout', type=float, default=0.5,
|
| 23 |
+
help='Dropout rate (1 - keep probability).')
|
| 24 |
+
|
| 25 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed', 'Flickr'], help='dataset')
|
| 26 |
+
parser.add_argument('--ptb_rate', type=float, default=0.25, help='pertubation rate')
|
| 27 |
+
parser.add_argument('--ptb_type', type=str, default='minmax', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 28 |
+
parser.add_argument('--model', type=str, default='min-max', choices=['PGD', 'min-max'], help='model variant')
|
| 29 |
+
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
|
| 33 |
+
|
| 34 |
+
np.random.seed(args.seed)
|
| 35 |
+
torch.manual_seed(args.seed)
|
| 36 |
+
if device != 'cpu':
|
| 37 |
+
torch.cuda.manual_seed(args.seed)
|
| 38 |
+
|
| 39 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 40 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 41 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 42 |
+
# features = normalize_feature(features)
|
| 43 |
+
|
| 44 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 45 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 46 |
+
|
| 47 |
+
perturbations = int(args.ptb_rate * (adj.sum()//2))
|
| 48 |
+
|
| 49 |
+
adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 50 |
+
|
| 51 |
+
# Setup Victim Model
|
| 52 |
+
victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
|
| 53 |
+
dropout=0.5, weight_decay=5e-4, device=device)
|
| 54 |
+
|
| 55 |
+
victim_model = victim_model.to(device)
|
| 56 |
+
victim_model.fit(features, adj, labels, idx_train)
|
| 57 |
+
|
| 58 |
+
# Setup Attack Model
|
| 59 |
+
|
| 60 |
+
model = MinMax(model=victim_model, nnodes=adj.shape[0], loss_type='CE', device=device)
|
| 61 |
+
|
| 62 |
+
model = model.to(device)
|
| 63 |
+
|
| 64 |
+
def test(adj):
|
| 65 |
+
''' test on GCN '''
|
| 66 |
+
|
| 67 |
+
# adj = normalize_adj_tensor(adj)
|
| 68 |
+
gcn = GCN(nfeat=features.shape[1],
|
| 69 |
+
nhid=args.hidden,
|
| 70 |
+
nclass=labels.max().item() + 1,
|
| 71 |
+
dropout=args.dropout, device=device)
|
| 72 |
+
gcn = gcn.to(device)
|
| 73 |
+
gcn.fit(features, adj, labels, idx_train) # train without model picking
|
| 74 |
+
# gcn.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
|
| 75 |
+
output = gcn.output.cpu()
|
| 76 |
+
loss_test = F.nll_loss(output[idx_test], labels[idx_test])
|
| 77 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 78 |
+
print("Test set results:",
|
| 79 |
+
"loss= {:.4f}".format(loss_test.item()),
|
| 80 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 81 |
+
|
| 82 |
+
return acc_test.item()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main():
|
| 86 |
+
model.attack(features, adj, labels, idx_train, perturbations)
|
| 87 |
+
print('=== testing GCN on original(clean) graph ===')
|
| 88 |
+
test(adj)
|
| 89 |
+
modified_adj = model.modified_adj
|
| 90 |
+
# modified_features = model.modified_features
|
| 91 |
+
|
| 92 |
+
save_dir = f"../attacked_adj/{args.dataset}"
|
| 93 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 94 |
+
|
| 95 |
+
save_path = os.path.join(save_dir, f"{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt")
|
| 96 |
+
torch.save(modified_adj, save_path)
|
| 97 |
+
|
| 98 |
+
test(modified_adj)
|
| 99 |
+
|
| 100 |
+
# # if you want to save the modified adj/features, uncomment the code below
|
| 101 |
+
# model.save_adj(root='./', name=f'mod_adj')
|
| 102 |
+
# model.save_features(root='./', name='mod_features')
|
| 103 |
+
|
| 104 |
+
if __name__ == '__main__':
|
| 105 |
+
main()
|
| 106 |
+
|
examples/graph/cgscore_experiments/attack_method/attack_nettack.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from deeprobust.graph.defense import GCN
|
| 6 |
+
from deeprobust.graph.targeted_attack import Nettack
|
| 7 |
+
from deeprobust.graph.utils import *
|
| 8 |
+
from deeprobust.graph.data import Dataset
|
| 9 |
+
import argparse
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 14 |
+
parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
args.cuda = torch.cuda.is_available()
|
| 19 |
+
print('cuda: %s' % args.cuda)
|
| 20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
np.random.seed(args.seed)
|
| 23 |
+
torch.manual_seed(args.seed)
|
| 24 |
+
if args.cuda:
|
| 25 |
+
torch.cuda.manual_seed(args.seed)
|
| 26 |
+
|
| 27 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 28 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 29 |
+
|
| 30 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 31 |
+
|
| 32 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 33 |
+
|
| 34 |
+
# Setup Surrogate model
|
| 35 |
+
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
|
| 36 |
+
nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)
|
| 37 |
+
|
| 38 |
+
surrogate = surrogate.to(device)
|
| 39 |
+
surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
|
| 40 |
+
|
| 41 |
+
# Setup Attack Model
|
| 42 |
+
target_node = 0
|
| 43 |
+
assert target_node in idx_unlabeled
|
| 44 |
+
|
| 45 |
+
model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
|
| 46 |
+
model = model.to(device)
|
| 47 |
+
|
| 48 |
+
def main():
|
| 49 |
+
degrees = adj.sum(0).A1
|
| 50 |
+
# How many perturbations to perform. Default: Degree of the node
|
| 51 |
+
n_perturbations = int(degrees[target_node])
|
| 52 |
+
|
| 53 |
+
# direct attack
|
| 54 |
+
model.attack(features, adj, labels, target_node, n_perturbations)
|
| 55 |
+
# # indirect attack/ influencer attack
|
| 56 |
+
# model.attack(features, adj, labels, target_node, n_perturbations, direct=False, n_influencers=5)
|
| 57 |
+
modified_adj = model.modified_adj
|
| 58 |
+
modified_features = model.modified_features
|
| 59 |
+
print(model.structure_perturbations)
|
| 60 |
+
print('=== testing GCN on original(clean) graph ===')
|
| 61 |
+
test(adj, features, target_node)
|
| 62 |
+
print('=== testing GCN on perturbed graph ===')
|
| 63 |
+
test(modified_adj, modified_features, target_node)
|
| 64 |
+
|
| 65 |
+
def test(adj, features, target_node):
|
| 66 |
+
''' test on GCN '''
|
| 67 |
+
gcn = GCN(nfeat=features.shape[1],
|
| 68 |
+
nhid=16,
|
| 69 |
+
nclass=labels.max().item() + 1,
|
| 70 |
+
dropout=0.5, device=device)
|
| 71 |
+
|
| 72 |
+
gcn = gcn.to(device)
|
| 73 |
+
|
| 74 |
+
gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
|
| 75 |
+
|
| 76 |
+
gcn.eval()
|
| 77 |
+
output = gcn.predict()
|
| 78 |
+
probs = torch.exp(output[[target_node]])[0]
|
| 79 |
+
print('Target node probs: {}'.format(probs.detach().cpu().numpy()))
|
| 80 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 81 |
+
|
| 82 |
+
print("Overall test set results:",
|
| 83 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 84 |
+
|
| 85 |
+
return acc_test.item()
|
| 86 |
+
|
| 87 |
+
def select_nodes(target_gcn=None):
|
| 88 |
+
'''
|
| 89 |
+
selecting nodes as reported in nettack paper:
|
| 90 |
+
(i) the 10 nodes with highest margin of classification, i.e. they are clearly correctly classified,
|
| 91 |
+
(ii) the 10 nodes with lowest margin (but still correctly classified) and
|
| 92 |
+
(iii) 20 more nodes randomly
|
| 93 |
+
'''
|
| 94 |
+
|
| 95 |
+
if target_gcn is None:
|
| 96 |
+
target_gcn = GCN(nfeat=features.shape[1],
|
| 97 |
+
nhid=16,
|
| 98 |
+
nclass=labels.max().item() + 1,
|
| 99 |
+
dropout=0.5, device=device)
|
| 100 |
+
target_gcn = target_gcn.to(device)
|
| 101 |
+
target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
|
| 102 |
+
target_gcn.eval()
|
| 103 |
+
output = target_gcn.predict()
|
| 104 |
+
|
| 105 |
+
margin_dict = {}
|
| 106 |
+
for idx in idx_test:
|
| 107 |
+
margin = classification_margin(output[idx], labels[idx])
|
| 108 |
+
if margin < 0: # only keep the nodes correctly classified
|
| 109 |
+
continue
|
| 110 |
+
margin_dict[idx] = margin
|
| 111 |
+
sorted_margins = sorted(margin_dict.items(), key=lambda x:x[1], reverse=True)
|
| 112 |
+
high = [x for x, y in sorted_margins[: 10]]
|
| 113 |
+
low = [x for x, y in sorted_margins[-10: ]]
|
| 114 |
+
other = [x for x, y in sorted_margins[10: -10]]
|
| 115 |
+
other = np.random.choice(other, 20, replace=False).tolist()
|
| 116 |
+
|
| 117 |
+
return high + low + other
|
| 118 |
+
|
| 119 |
+
def multi_test_poison_accumulative():
|
| 120 |
+
degrees = adj.sum(0).A1
|
| 121 |
+
node_list = select_nodes()
|
| 122 |
+
print("=== [Poisoning Accumulative] Attacking {} nodes sequentially ===".format(len(node_list)))
|
| 123 |
+
|
| 124 |
+
# 初始化结构
|
| 125 |
+
adj_attacked = adj.copy()
|
| 126 |
+
|
| 127 |
+
for target_node in tqdm(node_list):
|
| 128 |
+
n_perturbations = int(degrees[target_node])
|
| 129 |
+
model = Nettack(surrogate, nnodes=adj.shape[0],
|
| 130 |
+
attack_structure=True, attack_features=False, device=device)
|
| 131 |
+
model = model.to(device)
|
| 132 |
+
model.attack(features, adj_attacked, labels, target_node, n_perturbations, verbose=False)
|
| 133 |
+
adj_attacked = model.modified_adj
|
| 134 |
+
|
| 135 |
+
gcn = GCN(nfeat=features.shape[1],
|
| 136 |
+
nhid=16,
|
| 137 |
+
nclass=labels.max().item() + 1,
|
| 138 |
+
dropout=0.5, device=device)
|
| 139 |
+
gcn = gcn.to(device)
|
| 140 |
+
gcn.fit(features, adj_attacked, labels, idx_train, idx_val, patience=30)
|
| 141 |
+
gcn.eval()
|
| 142 |
+
output = gcn.predict()
|
| 143 |
+
|
| 144 |
+
# 在 node_list 上评估被误分类的节点数
|
| 145 |
+
preds = output.argmax(1)
|
| 146 |
+
node_preds = preds[node_list]
|
| 147 |
+
node_labels = labels[node_list]
|
| 148 |
+
|
| 149 |
+
# 转为 tensor,确保正确
|
| 150 |
+
correct_mask = torch.tensor(node_preds == node_labels, dtype=torch.bool)
|
| 151 |
+
correct = correct_mask.sum().item()
|
| 152 |
+
print("Accuracy on attacked nodes: {:.2f}%".format(100 * correct / len(node_list)))
|
| 153 |
+
|
| 154 |
+
def single_test(adj, features, target_node, gcn=None):
|
| 155 |
+
if gcn is None:
|
| 156 |
+
# test on GCN (poisoning attack)
|
| 157 |
+
gcn = GCN(nfeat=features.shape[1],
|
| 158 |
+
nhid=16,
|
| 159 |
+
nclass=labels.max().item() + 1,
|
| 160 |
+
dropout=0.5, device=device)
|
| 161 |
+
|
| 162 |
+
gcn = gcn.to(device)
|
| 163 |
+
|
| 164 |
+
gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
|
| 165 |
+
gcn.eval()
|
| 166 |
+
output = gcn.predict()
|
| 167 |
+
else:
|
| 168 |
+
# test on GCN (evasion attack)
|
| 169 |
+
output = gcn.predict(features, adj)
|
| 170 |
+
probs = torch.exp(output[[target_node]])
|
| 171 |
+
|
| 172 |
+
# acc_test = accuracy(output[[target_node]], labels[target_node])
|
| 173 |
+
acc_test = (output.argmax(1)[target_node] == labels[target_node])
|
| 174 |
+
return acc_test.item()
|
| 175 |
+
|
| 176 |
+
def multi_test_evasion():
|
| 177 |
+
# test on 40 nodes on evasion attack
|
| 178 |
+
target_gcn = GCN(nfeat=features.shape[1],
|
| 179 |
+
nhid=16,
|
| 180 |
+
nclass=labels.max().item() + 1,
|
| 181 |
+
dropout=0.5, device=device)
|
| 182 |
+
|
| 183 |
+
target_gcn = target_gcn.to(device)
|
| 184 |
+
|
| 185 |
+
target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
|
| 186 |
+
|
| 187 |
+
cnt = 0
|
| 188 |
+
degrees = adj.sum(0).A1
|
| 189 |
+
node_list = select_nodes(target_gcn)
|
| 190 |
+
num = len(node_list)
|
| 191 |
+
|
| 192 |
+
print('=== [Evasion] Attacking %s nodes respectively ===' % num)
|
| 193 |
+
for target_node in tqdm(node_list):
|
| 194 |
+
n_perturbations = int(degrees[target_node])
|
| 195 |
+
model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
|
| 196 |
+
model = model.to(device)
|
| 197 |
+
model.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
|
| 198 |
+
modified_adj = model.modified_adj
|
| 199 |
+
modified_features = model.modified_features
|
| 200 |
+
|
| 201 |
+
acc = single_test(modified_adj, modified_features, target_node, gcn=target_gcn)
|
| 202 |
+
if acc == 0:
|
| 203 |
+
cnt += 1
|
| 204 |
+
print('misclassification rate : %s' % (cnt/num))
|
| 205 |
+
|
| 206 |
+
if __name__ == '__main__':
|
| 207 |
+
# main()
|
| 208 |
+
# multi_test_poison()
|
| 209 |
+
# multi_test_poison()
|
| 210 |
+
multi_test_poison_accumulative()
|
| 211 |
+
|
| 212 |
+
|
examples/graph/cgscore_experiments/defense_method/GAT.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import scipy.sparse as sp
|
| 5 |
+
from deeprobust.graph.defense import GCNJaccard, GCN, GAT
|
| 6 |
+
from deeprobust.graph.utils import *
|
| 7 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset, Dpr2Pyg
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 12 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 13 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 14 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 15 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 16 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
args.cuda = torch.cuda.is_available()
|
| 20 |
+
print('cuda: %s' % args.cuda)
|
| 21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
+
# make sure you use the same data splits as you generated attacks
|
| 24 |
+
np.random.seed(args.seed)
|
| 25 |
+
if args.cuda:
|
| 26 |
+
torch.cuda.manual_seed(args.seed)
|
| 27 |
+
|
| 28 |
+
# Here the random seed is to split the train/val/test data,
|
| 29 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 30 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 31 |
+
# Or we can just use setting='prognn' to get the splits
|
| 32 |
+
|
| 33 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 34 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 35 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 36 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 37 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 38 |
+
# print(type((adj)))
|
| 39 |
+
# adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
## load from attacked_adj
|
| 43 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 44 |
+
perturbed_adj = torch.load(ptb_path)
|
| 45 |
+
print("type(perturbed_adj)", type(perturbed_adj))
|
| 46 |
+
perturbed_adj = sp.csr_matrix(perturbed_adj.cpu().numpy())
|
| 47 |
+
|
| 48 |
+
gat = GAT(nfeat=features.shape[1],
|
| 49 |
+
nhid=8, heads=8,
|
| 50 |
+
nclass=labels.max().item() + 1,
|
| 51 |
+
dropout=0.5, device=device)
|
| 52 |
+
gat = gat.to(device)
|
| 53 |
+
|
| 54 |
+
# test on clean graph
|
| 55 |
+
print('==================')
|
| 56 |
+
print('=== train on clean graph ===')
|
| 57 |
+
|
| 58 |
+
pyg_data = Dpr2Pyg(data)
|
| 59 |
+
pyg_data.update_edge_index(perturbed_adj) # inplace operation
|
| 60 |
+
gat.fit(pyg_data, train_iters=200, verbose=True) # train with earlystopping
|
| 61 |
+
gat.test()
|
examples/graph/cgscore_experiments/defense_method/GCN.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import GCNJaccard, GCN
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 13 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 14 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 15 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 16 |
+
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
args.cuda = torch.cuda.is_available()
|
| 19 |
+
print('cuda: %s' % args.cuda)
|
| 20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
# make sure you use the same data splits as you generated attacks
|
| 23 |
+
np.random.seed(args.seed)
|
| 24 |
+
if args.cuda:
|
| 25 |
+
torch.cuda.manual_seed(args.seed)
|
| 26 |
+
|
| 27 |
+
# Here the random seed is to split the train/val/test data,
|
| 28 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 29 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 30 |
+
# Or we can just use setting='prognn' to get the splits
|
| 31 |
+
|
| 32 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 33 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 34 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 35 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 36 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 37 |
+
adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
## load from attacked_adj
|
| 41 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 42 |
+
perturbed_adj = torch.load(ptb_path)
|
| 43 |
+
|
| 44 |
+
def test(adj):
|
| 45 |
+
''' test on GCN '''
|
| 46 |
+
|
| 47 |
+
# adj = normalize_adj_tensor(adj)
|
| 48 |
+
gcn = GCN(nfeat=features.shape[1],
|
| 49 |
+
nhid=args.hidden,
|
| 50 |
+
nclass=labels.max().item() + 1,
|
| 51 |
+
dropout=args.dropout, device=device)
|
| 52 |
+
gcn = gcn.to(device)
|
| 53 |
+
gcn.fit(features, adj, labels, idx_train) # train without model picking
|
| 54 |
+
# gcn.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
|
| 55 |
+
output = gcn.output.cpu()
|
| 56 |
+
loss_test = F.nll_loss(output[idx_test], labels[idx_test])
|
| 57 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 58 |
+
print("Test set results:",
|
| 59 |
+
"loss= {:.4f}".format(loss_test.item()),
|
| 60 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 61 |
+
|
| 62 |
+
return acc_test.item()
|
| 63 |
+
|
| 64 |
+
def main():
|
| 65 |
+
|
| 66 |
+
test(perturbed_adj)
|
| 67 |
+
|
| 68 |
+
# # if you want to save the modified adj/features, uncomment the code below
|
| 69 |
+
# model.save_adj(root='./', name=f'mod_adj')
|
| 70 |
+
# model.save_features(root='./', name='mod_features')
|
| 71 |
+
|
| 72 |
+
if __name__ == '__main__':
|
| 73 |
+
main()
|
examples/graph/cgscore_experiments/defense_method/GCNJaccard.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import GCNJaccard, GCN
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 13 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 14 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 15 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 16 |
+
parser.add_argument('--threshold', type=float, default=0.1, help='jaccard coeficient')
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
args.cuda = torch.cuda.is_available()
|
| 20 |
+
print('cuda: %s' % args.cuda)
|
| 21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
+
# make sure you use the same data splits as you generated attacks
|
| 24 |
+
np.random.seed(args.seed)
|
| 25 |
+
if args.cuda:
|
| 26 |
+
torch.cuda.manual_seed(args.seed)
|
| 27 |
+
|
| 28 |
+
# Here the random seed is to split the train/val/test data,
|
| 29 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 30 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 31 |
+
# Or we can just use setting='prognn' to get the splits
|
| 32 |
+
|
| 33 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 34 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 35 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 36 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 37 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 38 |
+
# adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## load from attacked_adj
|
| 42 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 43 |
+
perturbed_adj = torch.load(ptb_path)
|
| 44 |
+
perturbed_adj = perturbed_adj
|
| 45 |
+
|
| 46 |
+
def test_jaccard(adj):
|
| 47 |
+
''' test on GCN '''
|
| 48 |
+
|
| 49 |
+
# adj = normalize_adj_tensor(adj)
|
| 50 |
+
gcn = GCNJaccard(nfeat=features.shape[1],
|
| 51 |
+
nhid=args.hidden,
|
| 52 |
+
nclass=labels.max().item() + 1,
|
| 53 |
+
dropout=args.dropout, device=device)
|
| 54 |
+
gcn = gcn.to(device)
|
| 55 |
+
gcn.fit(features, adj, labels, idx_train, idx_val, threshold=args.threshold)
|
| 56 |
+
gcn.eval()
|
| 57 |
+
gcn.test(idx_test)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
|
| 61 |
+
|
| 62 |
+
test_jaccard(perturbed_adj)
|
| 63 |
+
# # if you want to save the modified adj/features, uncomment the code below
|
| 64 |
+
# model.save_adj(root='./', name=f'mod_adj')
|
| 65 |
+
# model.save_features(root='./', name='mod_features')
|
| 66 |
+
|
| 67 |
+
if __name__ == '__main__':
|
| 68 |
+
main()
|
examples/graph/cgscore_experiments/defense_method/GCNSVD.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import GCNJaccard, GCN, GCNSVD
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 13 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 14 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 15 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 16 |
+
parser.add_argument('--k', type=int, default=50, help='Truncated Components.')
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
args.cuda = torch.cuda.is_available()
|
| 20 |
+
print('cuda: %s' % args.cuda)
|
| 21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
+
# make sure you use the same data splits as you generated attacks
|
| 24 |
+
np.random.seed(args.seed)
|
| 25 |
+
if args.cuda:
|
| 26 |
+
torch.cuda.manual_seed(args.seed)
|
| 27 |
+
|
| 28 |
+
# Here the random seed is to split the train/val/test data,
|
| 29 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 30 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 31 |
+
# Or we can just use setting='prognn' to get the splits
|
| 32 |
+
|
| 33 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 34 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 35 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 36 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 37 |
+
# adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
## load from attacked_adj
|
| 41 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 42 |
+
perturbed_adj = torch.load(ptb_path)
|
| 43 |
+
|
| 44 |
+
def test_svd(adj):
|
| 45 |
+
|
| 46 |
+
''' test on GCNSVD '''
|
| 47 |
+
|
| 48 |
+
gcn = GCNSVD(nfeat=features.shape[1], nclass=labels.max()+1,nhid=args.hidden, device=device)
|
| 49 |
+
gcn = gcn.to(device)
|
| 50 |
+
gcn.fit(features, perturbed_adj, labels, idx_train, idx_val, k=args.k, verbose=True)
|
| 51 |
+
gcn.eval()
|
| 52 |
+
gcn.test(idx_test)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main():
|
| 56 |
+
|
| 57 |
+
test_svd(perturbed_adj)
|
| 58 |
+
# # if you want to save the modified adj/features, uncomment the code below
|
| 59 |
+
# model.save_adj(root='./', name=f'mod_adj')
|
| 60 |
+
# model.save_features(root='./', name='mod_features')
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
|
| 63 |
+
main()
|
examples/graph/cgscore_experiments/defense_method/GNNGuard.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense.gcn_guard import GCNGuard
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset
|
| 7 |
+
from deeprobust.graph.data import PtbDataset, PrePtbDataset
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
from scipy import sparse
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed')
|
| 14 |
+
parser.add_argument('--GNNGuard', type=bool, default=True, choices=[True, False])
|
| 15 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed', 'flickr'], help='dataset')
|
| 16 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 17 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 18 |
+
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
args.cuda = torch.cuda.is_available()
|
| 21 |
+
print('cuda: %s' % args.cuda)
|
| 22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
|
| 24 |
+
np.random.seed(args.seed)
|
| 25 |
+
torch.manual_seed(args.seed)
|
| 26 |
+
if args.cuda:
|
| 27 |
+
torch.cuda.manual_seed(args.seed)
|
| 28 |
+
|
| 29 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 30 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 31 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 32 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 33 |
+
|
| 34 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 35 |
+
perturbed_adj = torch.load(ptb_path)
|
| 36 |
+
perturbed_adj = sp.csr_matrix(perturbed_adj.to('cpu').numpy())
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test(adj):
|
| 40 |
+
# """defense models"""
|
| 41 |
+
''' testing model '''
|
| 42 |
+
gcn = GCNGuard(nfeat=features.shape[1], nclass=labels.max().item() + 1, nhid=16,
|
| 43 |
+
dropout=0.5, with_relu=False, with_bias=True, weight_decay=5e-4, device=device)
|
| 44 |
+
gcn = gcn.to(device)
|
| 45 |
+
|
| 46 |
+
gcn.fit(features, adj, labels, idx_train, train_iters=200, idx_val=idx_val, idx_test=idx_test, verbose=True, attention=args.GNNGuard)
|
| 47 |
+
gcn.eval()
|
| 48 |
+
|
| 49 |
+
# classifier.fit(features, adj, labels, idx_train, idx_val) # train with validation model picking
|
| 50 |
+
acc_test, _ = gcn.test(idx_test)
|
| 51 |
+
# acc_test = classifier.test(idx_test)
|
| 52 |
+
return acc_test
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
|
| 56 |
+
# print('=== testing GCN on original(clean) graph ===')
|
| 57 |
+
# test(adj)
|
| 58 |
+
#
|
| 59 |
+
print('=== testing GCN on Mettacked graph ===')
|
| 60 |
+
test(perturbed_adj)
|
| 61 |
+
|
| 62 |
+
if __name__ == '__main__':
|
| 63 |
+
main()
|
| 64 |
+
|
examples/graph/cgscore_experiments/defense_method/ProGNN.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import GCNJaccard, GCN, ProGNN
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 13 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 14 |
+
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
|
| 15 |
+
|
| 16 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 17 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 18 |
+
parser.add_argument('--debug', action='store_true',default=False, help='debug mode')
|
| 19 |
+
parser.add_argument('--only_gcn', action='store_true',default=False, help='test the performance of gcn without other components')
|
| 20 |
+
# parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
|
| 21 |
+
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
|
| 22 |
+
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')
|
| 23 |
+
parser.add_argument('--alpha', type=float, default=5e-4, help='weight of l1 norm')
|
| 24 |
+
parser.add_argument('--beta', type=float, default=1.5, help='weight of nuclear norm')
|
| 25 |
+
parser.add_argument('--gamma', type=float, default=1, help='weight of l2 norm')
|
| 26 |
+
parser.add_argument('--lambda_', type=float, default=0, help='weight of feature smoothing')
|
| 27 |
+
parser.add_argument('--phi', type=float, default=0, help='weight of symmetric loss')
|
| 28 |
+
parser.add_argument('--inner_steps', type=int, default=2, help='steps for inner optimization')
|
| 29 |
+
parser.add_argument('--outer_steps', type=int, default=1, help='steps for outer optimization')
|
| 30 |
+
parser.add_argument('--lr_adj', type=float, default=0.01, help='lr for training adj')
|
| 31 |
+
parser.add_argument('--symmetric', action='store_true', default=False, help='whether use symmetric matrix')
|
| 32 |
+
|
| 33 |
+
args = parser.parse_args()
|
| 34 |
+
args.cuda = torch.cuda.is_available()
|
| 35 |
+
device = torch.device("cuda:5" if args.cuda else "cpu")
|
| 36 |
+
print('Using device:', device)
|
| 37 |
+
|
| 38 |
+
# make sure you use the same data splits as you generated attacks
|
| 39 |
+
np.random.seed(args.seed)
|
| 40 |
+
if args.cuda:
|
| 41 |
+
torch.cuda.manual_seed(args.seed)
|
| 42 |
+
|
| 43 |
+
# Here the random seed is to split the train/val/test data,
|
| 44 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 45 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 46 |
+
# Or we can just use setting='prognn' to get the splits
|
| 47 |
+
|
| 48 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 49 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 50 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 51 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 52 |
+
idx_unlabeled = np.union1d(idx_val, idx_test)
|
| 53 |
+
print(type(adj))
|
| 54 |
+
# adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
## load from attacked_adj
|
| 58 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 59 |
+
perturbed_adj = torch.load(ptb_path)
|
| 60 |
+
perturbed_adj = sp.csr_matrix(perturbed_adj.to('cpu').numpy())
|
| 61 |
+
|
| 62 |
+
def test_prognn(features, adj, labels):
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False, device=device)
|
| 66 |
+
|
| 67 |
+
model = GCN(nfeat=features.shape[1],
|
| 68 |
+
nhid=args.hidden,
|
| 69 |
+
nclass=labels.max().item() + 1,
|
| 70 |
+
dropout=args.dropout, device=device).to(device)
|
| 71 |
+
prognn = ProGNN(model, args, device)
|
| 72 |
+
prognn.fit(features, adj, labels, idx_train, idx_val)
|
| 73 |
+
prognn.test(features, labels, idx_test)
|
| 74 |
+
|
| 75 |
+
def main():
|
| 76 |
+
|
| 77 |
+
test_prognn(features, perturbed_adj, labels)
|
| 78 |
+
|
| 79 |
+
if __name__ == '__main__':
|
| 80 |
+
main()
|
examples/graph/cgscore_experiments/defense_method/RGCN.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import RGCN
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset, PrePtbDataset
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_rate', type=float, default=0.0, help='pertubation rate')
|
| 13 |
+
parser.add_argument('--ptb_type', type=str, default='clean', choices=['clean', 'meta', 'dice', 'minmax', 'pgd', 'random'], help='attack type')
|
| 14 |
+
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
|
| 15 |
+
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
args.cuda = torch.cuda.is_available()
|
| 20 |
+
print('cuda: %s' % args.cuda)
|
| 21 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
|
| 23 |
+
# make sure you use the same data splits as you generated attacks
|
| 24 |
+
np.random.seed(args.seed)
|
| 25 |
+
if args.cuda:
|
| 26 |
+
torch.cuda.manual_seed(args.seed)
|
| 27 |
+
|
| 28 |
+
# Here the random seed is to split the train/val/test data,
|
| 29 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 30 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 31 |
+
# Or we can just use setting='prognn' to get the splits
|
| 32 |
+
|
| 33 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 34 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 35 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 36 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 37 |
+
|
| 38 |
+
# adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## load from attacked_adj
|
| 42 |
+
ptb_path = f"../attacked_adj/{args.dataset}/{args.ptb_type}_{args.dataset}_{args.ptb_rate}.pt"
|
| 43 |
+
perturbed_adj = torch.load(ptb_path)
|
| 44 |
+
perturbed_adj = perturbed_adj
|
| 45 |
+
|
| 46 |
+
def test_rgcn(adj):
|
| 47 |
+
''' test on GCN '''
|
| 48 |
+
|
| 49 |
+
# adj = normalize_adj_tensor(adj)
|
| 50 |
+
gcn = RGCN(nnodes=adj.shape[0], nfeat=features.shape[1], nclass=labels.max()+1,
|
| 51 |
+
nhid=args.hidden, device=device)
|
| 52 |
+
gcn = gcn.to(device)
|
| 53 |
+
gcn.fit(features, adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
|
| 54 |
+
gcn.eval()
|
| 55 |
+
gcn.test(idx_test)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main():
|
| 59 |
+
|
| 60 |
+
test_rgcn(perturbed_adj)
|
| 61 |
+
# # if you want to save the modified adj/features, uncomment the code below
|
| 62 |
+
# model.save_adj(root='./', name=f'mod_adj')
|
| 63 |
+
# model.save_features(root='./', name='mod_features')
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
|
| 66 |
+
main()
|
examples/graph/cgscore_experiments/defense_method/cgscore.py
ADDED
|
File without changes
|
examples/graph/cgscore_experiments/grb/grb_data.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch # pytorch backend
|
| 2 |
+
from grb.dataset import Dataset
|
| 3 |
+
from grb.model.torch import GCN
|
| 4 |
+
from grb.utils.trainer import Trainer
|
| 5 |
+
|
| 6 |
+
# Load data
|
| 7 |
+
# name: ["grb-cora", "grb-citeseer", "grb-aminer", "grb-reddit", "grb-flickr"].
|
| 8 |
+
# mode: [["easy", "medium", "hard", "full"]
|
| 9 |
+
# mode:
|
| 10 |
+
|
| 11 |
+
dataset = Dataset(name='grb-citeseer', mode='easy',feat_norm='None')
|
| 12 |
+
|
| 13 |
+
features = dataset.features # 注意:不是 tensor,而是 scipy.sparse
|
| 14 |
+
adj = dataset.adj # 是 numpy.ndarray 或 scipy.sparse
|
| 15 |
+
labels = dataset.labels
|
| 16 |
+
index = dataset.idx_train
|
| 17 |
+
|
| 18 |
+
print(type(features))
|
| 19 |
+
|
| 20 |
+
print(type(adj))
|
| 21 |
+
print(type(labels))
|
| 22 |
+
print(type(index))
|
| 23 |
+
# model = GCN(in_features=dataset.num_features,
|
| 24 |
+
# out_features=dataset.num_classes,
|
| 25 |
+
# hidden_features=[64, 64])
|
| 26 |
+
# # Training
|
| 27 |
+
# adam = torch.optim.Adam(model.parameters(), lr=0.01)
|
| 28 |
+
# trainer = Trainer(dataset=dataset, optimizer=adam,
|
| 29 |
+
# loss=torch.nn.functional.nll_loss)
|
| 30 |
+
# trainer.train(model=model, n_epoch=200, dropout=0.1,
|
| 31 |
+
# train_mode='inductive')
|
| 32 |
+
|
examples/graph/cgscore_save.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def calc_cg_score_gnn_with_sampling( # stable training and defense effect
|
| 2 |
+
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
|
| 3 |
+
):
|
| 4 |
+
"""
|
| 5 |
+
Optimized CG-score calculation with edge sampling.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
N = A.shape[0]
|
| 9 |
+
cg_scores = {
|
| 10 |
+
"vi": np.zeros((N, N)),
|
| 11 |
+
"ab": np.zeros((N, N)),
|
| 12 |
+
"a2": np.zeros((N, N)),
|
| 13 |
+
"b2": np.zeros((N, N)),
|
| 14 |
+
"times": np.zeros((N, N)),
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
A = A.to(device)
|
| 18 |
+
X = X.to(device)
|
| 19 |
+
labels = labels.to(device)
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
|
| 22 |
+
def normalize(tensor):
|
| 23 |
+
return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
|
| 24 |
+
|
| 25 |
+
for _ in range(rep_num):
|
| 26 |
+
AX = torch.matmul(A, X)
|
| 27 |
+
norm_AX = normalize(AX)
|
| 28 |
+
|
| 29 |
+
# Organize data by labels
|
| 30 |
+
dataset = defaultdict(list)
|
| 31 |
+
data_idx = defaultdict(list)
|
| 32 |
+
for i, label in enumerate(labels):
|
| 33 |
+
dataset[label.item()].append(norm_AX[i].unsqueeze(0))
|
| 34 |
+
data_idx[label.item()].append(i)
|
| 35 |
+
|
| 36 |
+
for label in dataset:
|
| 37 |
+
dataset[label] = torch.cat(dataset[label], dim=0)
|
| 38 |
+
data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
|
| 39 |
+
|
| 40 |
+
# Cache negative samples
|
| 41 |
+
neg_samples_dict = {}
|
| 42 |
+
neg_indices_dict = {}
|
| 43 |
+
for label in dataset:
|
| 44 |
+
neg_samples = torch.cat([dataset[l] for l in dataset if l != label])
|
| 45 |
+
neg_indices = torch.cat([data_idx[l] for l in data_idx if l != label])
|
| 46 |
+
neg_samples_dict[label] = neg_samples
|
| 47 |
+
neg_indices_dict[label] = neg_indices
|
| 48 |
+
|
| 49 |
+
# for curr_label, curr_samples in dataset.items():
|
| 50 |
+
for curr_label, curr_samples in tqdm(dataset.items(), desc="Label groups"):
|
| 51 |
+
curr_indices = data_idx[curr_label]
|
| 52 |
+
curr_num = len(curr_samples)
|
| 53 |
+
|
| 54 |
+
chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
|
| 55 |
+
chosen_curr_samples = curr_samples[chosen_curr_idx]
|
| 56 |
+
chosen_curr_indices = curr_indices[chosen_curr_idx]
|
| 57 |
+
|
| 58 |
+
# Get negative samples
|
| 59 |
+
neg_samples = neg_samples_dict[curr_label]
|
| 60 |
+
neg_indices = neg_indices_dict[curr_label]
|
| 61 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 62 |
+
rand_idx = torch.randperm(len(neg_samples))[:neg_num]
|
| 63 |
+
chosen_neg_samples = neg_samples[rand_idx]
|
| 64 |
+
chosen_neg_indices = neg_indices[rand_idx]
|
| 65 |
+
|
| 66 |
+
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
|
| 67 |
+
y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
|
| 68 |
+
|
| 69 |
+
# Gram matrix H
|
| 70 |
+
H_inner = torch.matmul(combined_samples, combined_samples.T)
|
| 71 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 72 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 73 |
+
H.fill_diagonal_(0.5)
|
| 74 |
+
H += 1e-6 * torch.eye(H.size(0), device=device)
|
| 75 |
+
invH = torch.inverse(H)
|
| 76 |
+
original_error = y @ (invH @ y)
|
| 77 |
+
|
| 78 |
+
# for idx_i in chosen_curr_indices:
|
| 79 |
+
for idx_i in tqdm(chosen_curr_indices.tolist(), desc=f"Nodes in label {curr_label}"):
|
| 80 |
+
for j in range(idx_i + 1, N):
|
| 81 |
+
if A[idx_i, j] == 0:
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
# Sparse AX1 update
|
| 85 |
+
AX1_i = AX[idx_i] - A[idx_i, j] * X[j]
|
| 86 |
+
AX1_j = AX[j] - A[j, idx_i] * X[idx_i]
|
| 87 |
+
|
| 88 |
+
norm_AX1 = norm_AX.clone()
|
| 89 |
+
norm_AX1[idx_i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
|
| 90 |
+
norm_AX1[j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
|
| 91 |
+
|
| 92 |
+
# Updated samples
|
| 93 |
+
curr_samples_A1 = norm_AX1[chosen_curr_indices]
|
| 94 |
+
neg_samples_A1 = norm_AX1[chosen_neg_indices]
|
| 95 |
+
combined_samples_A1 = torch.cat([curr_samples_A1, neg_samples_A1], dim=0)
|
| 96 |
+
|
| 97 |
+
# Recompute H_A1
|
| 98 |
+
H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
|
| 99 |
+
H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
|
| 100 |
+
H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
|
| 101 |
+
H_A1.fill_diagonal_(0.5)
|
| 102 |
+
H_A1 += 1e-6 * torch.eye(H_A1.size(0), device=device)
|
| 103 |
+
invH_A1 = torch.inverse(H_A1)
|
| 104 |
+
error_A1 = y @ (invH_A1 @ y)
|
| 105 |
+
|
| 106 |
+
score = (original_error - error_A1).item()
|
| 107 |
+
cg_scores["vi"][idx_i, j] += score
|
| 108 |
+
cg_scores["vi"][j, idx_i] = cg_scores["vi"][idx_i, j]
|
| 109 |
+
cg_scores["times"][idx_i, j] += 1
|
| 110 |
+
cg_scores["times"][j, idx_i] += 1
|
| 111 |
+
|
| 112 |
+
# Normalize
|
| 113 |
+
for key in cg_scores:
|
| 114 |
+
if key != "times":
|
| 115 |
+
cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
| 116 |
+
|
| 117 |
+
return cg_scores if sub_term else cg_scores["vi"]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def calc_cg_score_gnn_with_sampling(
|
| 121 |
+
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
|
| 122 |
+
):
|
| 123 |
+
"""
|
| 124 |
+
Optimized CG-score calculation with edge batching and GPU acceleration.
|
| 125 |
+
"""
|
| 126 |
+
# if hasattr(torch, "compile"):
|
| 127 |
+
# calc_cg_score_gnn_with_sampling = torch.compile(calc_cg_score_gnn_with_sampling)
|
| 128 |
+
|
| 129 |
+
N = A.shape[0]
|
| 130 |
+
cg_scores = {
|
| 131 |
+
"vi": np.zeros((N, N)),
|
| 132 |
+
"ab": np.zeros((N, N)),
|
| 133 |
+
"a2": np.zeros((N, N)),
|
| 134 |
+
"b2": np.zeros((N, N)),
|
| 135 |
+
"times": np.zeros((N, N)),
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
A = A.to(device)
|
| 139 |
+
X = X.to(device)
|
| 140 |
+
labels = labels.to(device)
|
| 141 |
+
|
| 142 |
+
@torch.no_grad()
|
| 143 |
+
def normalize(tensor):
|
| 144 |
+
return tensor / (torch.norm(tensor, dim=1, keepdim=True) + 1e-8)
|
| 145 |
+
|
| 146 |
+
for _ in range(rep_num):
|
| 147 |
+
AX = torch.matmul(A, X)
|
| 148 |
+
norm_AX = normalize(AX)
|
| 149 |
+
|
| 150 |
+
# Group nodes by label
|
| 151 |
+
dataset = defaultdict(list)
|
| 152 |
+
data_idx = defaultdict(list)
|
| 153 |
+
for i, label in enumerate(labels):
|
| 154 |
+
dataset[label.item()].append(norm_AX[i].unsqueeze(0))
|
| 155 |
+
data_idx[label.item()].append(i)
|
| 156 |
+
|
| 157 |
+
for label in dataset:
|
| 158 |
+
dataset[label] = torch.cat(dataset[label], dim=0)
|
| 159 |
+
data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
|
| 160 |
+
|
| 161 |
+
# Prepare negative samples
|
| 162 |
+
neg_samples_dict = {}
|
| 163 |
+
neg_indices_dict = {}
|
| 164 |
+
for label in dataset:
|
| 165 |
+
neg_samples = torch.cat([dataset[l] for l in dataset if l != label])
|
| 166 |
+
neg_indices = torch.cat([data_idx[l] for l in data_idx if l != label])
|
| 167 |
+
neg_samples_dict[label] = neg_samples
|
| 168 |
+
neg_indices_dict[label] = neg_indices
|
| 169 |
+
|
| 170 |
+
for curr_label, curr_samples in tqdm(dataset.items(), desc="Label groups"):
|
| 171 |
+
curr_indices = data_idx[curr_label]
|
| 172 |
+
curr_num = len(curr_samples)
|
| 173 |
+
|
| 174 |
+
chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
|
| 175 |
+
chosen_curr_samples = curr_samples[chosen_curr_idx]
|
| 176 |
+
chosen_curr_indices = curr_indices[chosen_curr_idx]
|
| 177 |
+
|
| 178 |
+
neg_samples = neg_samples_dict[curr_label]
|
| 179 |
+
neg_indices = neg_indices_dict[curr_label]
|
| 180 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 181 |
+
rand_idx = torch.randperm(len(neg_samples))[:neg_num]
|
| 182 |
+
chosen_neg_samples = neg_samples[rand_idx]
|
| 183 |
+
chosen_neg_indices = neg_indices[rand_idx]
|
| 184 |
+
|
| 185 |
+
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
|
| 186 |
+
y = torch.cat([torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0).to(device)
|
| 187 |
+
|
| 188 |
+
# Compute reference error
|
| 189 |
+
H_inner = torch.matmul(combined_samples, combined_samples.T)
|
| 190 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 191 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 192 |
+
H.fill_diagonal_(0.5)
|
| 193 |
+
H += 1e-6 * torch.eye(H.size(0), device=device)
|
| 194 |
+
invH = torch.inverse(H)
|
| 195 |
+
original_error = y @ (invH @ y)
|
| 196 |
+
|
| 197 |
+
# Gather candidate edges
|
| 198 |
+
edge_batch = []
|
| 199 |
+
for idx_i in chosen_curr_indices.tolist():
|
| 200 |
+
for j in range(idx_i + 1, N):
|
| 201 |
+
if A[idx_i, j] != 0:
|
| 202 |
+
edge_batch.append((idx_i, j))
|
| 203 |
+
|
| 204 |
+
# Process in batches
|
| 205 |
+
for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False):
|
| 206 |
+
batch = edge_batch[k : k + batch_size]
|
| 207 |
+
B = len(batch)
|
| 208 |
+
|
| 209 |
+
norm_AX1_batch = norm_AX.repeat(B, 1, 1)
|
| 210 |
+
updates = []
|
| 211 |
+
for b, (i, j) in enumerate(batch):
|
| 212 |
+
AX1_i = AX[i] - A[i, j] * X[j]
|
| 213 |
+
AX1_j = AX[j] - A[j, i] * X[i]
|
| 214 |
+
norm_AX1_batch[b, i] = AX1_i / (torch.norm(AX1_i) + 1e-8)
|
| 215 |
+
norm_AX1_batch[b, j] = AX1_j / (torch.norm(AX1_j) + 1e-8)
|
| 216 |
+
|
| 217 |
+
sample_idx = chosen_curr_indices.tolist() + chosen_neg_indices.tolist()
|
| 218 |
+
sample_batch = norm_AX1_batch[:, sample_idx, :] # [B, M, D]
|
| 219 |
+
|
| 220 |
+
H_inner = torch.matmul(sample_batch, sample_batch.transpose(1, 2))
|
| 221 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 222 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 223 |
+
eye = torch.eye(H.size(-1), device=device).unsqueeze(0).expand_as(H)
|
| 224 |
+
H = H + 1e-6 * eye
|
| 225 |
+
H.diagonal(dim1=-2, dim2=-1).copy_(0.5)
|
| 226 |
+
|
| 227 |
+
invH = torch.inverse(H)
|
| 228 |
+
y_expanded = y.unsqueeze(0).expand(B, -1)
|
| 229 |
+
error_A1 = torch.einsum('bi,bij,bj->b', y_expanded, invH, y_expanded)
|
| 230 |
+
|
| 231 |
+
for b, (i, j) in enumerate(batch):
|
| 232 |
+
score = (original_error - error_A1[b]).item()
|
| 233 |
+
cg_scores["vi"][i, j] += score
|
| 234 |
+
cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
|
| 235 |
+
cg_scores["times"][i, j] += 1
|
| 236 |
+
cg_scores["times"][j, i] += 1
|
| 237 |
+
|
| 238 |
+
for key in cg_scores:
|
| 239 |
+
if key != "times":
|
| 240 |
+
cg_scores[key] = cg_scores[key] / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
| 241 |
+
|
| 242 |
+
return cg_scores if sub_term else cg_scores["vi"]
|
| 243 |
+
|
| 244 |
+
def calc_cg_score_gnn_with_sampling( # based on the front code, remove more data to GPU, effect is approxiamate
|
| 245 |
+
A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Calculate CG-score for each edge in a graph with node labels and random sampling.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
A: torch.Tensor
|
| 252 |
+
Adjacency matrix of the graph (size: N x N).
|
| 253 |
+
X: torch.Tensor
|
| 254 |
+
Node features matrix (size: N x F).
|
| 255 |
+
labels: torch.Tensor
|
| 256 |
+
Node labels (size: N).
|
| 257 |
+
device: torch.device
|
| 258 |
+
Device to perform calculations.
|
| 259 |
+
rep_num: int
|
| 260 |
+
Number of repetitions for Monte Carlo sampling.
|
| 261 |
+
unbalance_ratio: float
|
| 262 |
+
Ratio of unbalanced data (1:unbalance_ratio).
|
| 263 |
+
sub_term: bool
|
| 264 |
+
If True, calculate and return sub-terms.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
cg_scores: dict
|
| 268 |
+
Dictionary containing CG-scores for edges and optionally sub-terms.
|
| 269 |
+
"""
|
| 270 |
+
N = A.shape[0]
|
| 271 |
+
cg_scores = {
|
| 272 |
+
"vi": np.zeros((N, N)),
|
| 273 |
+
"ab": np.zeros((N, N)),
|
| 274 |
+
"a2": np.zeros((N, N)),
|
| 275 |
+
"b2": np.zeros((N, N)),
|
| 276 |
+
"times": np.zeros((N, N)),
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
with torch.no_grad():
|
| 280 |
+
for _ in range(rep_num):
|
| 281 |
+
# Compute AX (node representations)
|
| 282 |
+
AX = torch.matmul(A, X).to(device)
|
| 283 |
+
norm_AX = AX / torch.norm(AX, dim=1, keepdim=True)
|
| 284 |
+
|
| 285 |
+
# Group nodes by their labels
|
| 286 |
+
dataset = defaultdict(list)
|
| 287 |
+
data_idx = defaultdict(list)
|
| 288 |
+
for i, label in enumerate(labels):
|
| 289 |
+
dataset[label.item()].append(norm_AX[i].unsqueeze(0)) # Store normalized data
|
| 290 |
+
data_idx[label.item()].append(i) # Store indices
|
| 291 |
+
|
| 292 |
+
# Convert to tensors
|
| 293 |
+
for label, data_list in dataset.items():
|
| 294 |
+
dataset[label] = torch.cat(data_list, dim=0)
|
| 295 |
+
data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device)
|
| 296 |
+
|
| 297 |
+
# Calculate CG-scores for each label group
|
| 298 |
+
for curr_label, curr_samples in dataset.items():
|
| 299 |
+
curr_indices = data_idx[curr_label]
|
| 300 |
+
curr_num = len(curr_samples)
|
| 301 |
+
|
| 302 |
+
# Randomly sample a subset of current label examples
|
| 303 |
+
chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False)
|
| 304 |
+
chosen_curr_samples = curr_samples[chosen_curr_idx]
|
| 305 |
+
chosen_curr_indices = curr_indices[chosen_curr_idx]
|
| 306 |
+
|
| 307 |
+
# Sample negative examples from other classes
|
| 308 |
+
neg_samples = torch.cat(
|
| 309 |
+
[dataset[l] for l in dataset if l != curr_label], dim=0
|
| 310 |
+
)
|
| 311 |
+
neg_indices = torch.cat(
|
| 312 |
+
[data_idx[l] for l in data_idx if l != curr_label], dim=0
|
| 313 |
+
)
|
| 314 |
+
neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples))
|
| 315 |
+
chosen_neg_samples = neg_samples[
|
| 316 |
+
torch.randperm(len(neg_samples))[:neg_num]
|
| 317 |
+
]
|
| 318 |
+
|
| 319 |
+
# Combine positive and negative samples
|
| 320 |
+
combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0)
|
| 321 |
+
y = torch.cat(
|
| 322 |
+
[torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0
|
| 323 |
+
).to(device)
|
| 324 |
+
|
| 325 |
+
# Compute the Gram matrix H^\infty
|
| 326 |
+
H_inner = torch.matmul(combined_samples, combined_samples.T)
|
| 327 |
+
del combined_samples
|
| 328 |
+
###
|
| 329 |
+
H_inner = torch.clamp(H_inner, min=-1.0, max=1.0)
|
| 330 |
+
###
|
| 331 |
+
H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi)
|
| 332 |
+
del H_inner
|
| 333 |
+
|
| 334 |
+
H.fill_diagonal_(0.5)
|
| 335 |
+
##
|
| 336 |
+
epsilon = 1e-6
|
| 337 |
+
H = H + epsilon * torch.eye(H.size(0), device=H.device)
|
| 338 |
+
##
|
| 339 |
+
invH = torch.inverse(H)
|
| 340 |
+
del H
|
| 341 |
+
original_error = y @ (invH @ y)
|
| 342 |
+
|
| 343 |
+
# Compute CG-scores for each edge
|
| 344 |
+
for i in chosen_curr_indices:
|
| 345 |
+
print("the node index:", i)
|
| 346 |
+
for j in range(i + 1, N): # Upper triangular traversal
|
| 347 |
+
# print(j)
|
| 348 |
+
if A[i, j] == 0: # Skip if no edge exists
|
| 349 |
+
continue
|
| 350 |
+
|
| 351 |
+
# Remove edge (i, j) to create A1
|
| 352 |
+
A1 = A.clone()
|
| 353 |
+
A1[i, j] = A1[j, i] = 0
|
| 354 |
+
|
| 355 |
+
# Recompute AX with A1
|
| 356 |
+
AX1 = torch.matmul(A1, X).to(device)
|
| 357 |
+
norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True)
|
| 358 |
+
|
| 359 |
+
# Repeat error calculation with A1
|
| 360 |
+
curr_samples_A1 = norm_AX1[chosen_curr_indices]
|
| 361 |
+
neg_samples_A1 = norm_AX1[neg_indices]
|
| 362 |
+
chosen_neg_samples_A1 = neg_samples_A1[
|
| 363 |
+
torch.randperm(len(neg_samples_A1))[:neg_num]
|
| 364 |
+
]
|
| 365 |
+
combined_samples_A1 = torch.cat(
|
| 366 |
+
[curr_samples_A1, chosen_neg_samples_A1], dim=0
|
| 367 |
+
)
|
| 368 |
+
H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T)
|
| 369 |
+
|
| 370 |
+
del combined_samples_A1
|
| 371 |
+
|
| 372 |
+
### trick1
|
| 373 |
+
H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0)
|
| 374 |
+
###
|
| 375 |
+
|
| 376 |
+
H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi)
|
| 377 |
+
del H_inner_A1
|
| 378 |
+
H_A1.fill_diagonal_(0.5)
|
| 379 |
+
|
| 380 |
+
### trick2
|
| 381 |
+
epsilon = 1e-6
|
| 382 |
+
H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device)
|
| 383 |
+
###
|
| 384 |
+
invH_A1 = torch.inverse(H_A1)
|
| 385 |
+
del H_A1
|
| 386 |
+
|
| 387 |
+
error_A1 = y @ (invH_A1 @ y)
|
| 388 |
+
|
| 389 |
+
print("i:", i)
|
| 390 |
+
print("j:", j)
|
| 391 |
+
print("current score:", (original_error - error_A1).item())
|
| 392 |
+
# Compute the difference in error (CG-score)
|
| 393 |
+
cg_scores["vi"][i, j] += (original_error - error_A1).item()
|
| 394 |
+
cg_scores["vi"][j, i] = cg_scores["vi"][i, j] # Symmetric
|
| 395 |
+
cg_scores["times"][i, j] += 1
|
| 396 |
+
cg_scores["times"][j, i] += 1
|
| 397 |
+
|
| 398 |
+
# Normalize CG-scores by repetition count
|
| 399 |
+
for key, values in cg_scores.items():
|
| 400 |
+
if key == "times":
|
| 401 |
+
continue
|
| 402 |
+
cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
|
examples/graph/test_adv_train_evasion.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from deeprobust.graph.defense import GCN
|
| 6 |
+
from deeprobust.graph.global_attack import Random
|
| 7 |
+
from deeprobust.graph.targeted_attack import Nettack
|
| 8 |
+
from deeprobust.graph.utils import *
|
| 9 |
+
from deeprobust.graph.data import Dataset
|
| 10 |
+
from deeprobust.graph.data import PtbDataset
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 16 |
+
parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 17 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
|
| 18 |
+
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
args.cuda = torch.cuda.is_available()
|
| 21 |
+
print('cuda: %s' % args.cuda)
|
| 22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
|
| 24 |
+
# make sure you use the same data splits as you generated attacks
|
| 25 |
+
np.random.seed(args.seed)
|
| 26 |
+
if args.cuda:
|
| 27 |
+
torch.cuda.manual_seed(args.seed)
|
| 28 |
+
|
| 29 |
+
# load original dataset (to get clean features and labels)
|
| 30 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 31 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 32 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 33 |
+
|
| 34 |
+
# Setup Target Model
|
| 35 |
+
model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
|
| 36 |
+
nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
|
| 37 |
+
|
| 38 |
+
model = model.to(device)
|
| 39 |
+
|
| 40 |
+
# test on original adj
|
| 41 |
+
print('=== test on original adj ===')
|
| 42 |
+
model.fit(features, adj, labels, idx_train)
|
| 43 |
+
output = model.output
|
| 44 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 45 |
+
print("Test set results:",
|
| 46 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 47 |
+
|
| 48 |
+
print('=== Adversarial Training for Evasion Attack===')
|
| 49 |
+
adversary = Random()
|
| 50 |
+
adv_train_model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
|
| 51 |
+
nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
|
| 52 |
+
|
| 53 |
+
adv_train_model = adv_train_model.to(device)
|
| 54 |
+
|
| 55 |
+
adv_train_model.initialize()
|
| 56 |
+
n_perturbations = int(0.01 * (adj.sum()//2))
|
| 57 |
+
for i in tqdm(range(100)):
|
| 58 |
+
# modified_adj = adversary.attack(features, adj)
|
| 59 |
+
adversary.attack(adj, n_perturbations=n_perturbations, type='add')
|
| 60 |
+
modified_adj = adversary.modified_adj
|
| 61 |
+
adv_train_model.fit(features, modified_adj, labels, idx_train, train_iters=50, initialize=False)
|
| 62 |
+
|
| 63 |
+
adv_train_model.eval()
|
| 64 |
+
# test directly or fine tune
|
| 65 |
+
print('=== test on perturbed adj ===')
|
| 66 |
+
output = adv_train_model.predict()
|
| 67 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 68 |
+
print("Test set results:",
|
| 69 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# set up Surrogate & Nettack to attack the graph
|
| 73 |
+
import random
|
| 74 |
+
target_nodes = random.sample(idx_test.tolist(), 20)
|
| 75 |
+
# Setup Surrogate model
|
| 76 |
+
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
|
| 77 |
+
nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)
|
| 78 |
+
surrogate = surrogate.to(device)
|
| 79 |
+
surrogate.fit(features, adj, labels, idx_train)
|
| 80 |
+
|
| 81 |
+
all_margins = []
|
| 82 |
+
all_adv_margins = []
|
| 83 |
+
|
| 84 |
+
for target_node in target_nodes:
|
| 85 |
+
# set up Nettack
|
| 86 |
+
adversary = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
|
| 87 |
+
adversary = adversary.to(device)
|
| 88 |
+
degrees = adj.sum(0).A1
|
| 89 |
+
n_perturbations = int(degrees[target_node]) + 2
|
| 90 |
+
adversary.attack(features, adj, labels, target_node, n_perturbations)
|
| 91 |
+
perturbed_adj = adversary.modified_adj
|
| 92 |
+
|
| 93 |
+
model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
|
| 94 |
+
nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
|
| 95 |
+
model = model.to(device)
|
| 96 |
+
|
| 97 |
+
print('=== testing GCN on perturbed graph ===')
|
| 98 |
+
model.fit(features, perturbed_adj, labels, idx_train)
|
| 99 |
+
output = model.output
|
| 100 |
+
margin = classification_margin(output[target_node], labels[target_node])
|
| 101 |
+
all_margins.append(margin)
|
| 102 |
+
|
| 103 |
+
print('=== testing adv-GCN on perturbed graph ===')
|
| 104 |
+
output = adv_train_model.predict(features, perturbed_adj)
|
| 105 |
+
adv_margin = classification_margin(output[target_node], labels[target_node])
|
| 106 |
+
all_adv_margins.append(adv_margin)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
print("No adversarial training: classfication margin for {0} nodes: {1}".format(len(target_nodes), np.mean(all_margins)))
|
| 110 |
+
|
| 111 |
+
print("Adversarial training: classfication margin for {0} nodes: {1}".format(len(target_nodes), np.mean(all_adv_margins)))
|
| 112 |
+
|
examples/graph/test_adv_train_poisoning.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from deeprobust.graph.defense import GCN
|
| 6 |
+
from deeprobust.graph.global_attack import Random
|
| 7 |
+
from deeprobust.graph.utils import *
|
| 8 |
+
from deeprobust.graph.data import Dataset
|
| 9 |
+
from deeprobust.graph.data import PtbDataset
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 14 |
+
parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 15 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
|
| 16 |
+
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
args.cuda = torch.cuda.is_available()
|
| 19 |
+
print('cuda: %s' % args.cuda)
|
| 20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
# make sure you use the same data splits as you generated attacks
|
| 23 |
+
np.random.seed(args.seed)
|
| 24 |
+
if args.cuda:
|
| 25 |
+
torch.cuda.manual_seed(args.seed)
|
| 26 |
+
|
| 27 |
+
# load original dataset (to get clean features and labels)
|
| 28 |
+
data = Dataset(root='/tmp/', name=args.dataset)
|
| 29 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 30 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 31 |
+
|
| 32 |
+
# load pre-attacked graph
|
| 33 |
+
perturbed_data = PtbDataset(root='/tmp/', name=args.dataset)
|
| 34 |
+
perturbed_adj = perturbed_data.adj
|
| 35 |
+
|
| 36 |
+
# Setup Target Model
|
| 37 |
+
model = GCN(nfeat=features.shape[1], nclass=labels.max()+1,
|
| 38 |
+
nhid=16, dropout=0, with_relu=False, with_bias=True, device=device)
|
| 39 |
+
|
| 40 |
+
model = model.to(device)
|
| 41 |
+
|
| 42 |
+
adversary = Random()
|
| 43 |
+
# test on original adj
|
| 44 |
+
print('=== test on original adj ===')
|
| 45 |
+
model.fit(features, adj, labels, idx_train)
|
| 46 |
+
output = model.output
|
| 47 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 48 |
+
print("Test set results:",
|
| 49 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 50 |
+
|
| 51 |
+
print('=== testing GCN on perturbed graph ===')
|
| 52 |
+
model.fit(features, perturbed_adj, labels, idx_train)
|
| 53 |
+
output = model.output
|
| 54 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 55 |
+
print("Test set results:",
|
| 56 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# For poisoning attack, the adjacency matrix you have
|
| 60 |
+
# is alreay perturbed
|
| 61 |
+
print('=== Adversarial Training for Poisoning Attack===')
|
| 62 |
+
model.initialize()
|
| 63 |
+
n_perturbations = int(0.01 * (adj.sum()//2))
|
| 64 |
+
for i in range(100):
|
| 65 |
+
# modified_adj = adversary.attack(features, adj)
|
| 66 |
+
adversary.attack(perturbed_adj, n_perturbations=n_perturbations, type='remove')
|
| 67 |
+
modified_adj = adversary.modified_adj
|
| 68 |
+
model.fit(features, modified_adj, labels, idx_train, train_iters=50, initialize=False)
|
| 69 |
+
|
| 70 |
+
model.eval()
|
| 71 |
+
|
| 72 |
+
# test directly or fine tune
|
| 73 |
+
print('=== test on perturbed adj ===')
|
| 74 |
+
output = model.predict(features, perturbed_adj)
|
| 75 |
+
acc_test = accuracy(output[idx_test], labels[idx_test])
|
| 76 |
+
print("Test set results:",
|
| 77 |
+
"accuracy= {:.4f}".format(acc_test.item()))
|
| 78 |
+
|
examples/graph/test_all.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
|
| 4 |
+
for file in os.listdir('./'):
|
| 5 |
+
if "py" not in file:
|
| 6 |
+
continue
|
| 7 |
+
if 'rl' in file or 'nipa' in file or 'meta' in file or 'all' in file:
|
| 8 |
+
continue
|
| 9 |
+
if osp.isfile(file):
|
| 10 |
+
print(file)
|
| 11 |
+
os.system('CUDA_VISIBLE_DEVICES=0 python %s' % file)
|
| 12 |
+
|
| 13 |
+
|
examples/graph/test_chebnet.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
from deeprobust.graph.data import Dataset, Dpr2Pyg
|
| 4 |
+
from deeprobust.graph.defense import ChebNet
|
| 5 |
+
from deeprobust.graph.data import Dataset
|
| 6 |
+
from deeprobust.graph.data import PrePtbDataset
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 10 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 11 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05, help='perturbation rate')
|
| 12 |
+
|
| 13 |
+
args = parser.parse_args()
|
| 14 |
+
args.cuda = torch.cuda.is_available()
|
| 15 |
+
print('cuda: %s' % args.cuda)
|
| 16 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 17 |
+
|
| 18 |
+
# use data splist provided by prognn
|
| 19 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 20 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 21 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 22 |
+
|
| 23 |
+
cheby = ChebNet(nfeat=features.shape[1],
|
| 24 |
+
nhid=16, num_hops=3,
|
| 25 |
+
nclass=labels.max().item() + 1,
|
| 26 |
+
dropout=0.5, device=device)
|
| 27 |
+
cheby = cheby.to(device)
|
| 28 |
+
|
| 29 |
+
# test on clean graph
|
| 30 |
+
print('==================')
|
| 31 |
+
print('=== train on clean graph ===')
|
| 32 |
+
|
| 33 |
+
pyg_data = Dpr2Pyg(data)
|
| 34 |
+
cheby.fit(pyg_data, verbose=True) # train with earlystopping
|
| 35 |
+
cheby.test()
|
| 36 |
+
|
| 37 |
+
# load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
|
| 38 |
+
print('==================')
|
| 39 |
+
print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
|
| 40 |
+
perturbed_data = PrePtbDataset(root='/tmp/',
|
| 41 |
+
name=args.dataset,
|
| 42 |
+
attack_method='meta',
|
| 43 |
+
ptb_rate=args.ptb_rate)
|
| 44 |
+
perturbed_adj = perturbed_data.adj
|
| 45 |
+
pyg_data.update_edge_index(perturbed_adj) # inplace operation
|
| 46 |
+
cheby.fit(pyg_data, verbose=True) # train with earlystopping
|
| 47 |
+
cheby.test()
|
| 48 |
+
|
examples/graph/test_deepwalk.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from deeprobust.graph.data import Dataset
|
| 2 |
+
from deeprobust.graph.defense import DeepWalk, Node2Vec
|
| 3 |
+
from deeprobust.graph.global_attack import NodeEmbeddingAttack
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
dataset_str = 'cora_ml'
|
| 7 |
+
data = Dataset(root='/tmp/', name=dataset_str, seed=15)
|
| 8 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 9 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 10 |
+
|
| 11 |
+
attacker = NodeEmbeddingAttack()
|
| 12 |
+
attacker.attack(adj, attack_type="remove", n_perturbations=1000)
|
| 13 |
+
modified_adj = attacker.modified_adj
|
| 14 |
+
|
| 15 |
+
# train defense model
|
| 16 |
+
print("Test DeepWalk on clean graph")
|
| 17 |
+
model = DeepWalk()
|
| 18 |
+
model.fit(adj)
|
| 19 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
| 20 |
+
# model.evaluate_node_classification(labels, idx_train, idx_test, lr_params={"max_iter": 1000})
|
| 21 |
+
|
| 22 |
+
print("Test DeepWalk on attacked graph")
|
| 23 |
+
model.fit(modified_adj)
|
| 24 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
| 25 |
+
|
| 26 |
+
print("Test DeepWalk on link prediciton...")
|
| 27 |
+
model.evaluate_link_prediction(modified_adj, np.array(adj.nonzero()).T)
|
| 28 |
+
|
| 29 |
+
print("Test DeepWalk SVD on attacked graph")
|
| 30 |
+
model = DeepWalk(type="svd")
|
| 31 |
+
model.fit(modified_adj)
|
| 32 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
| 33 |
+
|
| 34 |
+
print("Test Node2vec on attacked graph")
|
| 35 |
+
model = Node2Vec()
|
| 36 |
+
model.fit(modified_adj)
|
| 37 |
+
model.evaluate_node_classification(labels, idx_train, idx_test)
|
| 38 |
+
|
| 39 |
+
|
examples/graph/test_gat.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
from deeprobust.graph.data import Dataset, Dpr2Pyg
|
| 4 |
+
from deeprobust.graph.defense import GAT
|
| 5 |
+
from deeprobust.graph.data import Dataset
|
| 6 |
+
from deeprobust.graph.data import PrePtbDataset
|
| 7 |
+
import scipy.sparse as sp
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser()
|
| 10 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 11 |
+
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 12 |
+
parser.add_argument('--ptb_rate', type=float, default=0.10, help='perturbation rate')
|
| 13 |
+
|
| 14 |
+
args = parser.parse_args()
|
| 15 |
+
args.cuda = torch.cuda.is_available()
|
| 16 |
+
print('cuda: %s' % args.cuda)
|
| 17 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
|
| 19 |
+
# use data splist provided by prognn
|
| 20 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 21 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 22 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 23 |
+
|
| 24 |
+
gat = GAT(nfeat=features.shape[1],
|
| 25 |
+
nhid=8, heads=8,
|
| 26 |
+
nclass=labels.max().item() + 1,
|
| 27 |
+
dropout=0.5, device=device)
|
| 28 |
+
gat = gat.to(device)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# test on clean graph
|
| 32 |
+
print('==================')
|
| 33 |
+
print('=== train on clean graph ===')
|
| 34 |
+
|
| 35 |
+
print(type(features))
|
| 36 |
+
print(type(adj))
|
| 37 |
+
pyg_data = Dpr2Pyg(data)
|
| 38 |
+
gat.fit(pyg_data, verbose=True) # train with earlystopping
|
| 39 |
+
gat.test()
|
| 40 |
+
|
| 41 |
+
# load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
|
| 42 |
+
print('==================')
|
| 43 |
+
print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
|
| 44 |
+
perturbed_data = PrePtbDataset(root='/tmp/',
|
| 45 |
+
name=args.dataset,
|
| 46 |
+
attack_method='meta',
|
| 47 |
+
ptb_rate=args.ptb_rate)
|
| 48 |
+
perturbed_adj = perturbed_data.adj
|
| 49 |
+
pyg_data.update_edge_index(perturbed_adj) # inplace operation
|
| 50 |
+
gat.fit(pyg_data, verbose=True) # train with earlystopping
|
| 51 |
+
gat.test()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
examples/graph/test_gcn.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from deeprobust.graph.defense import GCN
|
| 5 |
+
from deeprobust.graph.utils import *
|
| 6 |
+
from deeprobust.graph.data import Dataset
|
| 7 |
+
from deeprobust.graph.data import PtbDataset, PrePtbDataset
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
|
| 12 |
+
parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
|
| 13 |
+
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
|
| 14 |
+
|
| 15 |
+
args = parser.parse_args()
|
| 16 |
+
args.cuda = torch.cuda.is_available()
|
| 17 |
+
print('cuda: %s' % args.cuda)
|
| 18 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
|
| 20 |
+
# Here the random seed is to split the train/val/test data,
|
| 21 |
+
# we need to set the random seed to be the same as that when you generate the perturbed graph
|
| 22 |
+
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
|
| 23 |
+
# Or we can just use setting='prognn' to get the splits
|
| 24 |
+
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
|
| 25 |
+
adj, features, labels = data.adj, data.features, data.labels
|
| 26 |
+
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
|
| 30 |
+
print('==================')
|
| 31 |
+
print('=== load graph perturbed by Zugner metattack (under prognn splits) ===')
|
| 32 |
+
perturbed_data = PrePtbDataset(root='/tmp/',
|
| 33 |
+
name=args.dataset,
|
| 34 |
+
attack_method='meta',
|
| 35 |
+
ptb_rate=args.ptb_rate)
|
| 36 |
+
# perturbed_adj = perturbed_data.adj
|
| 37 |
+
perturbed_adj = adj
|
| 38 |
+
|
| 39 |
+
np.random.seed(args.seed)
|
| 40 |
+
torch.manual_seed(args.seed)
|
| 41 |
+
if args.cuda:
|
| 42 |
+
torch.cuda.manual_seed(args.seed)
|
| 43 |
+
|
| 44 |
+
# Setup GCN Model
|
| 45 |
+
model = GCN(nfeat=features.shape[1], nhid=16, nclass=labels.max()+1, device=device)
|
| 46 |
+
model = model.to(device)
|
| 47 |
+
|
| 48 |
+
# model.fit(features, perturbed_adj, labels, idx_train, train_iters=200, verbose=True)
|
| 49 |
+
# # using validation to pick model
|
| 50 |
+
model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
|
| 51 |
+
model.eval()
|
| 52 |
+
# You can use the inner function of model to test
|
| 53 |
+
model.test(idx_test)
|
| 54 |
+
|
| 55 |
+
# print('==================')
|
| 56 |
+
# print('=== load graph perturbed by DeepRobust 5% metattack (under prognn splits) ===')
|
| 57 |
+
# perturbed_data = PtbDataset(root='/tmp/',
|
| 58 |
+
# name=args.dataset,
|
| 59 |
+
# attack_method='meta')
|
| 60 |
+
# perturbed_adj = perturbed_data.adj
|
| 61 |
+
|
| 62 |
+
# print("dataset:", args.dataset)
|
| 63 |
+
# # model.fit(features, perturbed_adj, labels, idx_train, train_iters=200, verbose=True)
|
| 64 |
+
# # # using validation to pick model
|
| 65 |
+
# model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True)
|
| 66 |
+
# model.eval()
|
| 67 |
+
# # You can use the inner function of model to test
|
| 68 |
+
# model.test(idx_test)
|
| 69 |
+
|