import functools
from typing import Callable

import torch
from accelerate.logging import get_logger
from peft.tuners.tuners_utils import BaseTunerLayer

from .constants import FINETRAINERS_LOG_LEVEL


logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)


def perform_peft_patches() -> None:
    _perform_patch_move_adapter_to_device_of_base_layer()
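

# PEFT's `BaseTunerLayer._move_adapter_to_device_of_base_layer` moves adapter weights onto
# the device of the base layer and, in some cases, may also request a dtype change while
# doing so. The patch below wraps the method so that, while it runs, `Tensor.to(...)` can
# still change the device but any requested dtype change is ignored, keeping adapter
# weights in the dtype they were created with.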
def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
    # Re-bind the method on the class so every existing and future tuner layer uses the
    # wrapped version.
    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
        BaseTunerLayer._move_adapter_to_device_of_base_layer
    )


def _patched_move_adapter_to_device_of_base_layer(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Call the original method with dtype conversions disabled for its duration.
        with DisableTensorToDtype():
            return func(self, *args, **kwargs)

    return wrapper
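

# `DisableTensorToDtype` temporarily swaps out `torch.Tensor.to` so that any dtype argument
# (positional or keyword) is dropped and only the remaining arguments (device, etc.) are
# forwarded to the original method. Note that the swap is global on `torch.Tensor`, so
# concurrent `.to()` calls from other threads would also be affected while the block is
# active; here it only wraps the short adapter-move call above.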
class DisableTensorToDtype:
    def __enter__(self):
        self.original_to = torch.Tensor.to

        def modified_to(tensor, *args, **kwargs):
            # Strip any dtype arguments so the original `.to()` only sees device and
            # other non-dtype options.
            args = [arg for arg in args if not isinstance(arg, torch.dtype)]
            kwargs.pop("dtype", None)
            return self.original_to(tensor, *args, **kwargs)

        torch.Tensor.to = modified_to

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.Tensor.to = self.original_to
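

# Usage sketch (illustrative only; `transformer`, `lora_config` and the import path are
# assumptions, not defined in this module):
#
#     from finetrainers.patches import perform_peft_patches
#
#     perform_peft_patches()            # install the patch once, before adding adapters
#     transformer.add_adapter(lora_config)
#
# The context manager can also be exercised on its own:
#
#     x = torch.ones(2, dtype=torch.float16)
#     with DisableTensorToDtype():
#         y = x.to("cpu", dtype=torch.float32)  # dtype request is ignored
#     assert y.dtype == torch.float16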