import torch import torch.nn as nn from monai.networks.nets import ViT, UNETR import os class ViTUNETRSegmentationModel(nn.Module): def __init__(self, simclr_ckpt_path: str, img_size=(96, 96, 96), in_channels=1, out_channels=1): super().__init__() # Load ViT backbone self.vit = ViT( in_channels=in_channels, img_size=img_size, patch_size=(16, 16, 16), hidden_size=768, mlp_dim=3072, num_layers=12, num_heads=12, save_attn=False, ) # Load SimCLR weights if provided if False:#simclr_ckpt_path and os.path.exists(simclr_ckpt_path): ckpt = torch.load(simclr_ckpt_path, map_location='cpu', weights_only=False) state_dict = ckpt.get('state_dict', ckpt) backbone_state_dict = {k[9:]: v for k, v in state_dict.items() if k.startswith('backbone.')} missing, unexpected = self.vit.load_state_dict(backbone_state_dict, strict=False) print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}") else: print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.") # UNETR decoder self.unetr = UNETR( in_channels=in_channels, out_channels=out_channels, img_size=img_size, feature_size=16, hidden_size=768, mlp_dim=3072, num_heads=12, norm_name='instance', res_block=True, dropout_rate=0.0 ) # Transfer ViT weights to UNETR encoder self.unetr.vit.load_state_dict(self.vit.state_dict(), strict=True) print("="*10) print("ViT loaded for segmentation") print("="*10) def forward(self, x): return self.unetr(x)