import json, os, subprocess, sys, signal, time
from datetime import datetime

from app.services.logger import get_logger
from app.services.github_service import GitHubClient
from app.services.hf_service import dataset_exists
from app.services.resource_manager import job_slot, temp_workdir, check_gpu, prune_old_results
from app.utils.validation import safe_tag, safe_dataset_id, safe_int, ValidationError

log = get_logger(__name__)

DEFAULT_TIMEOUT = int(os.getenv("DEFAULT_JOB_TIMEOUT", str(60 * 60 * 3)))  # 3 hours


# Streaming generator compatible with Gradio
def launch_training(config: dict, timeout_sec: int = DEFAULT_TIMEOUT):
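    """Validate ``config``, launch ``train.py`` in a subprocess, and stream its output.

    Yields raw log lines plus emoji-prefixed status messages so a caller such as a
    Gradio UI can display progress live. Enforces ``timeout_sec`` as a wall-clock
    limit, terminates the job's process group on timeout or cancellation, and
    finally attempts to read the results file and push it to GitHub.
    """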
    try:
        # Basic validations: missing keys surface as KeyError, bad values as ValidationError
        config['exp_tag'] = safe_tag(config.get('exp_tag', 'exp'))
        safe_dataset_id(config['dataset_name'])
        config['batch_size'] = safe_int(config.get('batch_size', 8), 1, 4096, 'batch_size')
        p = config.get('params', {})
        p['epochs'] = safe_int(p.get('epochs', 3), 1, 999, 'epochs')
        config['params'] = p
    except (ValidationError, KeyError) as e:
        yield f"❌ Validation error: {e}"
        return

    # Prune old results opportunistically
    prune_old_results()

    # Dataset validation
    try:
        dataset_exists(config['dataset_name'])
    except Exception as e:
        yield f"❌ Dataset validation failed: {e}"
        return

    # Resource check (best-effort)
    try:
        check_gpu(mem_required_gb=float(os.getenv('MIN_GPU_GB', '4')))
    except Exception as e:
        yield f"⚠️ GPU check warning: {e}"

    gh = GitHubClient()

    with job_slot():
        with temp_workdir(prefix='run_') as work:
            cfg_path = os.path.join(work, 'config.json')
            results_file = config.get('results_file') or os.path.join(
                'local_results', f"{config['exp_tag']}_{int(time.time())}.json"
            )
            config['results_file'] = results_file
            config['experiment_id'] = config.get('experiment_id') or (
                f"{config['exp_tag']}_{config.get('algorithm', 'algo')}_"
                f"{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{os.urandom(3).hex()}"
            )
            with open(cfg_path, 'w') as f:
                json.dump(config, f, indent=2)

            cmd = [sys.executable, 'train.py', '--config', cfg_path]
            log.info('Launching: %s', ' '.join(cmd))
            try:
                # start_new_session=True (the thread-safe equivalent of preexec_fn=os.setsid)
                # puts the child in its own process group so the whole tree can be signalled.
                proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                        text=True, bufsize=1, start_new_session=True)
            except Exception as e:
                yield f"❌ Failed to start training: {e}"
                return

            start = time.time()
            try:
                # Stream stdout lines to the caller and enforce the wall-clock timeout
                while True:
                    line = proc.stdout.readline()
                    if line:
                        yield line
                    elif proc.poll() is not None:
                        break
                    else:
                        # No output and the process is still running: back off briefly
                        time.sleep(0.1)
                    # Timeout check
                    if time.time() - start > timeout_sec:
                        yield f"⏱️ Timeout reached ({timeout_sec}s). Terminating job..."
                        try:
                            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                        except Exception:
                            proc.terminate()
                        break
            except GeneratorExit:
                # The caller (e.g. Gradio) closed the generator: terminate the job group.
                # Yielding here would raise RuntimeError ("generator ignored GeneratorExit"),
                # so the cancellation is logged instead of streamed.
                try:
                    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                except Exception:
                    proc.terminate()
                log.warning("Training was cancelled by the user; job terminated.")
                return
            except Exception as e:
                yield f"❌ Runtime error during training: {e}"
            finally:
                # Ensure the child process has ended
                if proc.poll() is None:
                    try:
                        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                    except Exception:
                        proc.terminate()
                    try:
                        proc.wait(timeout=30)
                    except subprocess.TimeoutExpired:
                        proc.kill()

            # Try to read results
            try:
                if os.path.exists(results_file):
                    with open(results_file, 'r') as f:
                        results = json.load(f)
                    # Push to GitHub if available
                    try:
                        if gh.push_json(config['experiment_id'], results):
                            yield "✅ Training complete. Results pushed to GitHub."
                        else:
                            yield "⚠️ Training complete. GitHub push skipped (not configured)."
                    except Exception as e:
                        yield f"⚠️ Training complete, but push failed: {e}"
                else:
                    yield "⚠️ Training finished but no results file found."
            except Exception as e:
                yield f"❌ Failed to read/push results: {e}"

    yield "🔚 Orchestrator finished."