|
|
|
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
import torch
|
|
|
|
__all__ = ["retry_if_cuda_oom"]
|
|
|
|
|
|
@contextmanager
|
|
def _ignore_torch_cuda_oom():
|
|
"""
|
|
A context which ignores CUDA OOM exception from pytorch.
|
|
"""
|
|
try:
|
|
yield
|
|
except RuntimeError as e:
|
|
|
|
if "CUDA out of memory. " in str(e):
|
|
pass
|
|
else:
|
|
raise
|
|
|
|
|
|
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        # Duck-typed check: only objects that expose a CUDA `.device` and a
        # `.to` method are copied to CPU; everything else passes through.
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        else:
            return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Attempt 1: call as-is; a CUDA OOM is silently swallowed by the
        # context manager, so control falls through to the next attempt.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 2: release cached GPU memory, then retry unchanged.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 3 (final): copy tensor-like inputs to CPU and call once more.
        # This call is NOT guarded — any exception now propagates to the caller.
        logger = logging.getLogger(__name__)
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logger.info("Attempting to copy inputs of %s to CPU due to CUDA OOM", str(func))
        new_args = (maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
|
|
|