"""Rearrange the DAGM2007 dataset into an anomaly-detection folder layout.

For every class directory under ``<DATA_ROOT>/DAGM2007`` and every state
sub-directory inside it, the script:

* collects the ``*PNG`` images of the state and the ``*PNG`` masks found in
  its ``Label`` sub-directory;
* treats an image as defective when some mask file name contains the four
  characters that precede the image's extension;
* writes the images to ``<target_root>/<class>/`` as

  - ``train/good``            normal images after the first ``N``
  - ``test/good``             the first ``N`` normal images
  - ``test/defect``           the ``N`` defective images
  - ``ground_truth/defect``   every mask of the state

  where ``N`` is the number of defective images in that state.

Images are decoded and re-encoded with OpenCV (``cv2.imread`` /
``cv2.imwrite`` with default flags), not copied byte-for-byte.

The conversion runs at module level, i.e. as soon as the file is executed.
"""
import argparse
import os

import cv2
import numpy as np
from sklearn.model_selection import train_test_split

from config import DATA_ROOT

dataset_root = os.path.join(DATA_ROOT, 'DAGM2007')
target_root = '../datasets/DAGM_anomaly_detection'


def _list_png(directory):
    """Return the sorted entry names of *directory* that end with ``PNG``."""
    # os.listdir() yields entries in arbitrary order; sorting makes the
    # train/test split below reproducible across machines and runs.
    return sorted(
        name for name in os.listdir(directory) if name.endswith('PNG')
    )


def _copy_images(source_dir, target_dir, names):
    """Re-encode every image in *names* from *source_dir* into *target_dir*.

    *target_dir* is created when missing, even if *names* is empty.

    Raises:
        OSError: if OpenCV cannot read a source image or write its copy.
    """
    os.makedirs(target_dir, exist_ok=True)
    for name in names:
        source_path = os.path.join(source_dir, name)
        image_data = cv2.imread(source_path)
        if image_data is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError('cv2.imread could not read {!r}'.format(source_path))
        target_path = os.path.join(target_dir, name)
        if not cv2.imwrite(target_path, image_data):
            raise OSError('cv2.imwrite could not write {!r}'.format(target_path))


# Only directories are classes; stray files next to them are ignored.
class_names = sorted(
    name for name in os.listdir(dataset_root)
    if os.path.isdir(os.path.join(dataset_root, name))
)

for class_name in class_names:
    class_dir = os.path.join(dataset_root, class_name)
    class_target = os.path.join(target_root, class_name)

    # NOTE(review): every state of a class is written into the same target
    # directories, so equally named files from different states overwrite
    # one another (the state processed last wins) -- confirm this is intended.
    for state in sorted(os.listdir(class_dir)):
        state_dir = os.path.join(class_dir, state)
        if not os.path.isdir(state_dir):
            continue

        label_dir = os.path.join(state_dir, 'Label')
        images = _list_png(state_dir)
        mask = _list_png(label_dir)

        normal_image_path = []
        abnormal_image_path = []
        for f in images:
            # The four characters before the 4-character extension suffix
            # (e.g. "0001" in "0001.PNG") identify the image.
            image_id = f[-8:-4]
            if any(image_id in y for y in mask):
                abnormal_image_path.append(f)
            else:
                normal_image_path.append(f)

        # Hold out as many normal images for testing as there are defects.
        length = len(abnormal_image_path)
        normal_image_path_test = normal_image_path[:length]
        normal_image_path_train = normal_image_path[length:]

        _copy_images(
            state_dir,
            os.path.join(class_target, 'train', 'good'),
            normal_image_path_train,
        )
        _copy_images(
            state_dir,
            os.path.join(class_target, 'test', 'good'),
            normal_image_path_test,
        )
        _copy_images(
            state_dir,
            os.path.join(class_target, 'test', 'defect'),
            abnormal_image_path,
        )
        # All masks of the state are exported, matched to an image or not.
        _copy_images(
            label_dir,
            os.path.join(class_target, 'ground_truth', 'defect'),
            mask,
        )

print("Done")