File size: 1,454 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn

from mmseg.models.utils import nlc_to_nchw
from mmseg.models.backbones import MixVisionTransformer

from opencd.registry import MODELS


@MODELS.register_module()
class IA_MixVisionTransformer(MixVisionTransformer):
    """MixVisionTransformer with per-stage interaction between two inputs.

    Both inputs are passed through the *same* stage modules (``self.layers``),
    so the two branches share weights. After every stage, the pair of feature
    maps is handed to an interaction module built from ``interaction_cfg``.
    Its two outputs replace the stage features and feed the next stage.

    Args:
        interaction_cfg (Sequence[dict | None]): One config per stage, built
            with ``MODELS.build``. A ``None`` entry falls back to
            ``dict(type='TwoIdentity')`` (presumably a pass-through; its
            implementation is not visible here). The length must equal
            ``num_stages``. Default: four ``None`` entries.
        **kwargs: Forwarded unchanged to ``MixVisionTransformer``.
    """

    def __init__(self, 
                 interaction_cfg=(None, None, None, None), 
                 **kwargs):
        super().__init__(**kwargs)
        assert self.num_stages == len(interaction_cfg), \
            'The length of the `interaction_cfg` should be same as the `num_stages`.'
        # Cross-correlation / interaction modules, one per stage. Build the
        # list locally, then register it as a ModuleList in a single
        # assignment so the submodules' parameters are tracked.
        ccs = [
            MODELS.build(dict(type='TwoIdentity') if ia_cfg is None else ia_cfg)
            for ia_cfg in interaction_cfg
        ]
        self.ccs = nn.ModuleList(ccs)

    def forward(self, x1, x2):
        """Extract interacted multi-stage features from a pair of inputs.

        Args:
            x1 (Tensor): First input batch (assumed (N, C, H, W) — TODO
                confirm against callers).
            x2 (Tensor): Second input batch. It must yield stage features
                whose non-channel dims match ``x1``'s, because the outputs
                are joined with ``torch.cat`` along dim 1.

        Returns:
            list[Tensor]: For each stage index in ``self.out_indices``, the
            two post-interaction feature maps concatenated along dim 1.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            # layer[0] returns the token sequence together with its spatial
            # shape. Keep one shape per branch so neither branch is reshaped
            # with the other's spatial size.
            x1, hw_shape1 = layer[0](x1)
            x2, hw_shape2 = layer[0](x2)
            for block in layer[1]:
                x1 = block(x1, hw_shape1)
                x2 = block(x2, hw_shape2)
            x1 = layer[2](x1)
            x2 = layer[2](x2)

            # Token sequence -> (N, C, H, W) feature map, per branch.
            x1 = nlc_to_nchw(x1, hw_shape1)
            x2 = nlc_to_nchw(x2, hw_shape2)

            # Stage-wise interaction between the two branches; its outputs
            # are what the next stage consumes.
            x1, x2 = self.ccs[i](x1, x2)
            if i in self.out_indices:
                outs.append(torch.cat([x1, x2], dim=1))
        return outs