|
import os |
|
import numpy as np |
|
import tensorflow as tf |
|
import matplotlib.pyplot as plt |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
from model import create_model |
|
|
|
# Dataset root; the "train" and "val" sub-folders below it are read by
# flow_from_directory, which treats each sub-folder as one class.
base_dir = 'data/chest_xray'
train_dir, val_dir = (os.path.join(base_dir, split) for split in ('train', 'val'))
|
|
|
# Training-time generator: rescales pixel values by 1/255 and applies random
# geometric augmentation (rotation, shifts, shear, zoom, horizontal flip).
# Pixels exposed by a transform are filled with the nearest edge value.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    fill_mode='nearest',
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation generator: same rescaling, but no augmentation, so validation
# metrics are computed on unmodified images.
val_datagen = ImageDataGenerator(rescale=1. / 255)
|
|
|
def _flow(datagen, directory):
    """Return a batched iterator over *directory* using the shared settings.

    Images are resized to 150x150 and yielded in batches of 32.
    Labels are binary (class_mode='binary'): one class per sub-folder.
    """
    return datagen.flow_from_directory(
        directory,
        class_mode='binary',
        batch_size=32,
        target_size=(150, 150),
    )


train_generator = _flow(train_datagen, train_dir)
val_generator = _flow(val_datagen, val_dir)
|
|
|
# Preview a few training images (after the train_datagen transforms) as a
# quick visual sanity check of the input pipeline.
sample_images, sample_labels = next(train_generator)

# class_indices maps class-folder name -> numeric label; invert it so each
# preview can be titled with its class name.
index_to_class = {index: name for name, index in train_generator.class_indices.items()}

# A batch can contain fewer than 5 images (e.g. a very small dataset);
# indexing past its end would raise IndexError.
num_preview = min(5, len(sample_images))
for i in range(num_preview):
    plt.subplot(1, num_preview, i + 1)
    plt.imshow(sample_images[i])
    # Binary labels come back as floats (0.0 / 1.0); cast for the dict lookup.
    plt.title(index_to_class[int(sample_labels[i])])
    plt.axis('off')
plt.show()
|
|
|
model = create_model()

# Stop once val_loss has not improved for 3 epochs, and roll back to the
# best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Step counts come from the generators instead of hard-coded numbers.
# len() of a flow_from_directory iterator is ceil(num_images / batch_size),
# so every epoch covers each image exactly once. Validation therefore never
# requests more batches than the directory actually contains.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=val_generator,
    validation_steps=len(val_generator),
    callbacks=[early_stopping],
)

model.save('xray_image_classifier_model.keras')
|
|