Spaces:
Runtime error
Runtime error
File size: 5,954 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
"""
Codegoni A, Lombardi G, Ferrari A.
TINYCD: A (Not So) Deep Learning Model For Change Detection[J].
arXiv preprint arXiv:2207.13159, 2022.
The code in this file is borrowed from:
https://github.com/AndreaCodegoni/Tiny_model_4_CD
"""
from typing import List, Optional
import torchvision
from torch import Tensor, reshape, stack
from torch.nn import (Conv2d, InstanceNorm2d, Module, ModuleList, PReLU,
Sequential, Upsample)
from opencd.registry import MODELS
class PixelwiseLinear(Module):
    """Stack of 1x1 convolutions, each followed by an activation.

    Layer ``i`` maps ``fin[i]`` channels to ``fout[i]`` channels with a
    1x1 ``Conv2d`` (with bias). Every layer is followed by a fresh ``PReLU``,
    except the last one, which uses ``last_activation`` when it is given.

    Args:
        fin: Input channel count of each layer.
        fout: Output channel count of each layer; must have the same length
            as ``fin``.
        last_activation: Optional module used instead of ``PReLU`` after the
            final layer.

    Raises:
        ValueError: If ``fin`` and ``fout`` have different lengths.
    """

    def __init__(
        self,
        fin: List[int],
        fout: List[int],
        last_activation: Optional[Module] = None,
    ) -> None:
        # Explicit exception instead of ``assert``: asserts are stripped
        # under ``python -O`` and the mismatch would go unnoticed.
        if len(fout) != len(fin):
            raise ValueError(
                "fin and fout must have the same length, "
                f"got {len(fin)} and {len(fout)}"
            )
        super().__init__()
        n = len(fin)
        self._linears = Sequential(
            *[
                Sequential(
                    Conv2d(fin[i], fout[i], kernel_size=1, bias=True),
                    # Only the final layer may use the caller's activation.
                    PReLU()
                    if i < n - 1 or last_activation is None
                    else last_activation,
                )
                for i in range(n)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the 1x1 convolution stack to ``x``."""
        return self._linears(x)
class MixingBlock(Module):
    """Interleave the channels of two feature maps and mix them.

    The channels of ``x`` and ``y`` are interleaved (``x0, y0, x1, y1, ...``)
    and passed through a grouped 3x3 convolution with ``ch_out`` groups,
    followed by ``PReLU`` and ``InstanceNorm2d``.

    Args:
        ch_in: Channel count of the interleaved tensor, i.e. the channels of
            ``x`` plus the channels of ``y``.
        ch_out: Output channels; also the number of convolution groups, so
            ``ch_in`` must be divisible by it.
    """

    def __init__(
        self,
        ch_in: int,
        ch_out: int,
    ):
        super().__init__()
        mixing_layers = [
            # With groups=ch_out each group sees a run of consecutive
            # interleaved channels, e.g. one x channel and its y partner.
            Conv2d(ch_in, ch_out, kernel_size=3, padding=1, groups=ch_out),
            PReLU(),
            InstanceNorm2d(ch_out),
        ]
        self._convmix = Sequential(*mixing_layers)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Stacking on a new axis right after the channel axis and folding it
        # back into the channel axis yields the order x0, y0, x1, y1, ...
        paired = stack((x, y), dim=2)
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        interleaved = reshape(paired, (batch, -1, height, width))
        return self._convmix(interleaved)
class MixingMaskAttentionBlock(Module):
    """Fuse two feature maps into an attention-like map via grouped mixing.

    The inputs are channel-interleaved and mixed by a ``MixingBlock``, then
    projected by a ``PixelwiseLinear`` stack. With ``generate_masked=True`` a
    second, independent ``MixingBlock`` branch is multiplied by that
    projection and the product is passed through ``InstanceNorm2d``.

    Args:
        ch_in: Channel count after interleaving both inputs.
        ch_out: Output channels of the mixing block(s).
        fin: Input channels of each ``PixelwiseLinear`` layer.
        fout: Output channels of each ``PixelwiseLinear`` layer.
        generate_masked: Enable the second mixing branch and the final
            normalisation.
    """

    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        fin: List[int],
        fout: List[int],
        generate_masked: bool = False,
    ):
        super().__init__()
        self._mixing = MixingBlock(ch_in, ch_out)
        self._linear = PixelwiseLinear(fin, fout)
        if generate_masked:
            self._final_normalization = InstanceNorm2d(ch_out)
            self._mixing_out = MixingBlock(ch_in, ch_out)
        else:
            self._final_normalization = None
            self._mixing_out = None

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        projection = self._linear(self._mixing(x, y))
        # Second branch; falls back to the scalar 0 when it is disabled.
        gate = 0 if self._mixing_out is None else self._mixing_out(x, y)
        if self._final_normalization is None:
            return projection
        return self._final_normalization(gate * projection)
class UpMask(Module):
    """Upsample a feature map, optionally gate it, then refine it.

    The input is bilinearly upsampled (``align_corners=True``) by
    ``scale_factor``. If a mask ``y`` is given, it is multiplied element-wise
    with the upsampled tensor. The result goes through a depthwise 3x3
    convolution and a 1x1 convolution, each followed by ``PReLU`` and
    ``InstanceNorm2d``.

    Args:
        scale_factor: Spatial upsampling factor.
        nin: Channels of the input.
        nout: Channels of the output.
    """

    def __init__(
        self,
        scale_factor: float,
        nin: int,
        nout: int,
    ):
        super().__init__()
        self._upsample = Upsample(
            scale_factor=scale_factor, mode="bilinear", align_corners=True
        )
        depthwise = [
            Conv2d(nin, nin, kernel_size=3, stride=1, padding=1, groups=nin),
            PReLU(),
            InstanceNorm2d(nin),
        ]
        pointwise = [
            Conv2d(nin, nout, kernel_size=1, stride=1),
            PReLU(),
            InstanceNorm2d(nout),
        ]
        self._convolution = Sequential(*depthwise, *pointwise)

    def forward(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor:
        upsampled = self._upsample(x)
        # The mask must broadcast against the already-upsampled tensor.
        gated = upsampled if y is None else upsampled * y
        return self._convolution(gated)
def _get_backbone(
    bkbn_name: str,
    pretrained: bool,
    output_layer_bkbn: str,
    freeze_backbone: bool,
) -> ModuleList:
    """Build a truncated torchvision feature extractor.

    Instantiates ``torchvision.models.<bkbn_name>(pretrained=pretrained)``,
    takes its ``features`` container and keeps its children, in order, up to
    and including the child named ``output_layer_bkbn``.

    Args:
        bkbn_name: Name of a torchvision model constructor whose model
            exposes a ``features`` attribute.
        pretrained: Forwarded to the constructor as ``pretrained``.
        output_layer_bkbn: Name of the last child of ``features`` to keep
            (children are compared by their string names, e.g. ``"3"``).
        freeze_backbone: If true, ``requires_grad`` is disabled on every
            parameter of the returned layers.

    Returns:
        A ``ModuleList`` holding the selected layers in forward order.

    Raises:
        ValueError: If no child of ``features`` is named ``output_layer_bkbn``.
    """
    # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; confirm against the installed version before
    # switching, since the default weights may differ per model.
    entire_model = getattr(torchvision.models, bkbn_name)(
        pretrained=pretrained
    ).features

    derived_model = ModuleList([])
    for name, layer in entire_model.named_children():
        derived_model.append(layer)
        if name == output_layer_bkbn:
            break
    else:
        # Previously a missing name silently returned the whole feature
        # stack, which only failed later with an unrelated IndexError.
        raise ValueError(
            f"output_layer_bkbn={output_layer_bkbn!r} is not a child of "
            f"{bkbn_name}.features"
        )

    if freeze_backbone:
        derived_model.requires_grad_(False)
    return derived_model
@MODELS.register_module()
class TinyCD(Module):
    """TinyCD change-detection network.

    Both images go through the same truncated torchvision backbone. The raw
    image pair, and the pair of feature maps after every backbone stage
    except the first, are fused by mixing blocks. The deepest fused map is
    then upsampled coarse-to-fine, each step gated by the next finer fused
    map, and a pixel-wise classifier produces a single-channel output.

    Args:
        in_channels: Accepted but not used by this class (presumably kept
            for config compatibility).
        bkbn_name: torchvision model constructor name, passed to
            ``_get_backbone``.
        pretrained: Forwarded to ``_get_backbone``.
        output_layer_bkbn: Name of the last backbone ``features`` child kept.
        freeze_backbone: Forwarded to ``_get_backbone``.

    NOTE(review): the channel sizes of the mixing and upsampling blocks are
    hard-coded and presumably match ``efficientnet_b4`` cut at ``"3"``;
    other backbones or cut points will likely need different sizes.
    """

    def __init__(
        self,
        in_channels,
        bkbn_name="efficientnet_b4",
        pretrained=True,
        output_layer_bkbn="3",
        freeze_backbone=False,
    ):
        super().__init__()
        # Shared backbone applied to both images, truncated after
        # ``output_layer_bkbn``.
        self._backbone = _get_backbone(
            bkbn_name, pretrained, output_layer_bkbn, freeze_backbone
        )

        # Fusion of the raw image pair (6 interleaved channels in total).
        self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1])
        # One fusion block per backbone stage after the first; the deepest
        # one is a plain MixingBlock that keeps 56 channels.
        self._mixing_mask = ModuleList(
            [
                MixingMaskAttentionBlock(48, 24, [24, 12, 6], [12, 6, 1]),
                MixingMaskAttentionBlock(64, 32, [32, 16, 8], [16, 8, 1]),
                MixingBlock(112, 56),
            ]
        )

        # Decoder: each UpMask doubles the spatial size.
        up_channels = ((56, 64), (64, 64), (64, 32))
        self._up = ModuleList(
            [UpMask(2, c_in, c_out) for c_in, c_out in up_channels]
        )

        # 32 -> 16 -> 1 channels: the prediction is a single-channel map.
        self._classify = PixelwiseLinear([32, 16], [16, 1], None)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return a 1-tuple holding the prediction for the pair (x1, x2)."""
        return (self._classify(self._decode(self._encode(x1, x2))),)

    def _encode(self, ref: Tensor, test: Tensor) -> List[Tensor]:
        """Return fused maps: raw input pair first, deepest stage last."""
        fused = [self._first_mix(ref, test)]
        for depth, stage in enumerate(self._backbone):
            ref = stage(ref)
            test = stage(test)
            # The first stage's output is not fused.
            if depth > 0:
                fused.append(self._mixing_mask[depth - 1](ref, test))
        return fused

    def _decode(self, features: List[Tensor]) -> Tensor:
        """Upsample the deepest fused map, gating each step with a finer one."""
        latent = features[-1]
        for step, up_block in enumerate(self._up):
            # step 0 -> features[-2], step 1 -> features[-3], step 2 -> features[-4]
            latent = up_block(latent, features[-2 - step])
        return latent
|