File size: 2,461 Bytes
62aa251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os, shutil, time
import tempfile
import torch
import threading
from contextlib import contextmanager
from app.services.logger import get_logger

# Logger for this module, obtained from the project's app.services.logger helper.
log = get_logger(__name__)

# Concurrency: maximum number of jobs that may hold a runner slot at once
# (slots are handed out by job_slot). Clamped to >= 1: a value of 0 would
# create a semaphore that can never be acquired (every job fails as "busy"),
# and a negative value makes threading.Semaphore raise ValueError at import.
_MAX_CONCURRENT = max(1, int(os.getenv("MAX_CONCURRENT_JOBS", "1")))
_sema = threading.Semaphore(_MAX_CONCURRENT)

# Retention policy (seconds): age threshold used by prune_old_results.
_RESULTS_RETENTION = int(os.getenv("RESULTS_RETENTION_SEC", str(60*60*24*7)))  # 7 days

@contextmanager
def job_slot(timeout_sec: int = 60*45, *, acquire_timeout_sec: float = 5):
    """Hold one runner slot from the module semaphore for the ``with`` block.

    The pool size is ``_MAX_CONCURRENT`` (env ``MAX_CONCURRENT_JOBS``). The
    slot is released on exit, including when the body raises.

    Args:
        timeout_sec: Accepted for backward compatibility but not used; nothing
            here limits how long the ``with`` body runs.
            NOTE(review): presumably intended as a per-job time limit --
            TODO confirm and either enforce it or drop it.
        acquire_timeout_sec: Seconds to wait for a free slot before giving
            up. Defaults to 5, the value that was previously hard-coded.

    Raises:
        RuntimeError: If no slot becomes free within ``acquire_timeout_sec``.
    """
    if not _sema.acquire(timeout=acquire_timeout_sec):
        raise RuntimeError("All runners busy; try later")
    try:
        yield
    finally:
        # Only reached after a successful acquire, so release is balanced.
        _sema.release()

@contextmanager
def temp_workdir(prefix: str = "job_"):
    """Create a fresh temporary directory, yield its path, then delete it.

    Args:
        prefix: Name prefix passed to ``tempfile.mkdtemp``.

    Yields:
        The absolute path of the new directory.

    Cleanup is best-effort: it never raises, but logs a warning if the
    directory could not be fully removed.
    """
    d = tempfile.mkdtemp(prefix=prefix)
    try:
        yield d
    finally:
        # ignore_errors=True keeps deleting past individual failures but also
        # never raises, so a try/except around it can't detect a failure.
        # Check for leftovers instead.
        shutil.rmtree(d, ignore_errors=True)
        if os.path.exists(d):
            log.warning("Failed to rmtree %s", d)

def check_gpu(mem_required_gb: float = 4.0) -> None:
    """Log CUDA device 0's total memory and reject devices that are too small.

    Args:
        mem_required_gb: Minimum total memory, in GiB (1024**3 bytes), that
            device 0 must report.

    Raises:
        RuntimeError: If device 0 reports less than ``mem_required_gb``.

    If CUDA is unavailable, or querying the device properties fails, this
    only logs and returns.
    """
    if not torch.cuda.is_available():
        log.info("CUDA not available; running on CPU")
        return
    # Keep the try limited to the device query: previously the memory-check
    # raise also sat inside it and was swallowed by the broad except, so a
    # too-small GPU was only logged as "GPU check failed" and never rejected.
    try:
        props = torch.cuda.get_device_properties(0)
    except Exception as e:
        log.warning("GPU check failed: %s", e)
        return
    total_gb = props.total_memory / (1024**3)
    log.info("Detected GPU with %.1f GB", total_gb)
    if total_gb < mem_required_gb:
        raise RuntimeError(f"GPU has {total_gb:.1f}GB < required {mem_required_gb}GB")

def prune_old_results(local_results_dir: str = "local_results", retention_sec=None):
    """Delete entries in ``local_results_dir`` whose mtime exceeds the retention age.

    Only the top level of the directory is scanned, and real subdirectories
    are skipped. A missing directory is a no-op.

    Args:
        local_results_dir: Directory to scan.
        retention_sec: Maximum age in seconds. ``None`` (the default) uses the
            module-level ``_RESULTS_RETENTION`` (env ``RESULTS_RETENTION_SEC``).

    Per-entry filesystem errors are logged and skipped, never raised.
    """
    if not os.path.isdir(local_results_dir):
        return
    max_age = _RESULTS_RETENTION if retention_sec is None else retention_sec
    now = time.time()
    for fname in os.listdir(local_results_dir):
        path = os.path.join(local_results_dir, fname)
        # os.remove can't delete directories; skip them rather than logging a
        # spurious failure every run. Symlinks (even to directories) are still
        # eligible, since os.remove unlinks them just as before.
        if os.path.isdir(path) and not os.path.islink(path):
            continue
        try:
            if now - os.path.getmtime(path) > max_age:
                os.remove(path)
                log.info("Pruned old result %s", path)
        except OSError as e:
            log.warning("Prune failed for %s: %s", path, e)

# Heuristic OOM estimate (very rough)
def suggest_batch_limit(model_size_mb: float = 400):
    """Return a rough batch-size ceiling for a model of ``model_size_mb`` MB.

    The estimate is ``floor(total MiB of CUDA device 0 / model_size_mb)``,
    clamped to at least 1. It uses the device's *total* memory, not the
    currently free amount, so memory already in use is ignored.

    Returns 1 when CUDA is unavailable, when ``model_size_mb`` is not
    positive, or when querying the device fails.
    """
    if model_size_mb <= 0:
        # Previously 0 hit ZeroDivisionError and negatives gave a negative
        # quotient; both ended in 1. Make that explicit.
        return 1
    try:
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(0)
            # Total device memory (was misleadingly named free_gb).
            total_mb = props.total_memory / (1024**2)
            return max(1, int(total_mb / model_size_mb))
    except Exception as e:
        # Best-effort heuristic: fall back to 1, but leave a trace instead of
        # silently swallowing the failure.
        log.warning("Batch limit estimate failed: %s", e)
    return 1