[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245)
Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
@@ -107,7 +107,7 @@ def ref_masked_attention(
|
||||
masked_score, dim=-1, return_max_reduce=True)
|
||||
else:
|
||||
norm_score = ref_softmax(masked_score, dim=-1)
|
||||
out = torch.einsum("hqk,khd->qhd", norm_score, value)
|
||||
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
|
||||
if return_max_reduce:
|
||||
return (
|
||||
out,
|
||||
@@ -118,7 +118,7 @@ def ref_masked_attention(
|
||||
scaled_qk,
|
||||
)
|
||||
else:
|
||||
return out
|
||||
return (out, )
|
||||
|
||||
|
||||
def ref_context_attention(
|
||||
@@ -128,8 +128,6 @@ def ref_context_attention(
|
||||
query_lens,
|
||||
seq_lens,
|
||||
head_size,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
num_queries_per_kv,
|
||||
return_max_reduce=False,
|
||||
):
|
||||
@@ -146,18 +144,19 @@ def ref_context_attention(
|
||||
attn_mask = torch.logical_not(attn_mask)
|
||||
attn_mask = attn_mask.float() * -30000
|
||||
|
||||
output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
|
||||
ref_masked_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
attn_mask,
|
||||
return_max_reduce=return_max_reduce,
|
||||
))
|
||||
output, *debug_tensors = ref_masked_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
attn_mask,
|
||||
return_max_reduce=return_max_reduce,
|
||||
)
|
||||
|
||||
output = output.unsqueeze(1)
|
||||
if return_max_reduce:
|
||||
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
|
||||
debug_tensors)
|
||||
return (
|
||||
output,
|
||||
cached_max,
|
||||
@@ -170,65 +169,22 @@ def ref_context_attention(
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_size, large_tile_size",
|
||||
[
|
||||
(32, 2048), # 64 blocks
|
||||
(32, 4096), # 128 blocks
|
||||
(32, 8192), # 256 blocks
|
||||
(64, 8192), # 128 blocks
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_queries_per_kv,head_size,mixed_precision",
|
||||
[
|
||||
(4, 2, 8, False),
|
||||
(4, 2, 8, True),
|
||||
(32, 8, 64, True),
|
||||
(16, 2, 128, True),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
large_tile_size,
|
||||
mixed_precision: bool,
|
||||
) -> None:
|
||||
import os
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
|
||||
|
||||
assert large_tile_size % block_size == 0
|
||||
|
||||
device = xm.xla_device()
|
||||
|
||||
compiler_flags = [
|
||||
"--model-type=transformer -O1",
|
||||
"--internal-hlo2tensorizer-options='--verify-hlo'",
|
||||
"--retry_failed_compilation",
|
||||
]
|
||||
compiler_flags_str = " ".join(compiler_flags)
|
||||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
||||
|
||||
torch.manual_seed(0)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
|
||||
min_ctx_len = 32
|
||||
max_ctx_len = 1024
|
||||
min_query_len = 16
|
||||
max_query_len = 512
|
||||
prefill_batch_size = 4
|
||||
decode_batch_size = 12
|
||||
def sample_inputs(
|
||||
prefill_batch_size,
|
||||
decode_batch_size,
|
||||
min_query_len,
|
||||
max_query_len,
|
||||
min_ctx_len,
|
||||
max_ctx_len,
|
||||
block_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
):
|
||||
batch_size = prefill_batch_size + decode_batch_size
|
||||
max_model_len = (max_query_len + max_ctx_len) * 4
|
||||
|
||||
max_block_per_request = max_model_len // block_size
|
||||
dtype = torch.float32
|
||||
cache_size = (batch_size * max_block_per_request) + 2
|
||||
prefill_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (prefill_batch_size, ),
|
||||
@@ -244,7 +200,6 @@ def test_contexted_kv_attention(
|
||||
dtype=torch.long,
|
||||
).tolist() + [1 for _ in range(decode_batch_size)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
@@ -304,47 +259,139 @@ def test_contexted_kv_attention(
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
|
||||
return (
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
)
|
||||
|
||||
|
||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    """Flatten the in-use block ids of every sequence into one padded tensor.

    For each sequence the context length is ``seq_len - query_len``; the
    number of blocks covering that context is the ceiling of
    ``context_len / block_size``. The first that-many entries of the
    sequence's row in ``block_tables`` are collected, in sequence order.

    Args:
        block_tables: 2-D indexable of block ids, one row per sequence
            (indexed as ``block_tables[seq_id, :k]``).
        query_lens: per-sequence query lengths. Must support element-wise
            arithmetic with ``seq_lens`` (presumably a 1-D tensor -- confirm
            against the caller).
        seq_lens: per-sequence total lengths, same layout as ``query_lens``.
        block_size: number of tokens held by one block.
        num_blocks: length of the returned tensor.

    Returns:
        A 1-D ``torch.int32`` tensor holding the collected block ids followed
        by zero padding up to ``num_blocks`` entries.
    """
    context_lens = seq_lens - query_lens
    # Ceiling division: a partially filled block still counts as one block.
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    active_blocks: list[int] = []
    for seq_id in range(len(seq_lens)):
        # Extend in place rather than rebuilding the list on every iteration.
        active_blocks.extend(
            block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
    return F.pad(
        torch.tensor(active_blocks, dtype=torch.int32),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_batch_size,decode_batch_size,block_size,large_tile_size",
|
||||
[
|
||||
(1, 199, 1, 512), # 512 blocks
|
||||
(4, 12, 256, 2048), # 128 blocks
|
||||
(4, 12, 16, 2048), # 128 blocks
|
||||
(4, 12, 4, 1024), # 256 blocks
|
||||
(4, 12, 32, 2048), # 64 blocks
|
||||
(4, 12, 32, 4096), # 128 blocks
|
||||
(4, 12, 32, 8192), # 256 blocks
|
||||
(4, 12, 64, 8192), # 128 blocks
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_queries_per_kv,head_size",
|
||||
[
|
||||
(4, 2, 8),
|
||||
(32, 8, 64),
|
||||
(4, 4, 128),
|
||||
(8, 1, 32),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("mixed_precision", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
prefill_batch_size: int,
|
||||
decode_batch_size: int,
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
large_tile_size,
|
||||
mixed_precision: bool,
|
||||
) -> None:
|
||||
import os
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
|
||||
reorder_context_mask)
|
||||
|
||||
assert large_tile_size % block_size == 0
|
||||
|
||||
device = xm.xla_device()
|
||||
|
||||
compiler_flags = [
|
||||
"-O1",
|
||||
"--retry_failed_compilation",
|
||||
]
|
||||
compiler_flags_str = " ".join(compiler_flags)
|
||||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
||||
|
||||
torch.manual_seed(0)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
dtype = torch.float32
|
||||
|
||||
min_ctx_len = 32
|
||||
max_ctx_len = 1024
|
||||
min_query_len = 16
|
||||
max_query_len = 512
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
(
|
||||
output_ref,
|
||||
cached_max,
|
||||
cached_sum_reciprocal,
|
||||
lse,
|
||||
masked_score,
|
||||
scaled_qk,
|
||||
) = ref_context_attention(
|
||||
query,
|
||||
k_active,
|
||||
v_active,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
) = sample_inputs(
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
decode_batch_size=decode_batch_size,
|
||||
min_query_len=min_query_len,
|
||||
max_query_len=max_query_len,
|
||||
min_ctx_len=min_ctx_len,
|
||||
max_ctx_len=max_ctx_len,
|
||||
block_size=block_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
output_ref = ref_context_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query_lens,
|
||||
seq_lens,
|
||||
head_size,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
num_queries_per_kv,
|
||||
return_max_reduce=True,
|
||||
return_max_reduce=False,
|
||||
)
|
||||
|
||||
# build neuron program
|
||||
return_debug_tensors = False
|
||||
B_P_SIZE = 128
|
||||
LARGE_TILE_SZ = large_tile_size
|
||||
|
||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    """Collect the block ids covering each sequence's context, zero-padded.

    The context length of a sequence is ``seq_len - query_len``. The ceiling
    of ``context_len / block_size`` gives how many leading entries of that
    sequence's row in ``block_tables`` are taken. Rows are visited in
    sequence order and the result is padded with zeros to ``num_blocks``.

    Args:
        block_tables: 2-D indexable of block ids, one row per sequence.
        query_lens: per-sequence query lengths; must support element-wise
            subtraction from ``seq_lens`` (presumably tensors -- confirm
            against the caller).
        seq_lens: per-sequence total lengths.
        block_size: number of tokens held by one block.
        num_blocks: length of the returned tensor.

    Returns:
        A 1-D tensor of block ids using torch's default integer dtype for a
        Python int list; no dtype is requested here.
    """
    context_lens = seq_lens - query_lens
    # Ceiling division: a partially filled block still counts as one block.
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    active_blocks: list[int] = []
    for seq_id in range(len(seq_lens)):
        # Extend in place rather than rebuilding the list on every iteration.
        active_blocks.extend(
            block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
    return F.pad(
        torch.tensor(active_blocks),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )
|
||||
assert (large_tile_size >= B_P_SIZE
|
||||
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
|
||||
|
||||
def ceil_div(a, b):
    """Return ``a / b`` rounded up, using only integer floor division.

    Computed as ``(a + b - 1) // b``, so it works for any operands that
    support ``+``, ``-`` and ``//`` (plain ints, or element-wise on tensors).
    The rounding-up identity holds for non-negative ``a`` and positive ``b``.
    """
    return (a + b - 1) // b
|
||||
@@ -357,32 +404,27 @@ def test_contexted_kv_attention(
|
||||
return 2**int(a - 1).bit_length()
|
||||
|
||||
# calculate input shapes
|
||||
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
|
||||
max_num_queries = pad_to_next_power_of_2(max_num_queries)
|
||||
head_size_padded = B_P_SIZE
|
||||
assert head_size_padded >= head_size
|
||||
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||
LARGE_TILE_SZ // block_size)
|
||||
large_tile_size // block_size)
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert (context_kv_len %
|
||||
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
|
||||
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
|
||||
|
||||
# pad QKV tensors
|
||||
pad_dims = (
|
||||
0,
|
||||
head_size_padded - query.shape[2],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
max_num_queries - query.shape[0],
|
||||
)
|
||||
query = F.pad(query, pad_dims, "constant", 0)
|
||||
k = F.pad(k, pad_dims, "constant", 0)
|
||||
v = F.pad(v, pad_dims, "constant", 0)
|
||||
k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
|
||||
v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
|
||||
k = F.pad(k_active, pad_dims, "constant", 0)
|
||||
v = F.pad(v_active, pad_dims, "constant", 0)
|
||||
|
||||
# permute QKV tensors
|
||||
# query: (1, n_heads, d, seq_q)
|
||||
@@ -391,6 +433,8 @@ def test_contexted_kv_attention(
|
||||
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
|
||||
k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
|
||||
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# transform block table
|
||||
active_block_table = get_active_block_tables(
|
||||
@@ -405,33 +449,31 @@ def test_contexted_kv_attention(
|
||||
prior_mask, active_mask = (
|
||||
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
query_lens, seq_lens, block_size=block_size))
|
||||
attn_mask = torch.concat(
|
||||
[
|
||||
F.pad(
|
||||
prior_mask,
|
||||
(
|
||||
0,
|
||||
context_kv_len - prior_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - prior_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool(),
|
||||
F.pad(
|
||||
active_mask,
|
||||
(
|
||||
0,
|
||||
max_num_queries - active_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - active_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool(),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
prior_mask_padded = F.pad(
|
||||
prior_mask,
|
||||
(
|
||||
0,
|
||||
context_kv_len - prior_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - prior_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool()
|
||||
active_mask_padded = F.pad(
|
||||
active_mask,
|
||||
(
|
||||
0,
|
||||
max_num_queries - active_mask.shape[1],
|
||||
0,
|
||||
max_num_queries - active_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
).bool()
|
||||
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)
|
||||
|
||||
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)
|
||||
|
||||
input_args = (
|
||||
query.to(device=device),
|
||||
@@ -439,29 +481,21 @@ def test_contexted_kv_attention(
|
||||
v.to(device=device),
|
||||
k_cache.to(device=device),
|
||||
v_cache.to(device=device),
|
||||
active_block_table.to(torch.int32).to(device=device),
|
||||
active_block_table.to(device=device),
|
||||
attn_mask.to(device=device),
|
||||
)
|
||||
input_kwargs = dict(
|
||||
n_kv_head=num_kv_heads,
|
||||
head_size=head_size,
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
return_debug_tensors=return_debug_tensors,
|
||||
LARGE_TILE_SZ=large_tile_size,
|
||||
)
|
||||
|
||||
if return_debug_tensors:
|
||||
output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
|
||||
*input_args, **input_kwargs)
|
||||
else:
|
||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||
debug_tensors = []
|
||||
|
||||
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
|
||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||
|
||||
num_actual_tokens = sum(query_lens)
|
||||
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
|
||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
|
||||
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
||||
output_ref_padded = F.pad(
|
||||
output_ref,
|
||||
|
||||
Reference in New Issue
Block a user