ZennyKenny committed on
Commit
6f7ec0d
·
verified ·
1 Parent(s): d71f074

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -10
app.py CHANGED
@@ -42,6 +42,15 @@ class SyntheticDataGenerator:
42
  max_training_time: int = 60,
43
  batch_size: int = 32,
44
  value_protection: bool = True,
 
 
 
 
 
 
 
 
 
45
  ) -> Tuple[bool, str]:
46
  if not self.mostly:
47
  return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
@@ -57,6 +66,15 @@ class SyntheticDataGenerator:
57
  "max_training_time": max_training_time,
58
  "value_protection": value_protection,
59
  "batch_size": batch_size,
 
 
 
 
 
 
 
 
 
60
  },
61
  }
62
  ]
@@ -110,11 +128,34 @@ def train_model(
110
  max_training_time: int,
111
  batch_size: int,
112
  value_protection: bool,
 
 
 
 
 
 
 
 
 
113
  ) -> str:
114
  if data is None or data.empty:
115
  return "Error: No data provided. Please upload or create sample data first."
116
  ok, msg = generator.train_generator(
117
- data, model_name, epochs, max_training_time, batch_size, value_protection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
119
  return ("Success: " if ok else "Error: ") + msg
120
 
@@ -209,23 +250,81 @@ def create_interface():
209
  memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
210
 
211
  with gr.Row():
212
- with gr.Column():
213
  model_name = gr.Textbox(
214
- value="My Synthetic Model", label="Model Name", placeholder="Enter a name for your model"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  )
216
- epochs = gr.Slider(1, 200, value=100, step=1, label="Training Epochs")
217
- max_training_time = gr.Slider(1, 1000, value=60, step=1, label="Maximum Training Time")
218
- batch_size = gr.Slider(8, 1024, value=32, step=8, label="Training Batch Size")
219
- value_protection = gr.Checkbox(label="Value Protection", info="Enable Value Protection")
220
  train_btn = gr.Button("Train Model", variant="primary")
221
- with gr.Column():
222
  train_status = gr.Textbox(label="Training Status", interactive=False)
223
 
224
  with gr.Tab("Generate Data"):
225
  gr.Markdown("### Generate synthetic data from your trained model")
226
  with gr.Row():
227
  with gr.Column():
228
- gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate")
 
229
  generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
230
  with gr.Column():
231
  gen_status = gr.Textbox(label="Generation Status", interactive=False)
@@ -243,7 +342,13 @@ def create_interface():
243
 
244
  train_btn.click(
245
  train_model,
246
- inputs=[uploaded_data, model_name, epochs, max_training_time, batch_size, value_protection],
 
 
 
 
 
 
247
  outputs=[train_status],
248
  )
249
 
 
42
  max_training_time: int = 60,
43
  batch_size: int = 32,
44
  value_protection: bool = True,
45
+ rare_category_protection: bool = True,
46
+ flexible_generation: bool = True,
47
+ model_size: str = "MEDIUM",
48
+ target_accuracy: float = 0.95,
49
+ validation_split: float = 0.2,
50
+ learning_rate: float = 0.001,
51
+ early_stopping_patience: int = 10,
52
+ dropout_rate: float = 0.1,
53
+ weight_decay: float = 0.0001,
54
  ) -> Tuple[bool, str]:
55
  if not self.mostly:
56
  return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
 
66
  "max_training_time": max_training_time,
67
  "value_protection": value_protection,
68
  "batch_size": batch_size,
69
+ "rare_category_protection": rare_category_protection,
70
+ "flexible_generation": flexible_generation,
71
+ "model_size": model_size, # "SMALL" | "MEDIUM" | "LARGE"
72
+ "target_accuracy": target_accuracy, # early stop once target met
73
+ "validation_split": validation_split,
74
+ "learning_rate": learning_rate,
75
+ "early_stopping_patience": early_stopping_patience,
76
+ "dropout_rate": dropout_rate,
77
+ "weight_decay": weight_decay,
78
  },
79
  }
80
  ]
 
128
  max_training_time: int,
129
  batch_size: int,
130
  value_protection: bool,
131
+ rare_category_protection: bool,
132
+ flexible_generation: bool,
133
+ model_size: str,
134
+ target_accuracy: float,
135
+ validation_split: float,
136
+ learning_rate: float,
137
+ early_stopping_patience: int,
138
+ dropout_rate: float,
139
+ weight_decay: float,
140
  ) -> str:
141
  if data is None or data.empty:
142
  return "Error: No data provided. Please upload or create sample data first."
143
  ok, msg = generator.train_generator(
144
+ data=data,
145
+ name=model_name,
146
+ epochs=epochs,
147
+ max_training_time=max_training_time,
148
+ batch_size=batch_size,
149
+ value_protection=value_protection,
150
+ rare_category_protection=rare_category_protection,
151
+ flexible_generation=flexible_generation,
152
+ model_size=model_size,
153
+ target_accuracy=target_accuracy,
154
+ validation_split=validation_split,
155
+ learning_rate=learning_rate,
156
+ early_stopping_patience=early_stopping_patience,
157
+ dropout_rate=dropout_rate,
158
+ weight_decay=weight_decay,
159
  )
160
  return ("Success: " if ok else "Error: ") + msg
161
 
 
250
  memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
251
 
252
  with gr.Row():
253
+ with gr.Column(scale=1):
254
  model_name = gr.Textbox(
255
+ value="My Synthetic Model",
256
+ label="Model Name",
257
+ placeholder="Enter a name for your model",
258
+ info="Appears in training runs and saved generators."
259
+ )
260
+ epochs = gr.Slider(
261
+ 1, 200, value=100, step=1, label="Training Epochs",
262
+ info="Maximum number of passes over the training data."
263
+ )
264
+ max_training_time = gr.Slider(
265
+ 1, 1000, value=60, step=1, label="Maximum Training Time (minutes)",
266
+ info="Upper bound in minutes; training stops if exceeded."
267
+ )
268
+ batch_size = gr.Slider(
269
+ 8, 1024, value=32, step=8, label="Batch Size",
270
+ info="Number of rows per optimization step. Larger can speed up but needs more memory."
271
+ )
272
+ value_protection = gr.Checkbox(
273
+ label="Value Protection",
274
+ info="Adds protections to reduce memorization of unique or sensitive values.",
275
+ value=True
276
+ )
277
+ rare_category_protection = gr.Checkbox(
278
+ label="Rare Category Protection",
279
+ info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
280
+ value=True
281
+ )
282
+ with gr.Column(scale=1):
283
+ flexible_generation = gr.Checkbox(
284
+ label="Flexible Generation",
285
+ info="Allows generation when inputs slightly differ from training schema.",
286
+ value=True
287
+ )
288
+ model_size = gr.Dropdown(
289
+ choices=["SMALL", "MEDIUM", "LARGE"],
290
+ value="MEDIUM",
291
+ label="Model Size",
292
+ info="Sets model capacity. Larger can improve fidelity but uses more compute."
293
+ )
294
+ target_accuracy = gr.Slider(
295
+ 0.50, 0.999, value=0.95, step=0.001, label="Target Accuracy",
296
+ info="Stop early when validation accuracy reaches this threshold."
297
+ )
298
+ validation_split = gr.Slider(
299
+ 0.05, 0.5, value=0.2, step=0.01, label="Validation Split",
300
+ info="Fraction of the dataset held out for validation during training."
301
+ )
302
+ early_stopping_patience = gr.Slider(
303
+ 0, 50, value=10, step=1, label="Early Stopping Patience (epochs)",
304
+ info="Stop if no validation improvement after this many epochs."
305
+ )
306
+ with gr.Column(scale=1):
307
+ learning_rate = gr.Number(
308
+ value=0.001, precision=6, label="Learning Rate",
309
+ info="Step size for the optimizer. Typical range: 1e-4 to 1e-2."
310
+ )
311
+ dropout_rate = gr.Slider(
312
+ 0.0, 0.6, value=0.1, step=0.01, label="Dropout Rate",
313
+ info="Regularization to reduce overfitting by randomly dropping units."
314
+ )
315
+ weight_decay = gr.Number(
316
+ value=0.0001, precision=6, label="Weight Decay",
317
+ info="L2 regularization strength applied to model weights."
318
  )
 
 
 
 
319
  train_btn = gr.Button("Train Model", variant="primary")
 
320
  train_status = gr.Textbox(label="Training Status", interactive=False)
321
 
322
  with gr.Tab("Generate Data"):
323
  gr.Markdown("### Generate synthetic data from your trained model")
324
  with gr.Row():
325
  with gr.Column():
326
+ gen_size = gr.Slider(10, 1000, value=100, step=10, label="Number of Records to Generate",
327
+ info="How many synthetic rows to create in the table.")
328
  generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
329
  with gr.Column():
330
  gen_status = gr.Textbox(label="Generation Status", interactive=False)
 
342
 
343
  train_btn.click(
344
  train_model,
345
+ inputs=[
346
+ uploaded_data, model_name,
347
+ epochs, max_training_time, batch_size,
348
+ value_protection, rare_category_protection, flexible_generation,
349
+ model_size, target_accuracy, validation_split,
350
+ learning_rate, early_stopping_patience, dropout_rate, weight_decay
351
+ ],
352
  outputs=[train_status],
353
  )
354