training_bench / app /services /resource_manager.py
rider-provider-777's picture
Upload 4 files
62aa251 verified
import os, shutil, time
import tempfile
import torch
import threading
from contextlib import contextmanager
from app.services.logger import get_logger
log = get_logger(__name__)
# Concurrency
# Upper bound on simultaneously running jobs; read once at import time from
# the MAX_CONCURRENT_JOBS environment variable (default "1").
_MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT_JOBS", "1"))
# Counting semaphore initialised with _MAX_CONCURRENT permits; job_slot()
# acquires one permit on entry and releases it on exit.
_sema = threading.Semaphore(_MAX_CONCURRENT)
# Retention policy (seconds)
# Read once at import time from RESULTS_RETENTION_SEC; prune_old_results()
# compares each entry's age (now - mtime) against this value.
_RESULTS_RETENTION = int(os.getenv("RESULTS_RETENTION_SEC", str(60*60*24*7))) # 7 days
@contextmanager
def job_slot(timeout_sec: int = 60*45, *, acquire_timeout_sec: float = 5):
    """Hold one job slot from the module-level semaphore for the ``with`` body.

    Args:
        timeout_sec: Accepted for backward compatibility only. It has never
            been used to bound either the wait for a slot or the run time of
            the job, and it still is not.
            NOTE(review): presumably intended as a maximum job duration --
            confirm with callers before enforcing it.
        acquire_timeout_sec: Seconds to wait for a free slot before giving
            up. Defaults to 5, the value that used to be hard-coded.

    Raises:
        RuntimeError: If no slot becomes free within ``acquire_timeout_sec``.
    """
    if not _sema.acquire(timeout=acquire_timeout_sec):
        raise RuntimeError("All runners busy; try later")
    try:
        yield
    finally:
        # Always hand the slot back, even when the job body raised.
        _sema.release()
@contextmanager
def temp_workdir(prefix: str = "job_"):
    """Create a temporary directory, yield its path, and remove it on exit.

    Args:
        prefix: Prefix passed to ``tempfile.mkdtemp`` for the directory name.

    Yields:
        The path of the freshly created directory.

    Cleanup is best effort: a failure to delete the directory is logged as a
    warning and never raised, so it cannot mask an exception coming from the
    ``with`` body.
    """
    d = tempfile.mkdtemp(prefix=prefix)
    try:
        yield d
    finally:
        try:
            # Strict removal first so that a failure is actually observable;
            # with ignore_errors=True the warning below could never fire.
            shutil.rmtree(d)
        except FileNotFoundError:
            # The directory was already removed by the body; nothing to do.
            pass
        except OSError:
            log.warning("Failed to rmtree %s", d)
            # Still delete whatever can be deleted, as the original did.
            shutil.rmtree(d, ignore_errors=True)
def check_gpu(mem_required_gb: float = 4.0) -> None:
    """Check that CUDA device 0 has at least ``mem_required_gb`` GiB in total.

    When CUDA is unavailable the function only logs that the job will run on
    CPU. When the device properties cannot be queried, the failure is logged
    as a warning and the check is skipped.

    Args:
        mem_required_gb: Minimum total device memory, in GiB (1024**3 bytes).

    Raises:
        RuntimeError: If device 0 reports less total memory than required.
            Previously this error was raised inside the same ``try`` whose
            ``except Exception`` swallowed it, so the check could never fail.
    """
    if not torch.cuda.is_available():
        log.info("CUDA not available; running on CPU")
        return

    # Only the device query is best effort; the memory verdict below must be
    # able to propagate to the caller.
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception as e:
        log.warning("GPU check failed: %s", e)
        return

    total_gb = props.total_memory / (1024**3)
    log.info("Detected GPU with %.1f GB", total_gb)
    if total_gb < mem_required_gb:
        raise RuntimeError(f"GPU has {total_gb:.1f}GB < required {mem_required_gb}GB")
def prune_old_results(local_results_dir: str = "local_results"):
    """Delete entries in ``local_results_dir`` older than the retention window.

    An entry is removed with ``os.remove`` when the time since its last
    modification exceeds ``_RESULTS_RETENTION`` seconds. A missing directory
    is a no-op. Any error while inspecting or removing an entry is logged as
    a warning and the scan moves on to the next entry.

    Args:
        local_results_dir: Directory whose direct children are examined.
    """
    if not os.path.isdir(local_results_dir):
        return

    reference_time = time.time()
    for entry_name in os.listdir(local_results_dir):
        path = os.path.join(local_results_dir, entry_name)
        try:
            age_sec = reference_time - os.path.getmtime(path)
            if age_sec <= _RESULTS_RETENTION:
                # Still inside the retention window; keep it.
                continue
            os.remove(path)
            log.info("Pruned old result %s", path)
        except Exception as err:
            log.warning("Prune failed for %s: %s", path, err)
# Heuristic OOM estimate (very rough)
def suggest_batch_limit(model_size_mb: float = 400):
    """Return a rough upper bound on batch size for the memory free on GPU 0.

    The estimate is ``free_device_memory_mb / model_size_mb`` truncated to an
    int and clamped to at least 1. It uses the memory that is currently free
    (``torch.cuda.mem_get_info``), not the device total, so memory already
    held by other work is not counted as available.

    Args:
        model_size_mb: Approximate per-sample footprint in MiB.

    Returns:
        The suggested limit (always >= 1). Returns 1 when CUDA is not
        available, when ``model_size_mb`` is not positive, or when the
        device cannot be queried.
    """
    try:
        if model_size_mb <= 0:
            # A non-positive size cannot yield a meaningful ratio.
            return 1
        if torch.cuda.is_available():
            free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
            free_mb = free_bytes / (1024**2)
            return max(1, int(free_mb / model_size_mb))
    except Exception as e:
        # Best effort: fall back to the most conservative answer, but leave
        # a trace instead of failing silently.
        log.warning("Batch limit estimate failed: %s", e)
    return 1