Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from keras.layers import Conv2D, Dense, MaxPool2D, Input, Flatten, BatchNormalization, Dropout, Layer | |
from keras.models import Model, Sequential | |
from keras.utils import plot_model | |
from keras.optimizers import SGD, Adam | |
from keras.callbacks import EarlyStopping, ModelCheckpoint | |
# --- Hyperparameters and dataset locations ---------------------------------
train_dir = "./dataset/train"
val_dir = "./dataset/valid"
batch_size = 32
lr = 0.001
epochs = 100

# --- Image generators -------------------------------------------------------
_ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator

# Training images are multiplied by 1/255 and augmented with random rotations
# (up to 30 degrees) and random horizontal flips.
train_datagen = _ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=30,
    horizontal_flip=True,
)
# Validation images are only rescaled; no augmentation is applied.
valid_datagen = _ImageDataGenerator(rescale=1.0 / 255)

# --- Directory iterators ----------------------------------------------------
# Both splits are read with the same settings: 224x224 RGB images, one-hot
# ("categorical") labels, shuffled with a fixed seed.
_flow_options = dict(
    target_size=(224, 224),
    color_mode="rgb",
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=True,
    seed=42,
)
train_data = train_datagen.flow_from_directory(train_dir, **_flow_options)
val_data = valid_datagen.flow_from_directory(val_dir, **_flow_options)

# Length of the training iterator (later used as steps_per_epoch).
print(len(train_data))

# Empty container; layers are appended to it further down the script.
cnn_model = Sequential()
class Normalization(Layer):
    """Standardize inputs with a fixed mean and standard deviation.

    Computes ``(inputs - mean) / std`` using three hard-coded values for each
    statistic, which broadcast against the trailing axis of ``inputs``.
    NOTE(review): the constants look like the usual ImageNet RGB statistics
    for inputs scaled to [0, 1] -- confirm this matches the backbone in use.

    The layer has no trainable weights.

    Args:
        **kwargs: Standard base-layer options (``name``, ``trainable``,
            ``dtype``, ...). They are forwarded to ``Layer.__init__`` so the
            layer can be rebuilt from its config when a saved model is loaded
            with ``custom_objects={"Normalization": Normalization}``.
    """

    def __init__(self, **kwargs):
        # The inherited get_config() emits name/trainable/dtype, and
        # from_config() passes them back as keyword arguments; without
        # **kwargs, reloading a saved model raises TypeError.
        super().__init__(**kwargs)
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def call(self, inputs):
        """Return ``inputs`` shifted by ``self.mean`` and scaled by ``self.std``."""
        return (inputs - self.mean) / self.std
# --- Backbone ---------------------------------------------------------------
# ResNet101 with ImageNet weights and no classification head; pooling="max"
# applies global max pooling to the final feature map. The `classes` argument
# only applies when include_top=True, so it is not passed here.
pretrained_model = tf.keras.applications.ResNet101(
    include_top=False,
    input_shape=(224, 224, 3),
    pooling="max",
    weights="imagenet",
)
# Freeze the whole backbone; only the new Dense head is trained.
pretrained_model.trainable = False

# --- Full model -------------------------------------------------------------
# NOTE(review): the inputs are rescaled by 1/255 and then standardized by the
# Normalization layer. Confirm this matches the preprocessing the ResNet101
# ImageNet weights expect (see tf.keras.applications.resnet.preprocess_input).
cnn_model.add(Input((224, 224, 3)))
cnn_model.add(Normalization())
cnn_model.add(pretrained_model)
# Emit raw logits (no activation); the loss below uses from_logits=True.
# The width follows the number of class sub-directories found by the training
# iterator instead of a hard-coded 15, so the head always matches the labels.
cnn_model.add(Dense(train_data.num_classes))

plot_model(cnn_model, "model.png", show_shapes=True)
cnn_model.summary()

# --- Training configuration -------------------------------------------------
metrics = ["acc"]
cnn_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=SGD(learning_rate=lr, momentum=0.9),
    metrics=metrics,
)

callbacks = [
    # Keep only the checkpoint with the lowest validation loss. Saving after
    # every epoch would leave the file holding the weights from the final
    # (post-patience) epoch rather than the best one.
    ModelCheckpoint(
        "files/model_new.h5",
        monitor="val_loss",
        save_best_only=True,
    ),
    # Stop after 10 epochs without val_loss improvement and roll the in-memory
    # model back to the best epoch so it agrees with the saved checkpoint.
    EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
]

# --- Training ---------------------------------------------------------------
cnn_model.fit(
    train_data,
    validation_data=val_data,
    epochs=epochs,
    steps_per_epoch=len(train_data),
    validation_steps=len(val_data),
    callbacks=callbacks,
)