# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.utils import nlc_to_nchw

from opencd.registry import MODELS


@MODELS.register_module()
class IA_MixVisionTransformer(MixVisionTransformer):
    """Siamese MixVisionTransformer with per-stage interaction layers.

    Args:
        interaction_cfg (Sequence[dict | None]): One interaction config per
            stage. A ``None`` entry falls back to ``TwoIdentity``, which passes
            both temporal features through unchanged.
    """

    def __init__(self,
                 interaction_cfg=(None, None, None, None),
                 **kwargs):
        super().__init__(**kwargs)
        assert self.num_stages == len(interaction_cfg), \
            'The length of `interaction_cfg` should be the same as `num_stages`.'

        # Build one cross-correlation (interaction) module per stage.
        self.ccs = []
        for ia_cfg in interaction_cfg:
            if ia_cfg is None:
                ia_cfg = dict(type='TwoIdentity')
            self.ccs.append(MODELS.build(ia_cfg))
        self.ccs = nn.ModuleList(self.ccs)

    def forward(self, x1, x2):
        """Run both temporal inputs through shared weights and fuse per stage."""
        outs = []
        for i, layer in enumerate(self.layers):
            # layer[0]: patch embedding, layer[1]: transformer blocks,
            # layer[2]: norm. Both inputs have the same spatial size, so
            # hw_shape is identical for x1 and x2.
            x1, hw_shape = layer[0](x1)
            x2, hw_shape = layer[0](x2)
            for block in layer[1]:
                x1 = block(x1, hw_shape)
                x2 = block(x2, hw_shape)
            x1 = layer[2](x1)
            x2 = layer[2](x2)
            x1 = nlc_to_nchw(x1, hw_shape)
            x2 = nlc_to_nchw(x2, hw_shape)
            # Exchange information between the two temporal branches.
            x1, x2 = self.ccs[i](x1, x2)
            if i in self.out_indices:
                # Each output is the channel-wise concatenation of both branches.
                outs.append(torch.cat([x1, x2], dim=1))
        return outs
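

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original Open-CD file): a minimal, illustrative
# way to instantiate the backbone and run it on a bitemporal image pair. The
# stage settings below are assumptions loosely following a MiT-b0-style layout;
# real values come from an Open-CD config, and `TwoIdentity` is assumed to be
# resolvable through the Open-CD MODELS registry.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    backbone = IA_MixVisionTransformer(
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        # `None` entries fall back to the identity interaction (`TwoIdentity`).
        interaction_cfg=(None, None, None, None))

    t1 = torch.randn(1, 3, 256, 256)  # image at time 1
    t2 = torch.randn(1, 3, 256, 256)  # image at time 2
    feats = backbone(t1, t2)
    for f in feats:
        # Channels are doubled per stage because the two branches are concatenated.
        print(f.shape)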