import gradio as gr
import zipfile, os, uuid, shutil, subprocess, time
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
# Directory constants
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
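# Expected ZIP layout (flow_from_directory needs one subfolder per class;
# class names below are illustrative):
#   train/
#       cats/        img001.jpg ...
#       dogs/        img101.jpg ...
#   validation/
#       cats/        ...
#       dogs/        ...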
def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
    # Sliders can hand back floats; the model and fit() need ints
    num_classes, epochs = int(num_classes), int(epochs)
    batch_size, image_size = int(batch_size), int(image_size)
    # Save the upload under a unique id so concurrent runs don't collide
    uid = str(uuid.uuid4())
    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
    # gr.File passes a temp-file path (or an object with .name in older Gradio versions)
    src = dataset_zip if isinstance(dataset_zip, str) else dataset_zip.name
    shutil.copy(src, zip_path)
    # Unpack the dataset; it must contain top-level "train" and "validation" folders
    extract_path = os.path.join(UPLOAD_DIR, uid)
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(extract_path)
    train_dir = os.path.join(extract_path, "train")
    val_dir = os.path.join(extract_path, "validation")
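    # Note: ImageDataGenerator is deprecated in recent TF releases in favour of
    # tf.keras.utils.image_dataset_from_directory, but it still works in TF 2.x.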
    # Data generators: augment the training set, only rescale the validation set
    train_gen = ImageDataGenerator(
        rescale=1 / 255,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
    ).flow_from_directory(
        train_dir,
        target_size=(image_size, image_size),
        batch_size=batch_size,
        class_mode="categorical",
    )
    val_gen = ImageDataGenerator(rescale=1 / 255).flow_from_directory(
        val_dir,
        target_size=(image_size, image_size),
        batch_size=batch_size,
        class_mode="categorical",
    )
    # Build a small CNN: two conv/pool blocks with batch norm, then a dense classifier head
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu",
                               input_shape=(image_size, image_size, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
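    # Optional sanity check (illustrative): the class-count input should match the
    # folders flow_from_directory discovered, or labels and the output layer disagree.
    # assert train_gen.num_classes == num_classes, "num_classes does not match dataset"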
    # Train and time the run
    start = time.time()
    history = model.fit(train_gen,
                        validation_data=val_gen,
                        epochs=epochs,
                        verbose=0)
    elapsed = time.time() - start
    # Save the trained model in three formats: HDF5, SavedModel, and TF.js
    space_model_dir = os.path.join(MODEL_DIR, uid)
    os.makedirs(space_model_dir, exist_ok=True)
    h5_path = os.path.join(space_model_dir, f"{model_name}.h5")
    model.save(h5_path)
    # In TF 2.x / Keras 2, saving to a path without an extension writes SavedModel format
    savedmodel_path = os.path.join(space_model_dir, "savedmodel")
    model.save(savedmodel_path)
    # Convert the SavedModel with the tensorflowjs_converter CLI
    # (provided by the tensorflowjs package, which must be installed)
    tfjs_dir = os.path.join(space_model_dir, "tfjs")
    subprocess.run([
        "tensorflowjs_converter",
        "--input_format=tf_saved_model",
        savedmodel_path,
        tfjs_dir
    ], check=True)
    # Total on-disk size of everything exported for this run
    size_mb = sum(os.path.getsize(os.path.join(dp, fn))
                  for dp, _, fns in os.walk(space_model_dir)
                  for fn in fns) / (1024 * 1024)
    return (f"✅ Trained in {elapsed:.1f}s\n"
            f"Best val acc: {max(history.history['val_accuracy']):.4f}\n"
            f"Model size: {size_mb:.2f} MB\n"
            f"Download H5: {h5_path}\n"
            f"Download SavedModel: {savedmodel_path}\n"
            f"Download TF.js: {tfjs_dir}")
# Gradio interface
demo = gr.Interface(
    fn=train_and_export,
    inputs=[
        gr.File(label="Dataset ZIP"),
        gr.Textbox(label="Model Name", value="my_model"),
        gr.Slider(2, 100, value=5, step=1, label="Number of Classes"),
        gr.Slider(1, 200, value=50, step=1, label="Epochs"),
        gr.Radio([16, 32, 64, 128], value=32, label="Batch Size"),
        gr.Radio([128, 224, 256], value=224, label="Image Size")
    ],
    outputs="text",
    title="AI Image Classifier Trainer",
    description="Upload a train/validation dataset ZIP, train a CNN, and export it as H5, SavedModel, or TF.js."
)

if __name__ == "__main__":
    demo.launch()
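# Reloading an exported model later (sketch; paths and filenames are illustrative):
#   model = tf.keras.models.load_model("models/<uid>/my_model.h5")
#   img = tf.keras.preprocessing.image.load_img("example.jpg", target_size=(224, 224))
#   x = tf.keras.preprocessing.image.img_to_array(img)[None] / 255.0
#   probs = model.predict(x)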