# SpiderClassification / gendata.py
# Uploaded by Spidartist ("Upload 13 files", commit 961d213)
import tensorflow as tf
from keras.layers import Conv2D, Dense, MaxPool2D, Input, Flatten, BatchNormalization, Dropout, Layer
from keras.models import Model, Sequential
from keras.utils import plot_model
from keras.optimizers import SGD, Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
# Training images are rescaled by 1/255 and randomly augmented (rotations of
# up to 30 degrees, horizontal flips).
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=30,
    rescale=1. / 255,
)

# Validation images only get the same 1/255 rescaling, with no augmentation.
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
# Locations of the two dataset splits and the batch size shared by both.
train_dir = "./dataset/train"
val_dir = "./dataset/valid"
batch_size = 32

# Batch iterator over the training directory: 224x224 RGB images with
# categorical labels, shuffled with a fixed seed.
train_data = train_datagen.flow_from_directory(
    directory=train_dir,
    target_size=(224, 224),
    color_mode="rgb",
    class_mode="categorical",
    batch_size=batch_size,
    shuffle=True,
    seed=42,
)

# Same settings for the validation directory, using the non-augmenting
# generator.
val_data = valid_datagen.flow_from_directory(
    directory=val_dir,
    target_size=(224, 224),
    color_mode="rgb",
    class_mode="categorical",
    batch_size=batch_size,
    shuffle=True,
    seed=42,
)
# Length of the training iterator (the same value is later passed to fit()
# as steps_per_epoch), printed as a quick sanity check.
print(len(train_data))

# Optimisation hyper-parameters consumed by compile() and fit() below.
epochs = 100
lr = 0.001

# Empty sequential container; its layers are appended further down.
cnn_model = Sequential()
class Normalization(Layer):
    """Standardise each colour channel as ``(inputs - mean) / std``.

    ``mean`` and ``std`` hold three per-channel constants that broadcast over
    the last axis of the input. The values match the commonly quoted ImageNet
    RGB statistics for images scaled to ``[0, 1]`` (which is what the
    generators in this script produce via ``rescale=1. / 255``).

    NOTE(review): the ``keras.applications`` ResNet weights are documented as
    expecting ``resnet.preprocess_input`` (RGB -> BGR, mean-centred, no
    scaling). Confirm that this normalisation is intended for that backbone.
    """

    def __init__(self, **kwargs):
        # Forward the standard layer kwargs (name, trainable, dtype, ...).
        # The base ``get_config`` emits them, so a constructor that rejects
        # them makes ``from_config`` fail when a saved model is reloaded.
        super().__init__(**kwargs)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def call(self, inputs):
        """Return ``inputs`` with the per-channel mean removed and std divided out."""
        return (inputs - self.mean) / self.std
# ResNet101 with ImageNet weights and no classification head; pooling='max'
# requests global max pooling on the backbone output.
pretrained_model = tf.keras.applications.ResNet101(
    weights='imagenet',
    include_top=False,
    pooling='max',
    input_shape=(224, 224, 3),
    classes=15,
)

# Mark every backbone layer as non-trainable so only layers added on top of it
# are updated during training.
for backbone_layer in pretrained_model.layers:
    backbone_layer.trainable = False
# Stack, in order: 224x224x3 input -> channel normalisation -> pretrained
# backbone -> 15-unit dense output (no activation, i.e. raw logits).
for stage in (
    Input((224, 224, 3)),
    Normalization(),
    pretrained_model,
    Dense(15),
):
    cnn_model.add(stage)

# Write an architecture diagram to model.png and print the layer summary.
plot_model(cnn_model, to_file="model.png", show_shapes=True)
cnn_model.summary()
# Track accuracy; the loss is configured with from_logits=True because the
# final Dense layer has no activation.
metrics = ["acc"]
cnn_model.compile(
    optimizer=SGD(learning_rate=lr, momentum=0.9),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=metrics,
)
# Training callbacks.
# - ModelCheckpoint only overwrites the file when val_loss improves, so the
#   file on disk always holds the best model seen so far rather than whatever
#   the last epoch produced.
# - EarlyStopping halts after 10 epochs without val_loss improvement and puts
#   the best weights back into the in-memory model, keeping it consistent
#   with the checkpoint file.
callbacks = [
    ModelCheckpoint(
        "files/model_new.h5",
        monitor="val_loss",
        save_best_only=True,
    ),
    EarlyStopping(
        monitor="val_loss",
        patience=10,
        restore_best_weights=True,
    ),
]
# Run training: each epoch covers the full length of the training iterator,
# followed by a validation pass over the full length of the validation one.
cnn_model.fit(
    train_data,
    epochs=epochs,
    steps_per_epoch=len(train_data),
    validation_data=val_data,
    validation_steps=len(val_data),
    callbacks=callbacks,
)