import sys
# Add the parent directory to sys.path so the `models` package can be imported.
sys.path.append(sys.path[0].replace('models', ''))
import re
import logging
import math
import json
import pathlib
import numpy as np
from copy import deepcopy
from pathlib import Path
from einops import rearrange
from collections import OrderedDict
from dataclasses import dataclass
from typing import Tuple, Union, Callable, Optional

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.checkpoint import checkpoint
from transformers import AutoModel, BertConfig, AutoTokenizer
# from pytorch_pretrained_vit import ViT
# from visualizer import get_local
from models.transformer_decoder import *
# from io import BytesIO
# from petrel_client.client import Client
# conf_path = '~/petreloss.conf'
# client = Client(conf_path)
from torch.autograd import Function
import timm


class ReverseLayerF(Function):
    """Gradient reversal layer: identity in the forward pass; negates the
    gradient (scaled by alpha) in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


class DomainClassifier(nn.Module):
    """A single-layer classifier preceded by a gradient reversal layer."""

    def __init__(self, domain_nums=4, feature_dims=768):
        super().__init__()
        self.domain_nums = domain_nums
        self.feature_dims = feature_dims
        self.fc = nn.Linear(feature_dims, domain_nums)

    def forward(self, x):
        reverse_x = ReverseLayerF.apply(x, 1.0)
        return self.fc(reverse_x)
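# Illustrative sanity check (a standalone sketch, not called by the model
# code): ReverseLayerF should act as the identity in the forward pass and
# flip the sign of the gradient, scaled by alpha, in the backward pass. This
# is what lets DomainClassifier train the feature extractor adversarially
# against the domain labels.
def _check_gradient_reversal():
    x = torch.randn(4, 768, requires_grad=True)
    y = ReverseLayerF.apply(x, 0.5)
    assert torch.equal(y, x)  # identity in the forward pass
    y.sum().backward()
    # d(sum)/dy is all ones, so the reversed gradient is all -0.5
    assert torch.allclose(x.grad, torch.full_like(x, -0.5))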
class CLP_clinical(nn.Module):
    def __init__(self,
                 bert_model_name: str,
                 embed_dim: int = 768,
                 freeze_layers: Union[Tuple[int, int], int] = None):
        super().__init__()
        self.bert_model = self._get_bert_basemodel(bert_model_name=bert_model_name,
                                                   freeze_layers=freeze_layers)
        self.mlp_embed = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )
        self.embed_dim = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.init_parameters()

    def init_parameters(self):
        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
        for m in self.mlp_embed:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=self.embed_dim ** -0.5)

    def _get_bert_basemodel(self, bert_model_name, freeze_layers=None):
        try:
            print(bert_model_name)
            config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)  # e.g. bert-base-uncased
            model = AutoModel.from_pretrained(bert_model_name, config=config)
            print("Text feature extractor:", bert_model_name)
            print("bert encoder layers:", len(model.encoder.layer))
        except Exception as e:
            raise ValueError("Invalid model name. Check the config file and pass "
                             "a BERT model from the transformers library") from e
        if freeze_layers is not None:
            for layer_idx in freeze_layers:
                for param in list(model.encoder.layer[layer_idx].parameters()):
                    param.requires_grad = False
        return model

    def encode_text(self, text):
        # input: (batch_size, tokens); returns (batch_size, dim)
        output = self.bert_model(input_ids=text['input_ids'],
                                 attention_mask=text['attention_mask'])
        last_hidden_state, pooler_output, hidden_states = output[0], output[1], output[2]
        encode_out = self.mlp_embed(pooler_output)
        # encode_out = pooler_output
        return encode_out

    def forward(self, text):
        # input: (batch_size, tokens); returns (batch_size, dim)
        return self.encode_text(text)

    # def forward(self, text1, text2):
    #     text1_features = self.encode_text(text1)
    #     text2_features = self.encode_text(text2)
    #     text1_features = F.normalize(text1_features, dim=-1)
    #     text2_features = F.normalize(text2_features, dim=-1)
    #     return text1_features, text2_features, self.logit_scale.exp()
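# Illustrative usage sketch (assumption: 'bert-base-uncased' stands in for
# whatever clinical BERT checkpoint the config actually names): tokenize a
# batch of prompts and embed them. `encode_text` expects the tokenizer's dict
# of input_ids / attention_mask tensors.
def _demo_text_encoding():
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    text_encoder = CLP_clinical(bert_model_name='bert-base-uncased')
    batch = tokenizer(["pleural effusion", "no acute findings"],
                      padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        emb = text_encoder(batch)
    print(emb.shape)  # torch.Size([2, 768])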
class ModelRes(nn.Module):
    def __init__(self, res_base_model):
        super(ModelRes, self).__init__()
        self.resnet_dict = {
            "resnet50": models.resnet50(pretrained=True),
            "resnet101": models.resnet101(pretrained=True),
            "resnet152": models.resnet152(pretrained=True),
            # The *_openai entries are placeholders for CLIP-style ResNets
            # that are loaded externally.
            "resnet50_openai": None,
            "resnet101_openai": None,
            "resnet50x4_openai": None,
        }
        self.resnet = self._get_res_basemodel(res_base_model)
        if 'openai' in res_base_model:
            # Redefine res_features for the CLIP-style ResNet: drop only the
            # attention-pooling head and read the width from its v_proj.
            num_ftrs = int(self.resnet.attnpool.v_proj.in_features)
            self.res_features = nn.Sequential(*list(self.resnet.children())[:-1])
        else:
            num_ftrs = int(self.resnet.fc.in_features)  # 2048 for resnet50
            self.res_features = nn.Sequential(*list(self.resnet.children())[:-2])
        self.res_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2 = nn.Linear(num_ftrs, 768)

    def _get_res_basemodel(self, res_model_name):
        try:
            res_model = self.resnet_dict[res_model_name]
            print("Image feature extractor:", res_model_name)
            return res_model
        except KeyError as e:
            raise ValueError("Invalid model name. Check the config file and pass "
                             "one of: resnet50, resnet101, resnet152 or a *_openai variant") from e

    def forward(self, img):
        # returns (batch_size, patch_num, dim)
        batch_size = img.shape[0]
        res_fea = self.res_features(img)
        # res_fea = F.adaptive_avg_pool2d(res_fea, (1, 1))
        res_fea = rearrange(res_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(res_fea, 'b n d -> (b n) d')
        x = self.res_l1(h)
        x = F.relu(x)
        x = self.res_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
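# Illustrative shape check (assumption: 224x224 input, torchvision resnet50):
# the stride-32 backbone yields a 7x7 feature map, so out_emb carries 49 patch
# tokens and out_pool is their mean.
def _demo_resnet_shapes():
    encoder = ModelRes(res_base_model='resnet50')
    img = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out_emb, out_pool = encoder(img)
    print(out_emb.shape, out_pool.shape)  # torch.Size([2, 49, 768]) torch.Size([2, 768])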
class ModelConvNeXt(nn.Module):
    def __init__(self, convnext_base_model):
        super(ModelConvNeXt, self).__init__()
        self.convnext_dict = {
            "convnext-tiny": timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k_384',
                                               pretrained=True, num_classes=1000),
            "convnext-base": timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384',
                                               pretrained=True, num_classes=1000),
        }
        convnext = self._get_convnext_basemodel(convnext_base_model)
        num_ftrs = int(convnext.head.in_features)
        self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
        self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.conv_l2 = nn.Linear(num_ftrs, 768)

    def _get_convnext_basemodel(self, convnext_model_name):
        try:
            convnext_model = self.convnext_dict[convnext_model_name]
            print("Image feature extractor:", convnext_model_name)
            return convnext_model
        except KeyError as e:
            raise ValueError("Invalid model name. Check the config file and pass "
                             "one of: convnext-tiny or convnext-base") from e

    def forward(self, img):
        # returns (batch_size, patch_num, dim); patch_num is 1 here because the
        # feature map is globally average-pooled before projection
        batch_size = img.shape[0]
        conv_fea = self.conv_features(img)
        conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
        conv_fea = rearrange(conv_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(conv_fea, 'b n d -> (b n) d')
        x = self.conv_l1(h)
        x = F.relu(x)
        x = self.conv_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool


# Earlier torchvision-based ConvNeXt variant, kept for reference:
# class ModelConvNeXt(nn.Module):
#     def __init__(self, convnext_base_model):
#         super(ModelConvNeXt, self).__init__()
#         self.convnext_dict = {"convnext-tiny": models.convnext_tiny(weights='ConvNeXt_Tiny_Weights.DEFAULT'),
#                               "convnext-small": models.convnext_small(weights='ConvNeXt_Small_Weights.DEFAULT'),
#                               "convnext-base": models.convnext_base(weights='ConvNeXt_Base_Weights.DEFAULT'),
#                               }
#         convnext = self._get_convnext_basemodel(convnext_base_model)
#         num_ftrs = int(convnext.classifier[-1].in_features)
#         self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
#         self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
#         self.conv_l2 = nn.Linear(num_ftrs, 768)
#
#     def _get_convnext_basemodel(self, convnext_model_name):
#         try:
#             convnext_model = self.convnext_dict[convnext_model_name]
#             print("Image feature extractor:", convnext_model_name)
#             return convnext_model
#         except KeyError as e:
#             raise ValueError("Invalid model name. Check the config file and pass "
#                              "one of: convnext-tiny, convnext-small or convnext-base") from e
#
#     def forward(self, img):
#         # returns (batch_size, patch_num, dim)
#         batch_size = img.shape[0]
#         conv_fea = self.conv_features(img)
#         conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
#         conv_fea = rearrange(conv_fea, 'b d n1 n2 -> b (n1 n2) d')
#         h = rearrange(conv_fea, 'b n d -> (b n) d')
#         x = self.conv_l1(h)
#         x = F.relu(x)
#         x = self.conv_l2(x)
#         out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
#         out_pool = torch.mean(out_emb, dim=1)
#         return out_emb, out_pool


# import open_clip
# class ModelCLIP(nn.Module):
#     def __init__(self, clip_base_model):
#         super(ModelCLIP, self).__init__()
#         # Load a different backbone depending on clip_base_model.
#         if clip_base_model == 'openai_EVA02-B-16':
#             model, _, preprocess = open_clip.create_model_and_transforms('EVA02-B-16', pretrained='merged2b_s8b_b131k')
#         elif clip_base_model == 'openai_convnext_base_w':
#             model, _, preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg')
#         else:
#             raise ValueError("Invalid model name. Check the config file and pass one of: EVA02-B-16 or convnext_base_w")


class ModelEfficientV2(nn.Module):
    def __init__(self, efficientv2_base_model):
        super(ModelEfficientV2, self).__init__()
        self.efficientv2_dict = {
            "efficientnet_v2_s": models.efficientnet_v2_s(weights='EfficientNet_V2_S_Weights.IMAGENET1K_V1'),
        }
        self.efficientv2_model = self._get_efficientv2_basemodel(efficientv2_base_model)
        num_ftrs = int(self.efficientv2_model.classifier[-1].in_features)
        self.efficientv2_features = nn.Sequential(*list(self.efficientv2_model.children())[:-2])
        self.efficientv2_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.efficientv2_l2 = nn.Linear(num_ftrs, 768)

    def _get_efficientv2_basemodel(self, efficientv2_model_name):
        try:
            efficientv2_model = self.efficientv2_dict[efficientv2_model_name]
            print("Image feature extractor:", efficientv2_model_name)
            return efficientv2_model
        except KeyError as e:
            raise ValueError("Invalid model name. Check the config file and pass "
                             "one of: efficientnet_v2_s") from e

    def forward(self, img):
        # returns (batch_size, patch_num, dim); the spatial grid is kept as patch tokens
        batch_size = img.shape[0]
        efficientv2_fea = self.efficientv2_features(img)
        # efficientv2_fea = F.adaptive_avg_pool2d(efficientv2_fea, (1, 1))
        efficientv2_fea = rearrange(efficientv2_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(efficientv2_fea, 'b n d -> (b n) d')
        x = self.efficientv2_l1(h)
        x = F.relu(x)
        x = self.efficientv2_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool


class ModelDense(nn.Module):
    def __init__(self, dense_base_model):
        super(ModelDense, self).__init__()
        self.densenet_dict = {
            "densenet121": models.densenet121(weights='DenseNet121_Weights.IMAGENET1K_V1'),
            "densenet161": models.densenet161(weights='DenseNet161_Weights.IMAGENET1K_V1'),
            "densenet201": models.densenet201(weights='DenseNet201_Weights.IMAGENET1K_V1'),
        }
        self.densenet = self._get_dense_basemodel(dense_base_model)
        num_ftrs = int(self.densenet.classifier.in_features)
        self.dense_features = self.densenet.features
        self.dense_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.dense_l2 = nn.Linear(num_ftrs, 768)

    def _get_dense_basemodel(self, dense_base_model):
        try:
            dense_model = self.densenet_dict[dense_base_model]
            print("Image feature extractor:", dense_base_model)
            return dense_model
        except KeyError as e:
            raise ValueError("Invalid model name. Check the config file and pass "
                             "one of: densenet121, densenet161 or densenet201") from e

    def forward(self, img):
        batch_size = img.shape[0]
        dense_fea = self.dense_features(img)  # (N, 1024, 7, 7) for densenet121 at 224x224
        dense_fea = rearrange(dense_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(dense_fea, 'b n d -> (b n) d')
        x = self.dense_l1(h)
        x = F.relu(x)
        x = self.dense_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
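# Illustrative comparison (an assumption-laden sketch: 224x224 inputs, and note
# that each constructor eagerly instantiates every pretrained backbone in its
# dict): ModelDense and ModelEfficientV2 keep the spatial grid as patch tokens,
# while ModelConvNeXt global-average-pools first, so its out_emb has one token.
def _demo_patch_counts():
    img = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        emb_dense, _ = ModelDense('densenet121')(img)      # (1, 49, 768)
        emb_conv, _ = ModelConvNeXt('convnext-tiny')(img)  # (1, 1, 768)
    print(emb_dense.shape, emb_conv.shape)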
class TQN_Model(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 class_num: int = 1,
                 lam: list = [1, 0]):
        super().__init__()
        self.d_model = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
                                                    0.1, 'relu', True, lam)
        self.decoder_norm = nn.LayerNorm(self.d_model)
        self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
                                              return_intermediate=False)
        self.dropout_feas = nn.Dropout(0.1)
        self.mlp_head = nn.Sequential(
            nn.Linear(embed_dim, class_num)
        )
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, return_atten=False):
        # image_features: (batch_size, patch_num, dim)
        # text_features:  (query_num, dim)
        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
        image_features = self.decoder_norm(image_features)
        text_features = self.decoder_norm(text_features)
        image_features_pool = torch.mean(image_features, dim=0).unsqueeze(0)
        features, atten_map = self.decoderV1(text_features, image_features, image_features_pool,
                                             memory_key_padding_mask=None, pos=None, query_pos=None)
        features = self.dropout_feas(features).transpose(0, 1)  # (batch_size, query_num, embed_dim)
        out = self.mlp_head(features)  # (batch_size, query_num, class_num)
        if return_atten:
            return out, atten_map
        else:
            return out
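# Illustrative shape walk-through (assumption: 49 image patch tokens and 41
# text queries, matching the MIMIC notes below): each text query cross-attends
# over the patch tokens and is mapped to a single logit.
def _demo_tqn_shapes():
    model = TQN_Model(embed_dim=768, class_num=1)
    image_features = torch.randn(2, 49, 768)  # (batch, patch_num, dim)
    text_features = torch.randn(41, 768)      # (query_num, dim)
    with torch.no_grad():
        out = model(image_features, text_features)
    print(out.shape)  # torch.Size([2, 41, 1])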
class TQN_Model_Ensemble(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 class_num: int = 1,
                 lam: list = [1, 0]):
        super().__init__()
        self.d_model = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
                                                    0.1, 'relu', True, lam)
        self.decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder_norm_1 = nn.LayerNorm(self.d_model)
        self.decoder_norm_2 = nn.LayerNorm(self.d_model)
        self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm,
                                              return_intermediate=False)
        self.decoderV1_1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_1,
                                                return_intermediate=False)
        self.decoderV1_2 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_2,
                                                return_intermediate=False)
        self.dropout_feas = nn.Dropout(0.1)
        self.mlp_head = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_1 = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_2 = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, return_atten=False):
        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
        # Note: norms _1 and _2 are applied on top of the decoder_norm output,
        # not on the raw features.
        image_features = self.decoder_norm(image_features)
        image_features_1 = self.decoder_norm_1(image_features)
        image_features_2 = self.decoder_norm_2(image_features)
        text_features = self.decoder_norm(text_features)
        text_features_1 = self.decoder_norm_1(text_features)
        text_features_2 = self.decoder_norm_2(text_features)
        image_features_pool = torch.mean(image_features, dim=0).unsqueeze(0)
        image_features_pool_1 = torch.mean(image_features_1, dim=0).unsqueeze(0)
        image_features_pool_2 = torch.mean(image_features_2, dim=0).unsqueeze(0)

        features, atten_map = self.decoderV1(text_features, image_features, image_features_pool,
                                             memory_key_padding_mask=None, pos=None, query_pos=None)
        features = self.dropout_feas(features).transpose(0, 1)
        out = self.mlp_head(features)

        features_1, atten_map_1 = self.decoderV1_1(text_features_1, image_features_1, image_features_pool_1,
                                                   memory_key_padding_mask=None, pos=None, query_pos=None)
        features_1 = self.dropout_feas(features_1).transpose(0, 1)
        out_1 = self.mlp_head_1(features_1)

        features_2, atten_map_2 = self.decoderV1_2(text_features_2, image_features_2, image_features_pool_2,
                                                   memory_key_padding_mask=None, pos=None, query_pos=None)
        features_2 = self.dropout_feas(features_2).transpose(0, 1)
        out_2 = self.mlp_head_2(features_2)

        # Average the three decoder heads.
        out_stack = torch.stack([out, out_1, out_2])
        out = torch.mean(out_stack, dim=0)
        if return_atten:
            return out, atten_map
        else:
            return out


# For MIMIC: batch_size=32, query_num=41, patch_num=256, dim=768
#   img:   (256, 32, 768)
#   txt:   (1, 32, 768)
#   query: (41, 32, 768)
#   fts:   (41, 32, 768)
#   out:   (41, 32, 1)
# The outputs are raw logits -- no sigmoid is applied here; apply the sigmoid
# when computing the loss!
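# Illustrative loss sketch (assumption: multi-label targets of shape
# (batch, query_num)): since the heads return raw logits, pair them with a
# sigmoid-based loss such as binary_cross_entropy_with_logits.
def _demo_loss():
    logits = torch.randn(32, 41, 1).squeeze(-1)      # (batch, query_num)
    targets = torch.randint(0, 2, (32, 41)).float()  # 0/1 labels per query
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    print(loss.item())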
if __name__ == "__main__":
    # torch 1.10.2 -> torch 1.12.1
    # torchvision 0.11.3 -> torchvision 0.13.1

    # image = torch.randn(1, 3, 224, 224)
    # image_encoder = ModelRes(res_base_model='resnet50')
    # image_encoder = ModelDense(dense_base_model='densenet121')
    # image_encoder = ModelViT(vit_base_model='vit_b_32')
    # image_encoder(image)

    # image = torch.randn(256, 1, 768)
    # query = torch.randn(41, 768)
    # model = TQN_Model()
    # out = model(image, query)

    # img = torch.randn(1, 3, 512, 512)
    img = torch.randn(2, 3, 224, 224)
    # model = ModelConvNeXt(convnext_base_model='convnext-base')
    # model = ModelEfficientV2(efficientv2_base_model='efficientnet_v2_s')
    # Note: the *_openai entries in ModelRes are None placeholders loaded
    # elsewhere, so this smoke test uses the torchvision resnet50.
    model = ModelRes(res_base_model='resnet50')
    out_emb, out_pool = model(img)
    print(out_emb.size(), out_pool.size())