ilhamap committed
Commit cee5099 · verified · 1 Parent(s): 18c4e17

Upload 4 files

Files changed (4)
  1. app.py +361 -0
  2. module.py +275 -0
  3. requirements.txt +9 -0
  4. sstvit.py +94 -0
app.py ADDED
@@ -0,0 +1,361 @@
+ import streamlit as st
+ import io
+ import collections
+ import os
+ import json
+ import time
+ import base64
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.utils.data as Data
+ import torch.backends.cudnn as cudnn
+ import matplotlib.pyplot as plt
+ from matplotlib import colors
+ from matplotlib import colors as mcolors
+ from PIL import Image
+ from scipy.io import loadmat, savemat
+ from torch import optim
+ from torch.autograd import Variable
+ from sklearn.metrics import confusion_matrix
+ from patchify import patchify, unpatchify
+ import plotly.express as px
+ import st_aggrid
+ from sstvit import SSTViT
+
+
+ css = '''
+ <style>
+ section.main > div {max-width:60rem}
+ </style>
+ '''
+ st.markdown(css, unsafe_allow_html=True)
+
+ class Args(dict):
+     # minimal attribute-style access wrapper around a plain dict
+     __setattr__ = dict.__setitem__
+     __getattr__ = dict.__getitem__
+
+ args = {
+     'dataset': 'mg',
+     'flag_test': 'train',
+     'gpu_id': 0,
+     'seed': 0,
+     'batch_size': 64,
+     'test_freq': 10,
+     'patches': 5,
+     'band_patches': 1,
+     'epoches': 2000,
+     'learning_rate': 5e-4,
+     'gamma': 0.9,
+     'weight_decay': 0,
+     'train_number': 500
+ }
+ args = Args(args)  # dict -> attribute-style object
+ obj = args.copy()  # object -> plain dict
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
+
+ def test_epoch(model, test_loader):
+     # Run inference over the loader and collect the argmax class for every patch.
+     pre = np.array([])
+     for batch_idx, (batch_data_t1, batch_data_t2) in enumerate(test_loader):
+         batch_pred = model(batch_data_t1, batch_data_t2)
+         _, pred = batch_pred.topk(1, 1, True, True)
+         pp = pred.squeeze()
+         pre = np.append(pre, pp.data.cpu().numpy())
+     return pre
+ mdic = ['Before', 'After', 'Before', 'After']
+ colors = ['#3b68f8', '#ff0201', '#23fe01']  # prediction classes 0, 1, 2
+ cmap = mcolors.ListedColormap(colors)
+
+ # Parameter setting: fixed seeds for reproducible inference
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ def encode_masks_to_rgb(masks):
+     colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
+     # Create an empty RGB image
+     height, width = masks.shape
+     rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
+
+     # Assign colors based on the mask values
+     for i in range(len(colors)):
+         mask_indices = masks == i
+         rgb_image[mask_indices] = colors[i]
+
+     return rgb_image
+
+ def count_pixel(pred):
+     image = Image.fromarray(pred)
+
+     # Define the colors to count, in RGB format
+     color2label = {
+         (0, 0, 255): "Non Mangrove",
+         (255, 0, 0): "Mangrove Loss",
+         (0, 255, 0): "Mangrove Before",
+     }
+
+     # Flattened list of pixel values
+     pixels = list(image.getdata())
+     # Number of pixels per color
+     color_counts = collections.Counter(pixels)
+     # Total number of pixels in the image
+     total_pixels = len(pixels)
+
+     # Percentage and absolute pixel count for each class
+     average_counts = {color2label[label]: (count / total_pixels) * 100 for label, count in color_counts.items()}
+     class_counts = {color2label[label]: count for label, count in color_counts.items()}
+
+     pix_avg = {}
+     pix_count = {}
+     for _, label in color2label.items():
+         pix_avg[label] = average_counts.get(label, 0)
+         pix_count[label] = class_counts.get(label, 0)
+
+     x = {
+         "class": list(pix_avg.keys()),
+         "percentage": list(pix_avg.values()),
+         "pixel_count": list(pix_count.values())
+     }
+     return pd.DataFrame(x)
+
+ def count_pixel1(pred):
+     # Same as count_pixel, except the green class is labelled "Mangrove After".
+     image = Image.fromarray(pred)
+
+     color2label = {
+         (0, 0, 255): "Non Mangrove",
+         (255, 0, 0): "Mangrove Loss",
+         (0, 255, 0): "Mangrove After",
+     }
+
+     pixels = list(image.getdata())
+     color_counts = collections.Counter(pixels)
+     total_pixels = len(pixels)
+
+     average_counts = {color2label[label]: (count / total_pixels) * 100 for label, count in color_counts.items()}
+     class_counts = {color2label[label]: count for label, count in color_counts.items()}
+
+     pix_avg = {}
+     pix_count = {}
+     for _, label in color2label.items():
+         pix_avg[label] = average_counts.get(label, 0)
+         pix_count[label] = class_counts.get(label, 0)
+
+     x = {
+         "class": list(pix_avg.keys()),
+         "percentage": list(pix_avg.values()),
+         "pixel_count": list(pix_count.values())
+     }
+     return pd.DataFrame(x)
+
+ file = st.file_uploader("Upload file", type=['mat'])
+
+ if file:
+     data_img2 = loadmat(file)['data_img2']
+     data_img1 = loadmat(file)['data_img1']
+     st.subheader("Preview Dataset")
+     col1, col2 = st.columns(2)
+     with col1:
+         fig = plt.figure(figsize=(5, 5))
+         plt.subplot(121)
+         plt.imshow(data_img1)
+         plt.title('Before', fontweight='bold')
+         plt.xticks([])
+         plt.yticks([])
+         plt.subplot(122)
+         plt.imshow(data_img2)
+         plt.title('After', fontweight='bold')
+         plt.xticks([])
+         plt.yticks([])
+         st.pyplot(fig)
+     holder = st.empty()
+     if holder.button("Start Prediction"):
+         start = time.time()
+         holder.empty()
+         with st.spinner("Processing, please wait around 7-15 minutes"):
+             data_t1 = loadmat(file)['data_t1']
+             data_t2 = loadmat(file)['data_t2']
+             L_post = loadmat(file)['L_post']
+             L_pre = loadmat(file)['L_pre']
+             data_img1 = loadmat(file)['data_img1']
+             data_img2 = loadmat(file)['data_img2']
+
+             # Remap the binary masks; net effect: background (0) -> -0.8, mask (1) -> -0.2
+             L_post = np.double(L_post)
+             L_post[L_post == 0] = -0.8
+             L_post[L_post == 1] = 0
+             L_post[L_post == 0] = -0.2
+
+             L_pre = np.double(L_pre)
+             L_pre[L_pre == 0] = -0.8
+             L_pre[L_pre == 1] = 0
+             L_pre[L_pre == 0] = -0.2
+
+             # Stack the 10 spectral bands with the mask as an 11th channel
+             data_t1 = data_t1[:L_post.shape[0], :L_post.shape[1], :]
+             data_t2 = data_t2[:L_post.shape[0], :L_post.shape[1], :]
+             data_cb1 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
+             data_cb2 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
+             data_cb1[:, :, :10] = data_t1
+             data_cb1[:, :, 10] = L_pre
+             data_cb2[:, :, :10] = data_t2
+             data_cb2[:, :, 10] = L_post
+             height, width, band = data_cb1.shape
+             height = height - 4
+             width = width - 4
+             # 5x5x11 sliding patches with step 1: one patch per pixel of the (height-4, width-4) interior
+             x1 = patchify(data_cb1, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)
+             x2 = patchify(data_cb2, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)
+
+             # create model
+             model = SSTViT(
+                 image_size=5,
+                 near_band=args.band_patches,
+                 num_patches=11,
+                 num_classes=3,
+                 dim=32,
+                 depth=2,
+                 heads=4,
+                 dim_head=16,
+                 mlp_dim=8,
+                 b_dim=512,
+                 b_depth=3,
+                 b_heads=8,
+                 b_dim_head=32,
+                 b_mlp_head=8,
+                 dropout=0.2,
+                 emb_dropout=0.1,
+             )
+             model.load_state_dict(torch.load("model/lsstformer.pth", map_location=torch.device("cpu")))
+
+             # First pass: both branches receive the pre-event cube ("Before" map)
+             x1_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
+             x2_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
+             Label_true = Data.TensorDataset(x1_true_band, x2_true_band)
+             label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
+             model.eval()
+             # output classification map
+             pre_u = test_epoch(model, label_true_loader)
+             prediction_matrix = pre_u.reshape(height, width)
+
+             # Second pass: pre-event vs post-event cube ("After" map)
+             x1_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
+             x2_true_band = torch.from_numpy(x2.transpose(0, 2, 1)).type(torch.FloatTensor)
+             Label_true = Data.TensorDataset(x1_true_band, x2_true_band)
+             label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
+             model.eval()
+             # output classification map
+             pre_u = test_epoch(model, label_true_loader)
+             prediction_matrix2 = pre_u.reshape(height, width)
+
+             # Pixel indices per class: 2 = mangrove, 1 = mangrove loss
+             A = prediction_matrix.reshape(-1)
+             B = prediction_matrix2.reshape(-1)
+             mg = np.array(np.where(A == 2))
+             mg1 = np.array(np.where(B == 2))
+             mgls = np.array(np.where(B == 1))
+             class_counts = count_pixel(encode_masks_to_rgb(prediction_matrix))
+             class_counts1 = count_pixel1(encode_masks_to_rgb(prediction_matrix2))
+
+         with st.container():
+             st.subheader("Prediction Result")
+             col1, col2 = st.columns(2)
+             with col1:
+                 with st.container():
+                     fig = plt.figure(figsize=(10, 10))
+                     plt.subplot(121)
+                     plt.imshow(prediction_matrix, cmap=cmap)
+                     plt.title('Before', fontsize=25, fontweight='bold')
+                     plt.xticks([])
+                     plt.yticks([])
+                     plt.subplot(122)
+                     plt.imshow(prediction_matrix2, cmap=cmap)
+                     plt.title('After', fontsize=25, fontweight='bold')
+                     plt.xticks([])
+                     plt.yticks([])
+                     st.pyplot(fig)
+                     buf = io.BytesIO()
+                     fig.savefig(buf, format="png")
+             with col2:
+                 with st.container():
+                     # each pixel counted as 100 m2 (10 m x 10 m Sentinel-2 pixels)
+                     table_data = {
+                         "Total mangrove before": f"{mg.shape[1]*100} m\u00B2",
+                         "Total mangrove after": f"{mg1.shape[1]*100} m\u00B2",
+                         "Total mangrove loss": f"{mgls.shape[1]*100} m\u00B2",
+                     }
+                     df = pd.DataFrame(list(table_data.items()), columns=['Key', 'Value'])
+
+                     MIN_HEIGHT = 100
+                     MAX_HEIGHT = 180
+                     ROW_HEIGHT = 50
+
+                     # st.dataframe(df, hide_index=True, use_container_width=True)
+                     st_aggrid.AgGrid(df, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(df) * ROW_HEIGHT, MAX_HEIGHT))
+
+         with st.container():
+             st.subheader("Pixel Distribution")
+
+             # keep "Mangrove Before" from the first map, "Mangrove Loss" and
+             # "Mangrove After" from the second map, then reorder the rows
+             df = class_counts
+             df = df.drop(0)
+             df1 = df.drop(1)
+
+             df2 = class_counts1
+             df3 = df2.drop(0)
+             vertical_concat = pd.concat([df1, df3], axis=0)
+             MIN_HEIGHT = 100
+             MAX_HEIGHT = 180
+             ROW_HEIGHT = 50
+             vertical_concat = vertical_concat.iloc[[0, 2, 1], :]
+
+             st_aggrid.AgGrid(vertical_concat, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(vertical_concat) * ROW_HEIGHT, MAX_HEIGHT))
+             fig = px.bar(vertical_concat, x='percentage', y='class', color='class', orientation='h',
+                          color_discrete_sequence=["green", "green", "red", "blue"],
+                          category_orders={"class": ["Mangrove Before", "Mangrove After", "Mangrove Loss", "Non Mangrove"]}
+                          )
+
+             st.plotly_chart(fig, use_container_width=False)
+             end = time.time()
+             process = end - start
+             st.write('Processing time (s):', process)
+
+
+ show_file = st.empty()
+
+ if not file:
+     url = "https://drive.usercontent.google.com/download?id=1u48pMzRWQ2Etfjaq5A0CUjRtGKZaJoJy&export=download&authuser=2&confirm=t&uuid=52b0e01e-377f-42cb-8412-c84aa38a1740&at=APZUnTXslmuCCV1drJ2WWtkZr9BR%3A1710357675310"
+     show_file.info("""
+     The model was trained on Sentinel-2 imagery. Upload a MAT file to run the trained
+     LSST-Former mangrove loss detection model from this research. A tool for converting
+     Sentinel-2 scenes into the expected MAT format will be released later; for now, please
+     download the demo dataset below. On a mobile phone, use desktop mode for a better layout.
+     """)
+     st.write("Download the demo dataset from this [link](%s)" % url)
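
For reference, app.py reads six arrays from the uploaded MAT file: data_t1 and data_t2 (10-band image cubes for the two dates), L_pre and L_post (binary masks stacked as the 11th channel), and data_img1/data_img2 (RGB previews). The snippet below is a minimal sketch, assuming array shapes and dtypes inferred from the loading code above, not an official specification, showing how such a file could be assembled with scipy.io.savemat.

    # Hypothetical example of packaging inputs into the MAT layout read by app.py.
    # Shapes and dtypes are assumptions inferred from how app.py loads the file.
    import numpy as np
    from scipy.io import savemat

    H, W = 200, 300  # example scene size
    mat = {
        "data_t1": np.random.rand(H, W, 10).astype(np.float32),            # 10 bands, date 1
        "data_t2": np.random.rand(H, W, 10).astype(np.float32),            # 10 bands, date 2
        "L_pre": np.random.randint(0, 2, (H, W)).astype(np.uint8),         # binary mask, date 1
        "L_post": np.random.randint(0, 2, (H, W)).astype(np.uint8),        # binary mask, date 2
        "data_img1": np.random.randint(0, 256, (H, W, 3), dtype=np.uint8), # RGB preview, date 1
        "data_img2": np.random.randint(0, 256, (H, W, 3), dtype=np.uint8), # RGB preview, date 2
    }
    savemat("demo_input.mat", mat)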
module.py ADDED
@@ -0,0 +1,275 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from einops import rearrange, repeat
+
+ class Residual(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+     def forward(self, x, **kwargs):
+         return self.fn(x, **kwargs) + x
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+     def forward(self, x, **kwargs):
+         return self.fn(self.norm(x), **kwargs)
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout = 0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+     def forward(self, x):
+         return self.net(x)
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads, dim_head, dropout):
+         super().__init__()
+         inner_dim = dim_head * heads
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, dim),
+             nn.Dropout(dropout)
+         )
+     def forward(self, x, mask = None):
+         # x: [b, n, dim]
+         b, n, _, h = *x.shape, self.heads
+
+         # get qkv tuple: ([b, n, head_num*head_dim], [...], [...])
+         qkv = self.to_qkv(x).chunk(3, dim = -1)
+         # split q, k, v from [b, n, head_num*head_dim] -> [b, head_num, n, head_dim]
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+         # q * transpose(k) / sqrt(head_dim) -> [b, head_num, n, n]
+         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+         mask_value = -torch.finfo(dots.dtype).max
+
+         # masked positions get -inf before the softmax
+         if mask is not None:
+             mask = F.pad(mask.flatten(1), (1, 0), value = True)
+             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+             mask = mask[:, None, :] * mask[:, :, None]
+             dots.masked_fill_(~mask, mask_value)
+             del mask
+
+         # softmax normalization -> attention matrix
+         attn = dots.softmax(dim=-1)
+         # attention matrix * value -> output
+         out = torch.einsum('bhij,bhjd->bhid', attn, v)
+         # concatenate heads -> [b, n, head_num*head_dim]
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.to_out(out)
+         return out
+ class CrossAttention(nn.Module):
+     def __init__(self, dim, heads, dim_head, dropout):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == dim)
+
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.to_k = nn.Linear(dim, inner_dim, bias = False)
+         self.to_v = nn.Linear(dim, inner_dim, bias = False)
+         self.to_q = nn.Linear(dim, inner_dim, bias = False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, dim),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+
+     def forward(self, x_qkv):
+         # queries come from the first (class) token only; keys/values from the full sequence
+         b, n, _, h = *x_qkv.shape, self.heads
+
+         k = self.to_k(x_qkv)
+         k = rearrange(k, 'b n (h d) -> b h n d', h = h)
+
+         v = self.to_v(x_qkv)
+         v = rearrange(v, 'b n (h d) -> b h n d', h = h)
+
+         q = self.to_q(x_qkv[:, 0].unsqueeze(1))
+         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+
+         dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+         attn = dots.softmax(dim=-1)
+
+         out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.to_out(out)
+         return out
+
+ class Transformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_head, dropout, num_channel):
+         super().__init__()
+
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+                 Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
+             ]))
+
+         self.skipcat = nn.ModuleList([])
+         for _ in range(depth - 2):
+             self.skipcat.append(nn.Conv2d(num_channel + 1, num_channel + 1, [1, 2], 1, 0))
+
+     def forward(self, x, mask = None):
+         for attn, ff in self.layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         return x
+
+ class SSTransformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout):
+         super().__init__()
+
+         self.layers = nn.ModuleList([])
+         self.k_layers = nn.ModuleList([])
+         self.channels_to_embedding = nn.Linear(num_patches, b_dim)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
+         # first-stage transformer blocks over the input token sequence
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+                 Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
+             ]))
+         # second-stage transformer blocks applied after the channel-to-embedding projection
+         for _ in range(b_depth):
+             self.k_layers.append(nn.ModuleList([
+                 Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
+                 Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
+             ]))
+
+     def forward(self, x, mask = None):
+         for attn, ff in self.layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         x = rearrange(x, 'b n d -> b d n')
+         x = self.channels_to_embedding(x)
+         b, d, n = x.shape
+         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+         x = torch.cat((cls_tokens, x), dim = 1)
+         for attn, ff in self.k_layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         return x
+
+ class SSTransformer_pyramid(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout):
+         super().__init__()
+
+         self.layers = nn.ModuleList([])
+         self.k_layers = nn.ModuleList([])
+         self.channels_to_embedding = nn.Linear(num_patches, b_dim)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+                 Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
+             ]))
+         for _ in range(b_depth):
+             self.k_layers.append(nn.ModuleList([
+                 Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
+                 Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
+             ]))
+
+     def forward(self, x, mask = None):
+         for attn, ff in self.layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         out_feature = x
+         x = rearrange(x, 'b n d -> b d n')
+         x = self.channels_to_embedding(x)
+         b, d, n = x.shape
+         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+         x = torch.cat((cls_tokens, x), dim = 1)
+         for attn, ff in self.k_layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         return x, out_feature
+ class ViT(nn.Module):
+     def __init__(self, image_size, near_band, num_patches, num_classes, dim, depth, heads, mlp_dim, pool='cls', channel_dim=1, dim_head = 16, dropout=0., emb_dropout=0., mode='ViT'):
+         super().__init__()
+
+         patch_dim = image_size ** 2 * near_band
+
+         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+         self.patch_to_embedding = nn.Linear(channel_dim, dim)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+
+         self.dropout = nn.Dropout(emb_dropout)
+         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_patches)
+
+         self.pool = pool
+         self.to_latent = nn.Identity()
+
+         self.mlp_head = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, num_classes)
+         )
+     def forward(self, x, mask = None):
+         # patches: [batch, patch_num, patch_size*patch_size*c]
+         # embed every patch vector to the embedding size: [batch, patch_num, embedding_size]
+         x = self.patch_to_embedding(x)  # [b, n, dim]
+         b, n, _ = x.shape
+
+         # add class token and position embedding
+         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # [b, 1, dim]
+         x = torch.cat((cls_tokens, x), dim = 1)  # [b, n+1, dim]
+         x += self.pos_embedding[:, :(n + 1)]
+         x = self.dropout(x)
+         # transformer: x[b, n+1, dim] -> x[b, n+1, dim]
+         x = self.transformer(x, mask)
+         # classification: use the cls_token output
+         x = self.to_latent(x[:, 0])
+
+         # MLP classification head
+         return self.mlp_head(x)
+
+ class SSFormer_v4(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout, mode):
+         super().__init__()
+
+         self.layers = nn.ModuleList([])
+         self.k_layers = nn.ModuleList([])
+         self.channels_to_embedding = nn.Linear(num_patches, b_dim)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
+                 Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
+             ]))
+         for _ in range(b_depth):
+             self.k_layers.append(nn.ModuleList([
+                 Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
+                 Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
+             ]))
+         self.mode = mode
+
+     def forward(self, x, c, mask = None):
+         for attn, ff in self.layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         x = rearrange(x, 'b n d -> b d n')
+         x = self.channels_to_embedding(x)
+         b, d, n = x.shape
+         cls_tokens = repeat(c, '() n d -> b n d', b = b)
+         x = torch.cat((cls_tokens, x), dim = 1)
+         for attn, ff in self.k_layers:
+             x = attn(x, mask = mask)
+             x = ff(x)
+         return x
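
As a quick shape check of the two-stage SSTransformer, the snippet below is an illustrative sketch (not part of the commit) that feeds a random token sequence through the block using the same hyperparameters that app.py passes to SSTViT; the commented shapes follow directly from the forward pass above.

    # Illustrative shape check for SSTransformer (assumed usage, dims mirror app.py).
    import torch
    from module import SSTransformer

    blocks = SSTransformer(dim=32, depth=2, heads=4, dim_head=16, mlp_head=8,
                           b_dim=512, b_depth=3, b_heads=8, b_dim_head=32, b_mlp_head=8,
                           num_patches=12, dropout=0.1)

    x = torch.randn(2, 12, 32)   # [batch, 11 band tokens + 1 cls token, dim]
    out = blocks(x)              # runs both attention stages and prepends the second-stage cls token
    print(out.shape)             # torch.Size([2, 33, 512]): 1 cls token + 32 channel tokens in b_dim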
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ einops
+ patchify
+ scipy
+ scikit-learn
+ torch
+ streamlit
+ streamlit-aggrid
+ plotly
+ numpy
+ pandas
+ matplotlib
+ Pillow
sstvit.py ADDED
@@ -0,0 +1,94 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from einops.layers.torch import Rearrange
+ from module import Attention, PreNorm, FeedForward, CrossAttention, SSTransformer
+ import numpy as np
+
+ class SSTTransformerEncoder(nn.Module):
+
+     def __init__(self, dim, depth, heads, dim_head, mlp_dim, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, cross_attn_depth=3, cross_attn_heads=8, dropout = 0):
+         super().__init__()
+
+         self.transformer = SSTransformer(dim, depth, heads, dim_head, mlp_dim, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout)
+
+         self.cross_attn_layers = nn.ModuleList([])
+         for _ in range(cross_attn_depth):
+             self.cross_attn_layers.append(PreNorm(b_dim, CrossAttention(b_dim, heads = cross_attn_heads, dim_head=dim_head, dropout=0)))
+
+     def forward(self, x1, x2):
+         # shared spatial-spectral transformer for both dates
+         x1 = self.transformer(x1)
+         x2 = self.transformer(x2)
+
+         for cross_attn in self.cross_attn_layers:
+             x1_class = x1[:, 0]
+             x1 = x1[:, 1:]
+             x2_class = x2[:, 0]
+             x2 = x2[:, 1:]
+
+             # cross attention: each date's class token attends to the other date's tokens
+             cat1_q = x1_class.unsqueeze(1)
+             cat1_qkv = torch.cat((cat1_q, x2), dim=1)
+             cat1_out = cat1_q + cross_attn(cat1_qkv)
+             x1 = torch.cat((cat1_out, x1), dim=1)
+             cat2_q = x2_class.unsqueeze(1)
+             cat2_qkv = torch.cat((cat2_q, x1), dim=1)
+             cat2_out = cat2_q + cross_attn(cat2_qkv)
+             x2 = torch.cat((cat2_out, x2), dim=1)
+
+         return cat1_out, cat2_out
+
+ class SSTViT(nn.Module):
+     def __init__(self, image_size, near_band, num_patches, num_classes, dim, depth, heads, mlp_dim, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, pool='cls', channels=1, dim_head = 16, dropout=0., emb_dropout=0., multi_scale_enc_depth=1):
+         super().__init__()
+
+         patch_dim = image_size ** 2 * near_band
+         self.num_patches = num_patches + 1
+         self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, dim))
+         self.patch_to_embedding = nn.Linear(patch_dim, dim)
+         self.cls_token_t1 = nn.Parameter(torch.randn(1, 1, dim))
+         self.cls_token_t2 = nn.Parameter(torch.randn(1, 1, dim))
+
+         self.dropout = nn.Dropout(emb_dropout)
+
+         self.multi_scale_transformers = nn.ModuleList([])
+         for _ in range(multi_scale_enc_depth):
+             self.multi_scale_transformers.append(SSTTransformerEncoder(dim, depth, heads, dim_head, mlp_dim, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, self.num_patches,
+                                                                        dropout = 0.))
+
+         self.pool = pool
+         self.to_latent = nn.Identity()
+
+         self.mlp_head = nn.Sequential(
+             nn.LayerNorm(b_dim),
+             nn.Linear(b_dim, num_classes)
+         )
+     def forward(self, x1, x2):
+         # patches: [batch, patch_num, patch_size*patch_size*c]
+         # embed every patch vector to the embedding size: [batch, patch_num, embedding_size]
+         x1 = self.patch_to_embedding(x1)  # [b, n, dim]
+         x2 = self.patch_to_embedding(x2)
+         b, n, _ = x1.shape
+         # add class tokens and position embedding
+         cls_tokens_t1 = repeat(self.cls_token_t1, '() n d -> b n d', b = b)  # [b, 1, dim]
+         cls_tokens_t2 = repeat(self.cls_token_t2, '() n d -> b n d', b = b)
+
+         x1 = torch.cat((cls_tokens_t1, x1), dim = 1)  # [b, n+1, dim]
+         x1 += self.pos_embedding[:, :(n + 1)]
+         x1 = self.dropout(x1)
+         x2 = torch.cat((cls_tokens_t2, x2), dim = 1)  # [b, n+1, dim]
+         x2 += self.pos_embedding[:, :(n + 1)]
+         x2 = self.dropout(x2)
+         # transformer encoders: x[b, n+1, dim] -> class tokens for both dates
+         for multi_scale_transformer in self.multi_scale_transformers:
+             out1, out2 = multi_scale_transformer(x1, x2)
+         # classification: use the cls_token outputs
+         out1 = self.to_latent(out1[:, 0])
+         out2 = self.to_latent(out2[:, 0])
+         out = out1 + out2
+         # MLP classification head
+         return self.mlp_head(out)
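
For context, a minimal end-to-end sketch (assumed usage, mirroring how app.py builds and calls the model): two batches of 5x5x11 patches, flattened to [batch, 11 bands, 25 pixels], produce 3-class logits.

    # Hypothetical smoke test for SSTViT with the hyperparameters used in app.py.
    import torch
    from sstvit import SSTViT

    model = SSTViT(image_size=5, near_band=1, num_patches=11, num_classes=3,
                   dim=32, depth=2, heads=4, dim_head=16, mlp_dim=8,
                   b_dim=512, b_depth=3, b_heads=8, b_dim_head=32, b_mlp_head=8,
                   dropout=0.2, emb_dropout=0.1)
    model.eval()

    # one 5x5 spatial patch per sample, 11 channels (10 bands + mask), 25 pixels per channel
    t1 = torch.randn(4, 11, 25)   # pre-event patches
    t2 = torch.randn(4, 11, 25)   # post-event patches
    with torch.no_grad():
        logits = model(t1, t2)
    print(logits.shape)           # torch.Size([4, 3]): non-mangrove / mangrove loss / mangrove scores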