File size: 26,178 Bytes
16a0f31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 |
from typing import Union, List, Optional
import numpy as np
import torch
from pkg_resources import packaging
from torch import nn
from torch.nn import functional as F
from .clip_model import CLIP
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from sklearn.cluster import KMeans
class ProjectLayer(nn.Module):
    """Bank of independent linear projections, one ``nn.Linear`` per replica.

    Args:
        input_dim: feature size consumed by every linear head.
        output_dim: feature size produced by every linear head.
        num_replicas: number of parallel heads.
        stack: when True the per-head outputs are stacked along dim 1 into one
            tensor; otherwise the outputs are returned as a Python list.
        is_array: when True ``tokens`` is indexed per head (``tokens[i]``) and the
            first token along dim 1 (the ViT class token) is dropped before the
            projection; when False every head projects the same ``tokens``.
    """

    def __init__(self, input_dim, output_dim, num_replicas, stack=False, is_array=True):
        super(ProjectLayer, self).__init__()
        self.head = nn.ModuleList(
            nn.Linear(input_dim, output_dim) for _ in range(num_replicas)
        )
        self.num_replicas = num_replicas
        self.stack = stack
        self.is_array = is_array

    def forward(self, tokens):
        def _project(idx):
            # For ViT features only the patch tokens are projected (class token excluded).
            source = tokens[idx][:, 1:, :] if self.is_array else tokens
            return self.head[idx](source)

        projected = [_project(idx) for idx in range(self.num_replicas)]
        return torch.stack(projected, dim=1) if self.stack else projected
class PromptLayer(nn.Module):
    """Wraps one transformer residual block and injects prompt tokens into it.

    Prompts are only injected into the first ``depth`` blocks. Inputs are in
    sequence-first layout: ``x.shape[1]`` is used as the batch size.

    * Text branch: block 0 runs on the embedded sentence unchanged; blocks
      1 .. depth-1 overwrite the ``length`` tokens right after position 0
      (presumably the start-of-text token) with the prompt context.
    * Visual branch: block 0 appends ``length`` prompt tokens to the end of the
      sequence; blocks 1 .. depth-1 replace those trailing tokens. The appended
      prompts are removed from the returned ``tokens``.

    ``prompting_type`` selects static prompts ('S', learnable per-layer
    parameters), dynamic prompts ('D', supplied via ``set_dynamic_prompts``),
    or both ('S' and 'D', summed).
    """

    def __init__(self, channel, length, depth, is_text, prompting_type, enabled=True):
        super(PromptLayer, self).__init__()
        self.channel = channel
        self.length = length
        self.depth = depth
        self.is_text = is_text
        self.enabled = enabled
        self.prompting_type = prompting_type
        if self.enabled:  # parameters are only constructed for an enabled branch
            if 'S' in prompting_type:  # static prompts: one learnable (length, channel) matrix per layer
                self.static_prompts = nn.ParameterList(
                    [nn.Parameter(torch.empty(self.length, self.channel))
                     for _ in range(self.depth)])
                for single_para in self.static_prompts:
                    nn.init.normal_(single_para, std=0.02)
            if 'D' in prompting_type:  # dynamic prompts are provided later via set_dynamic_prompts
                self.dynamic_prompts = [0.]  # placeholder

    def set_dynamic_prompts(self, dynamic_prompts):
        """Store externally generated prompts (assumed batch-first: (batch, length, channel))."""
        self.dynamic_prompts = dynamic_prompts

    def _prompt_context(self, indx, batch_size):
        """Return the batch-first prompt context for block ``indx``.

        Raises:
            NotImplementedError: if ``prompting_type`` contains neither 'S' nor 'D'.
        """
        use_static = 'S' in self.prompting_type
        use_dynamic = 'D' in self.prompting_type
        if not (use_static or use_dynamic):
            raise NotImplementedError(
                'You should at least choose one type of prompts when the prompting branches are not none.')
        if not use_static:
            return self.dynamic_prompts
        static_prompts = self.static_prompts[indx].unsqueeze(0).expand(batch_size, -1, -1)
        if use_dynamic:
            return self.dynamic_prompts + static_prompts
        return static_prompts

    @staticmethod
    def _to_sequence_first(context, like):
        """Permute ``context`` to (length, batch, channel) and cast it to ``like``'s dtype.

        Casting to the activation dtype (instead of a hard-coded fp16) keeps
        fp32 / bf16 backbones consistent and avoids rounding the prompts.
        """
        return context.permute(1, 0, 2).to(like.dtype)

    def forward_text(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        """Optionally overwrite prompt positions of ``x``, then run ``resblock``.

        Returns:
            ``(x, attn)`` as returned by ``resblock``.
        """
        if self.enabled and indx < self.depth:
            context = self._prompt_context(indx, x.shape[1])
            if indx > 0:  # block 0 keeps the embedded sentence as-is
                prefix = x[:1, :, :]
                suffix = x[1 + self.length:, :, :]
                x = torch.cat([prefix, self._to_sequence_first(context, x), suffix], dim=0)
        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        return x, attn_tmp

    def forward_visual(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        """Append / refresh trailing prompt tokens of ``x``, then run ``resblock``.

        Returns:
            ``(x, tokens, attn)`` where ``tokens`` is the block output without the
            trailing prompt tokens (identical to ``x`` when no prompts were appended).
        """
        if self.enabled and indx < self.depth:
            context = self._to_sequence_first(self._prompt_context(indx, x.shape[1]), x)
            if indx == 0:  # first block: append the prompt tokens
                x = torch.cat([x, context], dim=0)
            else:  # later blocks: replace the trailing prompt tokens
                x = torch.cat([x[:x.shape[0] - self.length, :, :], context], dim=0)
        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        # Prompts are only present when at least one block appended them.
        if self.enabled and self.depth > 0:
            tokens = x[:x.shape[0] - self.length, :, :]
        else:
            tokens = x
        return x, tokens, attn_tmp

    def forward(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
        """Dispatch to the text or visual prompting path according to ``is_text``."""
        if self.is_text:
            return self.forward_text(resblock, indx, x, k_x, v_x, attn_mask)
        else:
            return self.forward_visual(resblock, indx, x, k_x, v_x, attn_mask)
class TextEmbebddingLayer(nn.Module):
    """Prompt-ensembled text encoder producing one embedding per state (normal / abnormal).

    For each class name, every state template in ``prompt_state`` is combined
    with every sentence template in ``prompt_templates``; the resulting
    sentences are tokenized, encoded with ``model.encode_text`` and averaged
    into one unit-norm embedding per state.

    Args:
        fixed: if True, embeddings are computed once per class name and then
            served from ``ensemble_text_features``; if False they are recomputed
            on every call (the latest result is still stored in the dict).
    """

    def __init__(self, fixed):
        super(TextEmbebddingLayer, self).__init__()
        self.tokenizer = _Tokenizer()
        self.ensemble_text_features = {}
        self.prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw',
                              '{} without defect',
                              '{} without damage']
        self.prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
        self.prompt_state = [self.prompt_normal, self.prompt_abnormal]
        self.prompt_templates = ['a bad photo of a {}.',
                                 'a low resolution photo of the {}.',
                                 'a bad photo of the {}.',
                                 'a cropped photo of the {}.',
                                 ]
        self.fixed = fixed

    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[
        torch.IntTensor, torch.LongTensor]:
        """Tokenize ``texts`` into a zero-padded (len(texts), context_length) tensor.

        Each sequence is wrapped in start/end-of-text tokens.

        Raises:
            RuntimeError: if a sequence exceeds ``context_length`` and ``truncate`` is False.
        """
        if isinstance(texts, str):
            texts = [texts]
        sot_token = self.tokenizer.encoder["<|startoftext|>"]
        eot_token = self.tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
        # Older torch versions need int64 token ids; newer ones accept int32.
        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token  # keep the end-of-text marker after truncation
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)
        return result

    # TODO: the text layer is not compatible with multiple batches...
    def forward(self, model, texts, device):
        """Return L2-normalized (along dim 1) text features, one entry per name in ``texts``."""
        text_feature_list = []
        for text in texts:
            text_features = self.ensemble_text_features.get(text) if self.fixed else None
            if text_features is None:
                text_features = self.encode_text(model, text, device)
                self.ensemble_text_features[text] = text_features
            text_feature_list.append(text_features)
        text_features = torch.stack(text_feature_list, dim=0)
        return F.normalize(text_features, dim=1)

    def encode_text(self, model, text, device):
        """Encode one class name into a (dim, num_states) matrix of unit-norm embeddings."""
        text = text.replace('-', ' ')
        text_features = []
        for state_templates in self.prompt_state:
            sentences = [template.format(state.format(text))
                         for state in state_templates
                         for template in self.prompt_templates]
            tokens = self.tokenize(sentences, context_length=77).to(device)
            class_embeddings = model.encode_text(tokens)
            # Out-of-place normalization: `norm` saves its input for backward, so an
            # in-place division would break autograd when the text encoder is trained.
            # `.to(dtype)` keeps the original dtype under mixed precision.
            class_embeddings = (class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)).to(
                class_embeddings.dtype)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = (class_embedding / class_embedding.norm()).to(class_embedding.dtype)
            text_features.append(class_embedding)
        return torch.stack(text_features, dim=1)
# Note: this implementation of HSF differs slightly from the reported one, since we found the upgraded version to be more stable.
class HybridSemanticFusion(nn.Module):
    """Fuse the most abnormal patch tokens into one unit-norm feature per image.

    The ``k_clusters * 5`` patches with the highest abnormal probability are
    gathered from every layer, clustered with KMeans on the concatenation of
    all layers' features, and the per-cluster means (averaged over layers and
    member tokens) are averaged into a single feature.
    """

    def __init__(self, k_clusters):
        super(HybridSemanticFusion, self).__init__()
        self.k_clusters = k_clusters
        self.n_aggregate_patch_tokens = k_clusters * 5
        self.cluster_performer = KMeans(n_clusters=self.k_clusters, n_init="auto")

    def forward(self, patch_tokens: list, anomaly_maps: list):
        """
        Args:
            patch_tokens: per-layer tensors indexed as (batch, patches, channels);
                all layers presumably share the same channel size.
            anomaly_maps: per-layer logits indexed as (batch, patches, 2); index 1
                is taken as the "abnormal" class.

        Returns:
            (batch, channels) tensor, L2-normalized along dim 1.
        """
        # Average logits over layers, then keep the abnormal-class probability per patch.
        anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
        anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1]  # B, L

        # Gather the k most abnormal patch tokens from every layer.
        k = min(anomaly_map.shape[1], self.n_aggregate_patch_tokens)
        top_k_indices = torch.topk(anomaly_map, k=k, dim=1).indices
        selected_abnormal_tokens = [
            tokens.gather(dim=1, index=top_k_indices.unsqueeze(-1).expand(-1, -1, tokens.shape[-1]))
            for tokens in patch_tokens
        ]
        # Cluster on the concatenation of all layers' features: (B, k, sum of channels).
        stacked_data = torch.cat(selected_abnormal_tokens, dim=2)

        batch_cluster_centers = []
        for b in range(stacked_data.shape[0]):
            cluster_labels = self.cluster_performer.fit_predict(stacked_data[b, :, :].detach().cpu().numpy())
            cluster_centers = []
            for cluster_id in range(self.k_clusters):
                mask = cluster_labels == cluster_id
                if not mask.any():
                    # KMeans may yield fewer distinct clusters than requested (e.g. duplicated
                    # points); the mean of an empty selection would be NaN, so skip it.
                    continue
                collected_cluster_data = torch.cat(
                    [abnormal_tokens[b, :, :][mask] for abnormal_tokens in selected_abnormal_tokens], dim=0)
                cluster_centers.append(torch.mean(collected_cluster_data, dim=0, keepdim=True))
            # Average the per-cluster centers into one feature for this image.
            batch_cluster_centers.append(torch.cat(cluster_centers, dim=0).mean(dim=0))
        batch_cluster_centers = torch.stack(batch_cluster_centers, dim=0)
        return F.normalize(batch_cluster_centers, dim=1)
# # preprocess
# # compute the anomaly map
# anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
# anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1] # B, L
#
# # compute the average multi-hierarchy patch embeddings
# avg_patch_tokens = torch.mean(torch.stack(patch_tokens, dim=0), dim=0) # B, L, C
#
# # Initialize a list to store the centroids of clusters with the largest anomaly scores
# cluster_centroids = []
#
# # loop across the batch size
# for b in range(avg_patch_tokens.shape[0]):
# # step1: group features into clusters
# cluster_labels = self.cluster_performer.fit_predict(avg_patch_tokens[b, :, :].detach().cpu().numpy())
#
# # step2: compute the anomaly scores for individual clusters via the anomaly map
# # Convert cluster labels back to tensor
# cluster_labels = torch.tensor(cluster_labels).to(avg_patch_tokens.device)
# cluster_anomaly_scores = {}
# for label in torch.unique(cluster_labels):
# cluster_indices = torch.where(cluster_labels == label)[0]
# cluster_anomaly_scores[label.item()] = anomaly_map[b, cluster_indices].mean().item()
#
# # step3: select the cluster with the largest anomaly score and then compute its centroid by averaging the
# # corresponding avg_patch_tokens
# # Find the cluster with the largest anomaly score
# largest_anomaly_cluster = max(cluster_anomaly_scores, key=cluster_anomaly_scores.get)
#
# # Get the indices of the tokens belonging to the largest anomaly cluster
# largest_anomaly_cluster_indices = torch.where(cluster_labels == largest_anomaly_cluster)[0]
#
# # Compute the centroid of the largest anomaly cluster by averaging the corresponding avg_patch_tokens
# centroid = avg_patch_tokens[b, largest_anomaly_cluster_indices, :].mean(dim=0)
#
# # Append the centroid to the list of cluster centroids
# cluster_centroids.append(centroid)
#
# # Convert the list of centroids to a tensor
# cluster_centroids = torch.stack(cluster_centroids, dim=0)
# cluster_centroids = F.normalize(cluster_centroids, dim=1)
# return cluster_centroids
class AdaCLIP(nn.Module):
    """Anomaly detector built on a frozen CLIP backbone with prompt layers.

    ``prompting_branch`` enables text prompts ('L') and/or visual prompts ('V');
    ``prompting_type`` selects static ('S') and/or dynamic ('D') prompts. Patch
    tokens from ``output_layers`` (1-based block indices of the visual
    transformer) are projected into the text space and compared with the
    normal/abnormal text embeddings to obtain pixel-level anomaly maps and an
    image-level anomaly score.
    """

    def __init__(self, freeze_clip: "CLIP", text_channel: int, visual_channel: int,
                 prompting_length: int, prompting_depth: int, prompting_branch: str, prompting_type: str,
                 use_hsf: bool, k_clusters: int,
                 output_layers: list, device: str, image_size: int):
        super(AdaCLIP, self).__init__()
        # Handles onto the frozen CLIP sub-modules used by encode_image / encode_text.
        self.freeze_clip = freeze_clip
        self.visual = self.freeze_clip.visual
        self.transformer = self.freeze_clip.transformer
        self.token_embedding = self.freeze_clip.token_embedding
        self.positional_embedding = self.freeze_clip.positional_embedding
        self.ln_final = self.freeze_clip.ln_final
        self.text_projection = self.freeze_clip.text_projection
        self.attn_mask = self.freeze_clip.attn_mask
        self.output_layers = output_layers
        self.prompting_branch = prompting_branch
        self.prompting_type = prompting_type
        self.prompting_depth = prompting_depth
        self.prompting_length = prompting_length
        self.use_hsf = use_hsf
        self.k_clusters = k_clusters
        # 'L' enables the language (text) prompts, 'V' the visual prompts.
        self.enable_text_prompt = 'L' in self.prompting_branch
        self.enable_visual_prompt = 'V' in self.prompting_branch
        # Text embeddings are cached (fixed) only when the text branch is not prompted.
        self.text_embedding_layer = TextEmbebddingLayer(fixed=(not self.enable_text_prompt))
        self.text_prompter = PromptLayer(text_channel, prompting_length, prompting_depth, is_text=True,
                                         prompting_type=prompting_type,
                                         enabled=self.enable_text_prompt)
        self.visual_prompter = PromptLayer(visual_channel, prompting_length, prompting_depth, is_text=False,
                                           prompting_type=prompting_type,
                                           enabled=self.enable_visual_prompt)
        # One projection per output layer maps visual patch tokens into the text space.
        self.patch_token_layer = ProjectLayer(
            visual_channel,
            text_channel,
            len(output_layers), stack=False, is_array=True
        )
        self.cls_token_layer = ProjectLayer(
            text_channel,
            text_channel,
            1, stack=False, is_array=False
        )
        if 'D' in self.prompting_type:  # dynamic prompts: one linear head per prompt token
            self.dynamic_visual_prompt_generator = ProjectLayer(text_channel,
                                                                visual_channel,
                                                                prompting_length,
                                                                stack=True,
                                                                is_array=False)
            self.dynamic_text_prompt_generator = ProjectLayer(text_channel,
                                                              text_channel,
                                                              prompting_length,
                                                              stack=True,
                                                              is_array=False)
        if self.use_hsf:
            self.HSF = HybridSemanticFusion(k_clusters)
        self.image_size = image_size
        self.device = device

    def generate_and_set_dynamic_promtps(self, image):
        """Generate image-conditioned prompts and hand them to both prompters."""
        with torch.no_grad():
            # Only the frozen CLIP image encoder runs without gradients.
            image_features, _ = self.visual.forward(image, self.output_layers)
        # The prompt generators are trainable, so they must run with autograd enabled.
        dynamic_visual_prompts = self.dynamic_visual_prompt_generator(image_features)
        dynamic_text_prompts = self.dynamic_text_prompt_generator(image_features)
        self.visual_prompter.set_dynamic_prompts(dynamic_visual_prompts)
        self.text_prompter.set_dynamic_prompts(dynamic_text_prompts)

    def encode_image(self, image):
        """Run the CLIP vision tower with visual prompting.

        Returns:
            ``(pooled, patch_tokens, patch_embedding)``. ``pooled`` is the pooled
            image feature (projected when ``visual.proj`` exists). ``patch_tokens``
            is a batch-first list with one entry per ``output_layers`` block, with
            prompt tokens removed. ``patch_embedding`` is the input sequence after
            ``ln_pre``.
        """
        x = image
        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if self.visual.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1],
                          self.visual.grid_size[0],
                          self.visual.patch_size[0],
                          self.visual.grid_size[1],
                          self.visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1)
            x = self.visual.patchnorm_pre_ln(x)
            x = self.visual.conv1(x)
        else:
            x = self.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat(
            [self.visual.class_embedding.to(x.dtype) +
             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.visual.positional_embedding.to(x.dtype)
        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.visual.patch_dropout(x)
        x = self.visual.ln_pre(x)
        patch_embedding = x
        x = x.permute(1, 0, 2)  # NLD -> LND
        patch_tokens = []
        for indx, r in enumerate(self.visual.transformer.resblocks):
            x, tokens, _ = self.visual_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=None)
            if (indx + 1) in self.output_layers:  # output_layers are 1-based
                patch_tokens.append(tokens)
        x = x.permute(1, 0, 2)  # LND -> NLD
        patch_tokens = [t.permute(1, 0, 2) for t in patch_tokens]  # LND -> NLD
        if self.visual.attn_pool is not None:
            x = self.visual.attn_pool(x)
            x = self.visual.ln_post(x)
            pooled, tokens = self.visual._global_pool(x)
        else:
            pooled, tokens = self.visual._global_pool(x)
            pooled = self.visual.ln_post(pooled)
        if self.visual.proj is not None:
            pooled = pooled @ self.visual.proj
        return pooled, patch_tokens, patch_embedding

    def proj_visual_tokens(self, image_features, patch_tokens):
        """Project class/patch tokens into the text space and L2-normalize them.

        Normalization is done out of place: ``norm`` saves its input for backward,
        so dividing in place would raise an autograd error when training outside
        CUDA autocast. ``.to(dtype)`` keeps the projected dtype (as the in-place
        version did).
        """
        proj_patch_tokens = [
            (tokens / tokens.norm(dim=-1, keepdim=True)).to(tokens.dtype)
            for tokens in self.patch_token_layer(patch_tokens)
        ]
        proj_cls_tokens = self.cls_token_layer(image_features)[0]
        proj_cls_tokens = (proj_cls_tokens / proj_cls_tokens.norm(dim=-1, keepdim=True)).to(proj_cls_tokens.dtype)
        return proj_cls_tokens, proj_patch_tokens

    def encode_text(self, text):
        """Encode token ids with the CLIP text tower, applying text prompting per block."""
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for indx, r in enumerate(self.transformer.resblocks):
            x, _ = self.text_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def visual_text_similarity(self, image_feature, patch_token, text_feature, aggregation):
        """Compare visual features with the text embeddings.

        ``text_feature`` is assumed to be (batch, channels, 2), with column 1 the
        abnormal state (TODO confirm against callers).

        Returns:
            With ``aggregation``: ``(anomaly_map, anomaly_score)``, where the map is
            (B, 1, H, W), averaged over layers, and the score is the abnormal
            probability. Otherwise: the per-layer softmaxed maps and the full
            (B, 2) score softmax.
        """
        anomaly_maps = [100.0 * tokens @ text_feature for tokens in patch_token]
        if self.use_hsf:
            alpha = 0.2  # weight of the clustered patch feature relative to the class token
            clustered_feature = self.HSF.forward(patch_token, anomaly_maps)
            # aggregate the class token and the clustered features for more comprehensive information
            cur_image_feature = alpha * clustered_feature + (1 - alpha) * image_feature
            cur_image_feature = F.normalize(cur_image_feature, dim=1)
        else:
            cur_image_feature = image_feature
        anomaly_score = (100.0 * cur_image_feature.unsqueeze(1) @ text_feature).squeeze(1)
        anomaly_score = torch.softmax(anomaly_score, dim=1)
        # NOTE: this bilinear interpolation is not reproducible and may occasionally lead to unstable ZSAD performance.
        for i, logits in enumerate(anomaly_maps):
            B, L, C = logits.shape
            H = int(np.sqrt(L))  # patches are assumed to form a square grid
            logits = logits.permute(0, 2, 1).view(B, C, H, H)
            anomaly_maps[i] = F.interpolate(logits, size=self.image_size, mode='bilinear', align_corners=True)
        if aggregation:  # in the test stage, we firstly aggregate logits from all hierarchies and then do the softmax normalization
            anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
            anomaly_map = torch.softmax(anomaly_map, dim=1)
            anomaly_map = (anomaly_map[:, 1:, :, :] + 1 - anomaly_map[:, 0:1, :, :]) / 2.0
            anomaly_score = anomaly_score[:, 1]
            return anomaly_map, anomaly_score
        # otherwise, we do the softmax normalization for individual hierarchies
        anomaly_maps = [torch.softmax(logits, dim=1) for logits in anomaly_maps]
        return anomaly_maps, anomaly_score

    def extract_feat(self, image, cls_name):
        """Return projected class tokens, projected patch tokens and text features.

        The frozen towers run under ``no_grad`` unless their branch is prompted.
        """
        if 'D' in self.prompting_type:
            self.generate_and_set_dynamic_promtps(image)  # generate and set dynamic prompts for corresponding prompters
        if self.enable_visual_prompt:
            image_features, patch_tokens, _ = self.encode_image(image)
        else:
            with torch.no_grad():
                image_features, patch_tokens, _ = self.encode_image(image)
        if self.enable_text_prompt:
            text_features = self.text_embedding_layer(self, cls_name, self.device)
        else:
            with torch.no_grad():
                text_features = self.text_embedding_layer(self, cls_name, self.device)
        proj_cls_tokens, proj_patch_tokens = self.proj_visual_tokens(image_features, patch_tokens)
        return proj_cls_tokens, proj_patch_tokens, text_features

    @torch.cuda.amp.autocast()
    def forward(self, image, cls_name, aggregation=True):
        """Compute anomaly maps and scores for ``image`` given its class name(s).

        Returns:
            With ``aggregation``: a (B, H, W) anomaly map tensor and (B,) scores.
            Otherwise: a list of per-layer maps and the (B, 2) score softmax.
        """
        image_features, patch_tokens, text_features = self.extract_feat(image, cls_name)
        anomaly_map, anomaly_score = self.visual_text_similarity(image_features, patch_tokens, text_features,
                                                                 aggregation)
        if aggregation:
            return anomaly_map.squeeze(1), anomaly_score  # (B, 1, H, W) -> (B, H, W)
        return anomaly_map, anomaly_score
|