# Copyright (c) OpenMMLab. All rights reserved. from abc import abstractmethod from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from mmengine.model import BaseModel from torch import nn from mmpretrain.datasets.categories import (CIFAR100_CATEGORIES, IMAGENET_SIMPLE_CATEGORIES) from mmpretrain.registry import MODELS, TOKENIZER from mmpretrain.structures import DataSample from mmpretrain.utils import track_on_main_process from .utils import (OPENAI_CIFAR100_PROMPT, OPENAI_IMAGENET_PROMPT, OPENAI_IMAGENET_PROMPT_SUB) CIFAR100_CATEGORIES = [' '.join(c.split('_')) for c in CIFAR100_CATEGORIES] PROTOTYPE_MAP = { 'imagenet': IMAGENET_SIMPLE_CATEGORIES, 'cifar100': CIFAR100_CATEGORIES, } PROMPT_MAP = { 'openai_imagenet': OPENAI_IMAGENET_PROMPT, 'openai_cifar100': OPENAI_CIFAR100_PROMPT, 'vanilla': [lambda c: f'a photo of a {c}'], 'openai_imagenet_sub': OPENAI_IMAGENET_PROMPT_SUB } class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward function.""" orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) class CLIP(BaseModel): """The implementation of `CLIP `_. Args: vision_backbone (dict): Config dict for vision backbone. text_backbone (dict): Config dict for text backbone. tokenizer (dict): Config dict for text tokenizer. proj_dim (int): Projection dimension for similarity computation. text_prototype (str): Text prototype, which can be a key in `PROTOTYPE_MAP` or list of text. text_prompt (str): The prompt for text prototype. Defaults to 'vanilla',which refers to "a photo of {cls}". context_length (int): The context length to use. Defaults to 77. data_preprocessor (Union[dict, nn.Module], optional): The config for preprocessing input data. If None or no specified type, it will use "MultiModalDataPreprocessor" as type. See :class:`MultiModalDataPreprocessor` for more details. Defaults to None. init_cfg (dict, optional): The config to control the initialization. Defaults to None. """ def __init__(self, vision_backbone: dict, projection: dict, text_backbone: dict, tokenizer: dict, vocab_size: int, transformer_width: int, proj_dim: int, context_length: int = 77, data_preprocessor: Optional[dict] = None, init_cfg: Optional[dict] = None): if data_preprocessor is None: data_preprocessor = {} data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor') data_preprocessor = MODELS.build(data_preprocessor) super().__init__( data_preprocessor=data_preprocessor, init_cfg=init_cfg) self.context_length = context_length # build the vision transformer self.visual = MODELS.build(vision_backbone) # build the visual projection self.visual_proj = MODELS.build(projection) # build attn_mask for casual-attn text_backbone['attn_mask'] = self.build_attention_mask() # build the text transformer self.transformer = MODELS.build(text_backbone) self.vocab_size = vocab_size self.token_embedding = nn.Embedding(vocab_size, transformer_width) self.positional_embedding = nn.Parameter( torch.empty(self.context_length, transformer_width)) self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter( torch.empty(transformer_width, proj_dim)) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.initialize_parameters() self.tokenizer = TOKENIZER.build(tokenizer) self.tokenizer.vocab = self.tokenizer.get_vocab( ) # CLIPTokenizer has no attribute named 'vocab', so manually def initialize_parameters(self) -> None: """Initialize the parameters. The pretrained weight will override the initialized parameters by this function. """ nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) proj_std = (self.transformer.width**-0.5) * ( (2 * self.transformer.layers)**-0.5) attn_std = self.transformer.width**-0.5 fc_std = (2 * self.transformer.width)**-0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: nn.init.normal_( self.text_projection, std=self.transformer.width**-0.5) def build_attention_mask(self): # lazily create causal attention mask, # with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float('-inf')) mask.triu_(1) # zero out the lower diagonal return mask def forward( self, images: torch.Tensor, data_samples: Optional[list] = None, mode: str = 'predict', **kwargs, ): """The unified entry for a forward process in both training and test. The method accepts the following modes: - "predict": Forward and return a list of data samples contain the predict results. Args: images (torch.Tensor): the preprocessed image tensor of shape ``(N, C, H, W)``. data_samples (List[DataSample], optional): The annotation data of every samples. Defaults to None. mode (str): Return what kind of value. Defaults to 'predict'. """ if mode == 'predict': return self.predict(images, data_samples, **kwargs) else: raise RuntimeError(f'Invalid mode "{mode}".') def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor: """The function to extract image latent features.""" return self.visual_proj(self.visual(images))[0] def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor: """The function to extract text latent features.""" x = self.token_embedding(texts) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x)[0] x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding # (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)] @ self.text_projection return x def extract_feat( self, images: torch.Tensor, texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """The function to extract image and text latent features, the input image or text can not both be None.""" assert images is not None or texts is not None, \ 'text and image cannot both be None!' if images is None: return self.extract_text_feat(texts) elif texts is None: return self.extract_image_feat(images) image_features = self.extract_image_feat(images) text_features = self.extract_text_feat(texts) image_features = image_features / image_features.norm( dim=-1, keepdim=True) text_features = text_features / text_features.norm( dim=-1, keepdim=True) return image_features, text_features def compute_similarity(self, images, texts): """Extract images and texts features and compute cosine similarity.""" image_features, text_features = self.extract_feat( images=images, texts=texts) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t() # shape (N, N) return logits_per_image, logits_per_text @abstractmethod def predict(self, images: torch.Tensor, data_samples: DataSample = None) -> DataSample: raise NotImplementedError def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor: """Returns the tokenized representation of given input string(s) Args: texts (Union[str, List[str]]): An input string or a list of input strings to tokenize context_length (int): The context length to use. Defaults to 52. Returns: torch.Tensor: Resulting tokens. """ if isinstance(texts, str): texts = [texts] all_tokens = [] for text in texts: # adapt the text to Chinese BERT vocab # text = text.lower().replace('“', "\"").replace('”', "\"") # add special tokens all_tokens.append( [self.tokenizer.vocab['<|startoftext|>'] ] + # <|startoftext|>代表[CLS] token self.tokenizer.convert_tokens_to_ids( self.tokenizer.tokenize(text))[:self.context_length - 2] + [self.tokenizer.vocab['<|endoftext|>']]) result = torch.zeros( len(all_tokens), self.context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): assert len(tokens) <= self.context_length result[i, :len(tokens)] = torch.tensor(tokens) return result @MODELS.register_module() class CLIPZeroShot(CLIP): def __init__( self, vision_backbone: dict, projection: dict, text_backbone: dict, tokenizer: dict, vocab_size: int, transformer_width: int, proj_dim: int, context_length: int = 77, data_preprocessor: Optional[dict] = None, init_cfg: Optional[dict] = None, text_prototype: Union[str, List[str]] = 'imagenet', text_prompt: str = 'vanilla', ): super(CLIPZeroShot, self).__init__(vision_backbone, projection, text_backbone, tokenizer, vocab_size, transformer_width, proj_dim, context_length, data_preprocessor, init_cfg) # for zero-shot classification if isinstance(text_prototype, str) and text_prototype in PROTOTYPE_MAP.keys(): self.prototype = PROTOTYPE_MAP[text_prototype] else: self.prototype = text_prototype self.text_prototype_embeds = None self.prompt = PROMPT_MAP[text_prompt] def predict(self, images: torch.Tensor, data_samples: DataSample = None) -> DataSample: """Predict the classes of the input images. The prediction is for zero-shot classification and the text prototypes will be prepared in thisfunction. Args: images (torch.Tensor): The input images. data_samples (DataSample): The data samples with information from dataset. Returns: DataSample: The results of prediction. """ if self.text_prototype_embeds is None: self.prepare_text_prototype(device=images.device) image_features = self.extract_image_feat(images=images) image_features /= image_features.norm(dim=-1, keepdim=True) # cosine similarity as logits logits_per_image = image_features @ self.text_prototype_embeds.to( image_features.device) * self.logit_scale.exp() pred_scores = F.softmax(logits_per_image, dim=1) pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() out_data_samples = [] if data_samples is None: data_samples = [None for _ in range(pred_scores.size(0))] for data_sample, score, label in zip(data_samples, pred_scores, pred_labels): if data_sample is None: data_sample = DataSample() data_sample.set_pred_score(score).set_pred_label(label) out_data_samples.append(data_sample) return out_data_samples def prepare_text_prototype(self, device) -> None: """The function to prepare text prototypes with prompt.""" class_embeddings = [] for classname in track_on_main_process(self.prototype, 'Prepare text prototype...'): # format with class texts = [prompt(classname) for prompt in self.prompt] tokenized_texts = self.tokenize(texts) class_features = self.extract_text_feat(tokenized_texts.to(device)) class_features /= class_features.norm(dim=-1, keepdim=True) class_feature = class_features.mean(dim=0) class_feature /= class_feature.norm() class_embeddings.append(class_feature) self.text_prototype_embeds = torch.stack( class_embeddings, dim=1).to(device)