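"""Evaluate a pre-trained AdaCLIP model on every class folder found under ./data/,
dumping per-image anomaly scores and anomaly maps via DumpScores.

Example invocation (assuming this script is saved as test.py):
    python test.py --ckt_path weights/pretrained_mvtec_colondb.pth --testing_data visa
"""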
import warnings

# Suppress RuntimeWarnings globally before the heavyweight imports below.
warnings.filterwarnings("ignore", category=RuntimeWarning)

import os

# Must be set before CUDA initializes: requests a deterministic cuBLAS
# workspace so repeated runs are reproducible.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import argparse
import json

import numpy as np
import torch
from torchvision import transforms

from datasets.rayan_dataset import RayanDataset
from method import AdaCLIP_Trainer
from tools import setup_seed, Logger
from utils.dump_scores import DumpScores

setup_seed(111)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_available_class_names(data_path):
    """Return the names of the class sub-folders directly under data_path."""
    all_items = os.listdir(data_path)
    folder_names = [
        item for item in all_items if os.path.isdir(os.path.join(data_path, item))
    ]
    return folder_names


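# Run the evaluation pipeline: load the checkpoint, score every class folder,
# and dump per-image anomaly scores and anomaly maps.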
def test(args):
    assert os.path.isfile(
        args.ckt_path
    ), f"Please check the path of the pre-trained model: {args.ckt_path} is not valid."

    data_path = "./data/"
    class_names = get_available_class_names(data_path)

    batch_size = args.batch_size
    image_size = args.image_size
    device = "cuda" if torch.cuda.is_available() else "cpu"
    save_fig = args.save_fig

    logger = Logger("log.txt")
    for key, value in sorted(vars(args).items()):
        logger.info(f"{key} = {value}")

    config_path = os.path.join("./model_configs", f"{args.model}.json")
    with open(config_path, "r") as f:
        model_configs = json.load(f)

    # Tap one feature map at the end of each quarter of the vision backbone.
    n_layers = model_configs["vision_cfg"]["layers"]
    substage = n_layers // 4
    features_list = [substage, substage * 2, substage * 3, substage * 4]

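    # Build the AdaCLIP wrapper around the chosen CLIP backbone and move it to
    # the target device.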
    model = AdaCLIP_Trainer(
        backbone=args.model,
        feat_list=features_list,
        input_dim=model_configs["vision_cfg"]["width"],
        output_dim=model_configs["embed_dim"],
        learning_rate=0.0,  # inference only; no parameters are updated
        device=device,
        image_size=image_size,
        prompting_depth=args.prompting_depth,
        prompting_length=args.prompting_length,
        prompting_branch=args.prompting_branch,
        prompting_type=args.prompting_type,
        use_hsf=args.use_hsf,
        k_clusters=args.k_clusters,
    ).to(device)

    model.load(args.ckt_path)

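    # Dispatch on the requested testing mode; only full-dataset evaluation is
    # implemented below.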
    if args.testing_model == "dataset":
        save_root = args.save_path
        csv_root = os.path.join(save_root, "csvs")
        image_root = os.path.join(save_root, "images")
        csv_path = os.path.join(csv_root, f"{args.testing_data}.csv")
        image_dir = os.path.join(image_root, f"{args.testing_data}")
        os.makedirs(image_dir, exist_ok=True)

        dumper = DumpScores()

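        # Evaluate each class folder independently and dump its scores and maps.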
        for classname in class_names:
            test_data = RayanDataset(
                source=data_path,
                classname=classname,
                external_transform=transforms.Compose(
                    [
                        transforms.Resize((224, 224)),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                    ]
                ),
            )
            test_dataloader = torch.utils.data.DataLoader(
                test_data, batch_size=batch_size, shuffle=False
            )

            test_data_cls_names = [classname]
            results = model.evaluation(
                test_dataloader,
                test_data_cls_names,
                save_fig,  # controlled by --save_fig
                image_dir,
            )
            # Stack the per-image maps into (N, H, W), then add a channel axis
            # so the dumped maps have shape (N, 1, H, W).
            results["anomaly_maps"] = np.concatenate(results["anomaly_maps"], axis=0)
            results["anomaly_maps"] = results["anomaly_maps"][:, np.newaxis, :, :]

            dumper.save_scores(
                results["img_path"],
                results["anomaly_scores"],
                results["anomaly_maps"],
            )
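    else:
        # NOTE: single-image mode is accepted by the CLI but has no
        # implementation in this script; fail explicitly instead of silently
        # doing nothing.
        raise NotImplementedError("Single-image testing is not implemented here.")

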
def str2bool(v):
    # Lenient boolean parser for argparse flags; any other value maps to False.
    return v.lower() in ("yes", "true", "t", "1")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("AdaCLIP", add_help=True)

    parser.add_argument(
        "--ckt_path",
        type=str,
        default="weights/pretrained_mvtec_colondb.pth",
        help="Path to the pre-trained model (default: weights/pretrained_mvtec_colondb.pth)",
    )
    parser.add_argument(
        "--testing_model",
        type=str,
        default="dataset",
        choices=["dataset", "image"],
        help="Testing mode: a whole dataset or a single image (default: 'dataset')",
    )

    parser.add_argument(
        "--testing_data",
        type=str,
        default="visa",
        help="Dataset for testing (default: 'visa')",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="asset/img.png",
        help="Path to the image for testing (default: 'asset/img.png')",
    )
    parser.add_argument(
        "--class_name",
        type=str,
        default="candle",
        help="The class name of the testing image (default: 'candle')",
    )
    parser.add_argument(
        "--save_name",
        type=str,
        default="test.png",
        help="Filename for the saved single-image result (default: 'test.png')",
    )

    parser.add_argument(
        "--save_path",
        type=str,
        default="./workspaces",
        help="Directory to save results (default: './workspaces')",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-L-14-336",
        choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"],
        help="The CLIP model to be used (default: 'ViT-L-14-336')",
    )

    parser.add_argument(
        "--save_fig",
        type=str2bool,
        default=False,
        help="Save figures for visualizations (default: False)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size (default: 1)"
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=224,
        help="Size of the input images (default: 224)",
    )

    parser.add_argument(
        "--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)"
    )
    parser.add_argument(
        "--prompting_length",
        type=int,
        default=5,
        help="Length of prompting (default: 5)",
    )
    parser.add_argument(
        "--prompting_type",
        type=str,
        default="SD",
        choices=["", "S", "D", "SD"],
        help="Type of prompting. 'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')",
    )
    parser.add_argument(
        "--prompting_branch",
        type=str,
        default="VL",
        choices=["", "V", "L", "VL"],
        help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')",
    )
    parser.add_argument(
        "--use_hsf",
        type=str2bool,
        default=True,
        help="Use HSF for aggregation. If False, the original class embedding is used (default: True)",
    )
    parser.add_argument(
        "--k_clusters", type=int, default=20, help="Number of clusters (default: 20)"
    )

    args = parser.parse_args()

    if args.batch_size != 1:
        raise NotImplementedError(
            "Currently, only a batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1."
        )

    test(args)