rider-provider-777 commited on
Commit
4c78db1
·
verified ·
1 Parent(s): f3d54fd

Upload orchestrator.py

Browse files
Files changed (1) hide show
  1. app/core/orchestrator.py +119 -0
app/core/orchestrator.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, subprocess, sys, signal, time
2
+ from datetime import datetime
3
+ from app.services.logger import get_logger
4
+ from app.services.github_service import GitHubClient
5
+ from app.services.hf_service import dataset_exists
6
+ from app.services.resource_manager import job_slot, temp_workdir, check_gpu, prune_old_results
7
+ from app.utils.validation import safe_tag, safe_dataset_id, safe_int, ValidationError
8
+
9
+ log = get_logger(__name__)
10
+
11
+ DEFAULT_TIMEOUT = int(os.getenv("DEFAULT_JOB_TIMEOUT", str(60*60*3))) # 3 hours
12
+
13
+ # streaming generator compatible with Gradio
14
+ def launch_training(config: dict, timeout_sec: int = DEFAULT_TIMEOUT):
15
+ try:
16
+ # Basic validations
17
+ config['exp_tag'] = safe_tag(config.get('exp_tag', 'exp'))
18
+ safe_dataset_id(config['dataset_name'])
19
+ config['batch_size'] = safe_int(config.get('batch_size', 8), 1, 4096, 'batch_size')
20
+ p = config.get('params', {})
21
+ p['epochs'] = safe_int(p.get('epochs', 3), 1, 999, 'epochs')
22
+ config['params'] = p
23
+ except ValidationError as e:
24
+ yield f"❌ ValidationError: {e}"
25
+ return
26
+
27
+ # Prune old results opportunistically
28
+ prune_old_results()
29
+
30
+ # Dataset validation
31
+ try:
32
+ dataset_exists(config['dataset_name'])
33
+ except Exception as e:
34
+ yield f"❌ Dataset validation failed: {e}"
35
+ return
36
+
37
+ # Resource check (best-effort)
38
+ try:
39
+ check_gpu(mem_required_gb=float(os.getenv('MIN_GPU_GB', '4')))
40
+ except Exception as e:
41
+ yield f"⚠️ GPU check warning: {e}"
42
+
43
+ gh = GitHubClient()
44
+
45
+ with job_slot():
46
+ with temp_workdir(prefix='run_') as work:
47
+ cfg_path = os.path.join(work, 'config.json')
48
+ results_file = config.get('results_file') or os.path.join('local_results', f"{config['exp_tag']}_{int(time.time())}.json")
49
+ config['results_file'] = results_file
50
+ config['experiment_id'] = config.get('experiment_id') or f"{config['exp_tag']}_{config.get('algorithm','algo')}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{os.urandom(3).hex()}"
51
+
52
+ with open(cfg_path, 'w') as f:
53
+ json.dump(config, f, indent=2)
54
+
55
+ cmd = [sys.executable, 'train.py', '--config', cfg_path]
56
+ log.info('Launching: %s', ' '.join(cmd))
57
+
58
+ try:
59
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, preexec_fn=os.setsid)
60
+ except Exception as e:
61
+ yield f"❌ Failed to start training: {e}"
62
+ return
63
+
64
+ start = time.time()
65
+ try:
66
+ # Read lines and yield; enforce timeout
67
+ while True:
68
+ line = proc.stdout.readline()
69
+ if line:
70
+ yield line
71
+ elif proc.poll() is not None:
72
+ break
73
+ # timeout check
74
+ if time.time() - start > timeout_sec:
75
+ yield f"⏱️ Timeout reached ({timeout_sec}s). Terminating job..."
76
+ try:
77
+ os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
78
+ except Exception:
79
+ proc.terminate()
80
+ break
81
+ time.sleep(0.1)
82
+ except GeneratorExit:
83
+ # Gradio or caller closed generator, terminate gracefully
84
+ try:
85
+ os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
86
+ except Exception:
87
+ proc.terminate()
88
+ yield "⚠️ Training was cancelled by the user."
89
+ return
90
+ except Exception as e:
91
+ yield f"❌ Runtime error during training: {e}"
92
+ finally:
93
+ # Ensure process ended
94
+ if proc.poll() is None:
95
+ try:
96
+ os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
97
+ except Exception:
98
+ proc.terminate()
99
+ proc.wait(timeout=30)
100
+
101
+ # Try to read results
102
+ try:
103
+ if os.path.exists(results_file):
104
+ with open(results_file, 'r') as f:
105
+ results = json.load(f)
106
+ # push to GitHub if available
107
+ try:
108
+ if gh.push_json(config['experiment_id'], results):
109
+ yield "✅ Training complete. Results pushed to GitHub."
110
+ else:
111
+ yield "⚠️ Training complete. GitHub push skipped (not configured)."
112
+ except Exception as e:
113
+ yield f"⚠️ Training complete, but push failed: {e}"
114
+ else:
115
+ yield "⚠️ Training finished but no results file found."
116
+ except Exception as e:
117
+ yield f"❌ Failed to read/push results: {e}"
118
+
119
+ yield "🔚 Orchestrator finished."