doevent commited on
Commit
87abef8
·
1 Parent(s): 24685d9

Upload models/clusterkit.py

Browse files
Files changed (1) hide show
  1. models/clusterkit.py +291 -0
models/clusterkit.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from functools import partial
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+ import math, random
10
+ #from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth
11
+
12
+
13
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single feature map with scikit-learn's KMeans
    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W)
    :param n_clusters: (int) number of clusters
    :param metric: (str) unused here; accepted so the signature matches tensor_kmeans_pytorch
    :param need_layer_masks: (bool) if True return one-hot layer masks instead of the id map
    :param max_iters: (int) unused here; KMeans is always run with max_iter=300
    :return: (torch.tensor) long id map of shape (1, 1, H, W), or, when need_layer_masks
             is True, float one-hot masks of shape (1, n_clusters, H, W)
    """
    N, C, H, W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'

    # Imported lazily: the module-level sklearn import is commented out, so without
    # this the KMeans name below would be undefined. Importing here keeps
    # scikit-learn optional for users of the pure-PyTorch code path.
    from sklearn.cluster import KMeans

    ## (1,C,H,W) -> (HW,C)
    pixel_vecs = data_vecs.permute(0, 2, 3, 1).reshape(-1, C)
    ## convert tensor to array; no squeeze, so the matrix stays 2-D even when C == 1 or H*W == 1
    pixel_vecs_np = pixel_vecs.detach().to("cpu").numpy()

    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    km.fit(pixel_vecs_np)
    cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = cluster_ids_x.reshape(1, 1, H, W).long()

    if need_layer_masks:
        ## (1,H,W) -> (1,H,W,K) -> (1,K,H,W)
        one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
        cluster_mask = one_hot_labels.permute(0, 3, 1, 2)
        return cluster_mask
    return id_maps
29
+
30
+
31
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """
    cluster the pixels of a single feature map with the PyTorch kmeans in this module
    :param data_vecs: (torch.tensor) feature map of shape (1, C, H, W)
    :param n_clusters: (int) number of clusters
    :param metric: (str) distance passed to kmeans [options: 'euclidean', 'cosine']
    :param need_layer_masks: (bool) if True return one-hot layer masks instead of the id map
    :param max_iters: (int) forwarded to kmeans as iter_limit
    :return: (torch.tensor) id map of shape (1, 1, H, W), or, when need_layer_masks
             is True, float one-hot masks of shape (1, n_clusters, H, W)
    """
    batch, channels, height, width = data_vecs.shape
    assert batch == 1, 'only support singe image tensor'

    ## every pixel becomes one sample: (1,C,H,W) -> (HW,C)
    flat_pixels = data_vecs.permute(0, 2, 3, 1).view(-1, channels)

    ## the returned centres are not needed here, only the per-pixel assignment
    labels, _ = kmeans(
        X=flat_pixels,
        num_clusters=n_clusters,
        distance=metric,
        tqdm_flag=False,
        iter_limit=max_iters,
        device=flat_pixels.device,
    )
    id_maps = labels.reshape(1, 1, height, width)

    if not need_layer_masks:
        return id_maps

    ## (1,H,W) -> (1,H,W,K) -> (1,K,H,W)
    return F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float().permute(0, 3, 1, 2)
47
+
48
+
49
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """
    run k-means independently on every sample of a batch and stack the one-hot masks
    :param data_vecs: (torch.tensor) feature maps of shape (N, C, H, W)
    :param n_clusters: (int) number of clusters per sample
    :param metric: (str) distance forwarded to the per-sample clustering function
    :param use_sklearn_kmeans: (bool) use tensor_kmeans_sklearn instead of tensor_kmeans_pytorch
    :return: (torch.tensor) per-sample masks concatenated along dim 0
    """
    num_samples, _, _, _ = data_vecs.shape
    cluster_one = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch

    ## slicing with i:i+1 keeps the batch axis, as the single-image helpers require
    per_sample_masks = [
        cluster_one(data_vecs[i:i + 1, :, :, :], n_clusters, metric, True)
        for i in range(num_samples)
    ]
    return torch.cat(per_sample_masks, dim=0)
59
+
60
+
61
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """
    run k-means over all pixels of the input and return only the cluster centres
    :param data_vecs: (torch.tensor) feature maps of shape (N, C, H, W); the pixels of
                      all N samples are pooled into one sample set
    :param n_clusters: (int) number of clusters
    :param metric: (str) distance passed to kmeans [options: 'euclidean', 'cosine']
    :param max_iters: (int) forwarded to kmeans as iter_limit
    :return: (torch.tensor) cluster centres as returned by kmeans
    """
    N, C, H, W = data_vecs.shape
    ## (N,C,H,W) -> (NHW,C); reshape rather than view, because the permuted tensor
    ## cannot be viewed as (NHW,C) without a copy once N > 1
    pixel_vecs = data_vecs.permute(0, 2, 3, 1).reshape(-1, C)
    _, cluster_centers = kmeans(X=pixel_vecs, num_clusters=n_clusters, distance=metric,
                                tqdm_flag=False, iter_limit=max_iters, device=pixel_vecs.device)
    return cluster_centers
67
+
68
+
69
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """
    for every sample, mark the pixels closest to each of its k-means centroids
    :param data_tensor: (torch.tensor) feature maps of shape (N, C, H, W)
    :param n_clusters: (int) number of centroids computed per sample
    :param topk: (int) number of nearest pixels kept per centroid; pixels tied with the
                 topk-th smallest distance are kept as well
    :param metric: (str) distance used by the k-means step; the nearest-pixel ranking
                   below always uses squared euclidean distance
    :return: (torch.tensor) mask of shape (N, n_clusters, H, W) holding 1 at selected
             pixels and 0 elsewhere
    """
    N, C, H, W = data_tensor.shape
    centroid_list = [
        get_centroid_candidates(data_tensor[idx:idx + 1, :, :, :], n_clusters, metric)
        for idx in range(N)
    ]

    ## (N,K,C)
    batch_centroids = torch.stack(centroid_list, dim=0)
    ## (N,C,HW); cast so matmul does not fail when the input dtype differs from the centroids'
    data_vecs = data_tensor.flatten(2).to(batch_centroids.dtype)

    ## squared distance matrix (N,K,HW) via ||a||^2 - 2*a.b + ||b||^2
    AtB = torch.matmul(batch_centroids, data_vecs)
    ## squared norms are computed directly; forming the full (HW,HW) Gram matrix just
    ## to read its diagonal needs memory quadratic in the number of pixels
    A2 = (batch_centroids ** 2).sum(dim=2, keepdim=True)   # (N,K,1), broadcast over HW
    B2 = (data_vecs ** 2).sum(dim=1, keepdim=True)         # (N,1,HW), broadcast over K
    distance_map = A2 - 2 * AtB + B2

    values, _ = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    ## values are sorted ascending, so the last column is the topk-th smallest distance
    threshold = values[:, :, topk - 1:]
    cluster_mask = (distance_map <= threshold).to(distance_map.dtype)
    return cluster_mask.view(N, n_clusters, H, W)
91
+
92
+
93
+ ##---------------------------------------------------------------------------------
94
+ '''
95
+ The k-means implementation below is adapted from the kmeans_pytorch project on GitHub: https://github.com/subhadarship/kmeans_pytorch
96
+ '''
97
+ ##---------------------------------------------------------------------------------
98
+
99
def initialize(X, num_clusters):
    """
    initialize cluster centers by picking distinct rows of X
    :param X: (torch.tensor) matrix with one sample per row
    :param num_clusters: (int) number of clusters
    :return: rows of X selected as the initial state (same type as X)
    :raises ValueError: if num_clusters exceeds the number of samples
    """
    num_samples = len(X)
    if num_clusters > num_samples:
        raise ValueError(
            f'cannot pick {num_clusters} distinct initial centers from {num_samples} samples'
        )
    # A private generator seeded with 1 draws the same indices as
    # np.random.seed(1) followed by np.random.choice, but leaves the
    # process-wide NumPy random state untouched for the caller.
    rng = np.random.RandomState(1)
    indices = rng.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state
111
+
112
+
113
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix with one sample per row
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor) optional centers to resume from; None or any
                            list means "start from a fresh initialization"
    :param tol: (float) threshold on the squared total center shift [default: 0.0001]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 means no limit)
    :param device: (torch.device) device [default: cpu]
    :param gamma_for_soft_dtw: unused by this function; kept for interface compatibility
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers; the ids are the
             assignment computed at the start of the last iteration, before the final
             center update
    :raises NotImplementedError: if distance is not 'euclidean' or 'cosine'
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError(f'unsupported distance: {distance!r}')

    # convert to float and transfer to device
    X = X.float().to(device)

    # initialize
    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # find data point closest to the initial cluster center
        dis = pairwise_distance_function(X, cluster_centers)
        # keep the matrix 2-D even if the distance helper dropped size-1 axes
        dis = dis.reshape(len(X), len(cluster_centers))
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
    initial_state = initial_state.to(device)

    iteration = 0
    tqdm_meter = tqdm(desc='[running kmeans]') if tqdm_flag else None
    try:
        while True:
            dis = pairwise_distance_function(X, initial_state)
            # (num_samples, num_clusters); the reshape restores axes that a squeezing
            # distance helper drops for a single sample or a single cluster
            dis = dis.reshape(len(X), len(initial_state))

            choice_cluster = torch.argmin(dis, dim=1)

            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                # boolean mask selects the members; result is always 2-D
                selected = X[choice_cluster == index]

                # https://github.com/subhadarship/kmeans_pytorch/issues/16
                # an empty cluster is re-seeded from one randomly chosen sample
                if selected.shape[0] == 0:
                    selected = X[torch.randint(len(X), (1,))]

                initial_state[index] = selected.mean(dim=0)

            # sum over clusters of the euclidean distance each center moved
            center_shift = torch.sum(
                torch.sqrt(
                    torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
                ))

            # increment iteration
            iteration = iteration + 1

            # update tqdm meter
            if tqdm_meter is not None:
                tqdm_meter.set_postfix(
                    iteration=f'{iteration}',
                    center_shift=f'{center_shift ** 2:0.6f}',
                    tol=f'{tol:0.6f}'
                )
                tqdm_meter.update()
            if center_shift ** 2 < tol:
                break
            if iter_limit != 0 and iteration >= iter_limit:
                break
    finally:
        # release the progress meter even if an iteration raises
        if tqdm_meter is not None:
            tqdm_meter.close()

    return choice_cluster.to(device), initial_state.to(device)
210
+
211
+
212
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix with one sample per row
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: unused; kept for interface compatibility
    :param tqdm_flag: Allows to turn logs on and off
    :return: (torch.tensor) cluster ids, on the CPU
    :raises NotImplementedError: if distance is not 'euclidean' or 'cosine'
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # The soft-DTW helpers (SoftDTW, pairwise_soft_dtw) are not defined in this
        # module, so fail with a clear message instead of a NameError.
        raise NotImplementedError("distance 'soft_dtw' is not available in this module")
    else:
        raise NotImplementedError(f'unsupported distance: {distance!r}')

    # convert to float and transfer to device
    X = X.float().to(device)

    dis = pairwise_distance_function(X, cluster_centers)
    # keep the matrix (num_samples, num_centers) even if the distance helper
    # dropped size-1 axes, so argmin over dim=1 is always valid
    dis = dis.reshape(len(X), len(cluster_centers))
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
252
+
253
+
254
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """
    squared euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix of shape (N, M)
    :param data2: (torch.tensor) matrix of shape (K, M)
    :param device: (torch.device) device the computation runs on [default: cpu]
    :param tqdm_flag: if True, print the device on every call
    :return: (torch.tensor) matrix of shape (N, K); no square root is taken
    """
    if tqdm_flag:
        print(f'device is :{device}')

    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    dis = (A - B) ** 2.0
    # N*K matrix of pairwise distances. The result is deliberately not squeezed:
    # dropping size-1 axes would break argmin over dim=1 in callers whenever
    # there is a single sample or a single center.
    dis = dis.sum(dim=-1)
    return dis
271
+
272
+
273
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix of shape (N, M)
    :param data2: (torch.tensor) matrix of shape (K, M)
    :param device: (torch.device) device the computation runs on [default: cpu]
    :return: (torch.tensor) matrix of shape (N, K); an all-zero row has similarity 0,
             i.e. distance 1, to everything
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    # the norm is clamped away from zero so an all-zero vector yields 0 instead of NaN
    eps = 1e-12
    A_normalized = A / A.norm(dim=-1, keepdim=True).clamp_min(eps)
    B_normalized = B / B.norm(dim=-1, keepdim=True).clamp_min(eps)

    cosine = A_normalized * B_normalized

    # N*K matrix of pairwise distances. The result is deliberately not squeezed:
    # dropping size-1 axes would break argmin over dim=1 in callers whenever
    # there is a single sample or a single center.
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis