|
from typing import Tuple, Union, Literal |
|
|
|
from einops import repeat |
|
import torch |
|
import numpy as np |
|
|
|
|
|
def get_diags_indices( |
|
shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0 |
|
): |
|
if isinstance(shape, int): |
|
shape = (shape, shape) |
|
rows, cols = np.indices(shape) |
|
diag = cols - rows |
|
return np.where((diag >= k_min) & (diag <= k_max)) |
|
|
|
|
|
def generate_mask_from_indices( |
|
shape: Tuple[int, int], |
|
indices: Tuple[np.ndarray, np.ndarray], |
|
big_value: float = 0, |
|
small_value: float = -1e9, |
|
): |
|
matrix = np.ones(shape) * small_value |
|
matrix[indices] = big_value |
|
return matrix |
|
|
|
|
|
def generate_sparse_causcal_attn_mask( |
|
batch_size: int, |
|
n: int, |
|
n_near: int = 1, |
|
big_value: float = 0, |
|
small_value: float = -1e9, |
|
out_type: Literal["torch", "numpy"] = "numpy", |
|
expand: int = 1, |
|
) -> np.ndarray: |
|
"""generate b (n expand) (n expand) mask, |
|
where value of diag (0<=<=n_near) and first column of shape mat (n n) is set as big_value, others as small value |
|
expand的概念: |
|
attn 是 b n d 时,mask 是 b n n, 当 attn 是 b (expand n) d 时, mask 是 b (n expand) (n expand) |
|
Args: |
|
batch_size (int): _description_ |
|
n (int): _description_ |
|
n_near (int, optional): _description_. Defaults to 1. |
|
big_value (float, optional): _description_. Defaults to 0. |
|
small_value (float, optional): _description_. Defaults to -1e9. |
|
out_type (Literal["torch", "numpy"], optional): _description_. Defaults to "numpy". |
|
expand (int, optional): _description_. Defaults to 1. |
|
|
|
Returns: |
|
np.ndarray: _description_ |
|
""" |
|
shape = (n, n) |
|
diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0) |
|
first_column = (np.arange(n), np.zeros(n).astype(np.int)) |
|
indices = ( |
|
np.concatenate([diag_indices[0], first_column[0]]), |
|
np.concatenate([diag_indices[1], first_column[1]]), |
|
) |
|
mask = generate_mask_from_indices( |
|
shape=shape, indices=indices, big_value=big_value, small_value=small_value |
|
) |
|
mask = repeat(mask, "m n-> b m n", b=batch_size) |
|
if expand > 1: |
|
mask = repeat( |
|
mask, |
|
"b m n -> b (m d1) (n d2)", |
|
d1=expand, |
|
d2=expand, |
|
) |
|
if out_type == "torch": |
|
mask = torch.from_numpy(mask) |
|
return mask |
|
|