# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn

from mmseg.models.utils import nlc_to_nchw
from mmseg.models.backbones import MixVisionTransformer

from opencd.registry import MODELS


@MODELS.register_module()
class IA_MixVisionTransformer(MixVisionTransformer):

    def __init__(self,
                 interaction_cfg=(None, None, None, None),
                 **kwargs):
        super().__init__(**kwargs)
        assert self.num_stages == len(interaction_cfg), \
            'The length of `interaction_cfg` should be the same as ' \
            '`num_stages`.'

        # Cross-correlation: build one interaction module per stage.
        # `None` falls back to `TwoIdentity`, which returns both
        # streams unchanged.
        self.ccs = []
        for ia_cfg in interaction_cfg:
            if ia_cfg is None:
                ia_cfg = dict(type='TwoIdentity')
            self.ccs.append(MODELS.build(ia_cfg))
        self.ccs = nn.ModuleList(self.ccs)

    def forward(self, x1, x2):
        outs = []
        for i, layer in enumerate(self.layers):
            # layer[0] is the patch embedding, layer[1] the transformer
            # blocks, layer[2] the norm; both temporal streams share the
            # same weights (Siamese).
            x1, hw_shape = layer[0](x1)
            x2, hw_shape = layer[0](x2)
            for block in layer[1]:
                x1 = block(x1, hw_shape)
                x2 = block(x2, hw_shape)
            x1 = layer[2](x1)
            x2 = layer[2](x2)
            x1 = nlc_to_nchw(x1, hw_shape)
            x2 = nlc_to_nchw(x2, hw_shape)
            # Let the two streams interact before emitting this stage.
            x1, x2 = self.ccs[i](x1, x2)
            if i in self.out_indices:
                outs.append(torch.cat([x1, x2], dim=1))
        return outs
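
For reference, a minimal sketch of how this backbone is driven. `AbsDiffGate` is a hypothetical interaction module, not part of Open-CD; it only illustrates the contract that `interaction_cfg` implies: the built module is called as `self.ccs[i](x1, x2)` on two NCHW feature maps and must return two tensors of the same shape. The constructor values follow the MiT-b0 configuration from MMSegmentation (embed_dims=32, num_heads=[1, 2, 5, 8], giving stage channels 32, 64, 160, 256); treat those values, and the example as a whole, as assumptions rather than an official usage.

import torch
import torch.nn as nn

from opencd.registry import MODELS


@MODELS.register_module()
class AbsDiffGate(nn.Module):
    """Hypothetical interaction: gate both streams by their difference."""

    def __init__(self, channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.Sigmoid())

    def forward(self, x1, x2):
        # The same gate, derived from the absolute difference of the
        # two temporal streams, is applied to both of them.
        g = self.gate(torch.abs(x1 - x2))
        return x1 * g, x2 * g


# One config entry per stage; `None` stages fall back to `TwoIdentity`.
backbone = IA_MixVisionTransformer(
    embed_dims=32,
    num_layers=[2, 2, 2, 2],
    num_heads=[1, 2, 5, 8],
    interaction_cfg=(
        None,
        None,
        None,
        dict(type='AbsDiffGate', channels=256)))

x1 = torch.randn(1, 3, 256, 256)  # pre-change image
x2 = torch.randn(1, 3, 256, 256)  # post-change image
outs = backbone(x1, x2)
# Four stage outputs with the two streams concatenated on the channel
# axis: [B, 64, 64, 64], [B, 128, 32, 32], [B, 320, 16, 16], [B, 512, 8, 8].

Because each `interaction_cfg` entry is a plain registry config, any registered module with this two-in, two-out signature can be swapped in per stage without modifying the backbone itself.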