# (Hosting-page metadata captured together with this file; not part of the config.)
# KyanChen's picture
# Upload 1861 files
# 3b96cb1
# raw
# history blame
# 1.62 kB
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
PackInputs, RandomCrop, RandomFlip, Resize)
from mmpretrain.evaluation import Accuracy
# ---------------------------------------------------------------------------
# Dataset settings: CUB_200_2011 classification with 384x384 network inputs.
# ---------------------------------------------------------------------------
dataset_type = CUB

# Normalization applied batch-wise by the model's data preprocessor.
data_preprocessor = dict(
    num_classes=200,  # label space size for CUB_200_2011
    # Per-channel mean / std, expressed in RGB channel order.
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # Loaded images arrive in BGR order; swap to RGB before normalizing.
    to_rgb=True,
)

# Training transforms: resize (scale=510), take a random 384 crop, then
# flip horizontally with probability 0.5 for augmentation.
train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=Resize, scale=510),
    dict(type=RandomCrop, crop_size=384),
    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
    dict(type=PackInputs),
]

# Evaluation transforms: same resize as training, but a deterministic
# center crop instead of random crop / flip.
test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=Resize, scale=510),
    dict(type=CenterCrop, crop_size=384),
    dict(type=PackInputs),
]

# Both loaders read from the same data root with identical batch size and
# worker count; they differ only in split, pipeline and shuffling.
train_dataloader = dict(
    batch_size=8,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/CUB_200_2011',
        split='train',
        pipeline=train_pipeline,
    ),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

# Validation is run on the dataset's 'test' split, without shuffling.
val_dataloader = dict(
    batch_size=8,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/CUB_200_2011',
        split='test',
        pipeline=test_pipeline,
    ),
    sampler=dict(type=DefaultSampler, shuffle=False),
)

# Report top-1 accuracy only.
val_evaluator = dict(type=Accuracy, topk=(1, ))

# The test phase reuses the validation loader and evaluator objects as-is.
test_dataloader = val_dataloader
test_evaluator = val_evaluator