import torch
from torch import Tensor
from typing import Optional
from optimum.quanto import QBytesTensor
def compute_scale_for_dtype(tensor, dtype):
    """
    Compute appropriate scale for the given tensor and target dtype.
    Args:
        tensor: Input tensor to be quantized
        dtype: Target dtype for quantization
    Returns:
        Appropriate scale factor for the quantization
    """
    if dtype == torch.int8:
        abs_max = torch.max(torch.abs(tensor))
        return abs_max / 127.0 if abs_max > 0 else 1.0
    elif dtype == torch.uint8:
        max_val = torch.max(tensor)
        min_val = torch.min(tensor)
        range_val = max_val - min_val
        return range_val / 255.0 if range_val > 0 else 1.0
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we typically want to preserve the magnitude of the values
        # while fitting within the representable range of the format
        abs_max = torch.max(torch.abs(tensor))
        if dtype == torch.float8_e4m3fn:
            # e4m3fn has range [-448, 448] with no infinities
            max_representable = 448.0
        else:  # torch.float8_e5m2
            # e5m2 has range [-57344, 57344] with infinities
            max_representable = 57344.0
        return abs_max / max_representable if abs_max > 0 else 1.0
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
def quantize_tensor(tensor, dtype):
    """
    Quantize a floating-point tensor to the target dtype with appropriate scaling.
    Args:
        tensor: Input tensor (float)
        dtype: Target dtype for quantization
    Returns:
        quantized_data: Quantized tensor
        scale: Scale factor used
    """
    scale = compute_scale_for_dtype(tensor, dtype)
    if dtype == torch.int8:
        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
    elif dtype == torch.uint8:
        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we scale and then cast directly to the target type
        # The casting operation will handle the appropriate rounding
        scaled_tensor = tensor / scale
        quantized_data = scaled_tensor.to(dtype)
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
    return quantized_data, scale
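# --- Illustrative usage (added sketch, not part of the original module) -------
# Minimal example of how quantize_tensor pairs with the scale it returns:
# dequantization is just `quantized.float() * scale`. The tensor shapes and
# values below are arbitrary.
def _example_quantize_roundtrip():
    weights = torch.randn(4, 4)
    q_int8, scale_int8 = quantize_tensor(weights, torch.int8)
    recon_int8 = q_int8.to(torch.float32) * scale_int8
    q_fp8, scale_fp8 = quantize_tensor(weights, torch.float8_e4m3fn)
    recon_fp8 = q_fp8.to(torch.float32) * scale_fp8
    print("int8 max abs error:", (weights - recon_int8).abs().max().item())
    print("fp8  max abs error:", (weights - recon_fp8).abs().max().item())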
def update_parameter(target, result_float):
    """
    Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
    with proper rescaling for quantized tensors.
    Args:
        target: The parameter to update (either torch.Tensor or QBytesTensor)
        result_float: The new values to assign (torch.Tensor)
    """
    if isinstance(target, QBytesTensor):
        # Get the target dtype from the existing quantized tensor
        target_dtype = target._data.dtype
        # Handle device placement
        device = target._data.device
        result_float = result_float.to(device)
        # Compute new quantized values and scale
        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
        # Update the internal tensors with newly computed values
        target._data.copy_(quantized_data)
        target._scale.copy_(new_scale)
    else:
        # Regular tensor update
        target.copy_(result_float)
def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
    """
    Returns (mantissa_bits, total_bits) for each format.
    mantissa_bits excludes the implicit leading 1.
    """
    if dtype == torch.float32:
        return 23, 32
    elif dtype == torch.bfloat16:
        return 7, 16
    elif dtype == torch.float16:
        return 10, 16
    elif dtype == torch.float8_e4m3fn:
        return 3, 8
    elif dtype == torch.float8_e5m2:
        return 2, 8
    elif dtype == torch.int8:
        return 0, 8  # Int8 doesn't have mantissa bits
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
def copy_stochastic_bf16(target: torch.Tensor, source: torch.Tensor):
    # adapted from https://github.com/Nerogar/OneTrainer/blob/411532e85f3cf2b52baa37597f9c145073d54511/modules/util/bf16_stochastic_rounding.py#L5
    # create a random 16 bit integer
    result = torch.randint_like(
        source,
        dtype=torch.int32,
        low=0,
        high=(1 << 16),
    )
    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))
    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))
    del result
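# --- Illustrative usage (added sketch, not part of the original module) -------
# Averaging many stochastic bf16 roundings of the same float32 value should sit
# close to that value, whereas a plain cast always lands on the same bf16 grid
# point. This runs on CPU since copy_stochastic_bf16 itself has no device check.
def _example_stochastic_bf16():
    source = torch.full((10000,), 1.0 + 1e-3, dtype=torch.float32)
    target = torch.empty_like(source, dtype=torch.bfloat16)
    copy_stochastic_bf16(target, source)
    print("mean of stochastic bf16 copies:", target.to(torch.float32).mean().item())
    print("plain bf16 cast:", source.to(torch.bfloat16)[0].item())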
def copy_stochastic(target: torch.Tensor, source: torch.Tensor, eps: Optional[float] = None) -> None:
    with torch.no_grad():
        # both tensors must live on an accelerator; CPU tensors are rejected
        assert target.device.type != 'cpu', "Target is on cpu!"
        assert source.device.type != 'cpu', "Source is on cpu!"
        if target.dtype == torch.float32:
            target.copy_(source)
            return
        if target.dtype == torch.bfloat16:
            copy_stochastic_bf16(target, source)
            return
        mantissa_bits, _ = get_format_params(target.dtype)
        round_factor = 2 ** (23 - mantissa_bits)
        # Add uniform noise for stochastic rounding
        noise = torch.rand_like(source, device=source.device) - 0.5
        rounded = torch.round(source * round_factor + noise)
        result_float = rounded / round_factor
        # Clamp for float8
        if target.dtype == torch.float8_e4m3fn:
            result_float.clamp_(-448.0, 448.0)
        elif target.dtype == torch.float8_e5m2:
            result_float.clamp_(-57344.0, 57344.0)
        update_parameter(target, result_float)
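# --- Illustrative usage (added sketch, not part of the original module) -------
# The generic path of copy_stochastic, shown for float8_e4m3fn. The function
# asserts that neither tensor is on CPU, so this sketch only runs when CUDA is
# available.
def _example_stochastic_fp8():
    if not torch.cuda.is_available():
        return
    source = torch.randn(8, device="cuda")
    target = torch.zeros(8, dtype=torch.float8_e4m3fn, device="cuda")
    copy_stochastic(target, source)
    print("fp8 stochastic copy:", target.to(torch.float32).cpu())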
class Auto8bitTensor:
    def __init__(self, data: Tensor, *args, **kwargs):
        if isinstance(data, dict):  # allow construction from a state dict
            self._load_from_state_dict(data)
        else:
            abs_max = data.abs().max().item()
            scale = abs_max / 127.0 if abs_max > 0 else 1.0
            self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
            self.scale = scale
            self.orig_dtype = data.dtype

    def dequantize(self) -> Tensor:
        return self.quantized.to(dtype=torch.float32) * self.scale

    def to(self, *args, **kwargs):
        # Handle the dtype argument whether it's positional or keyword
        dtype = None
        if args and isinstance(args[0], torch.dtype):
            dtype = args[0]
            args = args[1:]
        elif 'dtype' in kwargs:
            dtype = kwargs['dtype']
            del kwargs['dtype']
        if dtype is not None:
            # First dequantize, then convert to the requested dtype
            return self.dequantize().to(*args, dtype=dtype, **kwargs)
        # If no dtype is specified, just dequantize and forward the remaining arguments
        return self.dequantize().to(*args, **kwargs)

    def state_dict(self):
        """Returns a dictionary containing the current state of the tensor."""
        return {
            'quantized': self.quantized,
            'scale': self.scale,
            'orig_dtype': self.orig_dtype
        }

    def _load_from_state_dict(self, state_dict):
        """Loads the tensor state from a state dictionary."""
        self.quantized = state_dict['quantized']
        self.scale = state_dict['scale']
        self.orig_dtype = state_dict['orig_dtype']

    def __str__(self):
        return f"Auto8bitTensor({self.dequantize()})"
def stochastic_grad_accummulation(param):
    """
    Folds param.grad into param._accum_grad using stochastic rounding, then
    frees param.grad. Meant to be called once per parameter after each
    backward pass when accumulating gradients across micro-batches.
    """
    if hasattr(param, "_accum_grad"):
        grad_fp32 = param._accum_grad.clone().to(torch.float32)
        grad_fp32.add_(param.grad.to(torch.float32))
        copy_stochastic(param._accum_grad, grad_fp32)
        del grad_fp32
        del param.grad
    else:
        param._accum_grad = param.grad.clone()
        del param.grad
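# --- Illustrative usage (added sketch, not part of the original module) -------
# One way the hook above might be driven from a training loop: fold each
# micro-batch gradient into param._accum_grad, then hand the accumulated
# gradient back to the optimizer on the last micro-batch. The names here
# (model, micro_batches, compute_loss, optimizer) are placeholders, and the
# parameters are assumed to live on an accelerator since copy_stochastic
# rejects CPU tensors.
def _example_grad_accumulation(model, micro_batches, compute_loss, optimizer):
    for step, batch in enumerate(micro_batches):
        compute_loss(model, batch).backward()
        for param in model.parameters():
            if param.grad is not None:
                stochastic_grad_accummulation(param)
        if step == len(micro_batches) - 1:
            for param in model.parameters():
                if hasattr(param, "_accum_grad"):
                    param.grad = param._accum_grad
                    del param._accum_grad
            optimizer.step()
            optimizer.zero_grad()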