|
import os |
|
import json |
|
import pandas as pd |
|
import random |
|
from dataset import VISA_ROOT |
|
|
|
class VisASolver:
    """Build a ``meta.json`` index from the ``split_csv/1cls.csv`` split file.

    The CSV is read positionally: column 0 is compared against the class
    names in ``CLSNAMES``, column 1 against the phase names (``'train'`` /
    ``'test'``), column 2 against the literal ``'anomaly'``, and columns 3
    and 4 are copied out as the image path and mask path respectively.
    """

    # Class names that are looked up in the first CSV column; the order here
    # is the order in which classes appear in the generated JSON.
    CLSNAMES = [
        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
        'pcb4', 'pipe_fryum',
    ]

    def __init__(self, root=VISA_ROOT, train_ratio=0.5):
        """Load the split CSV found under *root*.

        Args:
            root: Dataset directory. ``{root}/split_csv/1cls.csv`` is read
                immediately and ``{root}/meta.json`` is the output path.
            train_ratio: Stored on the instance; no method of this class
                reads it.
        """
        self.root = root
        self.meta_path = f'{root}/meta.json'
        self.phases = ['train', 'test']
        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
        self.train_ratio = train_ratio

    def run(self):
        """Generate the meta file (thin wrapper around generate_meta_info)."""
        self.generate_meta_info()

    @staticmethod
    def _build_entry(row, cls_name):
        """Turn one positional CSV row (a plain tuple) into a meta record."""
        is_abnormal = row[2] == 'anomaly'
        return dict(
            img_path=row[3],
            # The mask column is only consulted for anomalous samples.
            mask_path=row[4] if is_abnormal else '',
            cls_name=cls_name,
            specie_name='',
            anomaly=1 if is_abnormal else 0,
        )

    def _collect_meta_info(self):
        """Return ``{phase: {cls_name: [record, ...]}}`` built from the CSV."""
        columns = self.csv_data.columns
        class_col, phase_col = columns[0], columns[1]
        info = {phase: {} for phase in self.phases}
        for cls_name in self.CLSNAMES:
            cls_rows = self.csv_data[self.csv_data[class_col] == cls_name]
            for phase in self.phases:
                phase_rows = cls_rows[cls_rows[phase_col] == phase]
                # itertuples(name=None) yields plain tuples, so fields are
                # addressed strictly by position. Indexing a label-indexed
                # Series with an integer (``series[2]``) depends on a
                # deprecated positional fallback in pandas.
                info[phase][cls_name] = [
                    self._build_entry(row, cls_name)
                    for row in phase_rows.itertuples(index=False, name=None)
                ]
        return info

    def generate_meta_info(self):
        """Write the per-phase, per-class sample index to ``self.meta_path``.

        The file holds a JSON object (indent 4, trailing newline) keyed by
        phase, then by class name, each mapping to a list of records with the
        keys ``img_path``, ``mask_path``, ``cls_name``, ``specie_name`` and
        ``anomaly``. Classes with no matching rows get an empty list.
        """
        info = self._collect_meta_info()
        with open(self.meta_path, 'w') as f:
            f.write(json.dumps(info, indent=4) + "\n")
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the solver for the configured dataset root
    # and write its meta file.
    VisASolver(root=VISA_ROOT).run()
|
|