|
import tensorflow as tf |
|
from tensorflow.keras.applications import ResNet50 |
|
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D |
|
from tensorflow.keras.models import Model |
|
from tensorflow.keras.optimizers import Adam |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
|
|
|
|
# Dataset roots. Each is read with flow_from_directory further down, which
# expects one sub-directory per class inside it.
train_data_dir, validation_data_dir = 'data/train', 'data/validation'

# Width of the final classification layer and images per mini-batch.
num_classes, batch_size = 50, 32
|
|
|
|
|
# The ImageNet weights loaded for ResNet50 below were trained on inputs prepared
# by resnet50.preprocess_input (RGB->BGR conversion plus per-channel ImageNet
# mean subtraction), not on pixels rescaled to [0, 1]. Applying the same
# function here keeps the input distribution the pretrained backbone expects.
_resnet_preprocess = tf.keras.applications.resnet50.preprocess_input

# Training pipeline: ImageNet preprocessing plus light random augmentation
# (shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    preprocessing_function=_resnet_preprocess,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation pipeline: identical preprocessing, no augmentation.
validation_datagen = ImageDataGenerator(preprocessing_function=_resnet_preprocess)
|
|
|
|
|
# ImageNet-pretrained ResNet50 backbone without its original 1000-way
# classifier, so a new head sized for this dataset can be attached.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# New head: pool the backbone's spatial feature map to a single vector, then a
# hidden fully connected layer.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)

# The data generators use class_mode='categorical', which yields exactly one
# one-hot label per image, so the classes are mutually exclusive. Softmax plus
# categorical cross-entropy is the matching pairing. Sigmoid plus binary
# cross-entropy would treat each class as an independent yes/no decision, and
# 'accuracy' would then resolve to binary accuracy: roughly 98% for a model
# that predicts all zeros over 50 classes.
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Freeze the entire backbone (this propagates to all its layers) so only the
# new head is trained. This must happen before compile() to take effect.
base_model.trainable = False

# Current Keras optimizers accept `learning_rate`; the old `lr` alias is
# deprecated or rejected depending on the version.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
|
|
|
|
|
# Settings shared by both splits: resize to the 224x224 input the ResNet50
# backbone is built for, use the configured batch size, and emit one-hot
# labels derived from each image's class sub-directory.
_flow_options = {
    'target_size': (224, 224),
    'batch_size': batch_size,
    'class_mode': 'categorical',
}

train_generator = train_datagen.flow_from_directory(train_data_dir, **_flow_options)

validation_generator = validation_datagen.flow_from_directory(validation_data_dir, **_flow_options)

# The model's output layer is sized from the hard-coded num_classes, and each
# split builds its own class->index mapping from its sub-directory names.
# Fail fast with a clear message when either assumption is violated, rather than
# hitting an opaque shape error in fit() or silently scoring validation images
# against the wrong label indices.
if train_generator.num_classes != num_classes:
    raise ValueError(
        f'num_classes is {num_classes}, but {train_data_dir!r} contains '
        f'{train_generator.num_classes} class sub-directories'
    )
if validation_generator.class_indices != train_generator.class_indices:
    raise ValueError(
        f'class sub-directories in {validation_data_dir!r} do not match '
        f'those in {train_data_dir!r}'
    )
|
|
|
|
|
# Each DirectoryIterator is a keras Sequence whose len() is the number of
# batches in one full pass over its split, i.e. ceil(samples / batch_size).
# steps_per_epoch and validation_steps count batches, not images. Passing
# `.samples` made every "epoch" run roughly batch_size times too many steps.
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)
|
|
|
|
|
# Persist the trained model (architecture and weights) to an HDF5 file in the
# current working directory.
model.save(filepath='deepfashion_attribute_model.h5')
|
|