[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245)

Signed-off-by: Lingfan Yu <lingfany@amazon.com>
Author: Lingfan Yu
Date: 2025-02-20 17:45:45 -08:00 (committed by GitHub)
parent 71face8540
commit 33170081f1
3 changed files with 769 additions and 353 deletions
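For context, the change makes the kernel fetch the KV cache in large contiguous tiles of large_tile_size tokens instead of one block per transfer, so each DMA moves more data. Below is a minimal host-side sketch of that gather, for illustration only: gather_kv_tiles is a hypothetical helper, the (num_blocks, block_size, n_kv_heads, head_size) cache layout is an assumption, and the real work happens on-device in the NKI kernel.

import torch

def gather_kv_tiles(kv_cache, active_block_table, block_size, large_tile_size):
    # kv_cache: (num_blocks, block_size, n_kv_heads, head_size) -- assumed layout
    blocks_per_tile = large_tile_size // block_size
    assert active_block_table.numel() % blocks_per_tile == 0
    # Gather the blocks named in the flattened, padded block table ...
    gathered = kv_cache[active_block_table.long()]
    num_tiles = gathered.shape[0] // blocks_per_tile
    # ... and view them as a few large tiles: one big transfer per tile
    # instead of one small transfer per block.
    return gathered.reshape(num_tiles, large_tile_size, *gathered.shape[2:])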


@@ -107,7 +107,7 @@ def ref_masked_attention(
masked_score, dim=-1, return_max_reduce=True)
else:
norm_score = ref_softmax(masked_score, dim=-1)
out = torch.einsum("hqk,khd->qhd", norm_score, value)
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
if return_max_reduce:
return (
out,
@@ -118,7 +118,7 @@ def ref_masked_attention(
scaled_qk,
)
else:
return out
return (out, )
def ref_context_attention(
@@ -128,8 +128,6 @@ def ref_context_attention(
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=False,
):
@@ -146,18 +144,19 @@ def ref_context_attention(
attn_mask = torch.logical_not(attn_mask)
attn_mask = attn_mask.float() * -30000
output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
ref_masked_attention(
query,
key,
value,
scale,
attn_mask,
return_max_reduce=return_max_reduce,
))
output, *debug_tensors = ref_masked_attention(
query,
key,
value,
scale,
attn_mask,
return_max_reduce=return_max_reduce,
)
output = output.unsqueeze(1)
if return_max_reduce:
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
debug_tensors)
return (
output,
cached_max,
@@ -170,65 +169,22 @@ def ref_context_attention(
return output
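The single-value branch of ref_masked_attention now returns (out, ) so that both branches yield a tuple and the caller above can star-unpack uniformly; a tiny self-contained illustration of the pattern:

out, *extras = ("attn_out",)                # return_max_reduce=False -> extras == []
out, *extras = ("attn_out", "max", "lse")   # return_max_reduce=True  -> extras == ["max", "lse"]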
@pytest.mark.parametrize(
"block_size, large_tile_size",
[
(32, 2048), # 64 blocks
(32, 4096), # 128 blocks
(32, 8192), # 256 blocks
(64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
(16, 2, 128, True),
],
)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
assert large_tile_size % block_size == 0
device = xm.xla_device()
compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
prefill_batch_size = 4
decode_batch_size = 12
def sample_inputs(
prefill_batch_size,
decode_batch_size,
min_query_len,
max_query_len,
min_ctx_len,
max_ctx_len,
block_size,
num_heads,
num_kv_heads,
head_size,
dtype,
):
batch_size = prefill_batch_size + decode_batch_size
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
@@ -244,7 +200,6 @@ def test_contexted_kv_attention(
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@@ -304,47 +259,139 @@ def test_contexted_kv_attention(
cur_ctx += block_size
block_id += 1
return (
query,
k,
v,
k_cache,
v_cache,
block_table,
key,
value,
query_lens,
seq_lens,
)
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks, dtype=torch.int32),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
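A small worked example of the helper above, with hypothetical values: for block_size=32, context lengths of 40 and 10 tokens need 2 and 1 cache blocks, so the first two entries of row 0 and the first entry of row 1 are kept and the result is zero-padded to num_blocks entries.

block_tables = torch.arange(8, dtype=torch.int32).reshape(2, 4)  # [[0,1,2,3],[4,5,6,7]]
query_lens = torch.tensor([16, 16])
seq_lens = torch.tensor([56, 26])        # context lengths 40 and 10
active = get_active_block_tables(block_tables, query_lens, seq_lens,
                                 block_size=32, num_blocks=6)
# active -> tensor([0, 1, 4, 0, 0, 0], dtype=torch.int32)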
@pytest.mark.parametrize(
"prefill_batch_size,decode_batch_size,block_size,large_tile_size",
[
(1, 199, 1, 512), # 512 blocks
(4, 12, 256, 2048), # 8 blocks
(4, 12, 16, 2048), # 128 blocks
(4, 12, 4, 1024), # 256 blocks
(4, 12, 32, 2048), # 64 blocks
(4, 12, 32, 4096), # 128 blocks
(4, 12, 32, 8192), # 256 blocks
(4, 12, 64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size",
[
(4, 2, 8),
(32, 8, 64),
(4, 4, 128),
(8, 1, 32),
],
)
@pytest.mark.parametrize("mixed_precision", [True, False])
@torch.inference_mode()
def test_contexted_kv_attention(
prefill_batch_size: int,
decode_batch_size: int,
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
reorder_context_mask)
assert large_tile_size % block_size == 0
device = xm.xla_device()
compiler_flags = [
"-O1",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
dtype = torch.float32
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
num_kv_heads = num_heads // num_queries_per_kv
(
output_ref,
cached_max,
cached_sum_reciprocal,
lse,
masked_score,
scaled_qk,
) = ref_context_attention(
query,
k_active,
v_active,
k_cache,
v_cache,
block_table,
key,
value,
query_lens,
seq_lens,
) = sample_inputs(
prefill_batch_size=prefill_batch_size,
decode_batch_size=decode_batch_size,
min_query_len=min_query_len,
max_query_len=max_query_len,
min_ctx_len=min_ctx_len,
max_ctx_len=max_ctx_len,
block_size=block_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
)
output_ref = ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=True,
return_max_reduce=False,
)
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = large_tile_size
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
assert (large_tile_size >= B_P_SIZE
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
def ceil_div(a, b):
return (a + b - 1) // b
@@ -357,32 +404,27 @@ def test_contexted_kv_attention(
return 2**int(a - 1).bit_length()
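For reference, the padding helpers used below behave like this (illustrative values only):

assert ceil_div(1000, 32) == 32              # (1000 + 31) // 32
assert pad_to_next_power_of_2(1000) == 1024  # 2 ** (999).bit_length()
assert pad_to_next_power_of_2(1) == 1        # 2 ** (0).bit_length() == 1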
# calculate input shapes
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
max_num_queries = pad_to_next_power_of_2(max_num_queries)
head_size_padded = B_P_SIZE
assert head_size_padded >= head_size
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
LARGE_TILE_SZ // block_size)
large_tile_size // block_size)
context_kv_len = num_active_blocks * block_size
assert (context_kv_len %
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
# pad QKV tensors
pad_dims = (
0,
head_size_padded - query.shape[2],
0,
0,
0,
0,
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
v = F.pad(v, pad_dims, "constant", 0)
k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
k = F.pad(k_active, pad_dims, "constant", 0)
v = F.pad(v_active, pad_dims, "constant", 0)
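F.pad consumes pad_dims in (left, right) pairs starting from the last dimension, so the final pair above pads the token dimension (dim 0) up to max_num_queries; a quick self-contained check:

x = torch.ones(5, 4, 8)             # (num_tokens, n_heads, head_size)
F.pad(x, (0, 0, 0, 0, 0, 3)).shape  # -> torch.Size([8, 4, 8])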
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
@@ -391,6 +433,8 @@ def test_contexted_kv_attention(
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
# transform block table
active_block_table = get_active_block_tables(
@@ -405,33 +449,31 @@ def test_contexted_kv_attention(
prior_mask, active_mask = (
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size))
attn_mask = torch.concat(
[
F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool(),
F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool(),
],
dim=1,
)
prior_mask_padded = F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
).bool()
active_mask_padded = F.pad(
active_mask,
(
0,
max_num_queries - active_mask.shape[1],
0,
max_num_queries - active_mask.shape[0],
),
"constant",
0,
).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)
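reorder_context_mask presumably permutes the prior-context columns of the mask so they stay aligned with the order in which the vectorized kernel reads cached blocks within each large tile; the exact permutation lives in vllm.attention.ops.nki_flash_attn. The sketch below is only a hypothetical illustration of the general technique (a tile-local transpose from (blocks_per_tile, block_size) to (block_size, blocks_per_tile) column order), not the library's actual layout.

def reorder_context_mask_sketch(mask, large_tile_size, block_size, context_kv_len):
    # mask: (num_queries, context_kv_len + num_queries); only the prior-context
    # columns are permuted, the active (new-token) columns are left untouched.
    q = mask.shape[0]
    blocks_per_tile = large_tile_size // block_size
    num_tiles = context_kv_len // large_tile_size
    ctx = mask[:, :context_kv_len].reshape(q, num_tiles, blocks_per_tile, block_size)
    ctx = ctx.transpose(2, 3).reshape(q, context_kv_len)  # tile-local transpose
    return torch.cat([ctx, mask[:, context_kv_len:]], dim=1)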
input_args = (
query.to(device=device),
@@ -439,29 +481,21 @@ def test_contexted_kv_attention(
v.to(device=device),
k_cache.to(device=device),
v_cache.to(device=device),
active_block_table.to(torch.int32).to(device=device),
active_block_table.to(device=device),
attn_mask.to(device=device),
)
input_kwargs = dict(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
return_debug_tensors=return_debug_tensors,
LARGE_TILE_SZ=large_tile_size,
)
if return_debug_tensors:
output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
*input_args, **input_kwargs)
else:
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
num_actual_tokens = sum(query_lens)
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,