Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import zipfile, os, uuid, shutil, subprocess, sys, time
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
# Directory constants
|
8 |
+
UPLOAD_DIR = "uploads"
|
9 |
+
MODEL_DIR = "models"
|
10 |
+
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
11 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
12 |
+
|
13 |
+
def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
|
14 |
+
# Save upload
|
15 |
+
uid = str(uuid.uuid4())
|
16 |
+
zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
|
17 |
+
with open(zip_path, "wb") as f:
|
18 |
+
f.write(dataset_zip.read())
|
19 |
+
extract_path = os.path.join(UPLOAD_DIR, uid)
|
20 |
+
with zipfile.ZipFile(zip_path, "r") as z:
|
21 |
+
z.extractall(extract_path)
|
22 |
+
train_dir = os.path.join(extract_path, "train")
|
23 |
+
val_dir = os.path.join(extract_path, "validation")
|
24 |
+
# Data gens
|
25 |
+
train_gen = ImageDataGenerator(rescale=1/255,
|
26 |
+
rotation_range=20,
|
27 |
+
width_shift_range=0.2,
|
28 |
+
height_shift_range=0.2,
|
29 |
+
horizontal_flip=True,
|
30 |
+
zoom_range=0.2) \
|
31 |
+
.flow_from_directory(train_dir, target_size=(image_size, image_size),
|
32 |
+
batch_size=batch_size, class_mode="categorical")
|
33 |
+
val_gen = ImageDataGenerator(rescale=1/255) \
|
34 |
+
.flow_from_directory(val_dir, target_size=(image_size, image_size),
|
35 |
+
batch_size=batch_size, class_mode="categorical")
|
36 |
+
# Build model (e.g., simplified ResNet-inspired CNN)
|
37 |
+
model = tf.keras.Sequential([
|
38 |
+
tf.keras.layers.Conv2D(32,3,activation="relu",
|
39 |
+
input_shape=(image_size,image_size,3)),
|
40 |
+
tf.keras.layers.BatchNormalization(),
|
41 |
+
tf.keras.layers.MaxPooling2D(),
|
42 |
+
tf.keras.layers.Conv2D(64,3,activation="relu"),
|
43 |
+
tf.keras.layers.BatchNormalization(),
|
44 |
+
tf.keras.layers.MaxPooling2D(),
|
45 |
+
tf.keras.layers.Flatten(),
|
46 |
+
tf.keras.layers.Dense(128,activation="relu"),
|
47 |
+
tf.keras.layers.Dropout(0.5),
|
48 |
+
tf.keras.layers.Dense(num_classes,activation="softmax")
|
49 |
+
])
|
50 |
+
model.compile(optimizer="adam", loss="categorical_crossentropy",
|
51 |
+
metrics=["accuracy"])
|
52 |
+
# Train
|
53 |
+
start = time.time()
|
54 |
+
history = model.fit(train_gen,
|
55 |
+
validation_data=val_gen,
|
56 |
+
epochs=epochs,
|
57 |
+
verbose=0)
|
58 |
+
elapsed = time.time() - start
|
59 |
+
# Save outputs
|
60 |
+
space_model_dir = os.path.join(MODEL_DIR, uid)
|
61 |
+
os.makedirs(space_model_dir, exist_ok=True)
|
62 |
+
h5_path = os.path.join(space_model_dir, f"{model_name}.h5")
|
63 |
+
model.save(h5_path)
|
64 |
+
savedmodel_path = os.path.join(space_model_dir, "savedmodel")
|
65 |
+
model.save(savedmodel_path)
|
66 |
+
tfjs_dir = os.path.join(space_model_dir, "tfjs")
|
67 |
+
subprocess.run([
|
68 |
+
sys.executable, "-m", "tensorflowjs_converter",
|
69 |
+
"--input_format=tf_saved_model",
|
70 |
+
savedmodel_path,
|
71 |
+
tfjs_dir
|
72 |
+
], check=True)
|
73 |
+
size_mb = sum(os.path.getsize(os.path.join(dp,fn))
|
74 |
+
for dp,_,fns in os.walk(space_model_dir)
|
75 |
+
for fn in fns) / (1024*1024)
|
76 |
+
return (f"✅ Trained in {elapsed:.1f}s\n"
|
77 |
+
f"Final val acc: {max(history.history['val_accuracy']):.4f}\n"
|
78 |
+
f"Model size: {size_mb:.2f} MB\n"
|
79 |
+
f"Download H5: {h5_path}\n"
|
80 |
+
f"Download SavedModel: {savedmodel_path}\n"
|
81 |
+
f"Download TF.js: {tfjs_dir}")
|
82 |
+
|
83 |
+
# Gradio interface
|
84 |
+
demo = gr.Interface(
|
85 |
+
fn=train_and_export,
|
86 |
+
inputs=[
|
87 |
+
gr.File(label="Dataset ZIP"),
|
88 |
+
gr.Textbox(label="Model Name", value="my_model"),
|
89 |
+
gr.Slider(2, 100, value=5, label="Number of Classes"),
|
90 |
+
gr.Slider(1, 200, value=50, step=1, label="Epochs"),
|
91 |
+
gr.Radio([16,32,64,128], value=32, label="Batch Size"),
|
92 |
+
gr.Radio([128,224,256], value=224, label="Image Size")
|
93 |
+
],
|
94 |
+
outputs="text",
|
95 |
+
title="AI Image Classifier Trainer",
|
96 |
+
description="Upload train/validation dataset, train a CNN, and download H5, SavedModel, or TF.js."
|
97 |
+
)
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
demo.launch()
|