ameenmarashi's picture
Create app.py
a08e47f verified
raw
history blame
4.17 kB
import gradio as gr
import zipfile, os, uuid, shutil, subprocess, sys, time
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
# Working directories, relative to the process's current directory:
# uploaded archives are unpacked under UPLOAD_DIR, trained artifacts go to MODEL_DIR.
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"
for _dir in (UPLOAD_DIR, MODEL_DIR):
    # exist_ok keeps start-up idempotent across restarts.
    os.makedirs(_dir, exist_ok=True)
def _safe_model_name(name):
    """Return a filename-safe stem for *name*: no path separators, never empty."""
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(name or "").strip()
    ).strip(".")
    return cleaned or "model"


def _save_upload(dataset_zip, zip_path):
    """Write the uploaded archive to ``zip_path`` whatever form it arrived in.

    Depending on the Gradio version, ``gr.File`` yields a filepath string or a
    tempfile wrapper exposing ``.name``; a plain open binary file also works.
    """
    if dataset_zip is None:
        raise gr.Error("Please upload a dataset ZIP file.")
    if isinstance(dataset_zip, (str, os.PathLike)):
        shutil.copyfile(dataset_zip, zip_path)
        return
    src_path = getattr(dataset_zip, "name", None)
    if isinstance(src_path, str) and os.path.isfile(src_path):
        shutil.copyfile(src_path, zip_path)
        return
    with open(zip_path, "wb") as f:
        f.write(dataset_zip.read())


def _find_split_dirs(root):
    """Return ``(train_dir, val_dir)`` from the extracted archive at *root*.

    The ``train``/``validation`` folders may sit at the archive root or inside a
    single wrapping folder (what you get when zipping a folder itself).
    """
    bases = [root]
    entries = (
        [e for e in os.listdir(root) if not e.startswith((".", "__MACOSX"))]
        if os.path.isdir(root)
        else []
    )
    if len(entries) == 1 and os.path.isdir(os.path.join(root, entries[0])):
        bases.append(os.path.join(root, entries[0]))
    for base in bases:
        train_dir = os.path.join(base, "train")
        val_dir = os.path.join(base, "validation")
        if os.path.isdir(train_dir) and os.path.isdir(val_dir):
            return train_dir, val_dir
    raise gr.Error(
        "The ZIP must contain 'train/' and 'validation/' folders, each with one "
        "sub-folder per class (optionally wrapped in a single top-level folder)."
    )


def _build_model(image_size, num_classes):
    """Build and compile a small two-stage Conv/BatchNorm/Pool CNN classifier."""
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu",
                               input_shape=(image_size, image_size, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model


def _export_artifacts(model, out_dir, model_name):
    """Save *model* as H5 and SavedModel under *out_dir*, then try a TF.js conversion.

    Returns ``(h5_path, savedmodel_path, tfjs_dir, tfjs_error)``; ``tfjs_error``
    is ``None`` on success, otherwise a short failure message. A failed TF.js
    conversion no longer discards the already-saved H5/SavedModel artifacts.
    """
    os.makedirs(out_dir, exist_ok=True)
    h5_path = os.path.join(out_dir, f"{model_name}.h5")
    model.save(h5_path)

    savedmodel_path = os.path.join(out_dir, "savedmodel")
    # Keras 3 rejects model.save() on an extension-less path; Model.export()
    # writes an inference SavedModel. Fall back to save() on older Keras.
    if hasattr(model, "export"):
        model.export(savedmodel_path)
    else:
        model.save(savedmodel_path)

    tfjs_dir = os.path.join(out_dir, "tfjs")
    # `tensorflowjs_converter` is a console-script name, not an importable
    # module, so `python -m` must target the converter module it wraps.
    cmd = [
        sys.executable, "-m", "tensorflowjs.converters.converter",
        "--input_format=tf_saved_model",
        savedmodel_path,
        tfjs_dir,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        detail = (err.stderr or "").strip().splitlines()
        return h5_path, savedmodel_path, tfjs_dir, (detail[-1] if detail else str(err))
    return h5_path, savedmodel_path, tfjs_dir, None


def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
    """Train a CNN on an uploaded image dataset and export it as H5, SavedModel and TF.js.

    Args:
        dataset_zip: The uploaded archive from ``gr.File`` (filepath string,
            tempfile wrapper, or open binary file). It must contain ``train/``
            and ``validation/`` folders with one sub-folder per class.
        model_name: Stem for the ``.h5`` file; sanitised to a safe filename.
        num_classes: Expected number of classes. The model always uses the
            number of class folders found in ``train/``; a mismatch is noted
            in the returned text.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both data generators.
        image_size: Side length images are resized to (square, 3 channels).

    Returns:
        A human-readable summary: training time, final/best validation
        accuracy, total artifact size and the paths of the exported models.

    Raises:
        gr.Error: On a missing/invalid upload, wrong folder layout, or no images.
    """
    # Gradio sliders can deliver floats; Keras needs ints for shapes and epochs.
    epochs = int(epochs)
    batch_size = int(batch_size)
    image_size = int(image_size)
    requested_classes = int(num_classes)
    target_size = (image_size, image_size)

    uid = str(uuid.uuid4())
    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
    extract_path = os.path.join(UPLOAD_DIR, uid)
    try:
        _save_upload(dataset_zip, zip_path)
        try:
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(extract_path)
        except zipfile.BadZipFile as err:
            raise gr.Error("The uploaded file is not a valid ZIP archive.") from err
        train_dir, val_dir = _find_split_dirs(extract_path)

        # Augment training images only; validation images are just rescaled.
        train_gen = ImageDataGenerator(rescale=1/255,
                                       rotation_range=20,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       horizontal_flip=True,
                                       zoom_range=0.2) \
            .flow_from_directory(train_dir, target_size=target_size,
                                 batch_size=batch_size, class_mode="categorical")
        if train_gen.samples == 0:
            raise gr.Error("No images were found under 'train/'.")

        # Pin validation labels to the training class order so one-hot index i
        # means the same class in both generators even if the folder sets differ.
        class_names = sorted(train_gen.class_indices, key=train_gen.class_indices.get)
        val_gen = ImageDataGenerator(rescale=1/255) \
            .flow_from_directory(val_dir, target_size=target_size,
                                 batch_size=batch_size, class_mode="categorical",
                                 classes=class_names)
        if val_gen.samples == 0:
            raise gr.Error("No images were found under 'validation/'.")

        # The output layer must match the label width the generator produces.
        found_classes = train_gen.num_classes
        model = _build_model(image_size, found_classes)

        start = time.time()
        history = model.fit(train_gen,
                            validation_data=val_gen,
                            epochs=epochs,
                            verbose=0)
        elapsed = time.time() - start
    finally:
        # The images are only needed while training; don't leave a copy of
        # every uploaded dataset on disk.
        shutil.rmtree(extract_path, ignore_errors=True)
        if os.path.exists(zip_path):
            os.remove(zip_path)

    space_model_dir = os.path.join(MODEL_DIR, uid)
    h5_path, savedmodel_path, tfjs_dir, tfjs_error = _export_artifacts(
        model, space_model_dir, _safe_model_name(model_name))

    size_mb = sum(os.path.getsize(os.path.join(dp, fn))
                  for dp, _, fns in os.walk(space_model_dir)
                  for fn in fns) / (1024 * 1024)

    val_acc = history.history["val_accuracy"]
    lines = [
        f"✅ Trained in {elapsed:.1f}s",
        f"Final val acc: {val_acc[-1]:.4f} (best: {max(val_acc):.4f})",
    ]
    if requested_classes != found_classes:
        lines.append(f"⚠️ Expected {requested_classes} classes but found "
                     f"{found_classes} class folders in train/; used {found_classes}.")
    lines += [
        f"Model size: {size_mb:.2f} MB",
        f"Download H5: {h5_path}",
        f"Download SavedModel: {savedmodel_path}",
    ]
    if tfjs_error:
        lines.append(f"TF.js conversion failed: {tfjs_error}")
    else:
        lines.append(f"Download TF.js: {tfjs_dir}")
    return "\n".join(lines)
# Gradio interface: the order of `inputs` must match train_and_export's parameters.
demo = gr.Interface(
    fn=train_and_export,
    inputs=[
        gr.File(label="Dataset ZIP"),
        gr.Textbox(label="Model Name", value="my_model"),
        # Explicit step=1: without it Gradio derives a step from the range,
        # which can be fractional and would allow non-integer class counts.
        gr.Slider(2, 100, value=5, step=1, label="Number of Classes"),
        gr.Slider(1, 200, value=50, step=1, label="Epochs"),
        gr.Radio([16, 32, 64, 128], value=32, label="Batch Size"),
        gr.Radio([128, 224, 256], value=224, label="Image Size"),
    ],
    outputs="text",
    title="AI Image Classifier Trainer",
    description=(
        "Upload a ZIP containing train/ and validation/ folders (one sub-folder "
        "per class), train a CNN, and download H5, SavedModel, or TF.js."
    ),
)

if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    demo.launch()