"""JAX Ops for symmetric matrices used by the Shampoo optimizer.""" |

import functools
from typing import List

import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from jax import lax


@struct.dataclass
class SlicedSymmetricMatrix:
    """A symmetric matrix represented by lower-triangular block row slices.

    For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
    by the block rows a and [b, c].

    The matrix may be batched, in which case each entry of block_rows may have
    dimension greater than 2. The last two dimensions represent the rows and cols.
    """

    block_rows: List[jnp.ndarray]
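
# Illustrative sketch (values are assumptions, not part of the original module):
# the 4x4 symmetric matrix [[A, B^T], [B, C]] with 2x2 blocks is stored as the
# two block rows A (shape (2, 2)) and [B, C] (shape (2, 4)):
#
#   a = jnp.eye(2)
#   b_c = jnp.concatenate([jnp.ones((2, 2)), 2.0 * jnp.eye(2)], axis=-1)
#   m = SlicedSymmetricMatrix(block_rows=[a, b_c])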


def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Returns mat1 * mat2^T for two matrices (possibly batched).

    The rows and columns are the last two dimensions for each matrix.

    Args:
      mat1: First matrix.
      mat2: Second matrix.
      axes: The axes over which to apply the product, passed through to
        jnp.tensordot.
      precision: JAX precision to use for the multiplication.
    """
    return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
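
# Example sketch (shapes are assumptions, not from the original module): with
# 2-D inputs and a contraction over the last axis of each operand, this is
# mat1 @ mat2.T:
#
#   m1 = jnp.ones((3, 5))
#   m2 = jnp.ones((4, 5))
#   out = product_with_transpose(m1, m2, axes=((1,), (1,)))  # shape (3, 4)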


@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the blocked slices representing a symmetric contraction.

    Specifically, the output is a contraction of the input mat with itself, in the
    specified axes.

    Args:
      mat: The matrix for which we will compute a contraction with itself.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      The lower-triangular block rows of the contraction, as a
      SlicedSymmetricMatrix.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    rank = len(mat.shape)

    def _make_axis_positive(ax):
        assert -rank <= ax < rank
        return ax + rank if ax < 0 else ax

    # Exactly one axis must be left out of the contraction; it indexes the rows
    # of the (symmetric) result.
    positive_axes = [_make_axis_positive(ax) for ax in axes]
    assert len(positive_axes) == len(axes)
    remaining_axes = set(range(rank)) - set(positive_axes)
    assert len(remaining_axes) == 1
    remaining_ax = remaining_axes.pop()

    num_rows = mat.shape[remaining_ax]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    # Block row i contracts rows [i * block_size, (i + 1) * block_size) of mat
    # with rows [0, (i + 1) * block_size), so only the lower-triangular blocks
    # of the full symmetric product are computed.
    block_rows = []
    for i in range(num_rows // block_size):
        start_indices = [0] * rank
        start_indices[remaining_ax] = i * block_size

        slice_sizes = list(mat.shape)
        slice_sizes[remaining_ax] = block_size

        slice_sizes_full = list(mat.shape)
        slice_sizes_full[remaining_ax] = (i + 1) * block_size

        block_rows.append(
            product_with_transpose(
                lax.dynamic_slice(
                    mat, start_indices=start_indices, slice_sizes=slice_sizes
                ),
                lax.dynamic_slice(
                    mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
                ),
                axes=(axes, axes),
                precision=precision,
            )
        )

    return SlicedSymmetricMatrix(block_rows=block_rows)
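
# Example sketch (shapes are assumptions): an (8, 4) input with block_size=4
# yields two block rows whose widths grow with the row index:
#
#   mat = jnp.arange(32.0).reshape(8, 4)
#   sliced = sliced_transposed_product(mat, block_size=4)
#   [r.shape for r in sliced.block_rows]  # [(4, 4), (4, 8)]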


@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the concatenated slices representing mat * mat^T.

    Args:
      mat: The matrix for which we will compute mat * mat^T. It does not need to
        be square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      The lower-triangular block rows, concatenated along the last axis.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    sliced_symmetric_matrix = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes, precision=precision
    )
    return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
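
# Example sketch (shapes are assumptions): block rows of shapes (4, 4) and
# (4, 8) concatenate into a single (4, 12) array holding the 1 + 2 = 3
# lower-triangular blocks of the (8, 8) product mat @ mat.T:
#
#   mat = jnp.arange(32.0).reshape(8, 4)
#   concat = sliced_transposed_product_concat(mat, block_size=4)  # (4, 12)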


@jax.jit
def materialize_matrix(symmetric_matrix):
    """Returns a materialized symmetric matrix.

    Args:
      symmetric_matrix: the matrix represented by lower-triangular block slices.
    """
    block_rows = symmetric_matrix.block_rows
    block_size = block_rows[0].shape[-2]
    num_blocks = len(block_rows)

    # Slice the lower-triangular and diagonal blocks out of each block row.
    blocks = [
        [
            block_row[Ellipsis, i * block_size : (i + 1) * block_size]
            for i in range(k + 1)
        ]
        for k, block_row in enumerate(block_rows)
    ]

    # Generate the upper-triangular (off-diagonal) blocks by transposing the
    # corresponding lower-triangular blocks.
    off_diags = [[] for _ in range(num_blocks - 1)]
    for k, block_row in enumerate(block_rows[1:]):
        for i in range(k + 1):
            off_diags[i].append(
                jnp.swapaxes(
                    a=block_row[Ellipsis, i * block_size : (i + 1) * block_size],
                    axis1=-1,
                    axis2=-2,
                )
            )

    return jnp.block(
        [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]]
    )
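
# Example sketch (values are assumptions): materializing the slices of
# mat @ mat.T recovers the full symmetric matrix:
#
#   mat = jnp.arange(32.0).reshape(8, 4)
#   sliced = sliced_transposed_product(mat, block_size=4)
#   full = materialize_matrix(sliced)  # shape (8, 8), equals mat @ mat.T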


@functools.partial(jax.jit, static_argnames=("num_blocks",))
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks,
):
    """Returns a materialized symmetric matrix from concatenated slices.

    Args:
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric
        matrix.
    """
    block_size = block_rows_concat.shape[-2]

    # Block row k starts after the k * (k + 1) / 2 blocks of the previous rows
    # and contains k + 1 blocks of its own.
    block_rows = [
        block_rows_concat[
            Ellipsis,
            (k * (k + 1)) // 2 * block_size : ((k + 1) * (k + 2)) // 2 * block_size,
        ]
        for k in range(num_blocks)
    ]

    return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
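
# Example sketch (values are assumptions): the concatenated form round-trips
# through materialization the same way:
#
#   mat = jnp.arange(32.0).reshape(8, 4)
#   concat = sliced_transposed_product_concat(mat, block_size=4)
#   full = materialize_matrix_from_concat(concat, num_blocks=2)  # shape (8, 8)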


@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Implements the blocked equivalent of SYRK.

    Specifically, the symmetric matrix (represented using lower-triangular block
    rows) is updated using the sliced product of mat.

    Args:
      symmetric_matrix: The symmetric matrix to update.
      mat: The matrix used to form the update mat * mat^T. The number of rows
        should match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated matrix, represented as the block rows of
      alpha * mat * mat^T + beta * symmetric_matrix.
    """
    block_size = symmetric_matrix.block_rows[0].shape[-2]
    sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
    return SlicedSymmetricMatrix(
        block_rows=[
            update * alpha + row * beta
            for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
        ]
    )
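
# Example sketch (weights are assumptions): an exponentially decayed
# second-moment accumulation of the kind used for Shampoo statistics:
#
#   mat = jnp.arange(32.0).reshape(8, 4)
#   stats = sliced_transposed_product(mat, block_size=4)
#   stats = update_sliced_rows(stats, mat, alpha=0.01, beta=0.99)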


def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.

    For example, an input with dimensions [256, 2560] represents 10 square
    blocks, which matches 4 lower-triangular block rows (1+2+3+4). So this
    function will return 4.

    Use ordinary numpy functions here so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array.

    Returns:
      The number of block rows, as a static (non-traced) integer.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a lower
        triangular block representation.
    """
    # Compute the number of square blocks used to represent the matrix.
    total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
    # Recover the number of block rows by inverting y = x * (x + 1) / 2.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    if num_blocks * (num_blocks + 1) / 2 != total_blocks:
        raise ValueError(
            "Could not determine an appropriate number of blocks for "
            "the concatenated matrix."
        )
    return num_blocks
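
# Example sketch (shape taken from the docstring above): a (256, 2560)
# concatenation holds 10 square blocks, i.e. 4 block rows (1 + 2 + 3 + 4):
#
#   concat = jnp.zeros((256, 2560))
#   find_num_blocks(concat)  # -> 4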
|
|