Tonic committed
Commit 97ce9ed · verified · 1 Parent(s): 102da75

Upload dataset_utils.py with huggingface_hub

Files changed (1)
  1. dataset_utils.py +470 -0
dataset_utils.py ADDED
@@ -0,0 +1,470 @@
#!/usr/bin/env python3
"""
Dataset utilities for Trackio experiment data management
Provides functions for safe dataset operations with data preservation
"""

import json
import logging
from datetime import datetime
from typing import Dict, Any, List, Optional

from datasets import Dataset, load_dataset

logger = logging.getLogger(__name__)

class TrackioDatasetManager:
    """
    Manager class for Trackio experiment datasets with data preservation.

    This class ensures that existing experiment data is always preserved
    when adding new experiments or updating existing ones.
    """

    def __init__(self, dataset_repo: str, hf_token: str):
        """
        Initialize the dataset manager.

        Args:
            dataset_repo (str): HF dataset repository ID (e.g., "username/dataset-name")
            hf_token (str): Hugging Face token for authentication
        """
        self.dataset_repo = dataset_repo
        self.hf_token = hf_token
        self._validate_repo_format()

    def _validate_repo_format(self):
        """Validate dataset repository format"""
        if not self.dataset_repo or '/' not in self.dataset_repo:
            raise ValueError(f"Invalid dataset repository format: {self.dataset_repo}")

    def check_dataset_exists(self) -> bool:
        """
        Check if the dataset repository exists and is accessible.

        Returns:
            bool: True if dataset exists and is accessible, False otherwise
        """
        try:
            # Try a standard load first
            load_dataset(self.dataset_repo, token=self.hf_token)
            logger.info(f"✅ Dataset {self.dataset_repo} exists and is accessible")
            return True
        except Exception as e:
            # Retry with relaxed verification to handle split metadata mismatches
            try:
                logger.info(f"📊 Standard load failed: {e}. Retrying with relaxed verification...")
                load_dataset(
                    self.dataset_repo,
                    token=self.hf_token,
                    verification_mode="no_checks"  # type: ignore[arg-type]
                )
                logger.info(f"✅ Dataset {self.dataset_repo} accessible with relaxed verification")
                return True
            except Exception as e2:
                logger.info(f"📊 Dataset {self.dataset_repo} doesn't exist or isn't accessible: {e2}")
                return False

    def load_existing_experiments(self) -> List[Dict[str, Any]]:
        """
        Load all existing experiments from the dataset.

        Returns:
            List[Dict[str, Any]]: List of existing experiment dictionaries
        """
        try:
            if not self.check_dataset_exists():
                logger.info("📊 No existing dataset found, returning empty list")
                return []

            # Load with relaxed verification to avoid split-metadata mismatches blocking reads
            try:
                dataset = load_dataset(self.dataset_repo, token=self.hf_token)
            except Exception:
                dataset = load_dataset(self.dataset_repo, token=self.hf_token, verification_mode="no_checks")  # type: ignore[arg-type]

            if 'train' not in dataset:
                logger.info("📊 No 'train' split found in dataset")
                return []

            experiments = list(dataset['train'])
            logger.info(f"📊 Loaded {len(experiments)} existing experiments")

            # Validate experiment structure
            valid_experiments = []
            for exp in experiments:
                if self._validate_experiment_structure(exp):
                    valid_experiments.append(exp)
                else:
                    logger.warning(f"⚠️ Skipping invalid experiment: {exp.get('experiment_id', 'unknown')}")

            logger.info(f"📊 {len(valid_experiments)} valid experiments loaded")
            return valid_experiments

        except Exception as e:
            logger.error(f"❌ Failed to load existing experiments: {e}")
            return []

    def _validate_experiment_structure(self, experiment: Dict[str, Any]) -> bool:
        """
        Validate and SANITIZE an experiment structure to prevent destructive failures.

        - Requires 'experiment_id'; otherwise the row is skipped.
        - Fills defaults for missing non-JSON fields.
        - Normalizes JSON fields to valid JSON strings.
        """
        if not experiment.get('experiment_id'):
            logger.warning("⚠️ Missing required field 'experiment_id' in experiment; skipping row")
            return False

        defaults = {
            'name': '',
            'description': '',
            'created_at': datetime.now().isoformat(),
            'status': 'running',
        }
        for key, default_value in defaults.items():
            if experiment.get(key) in (None, ''):
                experiment[key] = default_value

        def _ensure_json_string(field_name: str, default_value: Any):
            raw_value = experiment.get(field_name)
            try:
                if isinstance(raw_value, str):
                    if raw_value.strip() == '':
                        experiment[field_name] = json.dumps(default_value, default=str)
                    else:
                        json.loads(raw_value)  # validate only; keep the original string
                else:
                    experiment[field_name] = json.dumps(
                        raw_value if raw_value is not None else default_value,
                        default=str
                    )
            except Exception:
                experiment[field_name] = json.dumps(default_value, default=str)

        for json_field, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
            _ensure_json_string(json_field, default)

        return True
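
    # Illustrative example (hypothetical values): a partially populated row such as
    #     {'experiment_id': 'exp_1', 'metrics': None, 'parameters': '{"lr": 0.001}'}
    # is sanitized in place to
    #     {'experiment_id': 'exp_1', 'name': '', 'description': '', 'status': 'running',
    #      'created_at': '<now>', 'metrics': '[]', 'parameters': '{"lr": 0.001}',
    #      'artifacts': '[]', 'logs': '[]'}
    # while a row lacking 'experiment_id' is rejected and skipped by callers.
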
    def save_experiments(self, experiments: List[Dict[str, Any]], commit_message: Optional[str] = None) -> bool:
        """
        Save a list of experiments to the dataset using a non-destructive union merge.

        - Loads existing experiments (if any) and builds a union by `experiment_id`.
        - For overlapping IDs, merges JSON fields:
          - metrics: concatenates lists and de-duplicates by (step, timestamp) for nested entries
          - parameters: dict-update (new values override)
          - artifacts: union with de-dup
          - logs: concatenation with de-dup
        - Non-JSON scalar fields from incoming experiments take precedence.

        Args:
            experiments (List[Dict[str, Any]]): List of experiment dictionaries
            commit_message (Optional[str]): Custom commit message

        Returns:
            bool: True if save was successful, False otherwise
        """
        try:
            if not experiments:
                logger.warning("⚠️ No experiments to save")
                return False

            # Helpers
            def _parse_json_field(value, default):
                try:
                    if value is None:
                        return default
                    if isinstance(value, str):
                        return json.loads(value) if value else default
                    return value
                except Exception:
                    return default

            def _metrics_key(entry: Dict[str, Any]):
                if isinstance(entry, dict):
                    return (entry.get('step'), entry.get('timestamp'))
                return (None, json.dumps(entry, sort_keys=True))

            # Load existing experiments for union merge
            existing = {}
            dataset_exists = self.check_dataset_exists()
            try:
                existing_list = self.load_existing_experiments()
                for row in existing_list:
                    exp_id = row.get('experiment_id')
                    if exp_id:
                        existing[exp_id] = row
            except Exception:
                existing = {}

            # Safety guard: avoid destructive overwrite if dataset exists but
            # we failed to read any existing records (e.g., transient HF issue)
            if dataset_exists and len(existing) == 0 and len(experiments) <= 3:
                logger.error(
                    "❌ Refusing to overwrite dataset: existing records could not be loaded "
                    "but repository exists. Skipping save to prevent data loss."
                )
                return False

            # Validate and merge
            merged_map: Dict[str, Dict[str, Any]] = {}
            # Seed with existing
            for exp_id, row in existing.items():
                merged_map[exp_id] = row

            # Apply incoming
            for exp in experiments:
                if not self._validate_experiment_structure(exp):
                    logger.error(f"❌ Invalid experiment structure: {exp.get('experiment_id', 'unknown')}")
                    return False
                exp_id = exp['experiment_id']
                incoming = exp
                if exp_id not in merged_map:
                    incoming['last_updated'] = incoming.get('last_updated') or datetime.now().isoformat()
                    merged_map[exp_id] = incoming
                    continue

                # Merge with existing
                base = merged_map[exp_id]

                # Parse JSON fields
                base_metrics = _parse_json_field(base.get('metrics'), [])
                base_params = _parse_json_field(base.get('parameters'), {})
                base_artifacts = _parse_json_field(base.get('artifacts'), [])
                base_logs = _parse_json_field(base.get('logs'), [])
                inc_metrics = _parse_json_field(incoming.get('metrics'), [])
                inc_params = _parse_json_field(incoming.get('parameters'), {})
                inc_artifacts = _parse_json_field(incoming.get('artifacts'), [])
                inc_logs = _parse_json_field(incoming.get('logs'), [])

                # Merge metrics with de-dup
                merged_metrics = []
                seen = set()
                for entry in base_metrics + inc_metrics:
                    try:
                        # Use the original entry so _metrics_key can properly
                        # distinguish dict vs non-dict entries
                        key = _metrics_key(entry)
                    except Exception:
                        key = (None, None)
                    if key not in seen:
                        seen.add(key)
                        merged_metrics.append(entry)

                # Merge params
                merged_params = {}
                if isinstance(base_params, dict):
                    merged_params.update(base_params)
                if isinstance(inc_params, dict):
                    merged_params.update(inc_params)

                # Merge artifacts and logs with de-dup
                def _dedup_list(lst):
                    out = []
                    seen_local = set()
                    for item in lst:
                        key = json.dumps(item, sort_keys=True, default=str) if not isinstance(item, str) else item
                        if key not in seen_local:
                            seen_local.add(key)
                            out.append(item)
                    return out

                merged_artifacts = _dedup_list(list(base_artifacts) + list(inc_artifacts))
                merged_logs = _dedup_list(list(base_logs) + list(inc_logs))

                # Rebuild merged record, preferring incoming scalars
                merged_rec = dict(base)
                merged_rec.update({k: v for k, v in incoming.items() if k not in ('metrics', 'parameters', 'artifacts', 'logs')})
                merged_rec['metrics'] = json.dumps(merged_metrics, default=str)
                merged_rec['parameters'] = json.dumps(merged_params, default=str)
                merged_rec['artifacts'] = json.dumps(merged_artifacts, default=str)
                merged_rec['logs'] = json.dumps(merged_logs, default=str)
                merged_rec['last_updated'] = datetime.now().isoformat()
                merged_map[exp_id] = merged_rec

            # Prepare final list
            valid_experiments = list(merged_map.values())

            # Ensure all records have mandatory fields encoded
            normalized = []
            for rec in valid_experiments:
                # Normalize JSON fields to strings
                for f, default in (('metrics', []), ('parameters', {}), ('artifacts', []), ('logs', [])):
                    val = rec.get(f)
                    if not isinstance(val, str):
                        rec[f] = json.dumps(val if val is not None else default, default=str)
                if 'last_updated' not in rec:
                    rec['last_updated'] = datetime.now().isoformat()
                normalized.append(rec)

            dataset = Dataset.from_list(normalized)

            # Generate a commit message if not provided
            if not commit_message:
                commit_message = f"Union-merge update with {len(normalized)} experiments ({datetime.now().isoformat()})"

            # Push to hub
            dataset.push_to_hub(
                self.dataset_repo,
                token=self.hf_token,
                private=True,
                commit_message=commit_message
            )

            logger.info(f"✅ Successfully saved {len(normalized)} experiments (union-merged) to {self.dataset_repo}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to save experiments to dataset: {e}")
            return False
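
    # Illustrative example of the union-merge semantics (hypothetical rows sharing
    # one experiment_id):
    #     existing: metrics='[{"step": 1, "timestamp": "t0", "loss": 0.9}]'
    #     incoming: metrics='[{"step": 1, "timestamp": "t0", "loss": 0.9},
    #                        {"step": 2, "timestamp": "t1", "loss": 0.7}]'
    # The merged record keeps each (step, timestamp) pair once, so only the step-2
    # entry is appended; parameters are dict-merged with incoming values overriding.
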
    def upsert_experiment(self, experiment: Dict[str, Any]) -> bool:
        """
        Insert a new experiment or update an existing one, preserving all other data.

        Args:
            experiment (Dict[str, Any]): Experiment dictionary to upsert

        Returns:
            bool: True if operation was successful, False otherwise
        """
        try:
            # Validate the experiment structure
            if not self._validate_experiment_structure(experiment):
                logger.error(f"❌ Invalid experiment structure for {experiment.get('experiment_id', 'unknown')}")
                return False

            # Load existing experiments
            existing_experiments = self.load_existing_experiments()

            # Find whether the experiment already exists
            experiment_id = experiment['experiment_id']
            experiment_found = False
            updated_experiments = []

            for existing_exp in existing_experiments:
                if existing_exp.get('experiment_id') == experiment_id:
                    # Update existing experiment
                    logger.info(f"🔄 Updating existing experiment: {experiment_id}")
                    experiment['last_updated'] = datetime.now().isoformat()
                    updated_experiments.append(experiment)
                    experiment_found = True
                else:
                    # Preserve existing experiment
                    updated_experiments.append(existing_exp)

            # If the experiment doesn't exist, add it
            if not experiment_found:
                logger.info(f"➕ Adding new experiment: {experiment_id}")
                experiment['last_updated'] = datetime.now().isoformat()
                updated_experiments.append(experiment)

            # Save all experiments
            commit_message = f"{'Update' if experiment_found else 'Add'} experiment {experiment_id} (preserving {len(existing_experiments)} existing experiments)"

            return self.save_experiments(updated_experiments, commit_message)

        except Exception as e:
            logger.error(f"❌ Failed to upsert experiment: {e}")
            return False

    def get_experiment_by_id(self, experiment_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve a specific experiment by its ID.

        Args:
            experiment_id (str): The experiment ID to search for

        Returns:
            Optional[Dict[str, Any]]: The experiment dictionary if found, None otherwise
        """
        try:
            experiments = self.load_existing_experiments()

            for exp in experiments:
                if exp.get('experiment_id') == experiment_id:
                    logger.info(f"✅ Found experiment: {experiment_id}")
                    return exp

            logger.info(f"📊 Experiment not found: {experiment_id}")
            return None

        except Exception as e:
            logger.error(f"❌ Failed to get experiment {experiment_id}: {e}")
            return None

    def list_experiments(self, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        List all experiments, optionally filtered by status.

        Args:
            status_filter (Optional[str]): Filter by experiment status (running, completed, failed, paused)

        Returns:
            List[Dict[str, Any]]: List of experiments matching the filter
        """
        try:
            experiments = self.load_existing_experiments()

            if status_filter:
                filtered_experiments = [exp for exp in experiments if exp.get('status') == status_filter]
                logger.info(f"📊 Found {len(filtered_experiments)} experiments with status '{status_filter}'")
                return filtered_experiments

            logger.info(f"📊 Found {len(experiments)} total experiments")
            return experiments

        except Exception as e:
            logger.error(f"❌ Failed to list experiments: {e}")
            return []

    def backup_dataset(self, backup_suffix: Optional[str] = None) -> str:
        """
        Create a backup of the current dataset.

        Args:
            backup_suffix (Optional[str]): Optional suffix for the backup repo name

        Returns:
            str: Backup repository name if successful, empty string otherwise
        """
        try:
            if not backup_suffix:
                backup_suffix = datetime.now().strftime('%Y%m%d_%H%M%S')

            backup_repo = f"{self.dataset_repo}-backup-{backup_suffix}"

            # Load current experiments
            experiments = self.load_existing_experiments()

            if not experiments:
                logger.warning("⚠️ No experiments to backup")
                return ""

            # Create a backup dataset manager
            backup_manager = TrackioDatasetManager(backup_repo, self.hf_token)

            # Save to the backup
            success = backup_manager.save_experiments(
                experiments,
                f"Backup of {self.dataset_repo} created on {datetime.now().isoformat()}"
            )

            if success:
                logger.info(f"✅ Backup created: {backup_repo}")
                return backup_repo
            else:
                logger.error("❌ Failed to create backup")
                return ""

        except Exception as e:
            logger.error(f"❌ Failed to create backup: {e}")
            return ""


def create_dataset_manager(dataset_repo: str, hf_token: str) -> TrackioDatasetManager:
    """
    Factory function to create a TrackioDatasetManager instance.

    Args:
        dataset_repo (str): HF dataset repository ID
        hf_token (str): Hugging Face token

    Returns:
        TrackioDatasetManager: Configured dataset manager instance
    """
    return TrackioDatasetManager(dataset_repo, hf_token)
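
A minimal usage sketch, assuming a hypothetical repository ID ("username/trackio-experiments"), an HF token exported as the HF_TOKEN environment variable, and that the file above is importable as dataset_utils:

    import json
    import os

    from dataset_utils import create_dataset_manager

    # Build a manager for the (hypothetical) experiments dataset
    manager = create_dataset_manager("username/trackio-experiments", os.environ["HF_TOKEN"])

    # Upsert one experiment; JSON fields are stored as JSON strings
    experiment = {
        "experiment_id": "exp_001",
        "name": "baseline-run",
        "metrics": json.dumps([{"step": 1, "timestamp": "2024-01-01T00:00:00", "loss": 0.9}]),
        "parameters": json.dumps({"lr": 1e-3}),
    }
    if manager.upsert_experiment(experiment):  # inserts or updates, preserving other rows
        print(manager.list_experiments(status_filter="running"))

Missing fields (description, status, artifacts, logs) are filled with defaults by the sanitization step, so the upserted row lists as "running" until its status is updated.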