# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import warnings
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from peft.tuners.tuners_utils import BaseTunerLayer

from .layer import LoraLayer


class LoraParallelLinear(nn.Module, LoraLayer):
| """ | |
| When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes | |
| consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear | |
| layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B, | |
| while lora_A is still a complete linear layer. | |
| """ | |

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        backend,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer=base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self.backend = backend
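        # True for RowParallelLinear targets, i.e. when the LoRA A matrix (rather than B) is the factor
        # that has to be tensor-parallel.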
        self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name

        megatron_config = kwargs["megatron_config"]
        parallel_linear_kwargs = {"megatron_config": megatron_config}
        init_method = init.xavier_normal_
        if hasattr(megatron_config, "init_method"):
            init_method = megatron_config.init_method
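        # Mirror the base layer's parallel I/O settings so the adapter consumes and produces tensors with
        # the same sharding as the layer it wraps.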
        input_is_parallel = True
        gather_output = False
        if isinstance(base_layer, self.backend.RowParallelLinear):
            input_is_parallel = base_layer.input_is_parallel
        else:
            gather_output = base_layer.gather_output

        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            init_method=init_method,
            input_is_parallel=input_is_parallel,
            gather_output=gather_output,
            **parallel_linear_kwargs,
        )

        self.is_target_conv_1d_layer = False

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora=False,
        init_method=init.xavier_normal_,
        input_is_parallel=True,
        gather_output=False,
        **parallel_linear_kwargs,
    ):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout[adapter_name] = lora_dropout_layer

        megatron_config = parallel_linear_kwargs["megatron_config"]
        # LoRA weights must be forced to float32; lower precision can overflow during training.
        megatron_config.params_dtype = torch.float32
        if self.is_parallel_a:
            lora_a = self.backend.RowParallelLinear(
                input_size=self.in_features,
                output_size=r,
                bias=False,
                input_is_parallel=input_is_parallel,
                skip_bias_add=True,
                init_method=init_method,
                config=megatron_config,
            )
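            # RowParallelLinear all-reduces its output, so the r-dimensional activation is replicated on
            # every rank and lora_b can stay a plain, non-parallel nn.Linear.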
            lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
        else:
            lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
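            # The input is replicated here, so lora_a is a plain nn.Linear and lora_b is column-parallel;
            # gather_output mirrors the base layer so the sharded outputs line up.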
            lora_b = self.backend.ColumnParallelLinear(
                input_size=r,
                output_size=self.out_features,
                bias=False,
                gather_output=gather_output,
                init_method=init_method,
                config=megatron_config,
            )
        self.lora_A[adapter_name] = lora_a
        self.lora_B[adapter_name] = lora_b
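
        # rsLoRA scales by alpha / sqrt(r) instead of the usual alpha / r, keeping the update magnitude
        # roughly constant as the rank grows.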
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / (r**0.5)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)

        self.set_adapter(self.active_adapters)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        previous_dtype = x.dtype
        # If we used the weight directly for matrix multiplication here, the final aggregation
        # (reduce/gather) of the original parallel_linear layer would be missing, so we call its
        # forward function directly to obtain the output of the original parallel_linear layer.
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result, bias = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result, bias = self.base_layer(x, *args, **kwargs)
        else:
            result, bias = self.base_layer(x, *args, **kwargs)
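            # The LoRA delta is only added in this branch, i.e. when adapters are enabled and have not
            # been merged into the base weights.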
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = x.to(lora_A.weight.dtype)
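
                # Megatron parallel linears return an (output, bias) tuple, hence the unwrapping after
                # each adapter matmul below.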
                lora_result = lora_A(dropout(x))
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_B(lora_result)
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_result * scaling

                result = result + lora_result

        result = result.to(previous_dtype)
        return result, bias


def dispatch_megatron(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target
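
    # The Megatron backend is imported lazily from the dotted module path stored in
    # `lora_config.megatron_core` (typically "megatron.core"), and only when a megatron_config is provided.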
    if lora_config.megatron_config:
        megatron_core = importlib.import_module(lora_config.megatron_core)
    else:
        megatron_core = None

    if megatron_core and isinstance(
        target_base_layer,
        (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
    ):
        megatron_kwargs = kwargs.copy()
        megatron_config = lora_config.megatron_config
        if isinstance(megatron_config, dict):
            transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
            megatron_config = transformer_config_class(**lora_config.megatron_config)
        megatron_kwargs["megatron_config"] = megatron_config
        if megatron_kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
                "or `RowParallelLinear`. "
                "Setting fan_in_fan_out to False."
            )
            megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(
            base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
        )

    return new_module
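

# Usage sketch (illustrative only, not executed here): `dispatch_megatron` is invoked by the LoRA tuner
# when replacing modules, so users normally just route the Megatron settings through `LoraConfig`.
# Roughly, and assuming a Megatron-LM model whose attention projection is named "linear_qkv" (the actual
# name depends on the model definition):
#
#     from peft import LoraConfig, get_peft_model
#
#     config = LoraConfig(
#         r=8,
#         lora_alpha=16,
#         target_modules=["linear_qkv"],
#         megatron_config=transformer_config,  # a megatron.core TransformerConfig or an equivalent dict
#         megatron_core="megatron.core",
#     )
#     peft_model = get_peft_model(megatron_model, config)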