[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
Lingfan Yu
2025-02-11 21:12:37 -08:00
committed by GitHub
parent 842b0fd402
commit e92694b6fe
2 changed files with 152 additions and 178 deletions

View File

@@ -28,7 +28,6 @@ class FlashConfig:
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
forward_mask,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
@@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
@@ -60,36 +59,25 @@ def _flash_attention_core(
q_local_tile,
k,
v,
q_h_per_k_h,
seqlen_q,
nheads,
o_buffer,
l_buffer,
m_buffer,
batch_id,
head_id,
gqa_head_idx,
q_tile_idx,
local_k_large_tile_idx,
kernel_dtype,
acc_type,
flash_config: FlashConfig,
use_causal_mask=False,
continuous_batching_mask=None,
use_causal_mask,
tile_mask,
initialize=False,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=None,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
already. The block size of K and V
is defined in the seq_tile_size of the flash_config. The results are stored
in the following three buffers
@@ -99,24 +87,9 @@ def _flash_attention_core(
"""
LARGE_TILE_SZ = flash_config.seq_tile_size
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
seqlen_k = k.shape[-1]
seqlen_q // B_P_SIZE
seqlen_k // B_F_SIZE
# TODO : support logit_bias with continuous_batching_mask
assert not use_causal_mask, "causal mask is not supported."
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
assert (
logit_bias_tile
is None), "continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
@@ -125,20 +98,27 @@ def _flash_attention_core(
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True,
mask=None) # (p(128), 512)
if use_causal_mask:
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
qk_res_buf[:, k_i_b_f_slice] = nl.where(
continuous_batching_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
@@ -147,7 +127,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
if qk_res_buffer is not None:
@@ -159,7 +138,6 @@ def _flash_attention_core(
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
@@ -170,8 +148,7 @@ def _flash_attention_core(
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_,
mask=forward_mask) # (128,1)
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
@@ -180,11 +157,8 @@ def _flash_attention_core(
m_previous,
bias=-1 * m_current,
scale=1.0,
mask=forward_mask,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
alpha,
mask=forward_mask)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@@ -207,10 +181,9 @@ def _flash_attention_core(
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
mask=forward_mask,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
@@ -218,7 +191,6 @@ def _flash_attention_core(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
forward_mask=forward_mask,
B_F_SIZE=B_F_SIZE,
)
@@ -230,27 +202,20 @@ def _flash_attention_core(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[k_i, :, :],
transpose_x=True,
mask=forward_mask,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(
nl.subtract(l_prev, m_current, mask=forward_mask),
mask=forward_mask,
),
nl.exp(nl.subtract(l_prev, m_current)),
ps,
mask=forward_mask,
)
l_buffer[:, 0] = nl.add(m_current,
nl.log(l_exp, mask=forward_mask),
mask=forward_mask)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
@@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
)
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
(num_blocks, ) = block_tables_hbm.shape
assert num_blocks % num_tiles == 0
num_blocks_per_tile = num_blocks // num_tiles
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
return block_tables_buffer
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def flash_paged_attention(
query,
@@ -316,24 +296,24 @@ def flash_paged_attention(
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_percision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
- If mixed_percision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`, if false, we use same precision as input types
is set to `true`, if false, we use same precision as input types
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
with Performance config parameters for flash attention with default
values
seq_tile_size: `default=2048`, size of the kv tile size for attention
seq_tile_size: `default=2048`, size of the kv tile size for attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
the spmd kernel for launching kernel should be on kv_heads instead of
nheads
Example usage:
@@ -415,18 +395,13 @@ def flash_paged_attention(
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert (num_blocks_per_large_tile <= B_P_SIZE
), f"The number of blocks in each large tile " \
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
assert is_power_of_2(
num_blocks_per_large_tile
), "The number of blocks in each large tile is expected of be power of 2"
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
0,
dtype=np.int32,
buffer=nl.sbuf)
for j in nl.affine_range(num_large_k_tile):
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
block_tables_sbuf[i_p, j] = nl.load(
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
@@ -457,7 +432,7 @@ def flash_paged_attention(
)
for k_i in nl.affine_range(num_blocks_per_large_tile):
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
head_id, :])
cur_k_tile[:, nl.ds(k_i *
block_size, block_size)] = nl.transpose(loaded)
@@ -469,7 +444,7 @@ def flash_paged_attention(
num_blocks_per_partition):
v_i = (partition_idx * num_blocks_per_partition +
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
head_id, :])
cur_v_tile[
partition_idx,
@@ -477,14 +452,15 @@ def flash_paged_attention(
:,
] = loaded_v
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
for i in nl.affine_range(n_tile_q):
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@@ -497,35 +473,24 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=j,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
tile_mask=cur_mask,
initialize=j == 0,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = seqlen_q
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
active_config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
@@ -552,11 +517,16 @@ def flash_paged_attention(
config=active_config,
)
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(
mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
],
dtype=mask.dtype,
)
for i_q_h in nl.affine_range(q_h_per_k_h):
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
@@ -568,32 +538,21 @@ def flash_paged_attention(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=0,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=active_config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
use_causal_mask=True,
tile_mask=cur_mask,
initialize=False,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
@@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
attn_mask,
n_kv_head=None,
head_size=None,
B_P_SIZE=128,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
mixed_precision=True,