"""Unsloth checkpointing"""

import torch


class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non-blocking calls.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        # Offload the input activations to CPU asynchronously, then run the
        # layer's forward pass without building an autograd graph.
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        # Bring the offloaded activations back to the GPU and recompute the
        # forward pass with gradients enabled.
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad = True
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        # Return None for forward_function, the recomputed input gradient for
        # hidden_states, and None for each extra positional argument.
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)
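

# A minimal, hypothetical usage sketch (not part of the original module): the
# wrapped forward callable must return a one-element tuple, since backward()
# unpacks a single output tensor before recomputing gradients. `layer` and
# `layer_forward` below are illustrative names, not Unsloth APIs.
if __name__ == "__main__":
    if torch.cuda.is_available():
        layer = torch.nn.Linear(16, 16).cuda()

        def layer_forward(hidden_states):
            # Return a one-element tuple, as expected by backward().
            return (layer(hidden_states),)

        x = torch.randn(4, 16, device="cuda", requires_grad=True)
        (out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(layer_forward, x)
        out.sum().backward()
        # Gradients flow to both the layer parameters and the original input.
        print(layer.weight.grad.shape, x.grad.shape)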