import json, os, subprocess, sys, signal, time
from datetime import datetime
from app.services.logger import get_logger
from app.services.github_service import GitHubClient
from app.services.hf_service import dataset_exists
from app.services.resource_manager import job_slot, temp_workdir, check_gpu, prune_old_results
from app.utils.validation import safe_tag, safe_dataset_id, safe_int, ValidationError

log = get_logger(__name__)

DEFAULT_TIMEOUT = int(os.getenv("DEFAULT_JOB_TIMEOUT", str(60*60*3)))  # 3 hours

def launch_training(config: dict, timeout_sec: int = DEFAULT_TIMEOUT):
    """Streaming generator compatible with Gradio: yields log lines and status messages."""
    try:
        # Basic validations
        config['exp_tag'] = safe_tag(config.get('exp_tag', 'exp'))
        safe_dataset_id(config['dataset_name'])
        config['batch_size'] = safe_int(config.get('batch_size', 8), 1, 4096, 'batch_size')
        p = config.get('params', {})
        p['epochs'] = safe_int(p.get('epochs', 3), 1, 999, 'epochs')
        config['params'] = p
    except ValidationError as e:
        yield f"❌ ValidationError: {e}"
        return

    # Prune old results opportunistically
    prune_old_results()

    # Dataset validation
    try:
        dataset_exists(config['dataset_name'])
    except Exception as e:
        yield f"❌ Dataset validation failed: {e}"
        return

    # Resource check (best-effort)
    try:
        check_gpu(mem_required_gb=float(os.getenv('MIN_GPU_GB', '4')))
    except Exception as e:
        yield f"⚠️ GPU check warning: {e}"

    gh = GitHubClient()

    with job_slot():
        with temp_workdir(prefix='run_') as work:
            cfg_path = os.path.join(work, 'config.json')
            results_file = config.get('results_file') or os.path.join('local_results', f"{config['exp_tag']}_{int(time.time())}.json")
            # Ensure the results directory exists before train.py tries to write to it.
            os.makedirs(os.path.dirname(results_file) or '.', exist_ok=True)
            config['results_file'] = results_file
            config['experiment_id'] = config.get('experiment_id') or f"{config['exp_tag']}_{config.get('algorithm','algo')}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{os.urandom(3).hex()}"

            with open(cfg_path, 'w') as f:
                json.dump(config, f, indent=2)

            cmd = [sys.executable, 'train.py', '--config', cfg_path]
            log.info('Launching: %s', ' '.join(cmd))

            try:
                proc = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    text=True,
                    bufsize=1,
                    start_new_session=True,  # own process group so killpg() can signal the whole tree
                )
            except Exception as e:
                yield f"❌ Failed to start training: {e}"
                return

            start = time.time()
            try:
                # Read lines and yield; enforce the timeout between lines.
                # Note: readline() blocks, so the timeout is only evaluated
                # after each output line or once stdout reaches EOF.
                while True:
                    line = proc.stdout.readline()
                    if line:
                        yield line
                    elif proc.poll() is not None:
                        break
                    else:
                        # stdout hit EOF but the process has not exited yet
                        time.sleep(0.1)
                    # timeout check
                    if time.time() - start > timeout_sec:
                        yield f"⏱️ Timeout reached ({timeout_sec}s). Terminating job..."
                        try:
                            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                        except Exception:
                            proc.terminate()
                        break
            except GeneratorExit:
                # Gradio or the caller closed the generator: terminate the job
                # gracefully. A generator must not yield while handling
                # GeneratorExit (it would raise RuntimeError), so log instead.
                try:
                    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                except Exception:
                    proc.terminate()
                log.info("Training was cancelled by the user.")
                return
            except Exception as e:
                yield f"❌ Runtime error during training: {e}"
            finally:
                # Ensure the child has ended; escalate to SIGKILL if it ignores SIGTERM.
                if proc.poll() is None:
                    try:
                        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                    except Exception:
                        proc.terminate()
                try:
                    proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    try:
                        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                    except Exception:
                        proc.kill()
                    proc.wait()

            # Try to read results
            try:
                if os.path.exists(results_file):
                    with open(results_file, 'r') as f:
                        results = json.load(f)
                    # push to GitHub if available
                    try:
                        if gh.push_json(config['experiment_id'], results):
                            yield "✅ Training complete. Results pushed to GitHub."
                        else:
                            yield "⚠️ Training complete. GitHub push skipped (not configured)."
                    except Exception as e:
                        yield f"⚠️ Training complete, but push failed: {e}"
                else:
                    yield "⚠️ Training finished but no results file found."
            except Exception as e:
                yield f"❌ Failed to read/push results: {e}"

    yield "🔚 Orchestrator finished."
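

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a Gradio app or other caller might
# consume the streaming generator above. The config fields below are
# assumptions for demonstration, not part of this module's contract.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_config = {
        "exp_tag": "demo",                      # sanitized by safe_tag()
        "dataset_name": "user/sample-dataset",  # hypothetical dataset id
        "batch_size": 8,
        "algorithm": "sft",                     # hypothetical algorithm name
        "params": {"epochs": 1},
    }
    # Consume the generator exactly as a streaming UI would.
    for message in launch_training(sample_config, timeout_sec=600):
        print(message, end="" if message.endswith("\n") else "\n")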