# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule
from torch import nn
from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .utils import OPENAI_PROMPT

PROTOTYPE_MAP = {'cifar100': CIFAR100_CATEGORIES_CN}
PROMPT_MAP = {'openai': OPENAI_PROMPT}
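# Each value in PROMPT_MAP is a sequence of callables that map a class name
# to a full prompt sentence (the concrete templates live in the sibling
# ``utils`` module). An illustrative, hypothetical entry would be
# ``lambda c: f'一张{c}的图片。'``.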


class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck block with anti-aliased downsampling.

    All convolutions run with stride 1; spatial downsampling is instead
    performed by an ``AvgPool2d`` after ``conv2`` and, in the shortcut
    branch, before the 1x1 convolution.
    """

    expansion = 4

def __init__(self, inplanes, planes, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out


class AttentionPool2d(nn.Module):
    """Attention-based 2D pooling, as used by CLIP's ModifiedResNet.

    The feature map is flattened, its spatial mean is prepended as an extra
    token, a learnable positional embedding is added, and multi-head
    attention is applied; the output at the mean-token position is returned
    as the pooled feature.
    """

    def __init__(self,
                 spatial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: Optional[int] = None):
super().__init__()
self.positional_embedding = nn.Parameter(
            torch.randn(spatial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.reshape(x.shape[0], x.shape[1],
x.shape[2] * x.shape[3]).permute(2, 0,
1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x[0]


@MODELS.register_module()
class ModifiedResNet(BaseModule):
"""A modified ResNet contains the following changes:
- Apply deep stem with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is
prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
""" # noqa
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}

    def __init__(self,
depth: int = 50,
base_channels: int = 64,
input_size: int = 224,
num_attn_heads: int = 32,
output_dim: int = 1024,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
self.input_size = input_size
self.block, stage_blocks = self.arch_settings[depth]
# the 3-layer stem
self.conv1 = nn.Conv2d(
3,
base_channels // 2,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(base_channels // 2)
self.conv2 = nn.Conv2d(
base_channels // 2,
base_channels // 2,
kernel_size=3,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(base_channels // 2)
self.conv3 = nn.Conv2d(
base_channels // 2,
base_channels,
kernel_size=3,
padding=1,
bias=False)
self.bn3 = nn.BatchNorm2d(base_channels)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
# this is a *mutable* variable used during construction
self._inplanes = base_channels
self.layer1 = self._make_layer(base_channels, stage_blocks[0])
self.layer2 = self._make_layer(
base_channels * 2, stage_blocks[1], stride=2)
self.layer3 = self._make_layer(
base_channels * 4, stage_blocks[2], stride=2)
self.layer4 = self._make_layer(
base_channels * 8, stage_blocks[3], stride=2)
embed_dim = base_channels * 32
self.attnpool = AttentionPool2d(input_size // 32, embed_dim,
num_attn_heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # only the first block downsamples; the rest keep stride 1
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
(self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x


@MODELS.register_module()
class ChineseCLIP(BaseModel):
"""The implementation of `ChineseCLIP <https://arxiv.org/abs/2211.01335>`_.
Args:
vision_backbone (dict): Config dict for vision backbone.
text_backbone (dict): Config dict for text backbone.
tokenizer (dict): Config dict for text tokenizer.
proj_dim (int): Projection dimension for similarity computation.
        text_prototype (str | List[str]): Text prototype, which can be a
            key in ``PROTOTYPE_MAP`` or a list of texts.
text_prompt (str): The prompt for text prototype. Defaults to 'openai'.
context_length (int): The context length to use. Defaults to 52.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None or no specified type, it will use
"MultiModalDataPreprocessor" as type.
See :class:`MultiModalDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): The config to control the initialization.
Defaults to None.
"""
def __init__(self,
vision_backbone: dict,
text_backbone: dict,
tokenizer: dict,
proj_dim: int,
text_prototype: Union[str, List[str]],
text_prompt: str = 'openai',
context_length: int = 52,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
data_preprocessor=data_preprocessor, init_cfg=init_cfg)
self.vision_backbone = MODELS.build(vision_backbone)
self.text_backbone = MODELS.build(text_backbone)
if not isinstance(self.vision_backbone, ModifiedResNet):
self.vision_projection = nn.Parameter(
torch.empty(self.vision_backbone.embed_dims, proj_dim))
text_hidden_size = text_backbone['config']['hidden_size']
self.text_projection = nn.Parameter(
torch.empty(text_hidden_size, proj_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.tokenizer = TOKENIZER.build(tokenizer)
self.context_length = context_length
# for zero-shot classification
        if isinstance(text_prototype,
                      str) and text_prototype in PROTOTYPE_MAP:
self.prototype = PROTOTYPE_MAP[text_prototype]
else:
self.prototype = text_prototype
self.text_prototype_embeds = None
self.prompt = PROMPT_MAP[text_prompt]

    def forward(
self,
images: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'predict',
**kwargs,
):
"""The unified entry for a forward process in both training and test.
The method accepts the following modes:
- "predict": Forward and return a list of data samples contain the
predict results.
Args:
images (torch.Tensor): the preprocessed image tensor of shape
``(N, C, H, W)``.
data_samples (List[DataSample], optional): The annotation data
of every samples. Defaults to None.
mode (str): Return what kind of value. Defaults to 'predict'.
"""
if mode == 'predict':
return self.predict(images, data_samples, **kwargs)
else:
raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
        """The function to extract image latent features."""
        if isinstance(self.vision_backbone, ModifiedResNet):
            # the ModifiedResNet already projects via its attention pool
            return self.vision_backbone(images)
        # otherwise take the last backbone output and project it
        return self.vision_backbone(images)[-1] @ self.vision_projection

    def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
        """The function to extract text latent features."""
        pad_index = self.tokenizer.vocab['[PAD]']
        attn_mask = texts.ne(pad_index)
        # [batch_size, seq_length, hidden_size]
        x = self.text_backbone(texts, attention_mask=attn_mask)[0]
        # take the [CLS] token feature and project it
        return x[:, 0, :] @ self.text_projection

    def extract_feat(
            self, images: torch.Tensor,
            texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        """The function to extract image and text latent features; the input
        image and text cannot both be None."""
assert images is not None or texts is not None, \
'text and image cannot both be None!'
if images is None:
return self.extract_text_feat(texts)
elif texts is None:
return self.extract_image_feat(images)
image_features = self.extract_image_feat(images)
text_features = self.extract_text_feat(texts)
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
return image_features, text_features

    def compute_similarity(self, images, texts):
"""Extract images and texts features and compute cosine similarity."""
image_features, text_features = self.extract_feat(
images=images, texts=texts)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
        # both matrices have shape (N, N); for a batch of matched pairs,
        # the diagonal holds the matched image-text scores
return logits_per_image, logits_per_text

    def predict(self,
                images: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None
                ) -> List[DataSample]:
"""Predict the classes of the input images.
The prediction is for zero-shot classification and the text prototypes
will be prepared in thisfunction.
Args:
images (torch.Tensor): The input images.
data_samples (DataSample): The data samples with information from
dataset.
Returns:
DataSample: The results of prediction.
"""
if self.text_prototype_embeds is None:
self.prepare_text_prototype(device=images.device)
image_features = self.extract_image_feat(images=images)
image_features /= image_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_image = image_features @ self.text_prototype_embeds.to(
image_features.device) * self.logit_scale.exp()
pred_scores = F.softmax(logits_per_image, dim=1)
pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()
out_data_samples = []
if data_samples is None:
data_samples = [None for _ in range(pred_scores.size(0))]
for data_sample, score, label in zip(data_samples, pred_scores,
pred_labels):
if data_sample is None:
data_sample = DataSample()
data_sample.set_pred_score(score).set_pred_label(label)
out_data_samples.append(data_sample)
return out_data_samples

    def prepare_text_prototype(self, device) -> None:
        """The function to prepare text prototypes with prompts."""
class_embeddings = []
for classname in track_on_main_process(self.prototype,
'Prepare text prototype...'):
            # fill each prompt template with the class name
texts = [prompt(classname) for prompt in self.prompt]
tokenized_texts = self.tokenize(texts)
class_features = self.extract_text_feat(tokenized_texts.to(device))
            class_features /= class_features.norm(dim=-1, keepdim=True)
            # prompt ensembling: average the normalized per-prompt features,
            # then re-normalize the ensembled embedding
            class_feature = class_features.mean(dim=0)
            class_feature /= class_feature.norm()
            class_embeddings.append(class_feature)
        # stack to shape (feat_dim, num_classes), ready to be right-
        # multiplied with image features in ``predict``
        self.text_prototype_embeds = torch.stack(
            class_embeddings, dim=1).to(device)

    def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
        """Return the tokenized representation of given input string(s).

        Args:
            texts (Union[str, List[str]]): An input string or a list of
                input strings to tokenize. Each result is truncated to
                ``self.context_length`` tokens, special tokens included.

        Returns:
            torch.Tensor: Resulting tokens.
        """
if isinstance(texts, str):
texts = [texts]
all_tokens = []
for text in texts:
# adapt the text to Chinese BERT vocab
            text = text.lower().replace('“', "\"").replace('”', "\"")
# add special tokens
all_tokens.append(
[self.tokenizer.vocab['[CLS]']] +
self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text))[:self.context_length - 2] +
[self.tokenizer.vocab['[SEP]']])
result = torch.zeros(
len(all_tokens), self.context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
assert len(tokens) <= self.context_length
result[i, :len(tokens)] = torch.tensor(tokens)
return result