ZennyKenny commited on
Commit
6ce5adf
Β·
verified Β·
1 Parent(s): 01b4e73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -65
app.py CHANGED
@@ -26,7 +26,6 @@ class SyntheticDataGenerator:
26
  self.original_data = None
27
 
28
  def initialize_mostly_ai(self) -> Tuple[bool, str]:
29
- """Initialize Mostly AI SDK"""
30
  if not MOSTLY_AI_AVAILABLE:
31
  return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
32
  try:
@@ -44,7 +43,6 @@ class SyntheticDataGenerator:
44
  batch_size: int = 32,
45
  value_protection: bool = True,
46
  ) -> Tuple[bool, str]:
47
- """Train the synthetic data generator"""
48
  if not self.mostly:
49
  return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
50
  try:
@@ -63,14 +61,12 @@ class SyntheticDataGenerator:
63
  }
64
  ]
65
  }
66
-
67
  self.generator = self.mostly.train(config=train_config)
68
  return True, f"Training completed successfully. Model name: {name}"
69
  except Exception as e:
70
  return False, f"Training failed with error: {str(e)}"
71
 
72
  def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
73
- """Generate synthetic data"""
74
  if not self.generator:
75
  return None, "No trained generator available. Please train a model first."
76
  try:
@@ -82,27 +78,28 @@ class SyntheticDataGenerator:
82
 
83
  def get_quality_report_file(self) -> Optional[str]:
84
  """
85
- Generate/export the quality report and return a file path for download.
86
- Tries to find an existing ZIP; otherwise saves a TXT fallback.
87
  """
88
  if not self.generator:
89
  return None
90
  try:
91
  rep = self.generator.reports(display=False)
92
 
93
- # 1) If a string path to a .zip is returned
94
  if isinstance(rep, str) and rep.endswith(".zip") and os.path.exists(rep):
95
  return rep
96
 
97
- # 2) If the object exposes a path-like attribute
98
  for attr in ("archive_path", "zip_path", "path", "file_path"):
99
  if hasattr(rep, attr):
100
  p = getattr(rep, attr)
101
  if isinstance(p, str) and os.path.exists(p):
102
  return p
103
 
104
- # 3) If the object can save/export itself
105
- target_zip = "/mnt/data/quality_report.zip"
 
106
  if hasattr(rep, "save"):
107
  try:
108
  rep.save(target_zip)
@@ -118,8 +115,8 @@ class SyntheticDataGenerator:
118
  except Exception:
119
  pass
120
 
121
- # 4) Fallback: write string representation
122
- target_txt = "/mnt/data/quality_report.txt"
123
  with open(target_txt, "w", encoding="utf-8") as f:
124
  f.write(str(rep))
125
  return target_txt
@@ -128,21 +125,12 @@ class SyntheticDataGenerator:
128
  return None
129
 
130
  def estimate_memory_usage(self, df: pd.DataFrame) -> str:
131
- """Estimate memory usage for the dataset"""
132
  if df is None or df.empty:
133
  return "No data available to analyze."
134
-
135
  memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
136
  rows, cols = len(df), len(df.columns)
137
  estimated_training_mb = memory_mb * 4
138
-
139
- if memory_mb < 100:
140
- status = "Good"
141
- elif memory_mb < 500:
142
- status = "Large"
143
- else:
144
- status = "Very Large"
145
-
146
  return f"""
147
  Memory Usage Estimate:
148
  - Data size: {memory_mb:.1f} MB
@@ -152,10 +140,12 @@ Memory Usage Estimate:
152
  """.strip()
153
 
154
 
155
- # Initialize the generator
156
  generator = SyntheticDataGenerator()
 
 
157
 
158
- # ---- Wrapper functions for Gradio ----
159
  def initialize_sdk() -> str:
160
  ok, msg = generator.initialize_mostly_ai()
161
  return ("Success: " if ok else "Error: ") + msg
@@ -178,62 +168,53 @@ def train_model(
178
 
179
 
180
  def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
181
- synthetic_df, message = generator.generate_synthetic_data(size)
182
- status = "Success" if synthetic_df is not None else "Error"
183
- return synthetic_df, f"{status}: {message}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
 
186
- def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame) -> Optional[go.Figure]:
187
  if original_df is None or synthetic_df is None:
188
  return None
189
-
190
  numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
191
  if not numeric_cols:
192
  return None
193
-
194
  n_cols = min(3, len(numeric_cols))
195
  n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
196
-
197
  fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
198
-
199
  for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
200
  row = i // n_cols + 1
201
  col_idx = i % n_cols + 1
202
-
203
- fig.add_trace(
204
- go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20),
205
- row=row,
206
- col=col_idx,
207
- )
208
- fig.add_trace(
209
- go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20),
210
- row=row,
211
- col=col_idx,
212
- )
213
-
214
  fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
215
  return fig
216
 
217
 
218
- def download_csv(df: pd.DataFrame) -> Optional[str]:
219
- if df is None or df.empty:
220
- return None
221
- path = "/mnt/data/synthetic_data.csv"
222
- df.to_csv(path, index=False)
223
- return path
224
-
225
-
226
  # ---- UI ----
227
  def create_interface():
228
  with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
229
- # Header image
230
  gr.Image(
231
  value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
232
  show_label=False,
233
  elem_id="header-image",
234
  )
235
 
236
- # README
237
  gr.Markdown(
238
  """
239
  # Synthetic Data SDK by MOSTLY AI Demo Space
@@ -289,6 +270,7 @@ def create_interface():
289
  train_status = gr.Textbox(label="Training Status", interactive=False)
290
 
291
  with gr.Row():
 
292
  get_report_btn = gr.DownloadButton("Get Quality Report", variant="secondary")
293
 
294
  with gr.Tab("Generate Data"):
@@ -302,10 +284,11 @@ def create_interface():
302
 
303
  synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
304
  with gr.Row():
 
305
  download_btn = gr.DownloadButton("Download CSV", variant="secondary")
306
  comparison_plot = gr.Plot(label="Data Comparison")
307
 
308
- # ---- Event handlers ----
309
  init_btn.click(initialize_sdk, outputs=[init_status])
310
 
311
  train_btn.click(
@@ -314,21 +297,18 @@ def create_interface():
314
  outputs=[train_status],
315
  )
316
 
317
- # Direct download of quality report
318
- get_report_btn.click(generator.get_quality_report_file, outputs=[get_report_btn])
319
 
320
- # Generate data
321
  generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
322
 
323
- # Update CSV DownloadButton whenever synthetic data changes
324
- synthetic_data.change(download_csv, inputs=[synthetic_data], outputs=[download_btn])
325
-
326
  # Build comparison plot when both datasets are available
327
- synthetic_data.change(
328
- create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot]
329
- )
 
330
 
331
- # Handle file upload with size and column limits
332
  def process_uploaded_file(file):
333
  if file is None:
334
  return None, "No file uploaded.", gr.update(visible=False)
 
26
  self.original_data = None
27
 
28
  def initialize_mostly_ai(self) -> Tuple[bool, str]:
 
29
  if not MOSTLY_AI_AVAILABLE:
30
  return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
31
  try:
 
43
  batch_size: int = 32,
44
  value_protection: bool = True,
45
  ) -> Tuple[bool, str]:
 
46
  if not self.mostly:
47
  return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
48
  try:
 
61
  }
62
  ]
63
  }
 
64
  self.generator = self.mostly.train(config=train_config)
65
  return True, f"Training completed successfully. Model name: {name}"
66
  except Exception as e:
67
  return False, f"Training failed with error: {str(e)}"
68
 
69
  def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
 
70
  if not self.generator:
71
  return None, "No trained generator available. Please train a model first."
72
  try:
 
78
 
79
  def get_quality_report_file(self) -> Optional[str]:
80
  """
81
+ Build/export the quality report and return a file path for immediate download.
82
+ Uses /tmp for Spaces; tries ZIP, falls back to TXT.
83
  """
84
  if not self.generator:
85
  return None
86
  try:
87
  rep = self.generator.reports(display=False)
88
 
89
+ # If a string path to a .zip is returned
90
  if isinstance(rep, str) and rep.endswith(".zip") and os.path.exists(rep):
91
  return rep
92
 
93
+ # If object exposes a path-like attribute
94
  for attr in ("archive_path", "zip_path", "path", "file_path"):
95
  if hasattr(rep, attr):
96
  p = getattr(rep, attr)
97
  if isinstance(p, str) and os.path.exists(p):
98
  return p
99
 
100
+ # Try saving/exporting
101
+ os.makedirs("/tmp", exist_ok=True)
102
+ target_zip = "/tmp/quality_report.zip"
103
  if hasattr(rep, "save"):
104
  try:
105
  rep.save(target_zip)
 
115
  except Exception:
116
  pass
117
 
118
+ # Fallback: stringify into TXT
119
+ target_txt = "/tmp/quality_report.txt"
120
  with open(target_txt, "w", encoding="utf-8") as f:
121
  f.write(str(rep))
122
  return target_txt
 
125
  return None
126
 
127
  def estimate_memory_usage(self, df: pd.DataFrame) -> str:
 
128
  if df is None or df.empty:
129
  return "No data available to analyze."
 
130
  memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
131
  rows, cols = len(df), len(df.columns)
132
  estimated_training_mb = memory_mb * 4
133
+ status = "Good" if memory_mb < 100 else ("Large" if memory_mb < 500 else "Very Large")
 
 
 
 
 
 
 
134
  return f"""
135
  Memory Usage Estimate:
136
  - Data size: {memory_mb:.1f} MB
 
140
  """.strip()
141
 
142
 
143
+ # App state
144
  generator = SyntheticDataGenerator()
145
+ _last_synth_df: Optional[pd.DataFrame] = None # store latest synthetic DF for download
146
+
147
 
148
+ # ---- Gradio wrappers ----
149
  def initialize_sdk() -> str:
150
  ok, msg = generator.initialize_mostly_ai()
151
  return ("Success: " if ok else "Error: ") + msg
 
168
 
169
 
170
  def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
171
+ global _last_synth_df
172
+ synth_df, message = generator.generate_synthetic_data(size)
173
+ if synth_df is not None:
174
+ _last_synth_df = synth_df.copy()
175
+ return synth_df, f"Success: {message}"
176
+ else:
177
+ return None, f"Error: {message}"
178
+
179
+
180
+ def download_csv_now() -> Optional[str]:
181
+ """Write the most recent synthetic DF to /tmp and return the path for direct download."""
182
+ global _last_synth_df
183
+ if _last_synth_df is None or _last_synth_df.empty:
184
+ return None
185
+ os.makedirs("/tmp", exist_ok=True)
186
+ path = "/tmp/synthetic_data.csv"
187
+ _last_synth_df.to_csv(path, index=False)
188
+ return path
189
 
190
 
191
+ def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
192
  if original_df is None or synthetic_df is None:
193
  return None
 
194
  numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
195
  if not numeric_cols:
196
  return None
 
197
  n_cols = min(3, len(numeric_cols))
198
  n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
 
199
  fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
 
200
  for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
201
  row = i // n_cols + 1
202
  col_idx = i % n_cols + 1
203
+ fig.add_trace(go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
204
+ fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
 
 
 
 
 
 
 
 
 
 
205
  fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
206
  return fig
207
 
208
 
 
 
 
 
 
 
 
 
209
  # ---- UI ----
210
  def create_interface():
211
  with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
 
212
  gr.Image(
213
  value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
214
  show_label=False,
215
  elem_id="header-image",
216
  )
217
 
 
218
  gr.Markdown(
219
  """
220
  # Synthetic Data SDK by MOSTLY AI Demo Space
 
270
  train_status = gr.Textbox(label="Training Status", interactive=False)
271
 
272
  with gr.Row():
273
+ # This download button calls a function that returns a file path β†’ download starts immediately
274
  get_report_btn = gr.DownloadButton("Get Quality Report", variant="secondary")
275
 
276
  with gr.Tab("Generate Data"):
 
284
 
285
  synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
286
  with gr.Row():
287
+ # Same pattern: click β†’ function returns the CSV path β†’ immediate download
288
  download_btn = gr.DownloadButton("Download CSV", variant="secondary")
289
  comparison_plot = gr.Plot(label="Data Comparison")
290
 
291
+ # ---- Events ----
292
  init_btn.click(initialize_sdk, outputs=[init_status])
293
 
294
  train_btn.click(
 
297
  outputs=[train_status],
298
  )
299
 
300
+ # IMPORTANT: For DownloadButton, do NOT specify outputs β€” the returned path is auto-downloaded.
301
+ get_report_btn.click(generator.get_quality_report_file, inputs=None, outputs=None)
302
 
 
303
  generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
304
 
 
 
 
305
  # Build comparison plot when both datasets are available
306
+ synthetic_data.change(create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot])
307
+
308
+ # CSV download: return a path from the click handler (no outputs)
309
+ download_btn.click(download_csv_now, inputs=None, outputs=None)
310
 
311
+ # File upload handler
312
  def process_uploaded_file(file):
313
  if file is None:
314
  return None, "No file uploaded.", gr.update(visible=False)