|
from typing import Union, List, Optional

import numpy as np
import torch
from pkg_resources import packaging
from sklearn.cluster import KMeans
from torch import nn
from torch.nn import functional as F

from .clip_model import CLIP
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
|
class ProjectLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_replicas, stack=False, is_array=True):
        super(ProjectLayer, self).__init__()

        self.head = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_replicas)])
        self.num_replicas = num_replicas
        self.stack = stack
        self.is_array = is_array

    def forward(self, tokens):
        out_tokens = []
        for i in range(self.num_replicas):
            if self.is_array:
                temp = self.head[i](tokens[i][:, 1:, :])
            else:
                temp = self.head[i](tokens)
            out_tokens.append(temp)

        if self.stack:
            out_tokens = torch.stack(out_tokens, dim=1)

        return out_tokens
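
# Shape sketch for ProjectLayer (illustrative; the concrete sizes below are assumptions,
# not values fixed anywhere in this file). With num_replicas=4, is_array=True, stack=False:
#   tokens : list of 4 tensors, each [B, 1 + N, input_dim]  (CLS token + N patch tokens)
#   output : list of 4 tensors, each [B, N, output_dim]     (CLS token dropped, then projected)
# With is_array=False the same tensor is fed to every head; stack=True stacks the per-head
# outputs along dim=1.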
|
|
|
class PromptLayer(nn.Module):
    def __init__(self, channel, length, depth, is_text, prompting_type, enabled=True):
        super(PromptLayer, self).__init__()

        self.channel = channel
        self.length = length
        self.depth = depth
        self.is_text = is_text
        self.enabled = enabled

        self.prompting_type = prompting_type

        if self.enabled:
            # Static prompts: one learnable [length, channel] block per prompted layer.
            if 'S' in prompting_type:
                self.static_prompts = nn.ParameterList(
                    [nn.Parameter(torch.empty(self.length, self.channel))
                     for _ in range(self.depth)])

                for single_para in self.static_prompts:
                    nn.init.normal_(single_para, std=0.02)

            # Dynamic prompts: generated per image and injected via set_dynamic_prompts().
            if 'D' in prompting_type:
                self.dynamic_prompts = [0.]

    def set_dynamic_prompts(self, dynamic_prompts):
        self.dynamic_prompts = dynamic_prompts

    def forward_text(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        if self.enabled:
            length = self.length

            if indx < self.depth:
                if 'S' in self.prompting_type and 'D' in self.prompting_type:
                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
                    textual_context = self.dynamic_prompts + static_prompts
                elif 'S' in self.prompting_type:
                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
                    textual_context = static_prompts
                elif 'D' in self.prompting_type:
                    textual_context = self.dynamic_prompts
                else:
                    raise NotImplementedError(
                        'At least one prompting type ("S" or "D") must be chosen when prompting is enabled.')

            # The first block keeps the original sequence; blocks 1..depth-1 have the `length`
            # tokens right after the SOS token replaced by the prompt tokens.
            if 0 < indx < self.depth:
                prefix = x[:1, :, :]
                suffix = x[1 + length:, :, :]
                textual_context = textual_context.permute(1, 0, 2).half()
                x = torch.cat([prefix, textual_context, suffix], dim=0)

        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x=v_x, attn_mask=attn_mask)

        return x, attn_tmp

    def forward_visual(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        if self.enabled:
            length = self.length

            if indx < self.depth:
                if 'S' in self.prompting_type and 'D' in self.prompting_type:
                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
                    visual_context = self.dynamic_prompts + static_prompts
                elif 'S' in self.prompting_type:
                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
                    visual_context = static_prompts
                elif 'D' in self.prompting_type:
                    visual_context = self.dynamic_prompts
                else:
                    raise NotImplementedError(
                        'At least one prompting type ("S" or "D") must be chosen when prompting is enabled.')

            if indx == 0:
                # First block: append the prompt tokens after the patch tokens.
                visual_context = visual_context.permute(1, 0, 2).half()
                x = torch.cat([x, visual_context], dim=0)
            elif indx < self.depth:
                # Later prompted blocks: replace the previously appended prompt tokens.
                prefix = x[0:x.shape[0] - length, :, :]
                visual_context = visual_context.permute(1, 0, 2).half()
                x = torch.cat([prefix, visual_context], dim=0)

        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x=v_x, attn_mask=attn_mask)

        if self.enabled:
            # Strip the appended prompt tokens before exposing the patch tokens downstream.
            tokens = x[:x.shape[0] - length, :, :]
        else:
            tokens = x

        return x, tokens, attn_tmp

    def forward(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        if self.is_text:
            return self.forward_text(resblock, indx, x, k_x, v_x, attn_mask)
        else:
            return self.forward_visual(resblock, indx, x, k_x, v_x, attn_mask)
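
# How PromptLayer is driven (a summary of its use in encode_text / encode_image further
# below, not an extra API): the owner iterates over the transformer blocks and lets the
# prompter splice prompt tokens in before every block whose index is below `depth`, e.g.
#   for indx, r in enumerate(transformer.resblocks):
#       x, attn = text_prompter(r, indx, x, attn_mask=attn_mask)    # text branch
#       x, tokens, attn = visual_prompter(r, indx, x)               # visual branch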
|
|
|
|
|
class TextEmbeddingLayer(nn.Module):
    def __init__(self, fixed):
        super(TextEmbeddingLayer, self).__init__()
        self.tokenizer = _Tokenizer()
        self.ensemble_text_features = {}
        self.prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw',
                              '{} without defect', '{} without damage']
        self.prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
        self.prompt_state = [self.prompt_normal, self.prompt_abnormal]
        self.prompt_templates = ['a bad photo of a {}.',
                                 'a low resolution photo of the {}.',
                                 'a bad photo of the {}.',
                                 'a cropped photo of the {}.',
                                 ]
        # When `fixed` is True, text embeddings are computed once per class name and cached.
        self.fixed = fixed
|
|
|
    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77,
                 truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.tokenizer.encoder["<|startoftext|>"]
        eot_token = self.tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
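
    # Example (illustrative only): tokenize(["a photo of a transistor"]) returns an integer
    # tensor of shape [1, 77] that starts with the <|startoftext|> id, ends the encoded
    # sentence with <|endoftext|>, and is zero-padded up to `context_length`.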
|
|
|
|
|
    def forward(self, model, texts, device):
        text_feature_list = []

        for text in texts:
            if self.fixed:
                # Reuse cached embeddings for class names seen before.
                if self.ensemble_text_features.get(text) is None:
                    text_features = self.encode_text(model, text, device)
                    self.ensemble_text_features[text] = text_features
                else:
                    text_features = self.ensemble_text_features[text]
            else:
                text_features = self.encode_text(model, text, device)
                self.ensemble_text_features[text] = text_features

            text_feature_list.append(text_features)

        text_features = torch.stack(text_feature_list, dim=0)
        text_features = F.normalize(text_features, dim=1)

        return text_features
|
|
|
    def encode_text(self, model, text, device):
        text_features = []
        for i in range(len(self.prompt_state)):
            text = text.replace('-', ' ')
            prompted_state = [state.format(text) for state in self.prompt_state[i]]
            prompted_sentence = []
            for s in prompted_state:
                for template in self.prompt_templates:
                    prompted_sentence.append(template.format(s))
            prompted_sentence = self.tokenize(prompted_sentence, context_length=77).to(device)

            class_embeddings = model.encode_text(prompted_sentence)

            # Average the embeddings of all prompted sentences for this state (normal/abnormal).
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            text_features.append(class_embedding)

        text_features = torch.stack(text_features, dim=1)

        return text_features
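
    # Prompt-ensemble bookkeeping (derived from the lists defined in __init__): each class
    # name is expanded into 7 "normal" and 5 "abnormal" state phrases, each phrase is filled
    # into the 4 templates, and the resulting sentence embeddings are mean-pooled per state,
    # so encode_text returns a [C, 2] tensor: one normal and one abnormal embedding per class.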
|
|
|
|
|
|
|
class HybridSemanticFusion(nn.Module):
    def __init__(self, k_clusters):
        super(HybridSemanticFusion, self).__init__()
        self.k_clusters = k_clusters
        self.n_aggregate_patch_tokens = k_clusters * 5
        self.cluster_performer = KMeans(n_clusters=self.k_clusters, n_init="auto")

    def forward(self, patch_tokens: list, anomaly_maps: list):
        # Average the per-layer anomaly maps and keep the abnormal-class probability.
        anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
        anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1]

        # Gather the most anomalous patch tokens from every layer.
        selected_abnormal_tokens = []
        k = min(anomaly_map.shape[1], self.n_aggregate_patch_tokens)
        top_k_indices = torch.topk(anomaly_map, k=k, dim=1).indices
        for layer in range(len(patch_tokens)):
            selected_tokens = patch_tokens[layer].gather(
                dim=1,
                index=top_k_indices.unsqueeze(-1).expand(-1, -1, patch_tokens[layer].shape[-1]))
            selected_abnormal_tokens.append(selected_tokens)

        # Cluster the channel-wise concatenation of the selected tokens, per image.
        stacked_data = torch.cat(selected_abnormal_tokens, dim=2)

        batch_cluster_centers = []
        for b in range(stacked_data.shape[0]):
            cluster_labels = self.cluster_performer.fit_predict(stacked_data[b, :, :].detach().cpu().numpy())

            cluster_centers = []
            for cluster_id in range(self.k_clusters):
                collected_cluster_data = []
                for abnormal_tokens in selected_abnormal_tokens:
                    cluster_data = abnormal_tokens[b, :, :][cluster_labels == cluster_id]
                    collected_cluster_data.append(cluster_data)
                collected_cluster_data = torch.cat(collected_cluster_data, dim=0)
                cluster_center = torch.mean(collected_cluster_data, dim=0, keepdim=True)
                cluster_centers.append(cluster_center)

            # Average the cluster centers into a single fused feature for this image.
            cluster_centers = torch.cat(cluster_centers, dim=0)
            cluster_centers = torch.mean(cluster_centers, dim=0)
            batch_cluster_centers.append(cluster_centers)

        batch_cluster_centers = torch.stack(batch_cluster_centers, dim=0)
        batch_cluster_centers = F.normalize(batch_cluster_centers, dim=1)

        return batch_cluster_centers
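
    # The fused output is one L2-normalized vector per image, shaped [B, C] with C equal to
    # the projected patch-token dimension; visual_text_similarity below blends it with the
    # CLS feature as alpha * fused + (1 - alpha) * cls, with alpha = 0.2.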
|
|
class AdaCLIP(nn.Module):
    def __init__(self, freeze_clip: CLIP, text_channel: int, visual_channel: int,
                 prompting_length: int, prompting_depth: int, prompting_branch: str, prompting_type: str,
                 use_hsf: bool, k_clusters: int,
                 output_layers: list, device: str, image_size: int):
        super(AdaCLIP, self).__init__()
        self.freeze_clip = freeze_clip

        self.visual = self.freeze_clip.visual
        self.transformer = self.freeze_clip.transformer
        self.token_embedding = self.freeze_clip.token_embedding
        self.positional_embedding = self.freeze_clip.positional_embedding
        self.ln_final = self.freeze_clip.ln_final
        self.text_projection = self.freeze_clip.text_projection
        self.attn_mask = self.freeze_clip.attn_mask

        self.output_layers = output_layers

        self.prompting_branch = prompting_branch
        self.prompting_type = prompting_type
        self.prompting_depth = prompting_depth
        self.prompting_length = prompting_length
        self.use_hsf = use_hsf
        self.k_clusters = k_clusters

        # 'L' enables prompting of the language branch, 'V' of the visual branch.
        self.enable_text_prompt = 'L' in self.prompting_branch
        self.enable_visual_prompt = 'V' in self.prompting_branch

        self.text_embedding_layer = TextEmbeddingLayer(fixed=(not self.enable_text_prompt))
        self.text_prompter = PromptLayer(text_channel, prompting_length, prompting_depth, is_text=True,
                                         prompting_type=prompting_type,
                                         enabled=self.enable_text_prompt)
        self.visual_prompter = PromptLayer(visual_channel, prompting_length, prompting_depth, is_text=False,
                                           prompting_type=prompting_type,
                                           enabled=self.enable_visual_prompt)

        # Project the patch tokens of each selected output layer into the text embedding space.
        self.patch_token_layer = ProjectLayer(
            visual_channel,
            text_channel,
            len(output_layers), stack=False, is_array=True
        )

        # Project the (already text_channel-sized) CLS feature with a single head.
        self.cls_token_layer = ProjectLayer(
            text_channel,
            text_channel,
            1, stack=False, is_array=False
        )

        if 'D' in self.prompting_type:
            # Generators for image-conditioned (dynamic) prompts of both branches.
            self.dynamic_visual_prompt_generator = ProjectLayer(text_channel,
                                                                visual_channel,
                                                                prompting_length,
                                                                stack=True,
                                                                is_array=False)
            self.dynamic_text_prompt_generator = ProjectLayer(text_channel,
                                                              text_channel,
                                                              prompting_length,
                                                              stack=True,
                                                              is_array=False)

        if self.use_hsf:
            self.HSF = HybridSemanticFusion(k_clusters)

        self.image_size = image_size
        self.device = device
|
|
|
    def generate_and_set_dynamic_prompts(self, image):
        # The frozen CLIP image feature conditions the dynamic prompts; no gradients are
        # needed for this feature-extraction pass.
        with torch.no_grad():
            image_features, _ = self.visual.forward(image, self.output_layers)

        dynamic_visual_prompts = self.dynamic_visual_prompt_generator(image_features)
        dynamic_text_prompts = self.dynamic_text_prompt_generator(image_features)

        self.visual_prompter.set_dynamic_prompts(dynamic_visual_prompts)
        self.text_prompter.set_dynamic_prompts(dynamic_text_prompts)
|
|
|
|
|
    def encode_image(self, image):
        x = image

        if self.visual.input_patchnorm:
            # Patchify manually and normalize each patch before the linear projection.
            x = x.reshape(x.shape[0], x.shape[1],
                          self.visual.grid_size[0],
                          self.visual.patch_size[0],
                          self.visual.grid_size[1],
                          self.visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1)
            x = self.visual.patchnorm_pre_ln(x)
            x = self.visual.conv1(x)
        else:
            x = self.visual.conv1(x)
            x = x.reshape(x.shape[0], x.shape[1], -1)
            x = x.permute(0, 2, 1)

        # Prepend the class token and add positional embeddings.
        x = torch.cat(
            [self.visual.class_embedding.to(x.dtype) +
             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)

        x = x + self.visual.positional_embedding.to(x.dtype)

        x = self.visual.patch_dropout(x)
        x = self.visual.ln_pre(x)

        patch_embedding = x

        x = x.permute(1, 0, 2)  # NLD -> LND

        patch_tokens = []

        for indx, r in enumerate(self.visual.transformer.resblocks):
            # The visual prompter appends/refreshes prompt tokens for blocks below `depth`
            # and returns the sequence with prompt tokens stripped in `tokens`.
            x, tokens, attn_tmp = self.visual_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=None)

            if (indx + 1) in self.output_layers:
                patch_tokens.append(tokens)

        x = x.permute(1, 0, 2)  # LND -> NLD
        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]

        if self.visual.attn_pool is not None:
            x = self.visual.attn_pool(x)
            x = self.visual.ln_post(x)
            pooled, tokens = self.visual._global_pool(x)
        else:
            pooled, tokens = self.visual._global_pool(x)
            pooled = self.visual.ln_post(pooled)

        if self.visual.proj is not None:
            pooled = pooled @ self.visual.proj

        return pooled, patch_tokens, patch_embedding
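
    # Returned shapes (a sketch, assuming a square grid of N patch tokens): `pooled` is the
    # [B, proj_dim] CLS feature, `patch_tokens` is a list with one [B, 1 + N, visual_channel]
    # tensor per entry in output_layers (prompt tokens already stripped by the prompter),
    # and `patch_embedding` holds the [B, 1 + N, visual_channel] tokens right after ln_pre.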
|
|
|
    def proj_visual_tokens(self, image_features, patch_tokens):
        # Project patch tokens of every selected layer into the text space and L2-normalize.
        proj_patch_tokens = self.patch_token_layer(patch_tokens)
        for layer in range(len(proj_patch_tokens)):
            proj_patch_tokens[layer] /= proj_patch_tokens[layer].norm(dim=-1, keepdim=True)

        # Project and L2-normalize the CLS feature.
        proj_cls_tokens = self.cls_token_layer(image_features)[0]
        proj_cls_tokens /= proj_cls_tokens.norm(dim=-1, keepdim=True)

        return proj_cls_tokens, proj_patch_tokens
|
|
|
    def encode_text(self, text):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND

        for indx, r in enumerate(self.transformer.resblocks):
            x, attn_tmp = self.text_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=self.attn_mask)

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
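
    # As in CLIP, the sentence embedding is read from the position of the EOT token
    # (text.argmax(dim=-1) picks it, since EOT has the largest token id in every sequence)
    # and is then mapped with text_projection.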
|
|
|
    def visual_text_similarity(self, image_feature, patch_token, text_feature, aggregation):
        anomaly_maps = []

        # Per-layer similarity between projected patch tokens and the (normal, abnormal)
        # text embeddings; each map is [B, L, 2].
        for layer in range(len(patch_token)):
            anomaly_map = (100.0 * patch_token[layer] @ text_feature)
            anomaly_maps.append(anomaly_map)

        if self.use_hsf:
            alpha = 0.2
            clustered_feature = self.HSF.forward(patch_token, anomaly_maps)

            # Blend the clustered abnormal feature with the CLS feature.
            cur_image_feature = alpha * clustered_feature + (1 - alpha) * image_feature
            cur_image_feature = F.normalize(cur_image_feature, dim=1)
        else:
            cur_image_feature = image_feature

        anomaly_score = (100.0 * cur_image_feature.unsqueeze(1) @ text_feature)
        anomaly_score = anomaly_score.squeeze(1)
        anomaly_score = torch.softmax(anomaly_score, dim=1)

        # Reshape each [B, L, 2] map to [B, 2, H, H] (square grid assumed) and upsample
        # to the input resolution.
        for i in range(len(anomaly_maps)):
            B, L, C = anomaly_maps[i].shape
            H = int(np.sqrt(L))
            anomaly_maps[i] = anomaly_maps[i].permute(0, 2, 1).view(B, 2, H, H)
            anomaly_maps[i] = F.interpolate(anomaly_maps[i], size=self.image_size, mode='bilinear', align_corners=True)

        if aggregation:
            anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
            anomaly_map = torch.softmax(anomaly_map, dim=1)
            anomaly_map = (anomaly_map[:, 1:, :, :] + 1 - anomaly_map[:, 0:1, :, :]) / 2.0
            anomaly_score = anomaly_score[:, 1]
            return anomaly_map, anomaly_score
        else:
            for i in range(len(anomaly_maps)):
                anomaly_maps[i] = torch.softmax(anomaly_maps[i], dim=1)
            return anomaly_maps, anomaly_score
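
    # Output shapes: with aggregation=True this returns a fused [B, 1, H, W] map (squeezed to
    # [B, H, W] in forward) and a [B] abnormal-probability score; with aggregation=False it
    # returns the list of per-layer softmaxed [B, 2, H, W] maps and the [B, 2] softmaxed score.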
|
|
|
    def extract_feat(self, image, cls_name):
        # Dynamic prompts are regenerated per batch before the prompted encoders run.
        if 'D' in self.prompting_type:
            self.generate_and_set_dynamic_prompts(image)

        # Only back-propagate through an encoder when its prompting branch is enabled.
        if self.enable_visual_prompt:
            image_features, patch_tokens, _ = self.encode_image(image)
        else:
            with torch.no_grad():
                image_features, patch_tokens, _ = self.encode_image(image)

        if self.enable_text_prompt:
            text_features = self.text_embedding_layer(self, cls_name, self.device)
        else:
            with torch.no_grad():
                text_features = self.text_embedding_layer(self, cls_name, self.device)

        proj_cls_tokens, proj_patch_tokens = self.proj_visual_tokens(image_features, patch_tokens)

        return proj_cls_tokens, proj_patch_tokens, text_features

    @torch.cuda.amp.autocast()
    def forward(self, image, cls_name, aggregation=True):
        image_features, patch_tokens, text_features = self.extract_feat(image, cls_name)
        anomaly_map, anomaly_score = self.visual_text_similarity(image_features, patch_tokens, text_features,
                                                                 aggregation)

        if aggregation:
            # A single fused anomaly map per image.
            anomaly_map = anomaly_map.squeeze(1)
            return anomaly_map, anomaly_score
        else:
            # A list of per-layer anomaly maps.
            anomaly_maps = anomaly_map
            return anomaly_maps, anomaly_score
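
# Minimal end-to-end sketch (illustrative; the CLIP backbone comes from the repository's
# model-loading utilities, which are not part of this file, and the hyper-parameter values
# below are placeholders rather than the project's defaults):
#   model = AdaCLIP(freeze_clip=clip_model, text_channel=768, visual_channel=1024,
#                   prompting_length=4, prompting_depth=4, prompting_branch='VL',
#                   prompting_type='SD', use_hsf=True, k_clusters=20,
#                   output_layers=[6, 12, 18, 24], device='cuda', image_size=518)
#   anomaly_map, anomaly_score = model(images, cls_names)   # [B, H, W] map, [B] score


if __name__ == "__main__":
    # Smoke test for the projection head alone (a sketch with assumed sizes; it does not
    # touch any CLIP weights and runs on CPU).
    proj = ProjectLayer(input_dim=1024, output_dim=768, num_replicas=2, stack=False, is_array=True)
    dummy = [torch.randn(2, 1 + 16, 1024) for _ in range(2)]  # per layer: [CLS + 16 patch tokens]
    outs = proj(dummy)
    print([tuple(o.shape) for o in outs])  # -> [(2, 16, 768), (2, 16, 768)]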
|
|
|
|