swin2mose (#1)
Browse files
- swin2mose: runnable version (22a2f9f37b55bf583c5a2e28910e41d8124be4c8)
Co-authored-by: Leonardo Rossi <[email protected]>
- .gitignore +1 -0
- swin2_mose/libs.py +56 -0
- swin2_mose/model.py +9 -12
- swin2_mose/moe.py +3 -2
- swin2_mose/run.py +36 -20
- swin2_mose/utils.py +77 -56
- swin2_mose/weights/config-70.yml +46 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
swin2_mose/libs.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def window_reverse(windows, window_size, H, W):
    """Reassemble non-overlapping windows into full feature maps.

    This undoes the layout produced by ``window_partition``.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # Number of windows along each spatial axis.
    rows = H // window_size
    cols = W // window_size
    # Recover the batch size with integer arithmetic only, so no float
    # division/rounding is involved in computing a tensor dimension.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave (row, in-window row) and (col, in-window col) before
    # flattening them back into H and W.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; ``in_features`` is used
            when this is falsy (e.g. ``None``).
        out_features: size of the last dimension of the output;
            ``in_features`` is used when this is falsy (e.g. ``None``).
        act_layer: zero-argument callable that builds the activation module.
        drop: probability given to the single ``nn.Dropout`` module that is
            applied after the activation and again after the second linear.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the parameter initialisation order.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the block to ``x`` and return the projected tensor."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def window_partition(x, window_size):
    """Cut a batch of feature maps into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    # Split each spatial axis into (number of windows, position in window).
    grid = x.view(
        batch,
        height // window_size, window_size,
        width // window_size, window_size,
        channels,
    )
    # Put the two window-index axes next to each other, then fold them
    # together with the batch axis.
    gathered = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return gathered.view(-1, window_size, window_size, channels)
|
swin2_mose/model.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
#
|
| 2 |
-
# Source code: https://github.com/
|
| 3 |
#
|
| 4 |
-
#
|
| 5 |
-
#
|
| 6 |
-
#
|
| 7 |
-
# -----------------------------------------------------------------------------------
|
| 8 |
|
| 9 |
import math
|
| 10 |
import numpy as np
|
|
@@ -14,7 +13,7 @@ import torch.nn.functional as F
|
|
| 14 |
import torch.utils.checkpoint as checkpoint
|
| 15 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 16 |
|
| 17 |
-
from
|
| 18 |
from moe import MoE
|
| 19 |
|
| 20 |
|
|
@@ -746,9 +745,8 @@ class UpsampleOneStep(nn.Sequential):
|
|
| 746 |
|
| 747 |
|
| 748 |
|
| 749 |
-
class
|
| 750 |
-
r"""
|
| 751 |
-
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
|
| 752 |
|
| 753 |
Args:
|
| 754 |
img_size (int | tuple(int)): Input image size. Default 64
|
|
@@ -784,8 +782,7 @@ class Swin2SR(nn.Module):
|
|
| 784 |
MoE_config=None,
|
| 785 |
use_rpe_bias=False,
|
| 786 |
**kwargs):
|
| 787 |
-
super(
|
| 788 |
-
print('==== SWIN 2SR')
|
| 789 |
num_in_ch = in_chans
|
| 790 |
num_out_ch = in_chans
|
| 791 |
num_feat = 64
|
|
@@ -1154,4 +1151,4 @@ class Swin2SR(nn.Module):
|
|
| 1154 |
flops += layer.flops()
|
| 1155 |
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
| 1156 |
flops += self.upsample.flops()
|
| 1157 |
-
return flops
|
|
|
|
| 1 |
#
|
| 2 |
+
# Source code: https://github.com/IMPLabUniPr/swin2-mose
|
| 3 |
#
|
| 4 |
+
# ----------------------------------------------------------------------------
|
| 5 |
+
# https://arxiv.org/abs/2404.18924
|
| 6 |
+
# ----------------------------------------------------------------------------
|
|
|
|
| 7 |
|
| 8 |
import math
|
| 9 |
import numpy as np
|
|
|
|
| 13 |
import torch.utils.checkpoint as checkpoint
|
| 14 |
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 15 |
|
| 16 |
+
from libs import window_reverse, Mlp, window_partition
|
| 17 |
from moe import MoE
|
| 18 |
|
| 19 |
|
|
|
|
| 745 |
|
| 746 |
|
| 747 |
|
| 748 |
+
class Swin2MoSE(nn.Module):
|
| 749 |
+
r""" Swin2-MoSE
|
|
|
|
| 750 |
|
| 751 |
Args:
|
| 752 |
img_size (int | tuple(int)): Input image size. Default 64
|
|
|
|
| 782 |
MoE_config=None,
|
| 783 |
use_rpe_bias=False,
|
| 784 |
**kwargs):
|
| 785 |
+
super(Swin2MoSE, self).__init__()
|
|
|
|
| 786 |
num_in_ch = in_chans
|
| 787 |
num_out_ch = in_chans
|
| 788 |
num_feat = 64
|
|
|
|
| 1151 |
flops += layer.flops()
|
| 1152 |
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
| 1153 |
flops += self.upsample.flops()
|
| 1154 |
+
return flops
|
swin2_mose/moe.py
CHANGED
|
@@ -18,7 +18,8 @@ from torch.distributions.normal import Normal
|
|
| 18 |
from copy import deepcopy
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
-
from
|
|
|
|
| 22 |
|
| 23 |
class SparseDispatcher(object):
|
| 24 |
"""Helper for implementing a mixture of experts.
|
|
@@ -320,4 +321,4 @@ class MoE(nn.Module):
|
|
| 320 |
expert_outputs = [self.experts[i](expert_inputs[i])
|
| 321 |
for i in range(self.num_experts)]
|
| 322 |
y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
|
| 323 |
-
return y, loss
|
|
|
|
| 18 |
from copy import deepcopy
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
+
from libs import Mlp as MLP
|
| 22 |
+
|
| 23 |
|
| 24 |
class SparseDispatcher(object):
|
| 25 |
"""Helper for implementing a mixture of experts.
|
|
|
|
| 321 |
expert_outputs = [self.experts[i](expert_inputs[i])
|
| 322 |
for i in range(self.num_experts)]
|
| 323 |
y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
|
| 324 |
+
return y, loss
|
swin2_mose/run.py
CHANGED
|
@@ -1,20 +1,36 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import benchmark
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import opensr_test
|
| 4 |
+
|
| 5 |
+
from utils import load_swin2_mose, load_config, run_swin2_mose
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Demo script: build the model from its config and checkpoint, super-resolve
# one sample of the opensr-test "venus" dataset, prepare a three-panel figure
# and finally hand the model to the benchmark helper.

# Both paths are relative, so they are resolved against the current working
# directory.
path = 'swin2_mose/weights/config-70.yml'
model_weights = "swin2_mose/weights/model-70.pt"
# Position of the sample taken from the LR and HR datasets below.
index = 2

# load config
cfg = load_config(path)
# load model
model = load_swin2_mose(model_weights, cfg)

# load the dataset
dataset = opensr_test.load("venus")
lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]

# Returns a dict with the keys 'lr', 'sr' and 'hr' (see utils.run_swin2_mose).
results = run_swin2_mose(model, lr_dataset[index], hr_dataset[index])

# Display the results
# Tensors are moved from channel-first to channel-last for imshow.
# NOTE(review): the division by 3000 presumably scales the values into a
# displayable range — confirm against the dataset's value range.
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
ax[0].imshow(results['lr'].numpy().transpose(1, 2, 0)/3000)
ax[0].set_title("LR")
ax[0].axis("off")
# Only the SR tensor is detached before conversion to numpy.
ax[1].imshow(results["sr"].detach().numpy().transpose(1, 2, 0)/3000)
ax[1].set_title("SR")
ax[1].axis("off")
ax[2].imshow(results['hr'].numpy().transpose(1, 2, 0) / 3000)
ax[2].set_title("HR")
# The figure is built but not displayed while the next line stays disabled.
# plt.show()

# Run the experiment
# NOTE(review): presumably "all" selects every dataset and "swin2mose/" is the
# output location — verify against benchmark.create_geotiff.
benchmark.create_geotiff(model, run_swin2_mose, "all", "swin2mose/")
|
swin2_mose/utils.py
CHANGED
|
@@ -1,56 +1,77 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import yaml
|
| 3 |
+
|
| 4 |
+
from model import Swin2MoSE
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def to_shape(t1, t2):
    """Tile per-channel values so they broadcast against ``t2``.

    ``t1`` is repeated once per entry of ``t2``'s first dimension and then
    viewed with ``t2``'s first two dimensions followed by two singleton axes.

    Args:
        t1: tensor of per-channel values.
        t2: reference tensor supplying the leading two dimensions.

    Returns:
        Tensor of shape ``t2.shape[:2] + (1, 1)``.
    """
    lead_dims = t2.shape[:2]
    tiled = t1.unsqueeze(0).repeat(lead_dims[0], 1)
    return tiled.view(*lead_dims, 1, 1)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def norm(tensor, mean, std):
    """Standardise ``tensor`` per channel: ``(tensor - mean) / std``.

    Args:
        tensor: input tensor; ``to_shape`` expands the statistics to
            ``tensor.shape[:2] + (1, 1)`` before the arithmetic.
        mean: per-channel means, in any form accepted by ``torch.tensor``.
        std: per-channel standard deviations, same form as ``mean``.

    Returns:
        The normalised tensor.
    """
    device = tensor.device
    # Build the statistics on the same device as the data.
    offset = to_shape(torch.tensor(mean).to(device), tensor)
    scale = to_shape(torch.tensor(std).to(device), tensor)
    return (tensor - offset) / scale
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def denorm(tensor, mean, std):
    """Undo per-channel standardisation: ``tensor * std + mean``.

    Args:
        tensor: normalised tensor; ``to_shape`` expands the statistics to
            ``tensor.shape[:2] + (1, 1)`` before the arithmetic.
        mean: per-channel means, in any form accepted by ``torch.tensor``.
        std: per-channel standard deviations, same form as ``mean``.

    Returns:
        The de-normalised tensor.
    """
    device = tensor.device
    # Build the statistics on the same device as the data.
    offset = to_shape(torch.tensor(mean).to(device), tensor)
    scale = to_shape(torch.tensor(std).to(device), tensor)
    return (tensor * scale) + offset
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_config(path):
    """Parse the YAML configuration file at ``path``.

    Args:
        path: location of the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_swin2_mose(model_weights, cfg):
    """Build a ``Swin2MoSE`` model from ``cfg`` and load trained weights.

    Args:
        model_weights: checkpoint location passed to ``torch.load``; the
            checkpoint must contain a ``'model_state_dict'`` entry.
        cfg: configuration mapping; ``cfg['super_res']['model']`` is expanded
            into the keyword arguments of ``Swin2MoSE``.

    Returns:
        The model with the weights loaded and ``cfg`` attached as ``.cfg``.
    """
    # load checkpoint
    # Map storages to CPU so that a checkpoint saved from a GPU can also be
    # read on a machine without CUDA; load_state_dict copies the values into
    # the freshly built model either way.
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_weights, map_location='cpu')

    # build model
    sr_model = Swin2MoSE(**cfg['super_res']['model'])
    sr_model.load_state_dict(checkpoint['model_state_dict'])

    # Keep the configuration on the model: run_swin2_mose reads the
    # normalisation statistics from it.
    sr_model.cfg = cfg

    return sr_model
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def run_swin2_mose(model, lr, hr):
    """Super-resolve one low-resolution sample with ``model``.

    Args:
        model: callable model carrying its configuration as ``model.cfg``
            (as attached by ``load_swin2_mose``).
        lr: low-resolution sample, band axis first, in any form accepted by
            ``torch.tensor``; band indices up to 7 are read.
        hr: high-resolution sample in any form accepted by ``torch.tensor``;
            it is only converted and returned.

    Returns:
        dict with the un-batched tensors ``'lr'`` (the four selected input
        bands, not normalised), ``'sr'`` (de-normalised model output) and
        ``'hr'``.
    """
    cfg = model.cfg

    # Per-band statistics used to normalise the input and de-normalise the
    # output.
    hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8']
    lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8']

    # select 10m lr bands: B02, B03, B04, B08 and hr bands
    # NOTE(review): indices [3, 2, 1, 7] pick the bands in that order, while
    # the statistics key reads b2b3b4b8 — confirm that the channel order
    # matches the order of the statistics.
    lr_orig = torch.tensor(lr)[None].float()[:, [3, 2, 1, 7]]
    hr_orig = torch.tensor(hr)[None].float()

    # normalize data (only the model input needs it; hr is returned as-is)
    lr_norm = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])

    # predict an image; no gradients are needed for inference, so skip
    # building the autograd graph.
    with torch.no_grad():
        sr = model(lr_norm)
    # Some configurations return a pair; the first item is the image.
    if not torch.is_tensor(sr):
        sr, _ = sr

    # denorm sr
    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])

    return {
        "lr": lr_orig[0],
        "sr": sr[0],
        "hr": hr_orig[0],
    }
|
swin2_mose/weights/config-70.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
root_path: data/sen2venus
|
| 3 |
+
stats:
|
| 4 |
+
use_minmax: true
|
| 5 |
+
tensor_05m_b2b3b4b8: {
|
| 6 |
+
mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875],
|
| 7 |
+
std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625],
|
| 8 |
+
min: [-1025.0, -3112.0, -5122.0, -3851.0],
|
| 9 |
+
max: [14748.0, 14960.0, 16472.0, 16109.0]
|
| 10 |
+
}
|
| 11 |
+
tensor_10m_b2b3b4b8: {
|
| 12 |
+
mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875],
|
| 13 |
+
std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875],
|
| 14 |
+
min: [-848.0, -902.0, -946.0, -323.0],
|
| 15 |
+
max: [19684.0, 17982.0, 17064.0, 15958.0]
|
| 16 |
+
}
|
| 17 |
+
hr_name: tensor_05m_b2b3b4b8
|
| 18 |
+
lr_name: tensor_10m_b2b3b4b8
|
| 19 |
+
collate_fn: mods.v3.collate_fn
|
| 20 |
+
denorm: mods.v3.uncollate_fn
|
| 21 |
+
printable: mods.v3.printable
|
| 22 |
+
super_res: {
|
| 23 |
+
version: 'v2',
|
| 24 |
+
model: {
|
| 25 |
+
upscale: 2,
|
| 26 |
+
use_lepe: true,
|
| 27 |
+
use_cpb_bias: false,
|
| 28 |
+
use_rpe_bias: true,
|
| 29 |
+
mlp_ratio: 1,
|
| 30 |
+
MoE_config: {
|
| 31 |
+
k: 2,
|
| 32 |
+
num_experts: 8,
|
| 33 |
+
with_noise: false,
|
| 34 |
+
with_smart_merger: v1,
|
| 35 |
+
},
|
| 36 |
+
depths: [6, 6, 6, 6],
|
| 37 |
+
embed_dim: 90,
|
| 38 |
+
img_range: 1.,
|
| 39 |
+
img_size: 64,
|
| 40 |
+
in_chans: 4,
|
| 41 |
+
num_heads: [6, 6, 6, 6],
|
| 42 |
+
resi_connection: 1conv,
|
| 43 |
+
upsampler: pixelshuffledirect,
|
| 44 |
+
window_size: 16,
|
| 45 |
+
}
|
| 46 |
+
}
|