|
import gradio as gr |
|
import zipfile |
|
import os |
|
import uuid |
|
import shutil |
|
import subprocess |
|
import sys |
|
import time |
|
from PIL import Image |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
import numpy as np |
|
|
|
|
|
# Working directories, relative to the process's current directory:
# raw/extracted datasets go under UPLOAD_DIR, trained artifacts under MODEL_DIR.
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"

# Create both directories at import time; existing ones are left untouched.
for _required_dir in (UPLOAD_DIR, MODEL_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
|
|
|
def _resolve_upload_path(dataset_file):
    """Return the local path of a Gradio upload, or None when nothing was uploaded.

    Depending on the Gradio version/configuration, ``gr.File`` hands the
    callback either a tempfile-like object exposing ``.name`` or a plain path
    string; both shapes are accepted.
    """
    if dataset_file is None:
        return None
    return getattr(dataset_file, "name", dataset_file)


def _safe_model_name(model_name):
    """Reduce a user-supplied model name to a single, safe filename component.

    Every character other than letters, digits, ``-`` and ``_`` becomes ``_``,
    so the result cannot contain path separators, ``..`` or an absolute path
    (``os.path.join`` would otherwise discard the model directory entirely).
    """
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in str(model_name or "")
    )
    return cleaned or "model"


def _write_dummy_dataset(train_dir, val_dir, image_size,
                         class_names=("class_a", "class_b"), images_per_class=2):
    """Populate ``train_dir``/``val_dir`` with tiny solid-colour JPEGs, one folder per class."""
    for split_dir in (train_dir, val_dir):
        for class_idx, class_name in enumerate(class_names):
            class_path = os.path.join(split_dir, class_name)
            os.makedirs(class_path, exist_ok=True)
            for i in range(images_per_class):
                # Colour depends on the class so the classes are distinguishable.
                color = (class_idx * 120, 100 + i * 50, 150)
                img = Image.new("RGB", (image_size, image_size), color=color)
                img.save(os.path.join(class_path, f"sample_{i}.jpg"))


def _find_dataset_root(extract_path):
    """Return the directory under ``extract_path`` holding ``train/`` and ``validation/``.

    The two folders may sit at the top of the archive or inside one wrapping
    folder (what zipping a dataset directory produces). ``__MACOSX`` and
    hidden entries are ignored when looking for that wrapping folder.

    Raises:
        ValueError: if neither location contains both folders.
    """
    def has_splits(root):
        return (os.path.isdir(os.path.join(root, "train"))
                and os.path.isdir(os.path.join(root, "validation")))

    if has_splits(extract_path):
        return extract_path
    entries = [name for name in os.listdir(extract_path)
               if name != "__MACOSX" and not name.startswith(".")]
    if len(entries) == 1:
        nested = os.path.join(extract_path, entries[0])
        if has_splits(nested):
            return nested
    raise ValueError(
        "Dataset ZIP must contain 'train/' and 'validation/' folders "
        "(at the top level or inside a single root folder)."
    )


def _prepare_dataset(dataset_file, workspace, image_size):
    """Materialise the dataset under ``workspace`` and return ``(train_dir, val_dir)``.

    With no upload, a dummy two-class dataset is generated. With an upload,
    the ZIP is extracted and must provide ``train/`` and ``validation/``;
    a malformed archive is an error rather than silently replaced by dummy data.
    """
    upload_path = _resolve_upload_path(dataset_file)
    if upload_path is None:
        train_dir = os.path.join(workspace, "train")
        val_dir = os.path.join(workspace, "validation")
        _write_dummy_dataset(train_dir, val_dir, image_size)
        return train_dir, val_dir

    # zipfile strips absolute paths and ".." components when extracting, so
    # members stay inside ``workspace``.
    # NOTE(review): total uncompressed size is not limited (zip-bomb risk).
    with zipfile.ZipFile(upload_path) as zip_ref:
        zip_ref.extractall(workspace)
    root = _find_dataset_root(workspace)
    return os.path.join(root, "train"), os.path.join(root, "validation")


def _make_generators(train_dir, val_dir, image_size, batch_size):
    """Build augmented training and rescale-only validation iterators.

    Raises:
        ValueError: if either split contains no images.
    """
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
    )
    val_datagen = ImageDataGenerator(rescale=1. / 255)
    common = dict(target_size=(image_size, image_size),
                  batch_size=batch_size, class_mode="categorical")

    train_generator = train_datagen.flow_from_directory(train_dir, **common)
    if train_generator.samples == 0:
        raise ValueError(f"No training images found under {train_dir}")

    # Reuse the training class order so validation one-hot vectors have the
    # same width and index mapping as the model output.
    train_classes = sorted(train_generator.class_indices,
                           key=train_generator.class_indices.get)
    val_generator = val_datagen.flow_from_directory(val_dir, classes=train_classes, **common)
    if val_generator.samples == 0:
        raise ValueError(f"No validation images found under {val_dir}")
    return train_generator, val_generator


def _build_model(image_size, num_classes):
    """Build and compile the small CNN classifier (softmax over ``num_classes``)."""
    layers = tf.keras.layers
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(image_size, image_size, 3)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Conv2D(64, 3, activation="relu"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Conv2D(128, 3, activation="relu"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(256, activation="relu"),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model


def _export_savedmodel(model, path):
    """Write ``model`` as a TensorFlow SavedModel directory at ``path``.

    Keras 2 writes a SavedModel when ``save`` gets an extension-less path;
    Keras 3 rejects that with ValueError, and ``Model.export`` is its SavedModel API.
    """
    try:
        model.save(path)
    except ValueError:
        model.export(path)


def _convert_to_tfjs(savedmodel_path, tfjs_path):
    """Convert a SavedModel directory to TensorFlow.js format via the CLI.

    Raises:
        OSError / subprocess.CalledProcessError: if the converter is
        unavailable or fails.
    """
    command = ["tensorflowjs_converter", "--input_format=tf_saved_model",
               savedmodel_path, tfjs_path]
    try:
        subprocess.run(command, check=True)
    except FileNotFoundError:
        # CLI missing: install tensorflowjs into this interpreter, retry once.
        # NOTE(review): installing packages at request time is fragile; prefer
        # listing tensorflowjs in the deployment requirements.
        subprocess.run([sys.executable, "-m", "pip", "install", "tensorflowjs"], check=True)
        subprocess.run(command, check=True)


def _dir_size_bytes(path):
    """Return the total size in bytes of every file below ``path``."""
    return sum(
        os.path.getsize(os.path.join(dirpath, name))
        for dirpath, _, filenames in os.walk(path)
        for name in filenames
    )


def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size, image_size):
    """Train a small CNN on an uploaded (or generated dummy) dataset and export it.

    Args:
        dataset_file: Gradio upload of a ZIP containing ``train/`` and
            ``validation/`` class folders, or None to use dummy data.
        model_name: base filename for the ``.h5`` export (sanitised).
        num_classes: requested class count; the dataset's folder structure wins.
        epochs, batch_size, image_size: training hyper-parameters.

    Returns:
        ``(result_text, h5_path, savedmodel_zip, tfjs_zip)``. On failure the
        three paths are None and ``result_text`` describes the error (this is
        the Gradio callback boundary, so errors are reported, not raised).
        ``tfjs_zip`` is also None when only the TensorFlow.js conversion fails.
    """
    uid = str(uuid.uuid4())
    workspace = os.path.join(UPLOAD_DIR, uid)
    try:
        # Gradio number widgets may deliver floats; Keras needs ints here.
        epochs, batch_size, image_size = int(epochs), int(batch_size), int(image_size)
        os.makedirs(workspace, exist_ok=True)

        train_dir, val_dir = _prepare_dataset(dataset_file, workspace, image_size)
        train_generator, val_generator = _make_generators(
            train_dir, val_dir, image_size, batch_size)

        # The folder structure is authoritative; the slider value is a hint.
        requested_classes = num_classes
        num_classes = train_generator.num_classes
        model = _build_model(image_size, num_classes)

        start_time = time.time()
        history = model.fit(
            train_generator,
            # len() of the iterator is ceil(samples / batch_size): a dataset
            # smaller than one batch still gets one step (floor division gave
            # 0 steps, which Keras rejects).
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            validation_data=val_generator,
            validation_steps=len(val_generator),
            verbose=0,
        )
        training_time = time.time() - start_time

        model_dir = os.path.join(MODEL_DIR, uid)
        os.makedirs(model_dir, exist_ok=True)

        h5_path = os.path.join(model_dir, f"{_safe_model_name(model_name)}.h5")
        model.save(h5_path)

        savedmodel_path = os.path.join(model_dir, "savedmodel")
        _export_savedmodel(model, savedmodel_path)

        # TF.js export is best-effort: a converter failure should not discard
        # the already-saved H5 and SavedModel artifacts.
        tfjs_path = os.path.join(model_dir, "tfjs")
        tfjs_error = None
        try:
            _convert_to_tfjs(savedmodel_path, tfjs_path)
        except (OSError, subprocess.CalledProcessError) as exc:
            tfjs_error = exc

        # Measure before archiving so the zip copies are not double-counted.
        model_size_mb = _dir_size_bytes(model_dir) / (1024 * 1024)

        # gr.File can only serve single files, so directory exports are zipped.
        savedmodel_zip = shutil.make_archive(savedmodel_path, "zip", root_dir=savedmodel_path)
        tfjs_zip = (None if tfjs_error is not None
                    else shutil.make_archive(tfjs_path, "zip", root_dir=tfjs_path))

        val_accuracy = history.history.get("val_accuracy") or [float("nan")]
        lines = [
            "✅ Training completed successfully!",
            f"⏱️ Training time: {training_time:.2f} seconds",
            f"📊 Best validation accuracy: {max(val_accuracy):.4f}",
            f"📦 Model size: {model_size_mb:.2f} MB",
            f"🗂️ Number of classes: {num_classes}",
        ]
        if requested_classes != num_classes:
            lines.append(f"ℹ️ Requested {requested_classes} classes; the dataset defines {num_classes}.")
        if tfjs_error is not None:
            lines.append(f"⚠️ TensorFlow.js export skipped: {tfjs_error}")
        return "\n".join(lines), h5_path, savedmodel_zip, tfjs_zip

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Training failed: {e}", None, None, None

    finally:
        # The extracted/generated dataset is only needed while training.
        shutil.rmtree(workspace, ignore_errors=True)
|
|
|
|
|
# Gradio UI: dataset/hyper-parameter inputs on the left, results and
# download links on the right (downloads stay hidden until a run succeeds).
with gr.Blocks(title="AI Image Classifier Trainer") as demo:
    gr.Markdown("# 🖼️ AI Image Classifier Trainer")
    gr.Markdown("Upload a ZIP of `train/` and `validation/`, or leave it empty to auto-generate dummy data.")

    with gr.Row():
        with gr.Column():
            dataset = gr.File(label="Dataset ZIP File", file_types=[".zip"])
            model_name = gr.Textbox(label="Model Name", value="my_classifier")
            num_classes = gr.Slider(2, 100, value=5, step=1, label="Number of Classes")
            epochs = gr.Slider(5, 200, value=30, step=1, label="Training Epochs")
            batch_size = gr.Radio([16, 32, 64], value=32, label="Batch Size")
            image_size = gr.Radio([128, 224, 256], value=224, label="Image Size (px)")
            train_btn = gr.Button("🚀 Train Model", variant="primary")

        with gr.Column():
            output = gr.Textbox(label="Training Results", interactive=False)
            with gr.Column(visible=False) as download_col:
                h5_download = gr.File(label="H5 Model Download")
                savedmodel_download = gr.File(label="SavedModel Download")
                tfjs_download = gr.File(label="TensorFlow.js Download")

    def toggle_downloads(result, h5_path, saved_path, tfjs_path):
        """Show the download column only when the training run produced an H5 file.

        ``train_and_export`` returns None for every path on failure, so a
        truthy ``h5_path`` marks success. ``result`` is accepted to match the
        event wiring but is unused. ``gr.update`` changes only the column's
        visibility; file outputs receive plain values (or None to clear them).
        """
        if h5_path:
            return gr.update(visible=True), h5_path, saved_path, tfjs_path
        return gr.update(visible=False), None, None, None

    train_btn.click(
        fn=train_and_export,
        inputs=[dataset, model_name, num_classes, epochs, batch_size, image_size],
        outputs=[output, h5_download, savedmodel_download, tfjs_download],
    ).then(
        fn=toggle_downloads,
        inputs=[output, h5_download, savedmodel_download, tfjs_download],
        outputs=[download_col, h5_download, savedmodel_download, tfjs_download],
    )
|
|
|
if __name__ == "__main__":
    # Listen on every interface at port 7860 without a public share link,
    # capping individual uploads at 100 MB via Gradio's max_file_size.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_file_size="100mb",
    )
    demo.launch(**launch_options)
|
|