import streamlit as st
import io
import collections
import os
import time
import base64
import json
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import plotly.express as px
from PIL import Image
from scipy.io import loadmat, savemat

import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.backends.cudnn as cudnn
from torch import optim
from torch.autograd import Variable

from sklearn.metrics import confusion_matrix
from patchify import patchify, unpatchify
import st_aggrid

from sstvit import SSTViT
css = '''
<style>
section.main > div {max-width:60rem}
</style>
'''
st.markdown(css, unsafe_allow_html=True)
# Lightweight wrapper so hyper-parameters can be accessed as attributes (args.seed)
# as well as dictionary keys.
class Args(dict):
    __setattr__ = dict.__setitem__
    __getattr__ = dict.__getitem__

args = {
    'dataset': 'mg',
    'flag_test': 'train',
    'gpu_id': 0,
    'seed': 0,
    'batch_size': 64,
    'test_freq': 10,
    'patches': 5,
    'band_patches': 1,
    'epoches': 2000,
    'learning_rate': 5e-4,
    'gamma': 0.9,
    'weight_decay': 0.0,
    'train_number': 500
}
args = Args(args)  # dict -> attribute-style object
obj = args.copy()  # attribute-style object -> plain dict

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
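# Inference helper: runs the loaded change-detection model over a DataLoader of
# (t1, t2) patch pairs and returns the predicted class index for every patch.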
def test_epoch(model, test_loader):
    pre = np.array([])
    with torch.no_grad():
        for batch_idx, (batch_data_t1, batch_data_t2) in enumerate(test_loader):
            batch_pred = model(batch_data_t1, batch_data_t2)
            _, pred = batch_pred.topk(1, 1, True, True)
            pp = pred.squeeze()
            pre = np.append(pre, pp.data.cpu().numpy())
    return pre
mdic = ['Before', 'After', 'Before', 'After']
colors = ['#3b68f8', '#ff0201', '#23fe01']  # class 0: non-mangrove (blue), 1: mangrove loss (red), 2: mangrove (green)
cmap = mcolors.ListedColormap(colors)

# Parameter setting: fix random seeds and make cuDNN deterministic for reproducible runs.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False
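# Convert a map of class indices (0 = non-mangrove, 1 = mangrove loss, 2 = mangrove)
# into an RGB image using the same blue/red/green scheme as the plotting colormap.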
def encode_masks_to_rgb(masks):
    colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0)]
    # Create an empty RGB image
    height, width = masks.shape
    rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
    # Assign colors based on the mask values
    for i in range(len(colors)):
        mask_indices = masks == i
        rgb_image[mask_indices] = colors[i]
    return rgb_image
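# count_pixel and count_pixel1 are identical apart from the label given to class 2
# ("Mangrove Before" vs "Mangrove After"); each returns per-class pixel counts and
# percentages for one prediction map as a DataFrame.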
def count_pixel(pred):
    image = Image.fromarray(pred)
    # Map RGB colors to class labels
    color2label = {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove Before",
    }
    # Flatten the image into a list of pixel values
    pixels = list(image.getdata())
    # Count the number of pixels of each color
    color_counts = collections.Counter(pixels)
    # Total number of pixels in the image
    total_pixels = len(pixels)
    # Percentage and absolute count per class
    average_counts = {color2label[label]: (count / total_pixels) * 100 for label, count in color_counts.items()}
    class_counts = {color2label[label]: count for label, count in color_counts.items()}
    pix_avg = {}
    pix_count = {}
    for _, label in color2label.items():
        try:
            pix_avg[label] = average_counts[label]
            pix_count[label] = class_counts[label]
        except KeyError:
            # Class not present in this prediction map
            pix_avg[label] = 0
            pix_count[label] = 0
    x = {
        "class": list(pix_avg.keys()),
        "percentage": list(pix_avg.values()),
        "pixel_count": list(pix_count.values())
    }
    return pd.DataFrame(x)
def count_pixel1(pred):
    image = Image.fromarray(pred)
    # Map RGB colors to class labels
    color2label = {
        (0, 0, 255): "Non Mangrove",
        (255, 0, 0): "Mangrove Loss",
        (0, 255, 0): "Mangrove After",
    }
    # Flatten the image into a list of pixel values
    pixels = list(image.getdata())
    # Count the number of pixels of each color
    color_counts = collections.Counter(pixels)
    # Total number of pixels in the image
    total_pixels = len(pixels)
    # Percentage and absolute count per class
    average_counts = {color2label[label]: (count / total_pixels) * 100 for label, count in color_counts.items()}
    class_counts = {color2label[label]: count for label, count in color_counts.items()}
    pix_avg = {}
    pix_count = {}
    for _, label in color2label.items():
        try:
            pix_avg[label] = average_counts[label]
            pix_count[label] = class_counts[label]
        except KeyError:
            # Class not present in this prediction map
            pix_avg[label] = 0
            pix_count[label] = 0
    x = {
        "class": list(pix_avg.keys()),
        "percentage": list(pix_avg.values()),
        "pixel_count": list(pix_count.values())
    }
    return pd.DataFrame(x)
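# --- Streamlit UI: upload a .mat scene, preview it, then run change detection ---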
file = st.file_uploader("Upload file", type=['mat'])
if file:
    # Read the .mat file once and reuse its contents below.
    mat = loadmat(file)
    data_img1 = mat['data_img1']
    data_img2 = mat['data_img2']
    st.subheader("Preview Dataset")
    col1, col2 = st.columns(2)
    with col1:
        fig = plt.figure(figsize=(5, 5))
        plt.subplot(121)
        plt.imshow(data_img1)
        plt.title('Before', fontweight='bold')
        plt.xticks([])
        plt.yticks([])
        plt.subplot(122)
        plt.imshow(data_img2)
        plt.title('After', fontweight='bold')
        plt.xticks([])
        plt.yticks([])
        st.pyplot(fig)
    holder = st.empty()
    if holder.button("Start Prediction"):
        start = time.time()
        holder.empty()
        with st.spinner("Processing, please wait around 7-15 minutes"):
            data_t1 = mat['data_t1']
            data_t2 = mat['data_t2']
            L_post = mat['L_post']
            L_pre = mat['L_pre']
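            # Remap the binary mangrove masks (0 -> -0.8, 1 -> -0.2); these values are
            # presumably the encoding the model saw during training.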
            L_post = np.double(L_post)
            L_post[L_post == 0] = -0.8
            L_post[L_post == 1] = 0
            L_post[L_post == 0] = -0.2
            L_pre = np.double(L_pre)
            L_pre[L_pre == 0] = -0.8
            L_pre[L_pre == 1] = 0
            L_pre[L_pre == 0] = -0.2
            # Crop the spectral cubes to the extent of the label masks.
            data_t1 = data_t1[:L_post.shape[0], :L_post.shape[1], :]
            data_t2 = data_t2[:L_post.shape[0], :L_post.shape[1], :]
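            # Stack the 10 spectral bands and the remapped mask into an 11-band cube
            # per date; the mask acts as an extra input channel for the model.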
            data_cb1 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
            data_cb2 = np.zeros(shape=(L_post.shape[0], L_post.shape[1], 11), dtype=np.float32)
            data_cb1[:, :, :10] = data_t1
            data_cb1[:, :, 10] = L_pre
            data_cb2[:, :, :10] = data_t2
            data_cb2[:, :, 10] = L_post
            height, width, band = data_cb1.shape
            # 5x5 patches with step 1 leave a 2-pixel border, so the prediction map is (H-4) x (W-4).
            height = height - 4
            width = width - 4
            x1 = patchify(data_cb1, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)
            x2 = patchify(data_cb2, (5, 5, 11), step=1).reshape(-1, 5 * 5, 11)
            # create model and load the trained weights (CPU inference)
            model = SSTViT(
                image_size=5,
                near_band=args.band_patches,
                num_patches=11,
                num_classes=3,
                dim=32,
                depth=2,
                heads=4,
                dim_head=16,
                mlp_dim=8,
                b_dim=512,
                b_depth=3,
                b_heads=8,
                b_dim_head=32,
                b_mlp_head=8,
                dropout=0.2,
                emb_dropout=0.1,
            )
            model.load_state_dict(torch.load("model/lsstformer.pth", map_location=torch.device("cpu")))
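            # Two inference passes follow: the "Before" map is produced by pairing the
            # t1 cube with itself, the "After" map by pairing t1 with t2.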
            x1_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
            x2_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
            Label_true = Data.TensorDataset(x1_true_band, x2_true_band)
            label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
            model.eval()
            # output classification maps
            pre_u = test_epoch(model, label_true_loader)
            prediction_matrix = pre_u.reshape(height, width)
            x1_true_band = torch.from_numpy(x1.transpose(0, 2, 1)).type(torch.FloatTensor)
            x2_true_band = torch.from_numpy(x2.transpose(0, 2, 1)).type(torch.FloatTensor)
            Label_true = Data.TensorDataset(x1_true_band, x2_true_band)
            label_true_loader = Data.DataLoader(Label_true, batch_size=100, shuffle=False)
            model.eval()
            # output classification maps
            pre_u = test_epoch(model, label_true_loader)
            prediction_matrix2 = pre_u.reshape(height, width)
            # Flatten the prediction maps and locate mangrove (class 2) and mangrove-loss (class 1) pixels.
            A = prediction_matrix.reshape(-1)
            B = prediction_matrix2.reshape(-1)
            mg = np.array(np.where(A == 2))    # mangrove pixels before
            mg1 = np.array(np.where(B == 2))   # mangrove pixels after
            mgls = np.array(np.where(B == 1))  # mangrove-loss pixels
            class_counts = count_pixel(encode_masks_to_rgb(prediction_matrix))
            class_counts1 = count_pixel1(encode_masks_to_rgb(prediction_matrix2))
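        # --- Display results: classification maps, area summary, pixel distribution ---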
        with st.container():
            st.subheader("Prediction Result")
            col1, col2 = st.columns(2)
            with col1:
                with st.container():
                    fig = plt.figure(figsize=(10, 10))
                    plt.subplot(121)
                    plt.imshow(prediction_matrix, cmap=cmap)
                    plt.title('Before', fontsize=25, fontweight='bold')
                    plt.xticks([])
                    plt.yticks([])
                    plt.subplot(122)
                    plt.imshow(prediction_matrix2, cmap=cmap)
                    plt.title('After', fontsize=25, fontweight='bold')
                    plt.xticks([])
                    plt.yticks([])
                    st.pyplot(fig)
                    buf = io.BytesIO()
                    fig.savefig(buf, format="png")
            with col2:
                with st.container():
                    # Area estimates assume 10 m x 10 m Sentinel-2 pixels (100 m^2 each).
                    table_data = {
                        "Total mangrove before": f"{mg.shape[1]*100} m\u00B2",
                        "Total mangrove after": f"{mg1.shape[1]*100} m\u00B2",
                        "Total mangrove loss": f"{mgls.shape[1]*100} m\u00B2",
                    }
                    df = pd.DataFrame(list(table_data.items()), columns=['Key', 'Value'])
                    MIN_HEIGHT = 100
                    MAX_HEIGHT = 180
                    ROW_HEIGHT = 50
                    # st.dataframe(df, hide_index=True, use_container_width=True)
                    st_aggrid.AgGrid(df, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(df) * ROW_HEIGHT, MAX_HEIGHT))
        with st.container():
            st.subheader("Pixel Distribution")
            # Keep only "Mangrove Before" from the first map and "Mangrove Loss" /
            # "Mangrove After" from the second, then order the rows Before, After, Loss.
            df = class_counts
            df = df.drop(0)
            df1 = df.drop(1)
            df2 = class_counts1
            df3 = df2.drop(0)
            vertical_concat = pd.concat([df1, df3], axis=0)
            MIN_HEIGHT = 100
            MAX_HEIGHT = 180
            ROW_HEIGHT = 50
            vertical_concat = vertical_concat.iloc[[0, 2, 1], :]
            st_aggrid.AgGrid(vertical_concat, fit_columns_on_grid_load=True, height=min(MIN_HEIGHT + len(vertical_concat) * ROW_HEIGHT, MAX_HEIGHT))
            fig = px.bar(vertical_concat, x='percentage', y='class', color='class', orientation='h',
                         color_discrete_sequence=["green", "green", "red", "blue"],
                         category_orders={"class": ["Mangrove Before", "Mangrove After", "Mangrove Loss", "Non Mangrove"]}
                         )
            st.plotly_chart(fig, use_container_width=False)
        end = time.time()
        process = end - start
        st.write('Processing time (s):', round(process, 2))
show_file = st.empty()
if not file:
    url = "https://drive.usercontent.google.com/download?id=1u48pMzRWQ2Etfjaq5A0CUjRtGKZaJoJy&export=download&authuser=2&confirm=t&uuid=52b0e01e-377f-42cb-8412-c84aa38a1740&at=APZUnTXslmuCCV1drJ2WWtkZr9BR%3A1710357675310"
    show_file.info("""
        The model was trained on Sentinel-2 imagery. Upload a MAT file to run the trained
        LSST-Former mangrove-loss detection model from this research. A tool for converting
        Sentinel-2 imagery to the MAT format will be released later; for now, please download
        the demo dataset below. On a mobile phone, desktop mode works best.
        """)
    st.write("Download the demo dataset from this [link](%s)" % url)