import json
import os
import signal
import subprocess
import sys
import time
from datetime import datetime, timezone

from app.services.logger import get_logger
from app.services.github_service import GitHubClient
from app.services.hf_service import dataset_exists
from app.services.resource_manager import job_slot, temp_workdir, check_gpu, prune_old_results
from app.utils.validation import safe_tag, safe_dataset_id, safe_int, ValidationError

log = get_logger(__name__)

DEFAULT_TIMEOUT = int(os.getenv("DEFAULT_JOB_TIMEOUT", str(60 * 60 * 3)))  # 3 hours
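
# ---------------------------------------------------------------------------
# Contract assumed for train.py (inferred from how launch_training invokes it
# below; the actual script is not shown in this file): it must accept a
# --config flag pointing at the JSON the orchestrator writes, print progress
# to stdout (streamed back to the UI line by line), and write a JSON results
# file to config["results_file"] on success. A minimal skeleton satisfying
# that contract might look like:
#
#   import argparse, json
#   ap = argparse.ArgumentParser()
#   ap.add_argument('--config', required=True)
#   cfg = json.load(open(ap.parse_args().config))
#   ...  # training loop, printing progress as it goes
#   json.dump({'status': 'ok'}, open(cfg['results_file'], 'w'))
# ---------------------------------------------------------------------------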

def launch_training(config: dict, timeout_sec: int = DEFAULT_TIMEOUT):
    """Validate the config, launch train.py in a subprocess, and stream its
    output line by line. Yielding strings makes this directly usable as a
    streaming generator for a Gradio event handler."""
    try:
        # Basic validations; any ValidationError aborts the run up front
        config['exp_tag'] = safe_tag(config.get('exp_tag', 'exp'))
        # .get() so a missing key surfaces as a ValidationError rather than a
        # KeyError (assumes safe_dataset_id rejects an empty id)
        safe_dataset_id(config.get('dataset_name', ''))
        config['batch_size'] = safe_int(config.get('batch_size', 8), 1, 4096, 'batch_size')
        p = config.get('params', {})
        p['epochs'] = safe_int(p.get('epochs', 3), 1, 999, 'epochs')
        config['params'] = p
    except ValidationError as e:
        yield f"❌ ValidationError: {e}"
        return

    # Prune old results opportunistically
    prune_old_results()

    # Dataset validation
    try:
        dataset_exists(config['dataset_name'])
    except Exception as e:
        yield f"❌ Dataset validation failed: {e}"
        return

    # Resource check (best-effort): a failed GPU probe warns but does not abort
    try:
        check_gpu(mem_required_gb=float(os.getenv('MIN_GPU_GB', '4')))
    except Exception as e:
        yield f"⚠️ GPU check warning: {e}"
    gh = GitHubClient()
    with job_slot():
        with temp_workdir(prefix='run_') as work:
            cfg_path = os.path.join(work, 'config.json')
            results_file = config.get('results_file') or os.path.join(
                'local_results', f"{config['exp_tag']}_{int(time.time())}.json")
            config['results_file'] = results_file
            # Ensure the results directory exists (assumption: train.py writes
            # the file but may not create the directory)
            os.makedirs(os.path.dirname(results_file) or '.', exist_ok=True)
            config['experiment_id'] = config.get('experiment_id') or (
                f"{config['exp_tag']}_{config.get('algorithm', 'algo')}_"
                f"{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}_{os.urandom(3).hex()}"
            )
            with open(cfg_path, 'w') as f:
                json.dump(config, f, indent=2)

            cmd = [sys.executable, 'train.py', '--config', cfg_path]
            log.info('Launching: %s', ' '.join(cmd))
            try:
                # Start in its own session so the whole process group can be signalled
                proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                        text=True, bufsize=1, preexec_fn=os.setsid)
            except Exception as e:
                yield f"❌ Failed to start training: {e}"
                return

            start = time.time()
            try:
                # Stream output lines and enforce the timeout. readline() blocks,
                # so the timeout is only checked when a line arrives or the
                # child's stdout reaches EOF.
                while True:
                    line = proc.stdout.readline()
                    if line:
                        yield line
                    elif proc.poll() is not None:
                        break
                    else:
                        # EOF race: stdout closed but the process not yet reaped
                        time.sleep(0.1)
                    if time.time() - start > timeout_sec:
                        yield f"⏱️ Timeout reached ({timeout_sec}s). Terminating job..."
                        try:
                            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                        except Exception:
                            proc.terminate()
                        break
            except GeneratorExit:
                # Gradio or the caller closed the generator; terminate the job.
                # A generator must not yield while handling GeneratorExit (that
                # raises RuntimeError), so log the cancellation instead.
                try:
                    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                except Exception:
                    proc.terminate()
                log.warning("Training was cancelled by the user.")
                return
            except Exception as e:
                yield f"❌ Runtime error during training: {e}"
            finally:
                # Make sure the child has exited; escalate to SIGKILL if SIGTERM is ignored
                if proc.poll() is None:
                    try:
                        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                    except Exception:
                        proc.terminate()
                    try:
                        proc.wait(timeout=30)
                    except subprocess.TimeoutExpired:
                        try:
                            os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                        except Exception:
                            proc.kill()
                        proc.wait()

            # Try to read the results file written by train.py
            try:
                if os.path.exists(results_file):
                    with open(results_file, 'r') as f:
                        results = json.load(f)
                    # Push to GitHub if the client is configured
                    try:
                        if gh.push_json(config['experiment_id'], results):
                            yield "✅ Training complete. Results pushed to GitHub."
                        else:
                            yield "⚠️ Training complete. GitHub push skipped (not configured)."
                    except Exception as e:
                        yield f"⚠️ Training complete, but push failed: {e}"
                else:
                    yield "⚠️ Training finished but no results file found."
            except Exception as e:
                yield f"❌ Failed to read/push results: {e}"

            yield "🔚 Orchestrator finished."