import copy
from typing import List, Tuple, Optional

import torch
import torch.nn.functional as F
import einops
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import PatchEmbed, FFN, build_transformer_layer
from mmengine.dist import is_main_process
from mmengine.model import BaseModule
from peft import get_peft_config, get_peft_model
from torch import Tensor, nn

# from mmdet.utils import OptConfigType, MultiConfig
from mmpretrain.models import resize_pos_embed
from mmpretrain.models import build_norm_layer as build_norm_layer_mmpretrain
from mmpretrain.models.backbones.vit_sam import Attention, window_partition, window_unpartition
from mmseg.models import BaseSegmentor, EncoderDecoder
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from mmseg.utils import OptConfigType, MultiConfig

from opencd.registry import MODELS

class MMPretrainSamVisionEncoder(BaseModule):
    """Wrapper around an mmpretrain SAM ViT image encoder.

    When ``peft_cfg`` is a dict, the encoder is wrapped with a PEFT/LoRA
    adapter; otherwise all encoder weights are frozen. In both cases the
    bi-temporal fusion parameters (``down_channel`` / ``soft_ffn``) are kept
    trainable.
    """

    def __init__(
            self,
            encoder_cfg,
            peft_cfg=None,
            init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        vision_encoder = MODELS.build(encoder_cfg)
        vision_encoder.init_weights()

        if peft_cfg is not None and isinstance(peft_cfg, dict):
            # default LoRA settings; any key in ``peft_cfg`` overrides them
            config = {
                "peft_type": "LORA",
                "r": 16,
                "target_modules": ["qkv"],
                "lora_alpha": 32,
                "lora_dropout": 0.05,
                "bias": "none",
                "inference_mode": False,
            }
            config.update(peft_cfg)
            peft_config = get_peft_config(config)
            self.vision_encoder = get_peft_model(vision_encoder, peft_config)
            if is_main_process():
                self.vision_encoder.print_trainable_parameters()
        else:
            self.vision_encoder = vision_encoder
            # freeze the vision encoder
            for param in self.vision_encoder.parameters():
                param.requires_grad = False

        # keep the time-fusion modules trainable in either branch
        for name, param in self.vision_encoder.named_parameters():
            if 'down_channel' in name:
                param.requires_grad = True
            if 'soft_ffn' in name:
                param.requires_grad = True
        if is_main_process() and peft_cfg is not None:
            self.vision_encoder.print_trainable_parameters()

    def forward(self, x):
        return self.vision_encoder(x)
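
# Usage sketch (illustrative only): how this wrapper might be instantiated.
# The ViTSAM type/keys and all hyper-parameters below are assumptions, not
# taken from this file; 'mmpretrain.ViTSAM' further assumes mmpretrain is
# installed and reachable through the registry scope. Only ``encoder_cfg``,
# ``peft_cfg`` and the LoRA defaults in ``__init__`` are defined here.
def _example_build_vision_encoder():
    encoder_cfg = dict(
        type='mmpretrain.ViTSAM',  # assumed backbone type / scope prefix
        arch='base',
        img_size=512,
        patch_size=16,
        out_channels=256,
        use_abs_pos=True,
        use_rel_pos=True,
        window_size=14,
    )
    # Any key here overrides the hard-coded LoRA defaults
    # (r=16, lora_alpha=32, target_modules=["qkv"], ...).
    peft_cfg = dict(r=8, lora_alpha=16)
    return MMPretrainSamVisionEncoder(encoder_cfg, peft_cfg=peft_cfg)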

class MLPSegHead(BaseDecodeHead):
    """All-MLP style decode head: each selected input is projected with a 1x1
    conv, resized to a fixed ``out_size``, then the concatenation is fused with
    another 1x1 conv before the final classification layer."""

    def __init__(
            self,
            out_size,
            interpolate_mode='bilinear',
            **kwargs
    ):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)
        self.out_size = out_size

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=self.out_size,
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))
        out = self.cls_seg(out)
        return out
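
# Shape sketch (illustrative only): the head projects each selected level to
# ``channels``, resizes everything to ``out_size``, fuses and predicts
# ``num_classes`` logits. The channel counts, class count and BN norm below are
# assumptions chosen just to keep the example self-contained and runnable.
def _example_mlp_seg_head():
    head = MLPSegHead(
        out_size=(128, 128),
        in_channels=[256, 256, 256, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        num_classes=2,
        norm_cfg=dict(type='BN'),
        align_corners=False,
    )
    # four pyramid levels at decreasing resolution
    feats = [torch.rand(2, 256, s, s) for s in (128, 64, 32, 16)]
    logits = head(feats)  # -> (2, num_classes, 128, 128)
    return logits.shape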

class LN2d(nn.Module):
    """A LayerNorm variant, popularized by Transformers, that performs
    pointwise mean and variance normalization over the channel dimension for
    inputs that have shape (batch_size, channels, height, width)."""

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
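
# Sanity sketch: LN2d normalizes over the channel axis of an NCHW tensor, so it
# should match nn.LayerNorm applied with channels moved last (same eps, same
# default affine init). The tensor sizes below are arbitrary.
def _example_check_ln2d():
    x = torch.randn(2, 8, 4, 4)
    ln2d = LN2d(8)
    ref = nn.LayerNorm(8, eps=ln2d.eps)
    out_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    return torch.allclose(ln2d(x), out_ref, atol=1e-5)  # expected: True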

class SequentialNeck(BaseModule):
    """Chains several necks: each neck is built from its config and the output
    of one neck (expected to be a tuple) is unpacked as the inputs of the
    next."""

    def __init__(self, necks):
        super().__init__()
        self.necks = nn.ModuleList()
        for neck in necks:
            self.necks.append(MODELS.build(neck))

    def forward(self, *args, **kwargs):
        for neck in self.necks:
            args = neck(*args, **kwargs)
        return args

class SimpleFPN(BaseModule):
    """ViTDet-style simple feature pyramid: the single-scale ViT feature map is
    turned into a multi-scale pyramid (4x up, 2x up, identity, 2x down), then
    each level goes through a lateral 1x1 conv and a 3x3 output conv."""

    def __init__(self,
                 backbone_channel: int,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.backbone_channel = backbone_channel
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2),
            build_norm_layer(norm_cfg, self.backbone_channel // 2)[1],
            nn.GELU(),
            nn.ConvTranspose2d(self.backbone_channel // 2,
                               self.backbone_channel // 4, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for i in range(self.num_ins):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, input: Tensor) -> tuple:
        # build the multi-scale inputs from the single-scale backbone feature
        inputs = []
        inputs.append(self.fpn1(input))
        inputs.append(self.fpn2(input))
        inputs.append(self.fpn3(input))
        inputs.append(self.fpn4(input))

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build outputs
        # part 1: from original levels
        outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)]
        # part 2: add extra levels by further downsampling the last output
        if self.num_outs > len(outs):
            for i in range(self.num_outs - self.num_ins):
                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        return tuple(outs)
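
# Scale sketch (illustrative only): with a 768-channel ViT feature map, the
# four pyramid branches expect in_channels [192, 384, 768, 768] and the extra
# level comes from strided max pooling. All numbers below are assumptions; a
# real config would likely pair this with the LN2d norm defined above (once
# registered), while plain BN is used here only to keep the sketch
# self-contained.
def _example_simple_fpn():
    neck = SimpleFPN(
        backbone_channel=768,
        in_channels=[192, 384, 768, 768],
        out_channels=256,
        num_outs=5,
        norm_cfg=dict(type='BN'),
    )
    x = torch.rand(2, 768, 32, 32)
    outs = neck(x)  # spatial sizes: 128, 64, 32, 16, 8
    return [o.shape for o in outs]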

class TimeFusionTransformerEncoderLayer(BaseModule):
    """ViT-SAM transformer encoder layer with an extra bi-temporal fusion step.

    The two temporal phases are stacked along the batch dimension (first half
    = time 1, second half = time 2). In global-attention blocks
    (``window_size == 0``), a 1x1 ``down_channel`` conv followed by a sigmoid
    produces an activation map from the concatenated phases, which gates a
    shared ``soft_ffn`` used to exchange information between the two phases.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 use_rel_pos: bool = False,
                 window_size: int = 0,
                 input_size: Optional[Tuple[int, int]] = None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size

        self.ln1 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims)
        self.attn = Attention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            input_size=input_size if window_size == 0 else
            (window_size, window_size),
        )

        self.ln2 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims)
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if self.window_size == 0:
            # time-phase fusion modules, only created for global-attention blocks
            in_channels = embed_dims * 2
            self.down_channel = nn.Conv2d(
                in_channels, 1, kernel_size=1, stride=1, bias=False)
            self.down_channel.weight.data.fill_(1.0 / in_channels)
            self.soft_ffn = nn.Sequential(
                nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1),
                nn.GELU(),
                nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1),
            )

    def norm1(self):
        return self.ln1

    def norm2(self):
        return self.ln2

    def forward(self, x):
        shortcut = x
        x = self.ln1(x)

        # window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = self.ffn(self.ln2(x), identity=x)

        # time-phase fusion (global blocks only): exchange information between
        # the two temporal phases stacked along the batch dimension
        if self.window_size == 0:
            x = einops.rearrange(x, 'b h w d -> b d h w')  # (2B, C, H, W)
            x0 = x[:x.size(0) // 2]
            x1 = x[x.size(0) // 2:]  # (B, C, H, W)
            x0_1 = torch.cat([x0, x1], dim=1)
            activate_map = self.down_channel(x0_1)
            activate_map = torch.sigmoid(activate_map)
            x0 = x0 + self.soft_ffn(x1 * activate_map)
            x1 = x1 + self.soft_ffn(x0 * activate_map)
            x = torch.cat([x0, x1], dim=0)
            x = einops.rearrange(x, 'b d h w -> b h w d')
        return x
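
# Bi-temporal layout sketch (illustrative only): the layer expects the two
# temporal phases stacked along the batch axis, with tokens in (B, H, W, C)
# layout as produced by ViT-SAM blocks. The dimensions below are assumptions
# picked to keep the example small and runnable; setting window_size=0 makes it
# a global block, so the time-fusion branch is exercised.
def _example_time_fusion_layer():
    layer = TimeFusionTransformerEncoderLayer(
        embed_dims=64,
        num_heads=4,
        feedforward_channels=256,
        window_size=0,        # global block -> time fusion is active
        input_size=(16, 16),
    )
    imgs_t1 = torch.rand(2, 16, 16, 64)   # time-1 tokens for a batch of 2
    imgs_t2 = torch.rand(2, 16, 16, 64)   # time-2 tokens for the same batch
    x = torch.cat([imgs_t1, imgs_t2], dim=0)   # (2B, H, W, C)
    return layer(x).shape                      # -> (4, 16, 16, 64)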