File size: 4,171 Bytes
a08e47f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import zipfile, os, uuid, shutil, subprocess, sys, time
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Working directories: UPLOAD_DIR receives uploaded archives and their
# extracted contents, MODEL_DIR receives exported models. Both are created
# at import time so request handlers can write into them immediately.
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"
for _required_dir in (UPLOAD_DIR, MODEL_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

def _save_upload(dataset_zip, dest_path):
    """Write the uploaded dataset archive to *dest_path*.

    Depending on the Gradio version and ``gr.File(type=...)``, the upload may
    arrive as a filepath string, a tempfile-like object exposing ``.name``,
    or raw ``bytes``; all three are accepted.

    Raises:
        gr.Error: if nothing usable was uploaded.
    """
    if isinstance(dataset_zip, (bytes, bytearray)):
        with open(dest_path, "wb") as f:
            f.write(dataset_zip)
        return
    if isinstance(dataset_zip, (str, os.PathLike)):
        src = dataset_zip
    else:
        src = getattr(dataset_zip, "name", None)
    # Prefer copying from disk: Gradio tempfile wrappers may already be closed,
    # in which case calling .read() on them fails.
    if isinstance(src, (str, os.PathLike)) and os.path.isfile(src):
        shutil.copyfile(src, dest_path)
        return
    if hasattr(dataset_zip, "read"):
        # Fallback for open file-like objects without an on-disk path.
        with open(dest_path, "wb") as f:
            f.write(dataset_zip.read())
        return
    raise gr.Error("Please upload a dataset ZIP file.")


def _safe_model_name(model_name):
    """Reduce a user-supplied model name to a safe file stem.

    Anything outside ``[A-Za-z0-9_.-]`` (including path separators) becomes
    ``_`` and leading/trailing dots are dropped, so the name cannot escape the
    output directory. Falls back to ``"model"`` when nothing is left.
    """
    import re

    stem = re.sub(r"[^A-Za-z0-9_.-]", "_", str(model_name or "").strip()).strip(".")
    return stem or "model"


def _find_dataset_root(extract_path):
    """Return the folder holding the ``train/`` and ``validation/`` splits.

    Zipping the dataset folder itself (rather than its contents) nests
    everything one level deeper, so a single wrapping top-level folder is also
    accepted. macOS ``__MACOSX`` metadata folders are ignored.

    Raises:
        gr.Error: if no such folder is found.
    """
    def has_splits(path):
        return all(os.path.isdir(os.path.join(path, split))
                   for split in ("train", "validation"))

    if has_splits(extract_path):
        return extract_path
    subdirs = [
        os.path.join(extract_path, entry)
        for entry in os.listdir(extract_path)
        if entry != "__MACOSX" and os.path.isdir(os.path.join(extract_path, entry))
    ]
    if len(subdirs) == 1 and has_splits(subdirs[0]):
        return subdirs[0]
    raise gr.Error("The ZIP must contain 'train/' and 'validation/' folders "
                   "(each with one sub-folder per class).")


def _make_generators(data_root, image_size, batch_size):
    """Build the augmented training and plain validation directory iterators."""
    target_size = (image_size, image_size)
    train_gen = ImageDataGenerator(rescale=1/255,
                                   rotation_range=20,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   horizontal_flip=True,
                                   zoom_range=0.2) \
        .flow_from_directory(os.path.join(data_root, "train"),
                             target_size=target_size,
                             batch_size=batch_size, class_mode="categorical")
    val_gen = ImageDataGenerator(rescale=1/255) \
        .flow_from_directory(os.path.join(data_root, "validation"),
                             target_size=target_size,
                             batch_size=batch_size, class_mode="categorical")
    return train_gen, val_gen


def _build_model(num_classes, image_size):
    """Build and compile a small two-block CNN with a softmax head."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(image_size, image_size, 3)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model


def _export_model(model, out_dir, model_name):
    """Save *model* as H5, SavedModel and TF.js under *out_dir*.

    Returns:
        ``(h5_path, savedmodel_path, tfjs_dir)``.

    Raises:
        RuntimeError: if the TF.js converter subprocess fails.
    """
    h5_path = os.path.join(out_dir, f"{model_name}.h5")
    model.save(h5_path)

    savedmodel_path = os.path.join(out_dir, "savedmodel")
    # Keras 3 refuses model.save() on an extension-less directory path;
    # Model.export() writes an inference SavedModel (Keras >= 2.13 / 3).
    if hasattr(model, "export"):
        model.export(savedmodel_path)
    else:
        tf.saved_model.save(model, savedmodel_path)

    tfjs_dir = os.path.join(out_dir, "tfjs")
    # "tensorflowjs_converter" is only a console-script name; the importable
    # module that implements it is tensorflowjs.converters.converter.
    result = subprocess.run([
        sys.executable, "-m", "tensorflowjs.converters.converter",
        "--input_format=tf_saved_model",
        savedmodel_path,
        tfjs_dir,
    ], capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"TF.js conversion failed:\n{result.stderr.strip()}")
    return h5_path, savedmodel_path, tfjs_dir


def _dir_size_mb(path):
    """Return the total size of all files below *path* in MiB."""
    total = sum(os.path.getsize(os.path.join(dirpath, name))
                for dirpath, _, names in os.walk(path)
                for name in names)
    return total / (1024 * 1024)


def train_and_export(dataset_zip, model_name, num_classes, epochs, batch_size, image_size):
    """Train a small CNN on an uploaded image dataset and export it.

    The ZIP must contain ``train/`` and ``validation/`` folders (optionally
    wrapped in one top-level folder), each holding one sub-folder per class,
    as read by ``ImageDataGenerator.flow_from_directory``.

    Args:
        dataset_zip: Uploaded archive (filepath, file-like object or bytes).
        model_name: Stem for the ``.h5`` file; sanitised to a safe file name.
        num_classes: Expected number of classes; must match ``train/``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        image_size: Side length that images are resized to (square, RGB).

    Returns:
        A human-readable summary: training time, final and best validation
        accuracy, total export size, and the H5 / SavedModel / TF.js paths.

    Raises:
        gr.Error: invalid upload, bad dataset layout or class mismatch.
        RuntimeError: if the TF.js conversion fails.
    """
    # Gradio sliders deliver floats (e.g. 50.0); Keras needs integers.
    num_classes = int(round(num_classes))
    epochs = int(round(epochs))
    batch_size = int(batch_size)
    image_size = int(image_size)
    safe_name = _safe_model_name(model_name)

    uid = str(uuid.uuid4())
    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
    extract_path = os.path.join(UPLOAD_DIR, uid)
    try:
        _save_upload(dataset_zip, zip_path)
        try:
            with zipfile.ZipFile(zip_path, "r") as z:
                # extractall() strips absolute paths and ".." components, so
                # member names cannot escape extract_path.
                z.extractall(extract_path)
        except zipfile.BadZipFile as err:
            raise gr.Error("The uploaded file is not a valid ZIP archive.") from err

        train_gen, val_gen = _make_generators(
            _find_dataset_root(extract_path), image_size, batch_size)
        if train_gen.samples == 0 or val_gen.samples == 0:
            raise gr.Error("No images found in 'train/' or 'validation/'.")
        if train_gen.num_classes != num_classes:
            raise gr.Error(f"'Number of Classes' is {num_classes}, but 'train/' "
                           f"contains {train_gen.num_classes} class folders.")
        if val_gen.class_indices != train_gen.class_indices:
            # Label indices are assigned per directory from its own folder
            # names, so differing class folders would scramble val labels.
            raise gr.Error("'train/' and 'validation/' must contain the "
                           "same class folders.")

        model = _build_model(num_classes, image_size)
        start = time.perf_counter()
        history = model.fit(train_gen,
                            validation_data=val_gen,
                            epochs=epochs,
                            verbose=0)
        elapsed = time.perf_counter() - start
    finally:
        # The raw upload is not needed once training has finished (or failed).
        if os.path.exists(zip_path):
            os.remove(zip_path)
        shutil.rmtree(extract_path, ignore_errors=True)

    space_model_dir = os.path.join(MODEL_DIR, uid)
    os.makedirs(space_model_dir, exist_ok=True)
    h5_path, savedmodel_path, tfjs_dir = _export_model(model, space_model_dir, safe_name)
    size_mb = _dir_size_mb(space_model_dir)

    val_acc = history.history["val_accuracy"]
    return (f"✅ Trained in {elapsed:.1f}s\n"
            f"Final val acc: {val_acc[-1]:.4f}\n"
            f"Best val acc: {max(val_acc):.4f}\n"
            f"Model size: {size_mb:.2f} MB\n"
            f"Download H5: {h5_path}\n"
            f"Download SavedModel: {savedmodel_path}\n"
            f"Download TF.js: {tfjs_dir}")

# Gradio interface
demo = gr.Interface(
    fn=train_and_export,
    inputs=[
        gr.File(label="Dataset ZIP", file_types=[".zip"]),
        gr.Textbox(label="Model Name", value="my_model"),
        # Explicit step=1: without it Gradio may choose a fractional increment,
        # and a non-integer class count cannot size the Dense output layer.
        gr.Slider(2, 100, value=5, step=1, label="Number of Classes"),
        gr.Slider(1, 200, value=50, step=1, label="Epochs"),
        gr.Radio([16, 32, 64, 128], value=32, label="Batch Size"),
        gr.Radio([128, 224, 256], value=224, label="Image Size"),
    ],
    outputs="text",
    title="AI Image Classifier Trainer",
    description="Upload train/validation dataset, train a CNN, and download H5, SavedModel, or TF.js."
)

if __name__ == "__main__":
    # Training can run for many minutes; the queue keeps long requests from
    # being dropped by request timeouts (already the default in newer Gradio).
    demo.queue().launch()