|
import gradio as gr |
|
import zipfile, os, uuid, shutil, subprocess, sys, time |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
import numpy as np |
|
|
|
|
|
# Scratch directory for uploaded/extracted datasets and output directory for
# exported models. Both are relative to the current working directory and are
# created eagerly at import time so the request handler can write into them.
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"

for _workdir in (UPLOAD_DIR, MODEL_DIR):
    os.makedirs(_workdir, exist_ok=True)
|
|
|
def _write_upload(dataset_zip, zip_path):
    """Copy the uploaded dataset archive to ``zip_path``.

    Depending on the Gradio version / ``gr.File(type=...)``, the callback may
    receive a filepath string, a temp-file wrapper exposing ``.name``, raw
    bytes, or a readable file object; all of these are accepted here.
    """
    if dataset_zip is None:
        raise gr.Error("Please upload a dataset ZIP file.")
    if isinstance(dataset_zip, (bytes, bytearray)):
        with open(zip_path, "wb") as f:
            f.write(dataset_zip)
        return
    if isinstance(dataset_zip, (str, os.PathLike)):
        src = os.fspath(dataset_zip)
    else:
        src = getattr(dataset_zip, "name", None)
    if src is not None and os.path.isfile(src):
        shutil.copyfile(src, zip_path)
        return
    if hasattr(dataset_zip, "read"):
        with open(zip_path, "wb") as f:
            f.write(dataset_zip.read())
        return
    raise gr.Error("Unsupported upload; please provide a ZIP file.")


def _find_dataset_root(extract_path):
    """Return the directory under ``extract_path`` that holds ``train/`` and ``validation/``.

    Zipping a folder (instead of its contents) nests everything one level
    down, so a single wrapping directory is accepted as well.
    """
    def has_splits(path):
        return all(os.path.isdir(os.path.join(path, split))
                   for split in ("train", "validation"))

    if has_splits(extract_path):
        return extract_path
    candidates = [
        os.path.join(extract_path, entry)
        for entry in sorted(os.listdir(extract_path))
        # macOS Finder adds a "__MACOSX" metadata folder; never train on it.
        if entry != "__MACOSX" and has_splits(os.path.join(extract_path, entry))
    ]
    if len(candidates) == 1:
        return candidates[0]
    raise gr.Error("The ZIP must contain 'train/' and 'validation/' folders, "
                   "each with one sub-folder per class.")


def _build_model(image_size, num_classes):
    """Build the small two-stage Conv/BN/Pool CNN with a softmax head of ``num_classes`` units."""
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu",
                               input_shape=(image_size, image_size, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])


def _save_savedmodel(model, path):
    """Write ``model`` to ``path`` in TensorFlow SavedModel format."""
    try:
        model.save(path)
    except ValueError:
        # Keras 3 (TF >= 2.16) only accepts .keras/.h5 paths in save(); the
        # SavedModel export moved to Model.export().
        model.export(path)


def _convert_to_tfjs(savedmodel_path, tfjs_dir):
    """Convert a SavedModel directory to TF.js format using the tensorflowjs converter.

    ``tensorflowjs_converter`` is a console script, not an importable module,
    so ``python -m tensorflowjs_converter`` cannot work. Prefer the script on
    PATH and fall back to running the converter module with this interpreter.
    """
    exe = shutil.which("tensorflowjs_converter")
    cmd = [exe] if exe else [sys.executable, "-m", "tensorflowjs.converters.converter"]
    cmd += ["--input_format=tf_saved_model", savedmodel_path, tfjs_dir]
    subprocess.run(cmd, check=True)


def _dir_size_mb(path):
    """Return the total size in MiB of all files below ``path``."""
    total = sum(os.path.getsize(os.path.join(dirpath, fname))
                for dirpath, _, fnames in os.walk(path)
                for fname in fnames)
    return total / (1024 * 1024)


def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
    """Train a CNN on an uploaded image dataset and export it as H5, SavedModel and TF.js.

    Args:
        dataset_zip: The uploaded ZIP (path, file wrapper, bytes or file object).
            It must contain ``train/`` and ``validation/`` folders, each with one
            sub-folder per class, either at the archive root or inside a single
            top-level folder.
        model_name: Base file name for the ``.h5`` export; directory components
            are stripped.
        num_classes: Expected number of classes; must match the dataset.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both generators.
        image_size: Images are resized to ``image_size x image_size``.

    Returns:
        A human-readable summary with timing, validation accuracy, total export
        size and the server-side paths of each exported artifact.

    Raises:
        gr.Error: On a missing/invalid upload, bad dataset layout, empty splits,
            mismatched train/validation classes, or a class-count mismatch.
        subprocess.CalledProcessError: If the TF.js conversion fails.
    """
    # Gradio sliders/radios may deliver floats; Keras layers and generators need ints.
    num_classes = int(num_classes)
    epochs = int(epochs)
    batch_size = int(batch_size)
    image_size = int(image_size)
    # Keep only the final path component so "../../x" can't escape the model dir.
    safe_name = os.path.basename(str(model_name or "").strip()) or "model"

    uid = str(uuid.uuid4())
    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
    extract_path = os.path.join(UPLOAD_DIR, uid)
    target_size = (image_size, image_size)

    try:
        _write_upload(dataset_zip, zip_path)
        try:
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(extract_path)
        except zipfile.BadZipFile as err:
            raise gr.Error("The uploaded file is not a valid ZIP archive.") from err

        root = _find_dataset_root(extract_path)
        train_dir = os.path.join(root, "train")
        val_dir = os.path.join(root, "validation")

        train_gen = ImageDataGenerator(rescale=1 / 255,
                                       rotation_range=20,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       horizontal_flip=True,
                                       zoom_range=0.2) \
            .flow_from_directory(train_dir, target_size=target_size,
                                 batch_size=batch_size, class_mode="categorical")
        val_gen = ImageDataGenerator(rescale=1 / 255) \
            .flow_from_directory(val_dir, target_size=target_size,
                                 batch_size=batch_size, class_mode="categorical")

        if train_gen.samples == 0 or val_gen.samples == 0:
            raise gr.Error("No images found in 'train/' or 'validation/'.")
        # Differing class folders would silently misalign one-hot labels.
        if val_gen.class_indices != train_gen.class_indices:
            raise gr.Error("'train/' and 'validation/' must contain the same class folders.")
        if train_gen.num_classes != num_classes:
            raise gr.Error(f"The dataset has {train_gen.num_classes} classes but "
                           f"'Number of Classes' is {num_classes}.")

        model = _build_model(image_size, num_classes)
        model.compile(optimizer="adam", loss="categorical_crossentropy",
                      metrics=["accuracy"])

        start = time.perf_counter()
        history = model.fit(train_gen,
                            validation_data=val_gen,
                            epochs=epochs,
                            verbose=0)
        elapsed = time.perf_counter() - start
    finally:
        # The dataset is only needed for training; delete the copied archive and
        # extracted files so repeated runs don't fill the disk.
        shutil.rmtree(extract_path, ignore_errors=True)
        try:
            os.remove(zip_path)
        except FileNotFoundError:
            pass

    space_model_dir = os.path.join(MODEL_DIR, uid)
    os.makedirs(space_model_dir, exist_ok=True)

    h5_path = os.path.join(space_model_dir, f"{safe_name}.h5")
    model.save(h5_path)

    savedmodel_path = os.path.join(space_model_dir, "savedmodel")
    _save_savedmodel(model, savedmodel_path)

    tfjs_dir = os.path.join(space_model_dir, "tfjs")
    _convert_to_tfjs(savedmodel_path, tfjs_dir)

    size_mb = _dir_size_mb(space_model_dir)
    val_acc = history.history["val_accuracy"]

    return (f"✅ Trained in {elapsed:.1f}s\n"
            f"Final val acc: {val_acc[-1]:.4f}\n"
            f"Best val acc: {max(val_acc):.4f}\n"
            f"Model size: {size_mb:.2f} MB\n"
            f"Download H5: {h5_path}\n"
            f"Download SavedModel: {savedmodel_path}\n"
            f"Download TF.js: {tfjs_dir}")
|
|
|
|
|
# Gradio UI wiring: each input component maps positionally onto a parameter of
# train_and_export, and its returned summary string is shown as plain text.
demo = gr.Interface(
    fn=train_and_export,
    inputs=[
        # Restrict the picker to archives; the handler expects a ZIP.
        gr.File(label="Dataset ZIP", file_types=[".zip"]),
        gr.Textbox(label="Model Name", value="my_model"),
        # step=1 so the class count is always a whole number.
        gr.Slider(2, 100, value=5, step=1, label="Number of Classes"),
        gr.Slider(1, 200, value=50, step=1, label="Epochs"),
        gr.Radio([16, 32, 64, 128], value=32, label="Batch Size"),
        gr.Radio([128, 224, 256], value=224, label="Image Size"),
    ],
    outputs="text",
    title="AI Image Classifier Trainer",
    description="Upload train/validation dataset, train a CNN, and download H5, SavedModel, or TF.js."
)

if __name__ == "__main__":
    demo.launch()
|
|