[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (#13245)
Signed-off-by: Lingfan Yu <lingfany@amazon.com>
This commit is contained in:
153
tests/neuron/test_block_table.py
Normal file
153
tests/neuron/test_block_table.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import os
|
||||||
|
|
||||||
|
import neuronxcc.nki.language as nl
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from neuronxcc import nki
|
||||||
|
|
||||||
|
from vllm.attention.ops.nki_flash_attn import (
|
||||||
|
load_block_tables, transform_block_tables_for_indirect_load)
|
||||||
|
|
||||||
|
|
||||||
|
def is_power_of_2(n):
    """Return whether ``n`` is a positive power of two."""
    # A power of two has exactly one bit set, so clearing the lowest set
    # bit via ``(n - 1) & n`` must leave zero.
    return n > 0 and ((n - 1) & n) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """
    Kernel body used by the test to exercise `load_block_tables` followed by
    `transform_block_tables_for_indirect_load`.

    The flat `block_tables` input is loaded with `load_block_tables`, then
    transformed for the given `head_id`, and finally the transformed table is
    copied slice by slice into a freshly allocated int32 tensor in shared HBM,
    which is returned to the caller.

    `num_blocks_per_tile` must be a power of 2 (asserted below).
    """
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    # (the plain integer is turned into an index expression of shape (1, 1))
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    # the second dimension of the transformed table is expected to be
    # exactly B_P_SIZE entries wide
    assert block_tables_transposed.shape[1] == B_P_SIZE

    # output buffer with the same shape as the transformed table, placed in
    # shared HBM so the result can be returned from the kernel
    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    # copy one leading-dimension slice at a time
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """
    Pure PyTorch reference for the block table transform.

    Steps:
    1. view the flat `block_tables` as `(num_tiles, num_blocks_per_tile)` and
       zero-pad the tile dimension up to a multiple of 128
    2. fuse the head dimension: `block_id * num_head + head_id`
    3. expand each entry into `block_size_tiling_factor` consecutive ids
    4. transpose so tiles are on the last dimension, then split the leading
       dimension into chunks of 128

    Returns a tensor of shape `(rows // 128, 128, num_tiles_padded)` where
    `rows = num_blocks_per_tile * block_size_tiling_factor`.
    """
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    B_F_SIZE = 128
    padded_tile_count = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE

    # (num_tiles, num_blocks_per_tile) -> pad rows at the end with zeros
    table_2d = F.pad(
        block_tables.view(num_tiles, num_blocks_per_tile),
        (0, 0, 0, padded_tile_count - num_tiles),
        "constant",
        0,
    )

    # fuse block id with head id; padded rows are transformed as well
    fused = (table_2d * num_head + head_id).view(padded_tile_count,
                                                 num_blocks_per_tile, 1)

    # every block id fans out into `block_size_tiling_factor` sub-block ids
    tiling_offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    expanded = fused * block_size_tiling_factor + tiling_offset

    transposed = expanded.view(padded_tile_count, -1).t()
    row_count = transposed.shape[0]
    assert row_count % B_F_SIZE == 0
    return transposed.view(row_count // B_F_SIZE, B_F_SIZE,
                           padded_tile_count)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    """
    Compare the NKI block table load + transform against the PyTorch
    reference on randomly generated block ids.
    """
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    os.environ["NEURON_CC_FLAGS"] = " ".join((
        "-O1",
        "--retry_failed_compilation",
    ))

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    B_P_SIZE = 128
    block_size_tiling_factor = 1
    if num_blocks_per_tile < B_P_SIZE:
        # split every block so that one tile still yields B_P_SIZE entries
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile

    max_num_blocks = 100000
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (num_tiles * num_blocks_per_tile, ),
        dtype=torch.int32,
    )

    # arguments shared by the kernel and the reference implementation
    shared_args = (
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )
    kernel = nki.jit(nki_load_and_transform_block_tables)
    nki_out = kernel[1, 1](block_tables.to(device=device), *shared_args).cpu()
    ref_out = ref_block_tables_transform(block_tables, *shared_args)

    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)
|
||||||
@@ -107,7 +107,7 @@ def ref_masked_attention(
|
|||||||
masked_score, dim=-1, return_max_reduce=True)
|
masked_score, dim=-1, return_max_reduce=True)
|
||||||
else:
|
else:
|
||||||
norm_score = ref_softmax(masked_score, dim=-1)
|
norm_score = ref_softmax(masked_score, dim=-1)
|
||||||
out = torch.einsum("hqk,khd->qhd", norm_score, value)
|
out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
|
||||||
if return_max_reduce:
|
if return_max_reduce:
|
||||||
return (
|
return (
|
||||||
out,
|
out,
|
||||||
@@ -118,7 +118,7 @@ def ref_masked_attention(
|
|||||||
scaled_qk,
|
scaled_qk,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return out
|
return (out, )
|
||||||
|
|
||||||
|
|
||||||
def ref_context_attention(
|
def ref_context_attention(
|
||||||
@@ -128,8 +128,6 @@ def ref_context_attention(
|
|||||||
query_lens,
|
query_lens,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
head_size,
|
head_size,
|
||||||
num_kv_heads,
|
|
||||||
num_heads,
|
|
||||||
num_queries_per_kv,
|
num_queries_per_kv,
|
||||||
return_max_reduce=False,
|
return_max_reduce=False,
|
||||||
):
|
):
|
||||||
@@ -146,18 +144,19 @@ def ref_context_attention(
|
|||||||
attn_mask = torch.logical_not(attn_mask)
|
attn_mask = torch.logical_not(attn_mask)
|
||||||
attn_mask = attn_mask.float() * -30000
|
attn_mask = attn_mask.float() * -30000
|
||||||
|
|
||||||
output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
|
output, *debug_tensors = ref_masked_attention(
|
||||||
ref_masked_attention(
|
query,
|
||||||
query,
|
key,
|
||||||
key,
|
value,
|
||||||
value,
|
scale,
|
||||||
scale,
|
attn_mask,
|
||||||
attn_mask,
|
return_max_reduce=return_max_reduce,
|
||||||
return_max_reduce=return_max_reduce,
|
)
|
||||||
))
|
|
||||||
|
|
||||||
output = output.unsqueeze(1)
|
output = output.unsqueeze(1)
|
||||||
if return_max_reduce:
|
if return_max_reduce:
|
||||||
|
cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
|
||||||
|
debug_tensors)
|
||||||
return (
|
return (
|
||||||
output,
|
output,
|
||||||
cached_max,
|
cached_max,
|
||||||
@@ -170,65 +169,22 @@ def ref_context_attention(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def sample_inputs(
|
||||||
"block_size, large_tile_size",
|
prefill_batch_size,
|
||||||
[
|
decode_batch_size,
|
||||||
(32, 2048), # 64 blocks
|
min_query_len,
|
||||||
(32, 4096), # 128 blocks
|
max_query_len,
|
||||||
(32, 8192), # 256 blocks
|
min_ctx_len,
|
||||||
(64, 8192), # 128 blocks
|
max_ctx_len,
|
||||||
],
|
block_size,
|
||||||
)
|
num_heads,
|
||||||
@pytest.mark.parametrize(
|
num_kv_heads,
|
||||||
"num_heads,num_queries_per_kv,head_size,mixed_precision",
|
head_size,
|
||||||
[
|
dtype,
|
||||||
(4, 2, 8, False),
|
):
|
||||||
(4, 2, 8, True),
|
|
||||||
(32, 8, 64, True),
|
|
||||||
(16, 2, 128, True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_contexted_kv_attention(
|
|
||||||
num_heads: int,
|
|
||||||
num_queries_per_kv: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
large_tile_size,
|
|
||||||
mixed_precision: bool,
|
|
||||||
) -> None:
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
|
|
||||||
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
|
|
||||||
|
|
||||||
assert large_tile_size % block_size == 0
|
|
||||||
|
|
||||||
device = xm.xla_device()
|
|
||||||
|
|
||||||
compiler_flags = [
|
|
||||||
"--model-type=transformer -O1",
|
|
||||||
"--internal-hlo2tensorizer-options='--verify-hlo'",
|
|
||||||
"--retry_failed_compilation",
|
|
||||||
]
|
|
||||||
compiler_flags_str = " ".join(compiler_flags)
|
|
||||||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
torch.set_printoptions(sci_mode=False)
|
|
||||||
|
|
||||||
min_ctx_len = 32
|
|
||||||
max_ctx_len = 1024
|
|
||||||
min_query_len = 16
|
|
||||||
max_query_len = 512
|
|
||||||
prefill_batch_size = 4
|
|
||||||
decode_batch_size = 12
|
|
||||||
batch_size = prefill_batch_size + decode_batch_size
|
batch_size = prefill_batch_size + decode_batch_size
|
||||||
max_model_len = (max_query_len + max_ctx_len) * 4
|
max_model_len = (max_query_len + max_ctx_len) * 4
|
||||||
|
|
||||||
max_block_per_request = max_model_len // block_size
|
max_block_per_request = max_model_len // block_size
|
||||||
dtype = torch.float32
|
|
||||||
cache_size = (batch_size * max_block_per_request) + 2
|
cache_size = (batch_size * max_block_per_request) + 2
|
||||||
prefill_ctx_lens = torch.randint(min_ctx_len,
|
prefill_ctx_lens = torch.randint(min_ctx_len,
|
||||||
max_ctx_len + 1, (prefill_batch_size, ),
|
max_ctx_len + 1, (prefill_batch_size, ),
|
||||||
@@ -244,7 +200,6 @@ def test_contexted_kv_attention(
|
|||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
).tolist() + [1 for _ in range(decode_batch_size)]
|
).tolist() + [1 for _ in range(decode_batch_size)]
|
||||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||||
num_kv_heads = num_heads // num_queries_per_kv
|
|
||||||
|
|
||||||
num_tokens = sum(query_lens)
|
num_tokens = sum(query_lens)
|
||||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||||
@@ -304,47 +259,139 @@ def test_contexted_kv_attention(
|
|||||||
cur_ctx += block_size
|
cur_ctx += block_size
|
||||||
block_id += 1
|
block_id += 1
|
||||||
|
|
||||||
|
return (
|
||||||
|
query,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_table,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
query_lens,
|
||||||
|
seq_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    """
    Flatten the per-sequence block tables into one list of active block ids.

    For every sequence, only the blocks that hold already-cached context
    tokens (`seq_len - query_len`, rounded up to whole blocks) are kept. The
    concatenated ids are zero-padded at the end to `num_blocks` entries.

    Args:
        block_tables: 2D tensor indexed as `[seq_id, block_idx]`.
        query_lens: per-sequence query lengths (supports `-` with seq_lens).
        seq_lens: per-sequence total lengths.
        block_size: number of tokens per KV cache block.
        num_blocks: length of the returned 1D tensor.

    Returns:
        int32 tensor of shape `(num_blocks, )`.

    Raises:
        ValueError: if more than `num_blocks` blocks are active. `F.pad`
            with a negative pad would otherwise silently drop block ids.
    """
    context_lens = seq_lens - query_lens
    blocks_per_seq = (context_lens + block_size - 1) // block_size
    active_blocks: list[int] = []
    for seq_id in range(len(seq_lens)):
        # extend in place instead of rebuilding the list on every iteration
        active_blocks.extend(
            block_tables[seq_id, :int(blocks_per_seq[seq_id])].tolist())
    if len(active_blocks) > num_blocks:
        raise ValueError(
            f"{len(active_blocks)} active blocks do not fit into "
            f"{num_blocks=}")
    return F.pad(
        torch.tensor(active_blocks, dtype=torch.int32),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prefill_batch_size,decode_batch_size,block_size,large_tile_size",
|
||||||
|
[
|
||||||
|
(1, 199, 1, 512), # 512 blocks
|
||||||
|
(4, 12, 256, 2048), # 128 blocks
|
||||||
|
(4, 12, 16, 2048), # 128 blocks
|
||||||
|
(4, 12, 4, 1024), # 256 blocks
|
||||||
|
(4, 12, 32, 2048), # 64 blocks
|
||||||
|
(4, 12, 32, 4096), # 128 blocks
|
||||||
|
(4, 12, 32, 8192), # 256 blocks
|
||||||
|
(4, 12, 64, 8192), # 128 blocks
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_heads,num_queries_per_kv,head_size",
|
||||||
|
[
|
||||||
|
(4, 2, 8),
|
||||||
|
(32, 8, 64),
|
||||||
|
(4, 4, 128),
|
||||||
|
(8, 1, 32),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("mixed_precision", [True, False])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_contexted_kv_attention(
|
||||||
|
prefill_batch_size: int,
|
||||||
|
decode_batch_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_queries_per_kv: int,
|
||||||
|
head_size: int,
|
||||||
|
block_size: int,
|
||||||
|
large_tile_size,
|
||||||
|
mixed_precision: bool,
|
||||||
|
) -> None:
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
|
||||||
|
reorder_context_mask)
|
||||||
|
|
||||||
|
assert large_tile_size % block_size == 0
|
||||||
|
|
||||||
|
device = xm.xla_device()
|
||||||
|
|
||||||
|
compiler_flags = [
|
||||||
|
"-O1",
|
||||||
|
"--retry_failed_compilation",
|
||||||
|
]
|
||||||
|
compiler_flags_str = " ".join(compiler_flags)
|
||||||
|
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.set_printoptions(sci_mode=False)
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
min_ctx_len = 32
|
||||||
|
max_ctx_len = 1024
|
||||||
|
min_query_len = 16
|
||||||
|
max_query_len = 512
|
||||||
|
num_kv_heads = num_heads // num_queries_per_kv
|
||||||
(
|
(
|
||||||
output_ref,
|
query,
|
||||||
cached_max,
|
k_active,
|
||||||
cached_sum_reciprocal,
|
v_active,
|
||||||
lse,
|
k_cache,
|
||||||
masked_score,
|
v_cache,
|
||||||
scaled_qk,
|
block_table,
|
||||||
) = ref_context_attention(
|
key,
|
||||||
|
value,
|
||||||
|
query_lens,
|
||||||
|
seq_lens,
|
||||||
|
) = sample_inputs(
|
||||||
|
prefill_batch_size=prefill_batch_size,
|
||||||
|
decode_batch_size=decode_batch_size,
|
||||||
|
min_query_len=min_query_len,
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
min_ctx_len=min_ctx_len,
|
||||||
|
max_ctx_len=max_ctx_len,
|
||||||
|
block_size=block_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_ref = ref_context_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
query_lens,
|
query_lens,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
head_size,
|
head_size,
|
||||||
num_kv_heads,
|
|
||||||
num_heads,
|
|
||||||
num_queries_per_kv,
|
num_queries_per_kv,
|
||||||
return_max_reduce=True,
|
return_max_reduce=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# build neuron program
|
# build neuron program
|
||||||
return_debug_tensors = False
|
|
||||||
B_P_SIZE = 128
|
B_P_SIZE = 128
|
||||||
LARGE_TILE_SZ = large_tile_size
|
assert (large_tile_size >= B_P_SIZE
|
||||||
|
), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"
|
||||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
|
|
||||||
num_blocks):
|
|
||||||
context_lens = seq_lens - query_lens
|
|
||||||
blocks_per_seq = (context_lens + block_size - 1) // block_size
|
|
||||||
num_seqs = len(seq_lens)
|
|
||||||
active_blocks: list[int] = []
|
|
||||||
for seq_id in range(num_seqs):
|
|
||||||
active_blocks = (
|
|
||||||
active_blocks +
|
|
||||||
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
|
|
||||||
return F.pad(
|
|
||||||
torch.tensor(active_blocks),
|
|
||||||
(0, num_blocks - len(active_blocks)),
|
|
||||||
"constant",
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def ceil_div(a, b):
|
def ceil_div(a, b):
|
||||||
return (a + b - 1) // b
|
return (a + b - 1) // b
|
||||||
@@ -357,32 +404,27 @@ def test_contexted_kv_attention(
|
|||||||
return 2**int(a - 1).bit_length()
|
return 2**int(a - 1).bit_length()
|
||||||
|
|
||||||
# calculate input shapes
|
# calculate input shapes
|
||||||
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
|
max_num_queries = pad_to_next_power_of_2(sum(query_lens))
|
||||||
max_num_queries = pad_to_next_power_of_2(max_num_queries)
|
|
||||||
head_size_padded = B_P_SIZE
|
|
||||||
assert head_size_padded >= head_size
|
|
||||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||||
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
||||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||||
LARGE_TILE_SZ // block_size)
|
large_tile_size // block_size)
|
||||||
context_kv_len = num_active_blocks * block_size
|
context_kv_len = num_active_blocks * block_size
|
||||||
assert (context_kv_len %
|
assert (context_kv_len %
|
||||||
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
|
large_tile_size == 0), f"invalid context_kv_len={context_kv_len}"
|
||||||
|
|
||||||
# pad QKV tensors
|
# pad QKV tensors
|
||||||
pad_dims = (
|
pad_dims = (
|
||||||
0,
|
0,
|
||||||
head_size_padded - query.shape[2],
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
max_num_queries - query.shape[0],
|
max_num_queries - query.shape[0],
|
||||||
)
|
)
|
||||||
query = F.pad(query, pad_dims, "constant", 0)
|
query = F.pad(query, pad_dims, "constant", 0)
|
||||||
k = F.pad(k, pad_dims, "constant", 0)
|
k = F.pad(k_active, pad_dims, "constant", 0)
|
||||||
v = F.pad(v, pad_dims, "constant", 0)
|
v = F.pad(v_active, pad_dims, "constant", 0)
|
||||||
k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
|
|
||||||
v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
|
|
||||||
|
|
||||||
# permute QKV tensors
|
# permute QKV tensors
|
||||||
# query: (1, n_heads, d, seq_q)
|
# query: (1, n_heads, d, seq_q)
|
||||||
@@ -391,6 +433,8 @@ def test_contexted_kv_attention(
|
|||||||
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||||
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
|
||||||
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
|
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
|
||||||
|
k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
|
||||||
|
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
|
||||||
|
|
||||||
# transform block table
|
# transform block table
|
||||||
active_block_table = get_active_block_tables(
|
active_block_table = get_active_block_tables(
|
||||||
@@ -405,33 +449,31 @@ def test_contexted_kv_attention(
|
|||||||
prior_mask, active_mask = (
|
prior_mask, active_mask = (
|
||||||
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||||
query_lens, seq_lens, block_size=block_size))
|
query_lens, seq_lens, block_size=block_size))
|
||||||
attn_mask = torch.concat(
|
prior_mask_padded = F.pad(
|
||||||
[
|
prior_mask,
|
||||||
F.pad(
|
(
|
||||||
prior_mask,
|
0,
|
||||||
(
|
context_kv_len - prior_mask.shape[1],
|
||||||
0,
|
0,
|
||||||
context_kv_len - prior_mask.shape[1],
|
max_num_queries - prior_mask.shape[0],
|
||||||
0,
|
),
|
||||||
max_num_queries - prior_mask.shape[0],
|
"constant",
|
||||||
),
|
0,
|
||||||
"constant",
|
).bool()
|
||||||
0,
|
active_mask_padded = F.pad(
|
||||||
).bool(),
|
active_mask,
|
||||||
F.pad(
|
(
|
||||||
active_mask,
|
0,
|
||||||
(
|
max_num_queries - active_mask.shape[1],
|
||||||
0,
|
0,
|
||||||
max_num_queries - active_mask.shape[1],
|
max_num_queries - active_mask.shape[0],
|
||||||
0,
|
),
|
||||||
max_num_queries - active_mask.shape[0],
|
"constant",
|
||||||
),
|
0,
|
||||||
"constant",
|
).bool()
|
||||||
0,
|
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)
|
||||||
).bool(),
|
|
||||||
],
|
attn_mask = reorder_context_mask(attn_mask, large_tile_size, block_size)
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_args = (
|
input_args = (
|
||||||
query.to(device=device),
|
query.to(device=device),
|
||||||
@@ -439,29 +481,21 @@ def test_contexted_kv_attention(
|
|||||||
v.to(device=device),
|
v.to(device=device),
|
||||||
k_cache.to(device=device),
|
k_cache.to(device=device),
|
||||||
v_cache.to(device=device),
|
v_cache.to(device=device),
|
||||||
active_block_table.to(torch.int32).to(device=device),
|
active_block_table.to(device=device),
|
||||||
attn_mask.to(device=device),
|
attn_mask.to(device=device),
|
||||||
)
|
)
|
||||||
input_kwargs = dict(
|
input_kwargs = dict(
|
||||||
n_kv_head=num_kv_heads,
|
n_kv_head=num_kv_heads,
|
||||||
head_size=head_size,
|
head_size=head_size,
|
||||||
mixed_precision=mixed_precision,
|
mixed_precision=mixed_precision,
|
||||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
LARGE_TILE_SZ=large_tile_size,
|
||||||
return_debug_tensors=return_debug_tensors,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_debug_tensors:
|
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||||
output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
|
|
||||||
*input_args, **input_kwargs)
|
|
||||||
else:
|
|
||||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
|
||||||
debug_tensors = []
|
|
||||||
|
|
||||||
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
|
|
||||||
|
|
||||||
num_actual_tokens = sum(query_lens)
|
num_actual_tokens = sum(query_lens)
|
||||||
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
||||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
|
output_nki = output_nki.cpu().permute(0, 2, 1, 3)
|
||||||
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
||||||
output_ref_padded = F.pad(
|
output_ref_padded = F.pad(
|
||||||
output_ref,
|
output_ref,
|
||||||
|
|||||||
@@ -1,27 +1,203 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import neuronxcc.nki.isa as nisa
|
import neuronxcc.nki.isa as nisa
|
||||||
import neuronxcc.nki.language as nl
|
import neuronxcc.nki.language as nl
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from neuronxcc import nki
|
from neuronxcc import nki
|
||||||
from neuronxcc.nki.language import par_dim
|
from neuronxcc.nki.language import par_dim
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
def ceil_div(a, b):
|
||||||
class FlashConfig:
|
return (a + b - 1) // b
|
||||||
"""
|
|
||||||
Config class for flash attention with default values
|
|
||||||
"""
|
|
||||||
|
|
||||||
seq_tile_size: int = 2048
|
|
||||||
should_transpose_v: bool = False
|
|
||||||
|
|
||||||
__annotations__ = {
|
def is_power_of_2(x):
|
||||||
"seq_tile_size": int,
|
return x > 0 and (x & (x - 1)) == 0
|
||||||
"should_transpose_v": bool,
|
|
||||||
}
|
|
||||||
|
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into an on-chip tensor (`block_tables_sbuf`).

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    In case `num_tiles > B_P_SIZE`, the `num_tiles` dimension is further
    tiled into chunks of `B_P_SIZE`.

    Returns an int32 tensor of shape
    `(ceil_div(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile)`.
    The tensor is created with `nl.zeros` and the load is masked to rows
    below `num_tiles`, so rows past the last tile are expected to stay 0.
    """
    B_P_SIZE = 128

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    block_tables_sbuf = nl.zeros(
        (ceil_div(num_tiles,
                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    # one iteration per chunk of up to B_P_SIZE tiles
    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        # the mask guards the last chunk, which may hold fewer than
        # B_P_SIZE valid tiles
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf
|
||||||
|
|
||||||
|
|
||||||
|
@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate new `block_tables` for a `head_id` after flattening
    `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that `block_table` for each tile is mapped to
    SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
    fully utilize hardware parallelization. The solution is to tile `block_size`
    into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
    block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
    `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.

    Note:
    We don't further tile D dimension as small DMA size also hurts performance.

    Args:
        block_tables: 3D tensor of shape `(num_partitions, B_P_SIZE,
            num_blocks_per_tile)`; `num_blocks_per_tile` must be a power of 2.
        block_size_tiling_factor: how many sub-blocks each block is split
            into; values `> 1` require `num_blocks_per_tile *
            block_size_tiling_factor == B_P_SIZE`.
        num_head: multiplier used to fuse block id and head id.
        head_id: index expression of the head to address (passed to
            `nisa.iota`).

    Returns:
        int32 tensor of shape `(ceil_div(num_blocks_per_tile, B_P_SIZE),
        par_dim(B_P_SIZE), num_partitions * B_P_SIZE)`.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

    # number of B_P_SIZE-wide column chunks to transpose per partition
    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        # materialize head_id and broadcast it to the shape of one partition
        # of the block table so it can be added elementwise below
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        # offsets 0..block_size_tiling_factor-1 along the last dimension
        offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
                           dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            # flatten (num_blocks_per_tile, block_size_tiling_factor) into a
            # single dimension of size B_P_SIZE
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            # destination columns owned by this partition
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed
|
||||||
|
|
||||||
|
|
||||||
|
@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    key_cache,
    value_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load KV cache and transform Key and Value into layout required by Matmul

    Vectorized DMA Load layout:
    Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
    Key: (par_dim(B_D_SIZE), seqlen_kv)
    Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
           equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    The results are written in place into `cur_k_tile` and `cur_v_tile`;
    nothing is returned. `block_tables` is indexed as
    `[load_idx, partition, large_k_tile_idx]`. Loaded data is converted with
    `nl.copy` when its dtype differs from the destination tile's dtype.
    """
    # load key cache
    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
    for load_idx in nl.affine_range(num_loads):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(key_cache[block_tables[load_idx, i_p,
                                                large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        # define the index expressions before the load that uses them, so
        # this loop does not depend on names left over from the key loop
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
        loaded = nl.load(value_cache[block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        # value needs no transpose: the DMA layout already matches the
        # layout used by the matmul
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded
|
||||||
|
|
||||||
|
|
||||||
@nki.jit
|
@nki.jit
|
||||||
@@ -62,13 +238,13 @@ def _flash_attention_core(
|
|||||||
o_buffer,
|
o_buffer,
|
||||||
l_buffer,
|
l_buffer,
|
||||||
m_buffer,
|
m_buffer,
|
||||||
q_tile_idx,
|
|
||||||
kernel_dtype,
|
kernel_dtype,
|
||||||
acc_type,
|
acc_type,
|
||||||
flash_config: FlashConfig,
|
|
||||||
use_causal_mask,
|
|
||||||
tile_mask,
|
tile_mask,
|
||||||
|
use_causal_mask,
|
||||||
|
q_tile_idx=None,
|
||||||
initialize=False,
|
initialize=False,
|
||||||
|
LARGE_TILE_SZ=2048,
|
||||||
B_P_SIZE=128,
|
B_P_SIZE=128,
|
||||||
B_F_SIZE=512,
|
B_F_SIZE=512,
|
||||||
B_D_SIZE=128,
|
B_D_SIZE=128,
|
||||||
@@ -77,19 +253,19 @@ def _flash_attention_core(
|
|||||||
"""
|
"""
|
||||||
The flash attention core function to calculate self attention between a tile
|
The flash attention core function to calculate self attention between a tile
|
||||||
of q and a block of K and V.
|
of q and a block of K and V.
|
||||||
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
|
The q_local_tile has (B_P_SIZE, B_D_SIZE)
|
||||||
already. The block size of K and V
|
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
|
||||||
is defined in the seq_tile_size of the flash_config. The results are stored
|
be split into size B_F_SIZE tiles
|
||||||
in the following three buffers
|
|
||||||
|
The results are stored in the following three buffers
|
||||||
o_buffer: (B_P_SIZE, d)
|
o_buffer: (B_P_SIZE, d)
|
||||||
l_buffer: (B_P_SIZE, 1)
|
l_buffer: (B_P_SIZE, 1)
|
||||||
m_buffer: (B_P_SIZE, 1)
|
m_buffer: (B_P_SIZE, 1)
|
||||||
|
|
||||||
|
All IO buffers are in SBUF.
|
||||||
"""
|
"""
|
||||||
LARGE_TILE_SZ = flash_config.seq_tile_size
|
|
||||||
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
||||||
|
|
||||||
# mask are used to only apply computation to the lower half of the matrix,
|
|
||||||
# which reduce the arithmetic intensity by half
|
|
||||||
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||||
buffer=nl.sbuf,
|
buffer=nl.sbuf,
|
||||||
dtype=acc_type)
|
dtype=acc_type)
|
||||||
@@ -99,6 +275,8 @@ def _flash_attention_core(
|
|||||||
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
||||||
|
|
||||||
if use_causal_mask:
|
if use_causal_mask:
|
||||||
|
# mask are used to only apply computation to the lower half of the
|
||||||
|
# matrix, which reduce the arithmetic intensity by up to 50%
|
||||||
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
||||||
>= k_i * B_F_SIZE)
|
>= k_i * B_F_SIZE)
|
||||||
else:
|
else:
|
||||||
@@ -165,7 +343,9 @@ def _flash_attention_core(
|
|||||||
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
|
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
|
||||||
|
|
||||||
p_partial_sum = nl.ndarray(
|
p_partial_sum = nl.ndarray(
|
||||||
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
|
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
|
||||||
|
dtype=acc_type,
|
||||||
|
)
|
||||||
|
|
||||||
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
|
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
|
||||||
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
|
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
|
||||||
@@ -194,13 +374,15 @@ def _flash_attention_core(
|
|||||||
B_F_SIZE=B_F_SIZE,
|
B_F_SIZE=B_F_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE),
|
pv_psum = nl.zeros(
|
||||||
dtype=np.float32,
|
(par_dim(B_P_SIZE), B_D_SIZE),
|
||||||
buffer=nl.psum)
|
dtype=np.float32,
|
||||||
|
buffer=nl.psum,
|
||||||
|
)
|
||||||
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||||
pv_psum[:, :] += nl.matmul(
|
pv_psum[:, :] += nl.matmul(
|
||||||
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
||||||
v[k_i, :, :],
|
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
|
||||||
transpose_x=True,
|
transpose_x=True,
|
||||||
) # (128, 128) (p(Br), d)
|
) # (128, 128) (p(Br), d)
|
||||||
|
|
||||||
@@ -219,44 +401,16 @@ def _flash_attention_core(
|
|||||||
|
|
||||||
|
|
||||||
@nki.jit
|
@nki.jit
|
||||||
def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
|
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
|
||||||
LARGE_TILE_SZ = config.seq_tile_size
|
|
||||||
B_P_SIZE = 128
|
B_P_SIZE = 128
|
||||||
|
B_D_SIZE = v_hbm_tile.shape[-1]
|
||||||
if not config.should_transpose_v:
|
loaded = nl.load(v_hbm_tile[
|
||||||
cur_v_tile[v_i, :, :] = nl.load(
|
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
|
||||||
v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
|
:,
|
||||||
dtype=cur_v_tile.dtype,
|
])
|
||||||
)
|
if cur_v_tile.dtype != loaded.dtype:
|
||||||
return
|
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
|
||||||
|
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
|
||||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
|
||||||
cur_v_tile_transposed = nisa.dma_transpose(
|
|
||||||
v_hbm_tile[:,
|
|
||||||
nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
|
|
||||||
cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
|
|
||||||
dtype=cur_v_tile.dtype)
|
|
||||||
return
|
|
||||||
|
|
||||||
cur_v_tile[v_i, :, :] = nl.load_transpose2d(
|
|
||||||
v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
|
|
||||||
dtype=cur_v_tile.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@nki.jit
|
|
||||||
def load_block_tables(block_tables_hbm, num_tiles):
|
|
||||||
(num_blocks, ) = block_tables_hbm.shape
|
|
||||||
assert num_blocks % num_tiles == 0
|
|
||||||
num_blocks_per_tile = num_blocks // num_tiles
|
|
||||||
block_tables_hbm = block_tables_hbm.reshape(
|
|
||||||
(num_tiles, num_blocks_per_tile))
|
|
||||||
block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
|
|
||||||
return block_tables_buffer
|
|
||||||
|
|
||||||
|
|
||||||
def is_power_of_2(x):
|
|
||||||
return x > 0 and (x & (x - 1)) == 0
|
|
||||||
|
|
||||||
|
|
||||||
@nki.jit
|
@nki.jit
|
||||||
@@ -270,24 +424,21 @@ def flash_paged_attention(
|
|||||||
mask,
|
mask,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
mixed_precision=True,
|
mixed_precision=True,
|
||||||
config=None,
|
LARGE_TILE_SZ=2048,
|
||||||
return_debug_tensors=False,
|
return_debug_tensors=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Flash PagedAttention Forward Kernel.
|
Flash PagedAttention Forward Kernel.
|
||||||
- PagedAttention Paper: https://arxiv.org/abs/2309.06180
|
|
||||||
- Chunked Prefill Paper: https://arxiv.org/abs/2403.02310
|
|
||||||
|
|
||||||
IO tensor layouts:
|
IO tensor layouts:
|
||||||
- query: shape (1, n_heads, d, seq_q)
|
- query: shape (1, n_heads, d, seq_q)
|
||||||
- key: shape (1, n_kv_heads, d, seq_k)
|
- key: shape (1, n_kv_heads, d, seq_k)
|
||||||
- value: shape (1, n_kv_heads, seq_v, d)
|
- value: shape (1, n_kv_heads, seq_v, d)
|
||||||
- key_cache: (num_blocks, block_size, n_kv_heads, d)
|
- key_cache: (num_blocks, n_kv_heads, block_size, d)
|
||||||
- value_cache: (num_blocks, block_size, n_kv_heads, d)
|
- value_cache: (num_blocks, n_kv_heads, block_size, d)
|
||||||
- block_tables: (num_active_blocks, )
|
- block_tables: (num_active_blocks, )
|
||||||
- mask: (seq_q, num_active_blocks * block_size)
|
- mask: (seq_q, num_active_blocks * block_size + seq_q)
|
||||||
- o: shape (1, n_heads, seq_q, d)
|
- o: shape (1, n_heads, seq_q, d)
|
||||||
- l_m: shape (1, n_heads, seq_q, 2)
|
|
||||||
|
|
||||||
- This kernel requires seq_k == seq_v
|
- This kernel requires seq_k == seq_v
|
||||||
- We use continuous batching by default, so the batch dimension is
|
- We use continuous batching by default, so the batch dimension is
|
||||||
@@ -306,11 +457,8 @@ def flash_paged_attention(
|
|||||||
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
|
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
|
||||||
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
|
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
|
||||||
is set to `true`, if false, we use same precision as input types
|
is set to `true`, if false, we use same precision as input types
|
||||||
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
|
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
|
||||||
with Performance config parameters for flash attention with default
|
computation reduction
|
||||||
values
|
|
||||||
seq_tile_size: `default=2048`, size of the kv tile size for attention
|
|
||||||
computation reduction
|
|
||||||
|
|
||||||
GQA support Notes:
|
GQA support Notes:
|
||||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||||
@@ -322,31 +470,65 @@ def flash_paged_attention(
|
|||||||
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
|
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
|
||||||
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
|
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
|
||||||
"""
|
"""
|
||||||
config = config or FlashConfig()
|
|
||||||
B_F_SIZE = 512
|
B_F_SIZE = 512
|
||||||
B_P_SIZE = 128
|
B_P_SIZE = 128
|
||||||
b, h, d, seqlen_q = query.shape
|
b, h, d, seqlen_q = query.shape
|
||||||
B_D_SIZE = d
|
B_D_SIZE = d
|
||||||
LARGE_TILE_SZ = config.seq_tile_size
|
|
||||||
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
|
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
|
||||||
num_blocks, block_size, k_h, _ = key_cache.shape
|
num_blocks, k_h, block_size, _ = key_cache.shape
|
||||||
q_h_per_k_h = h // k_h
|
q_h_per_k_h = h // k_h
|
||||||
assert tuple(key_cache.shape) == (
|
|
||||||
num_blocks,
|
|
||||||
block_size,
|
|
||||||
k_h,
|
|
||||||
d,
|
|
||||||
), "Input shape mismatch!"
|
|
||||||
assert tuple(value_cache.shape) == (
|
|
||||||
num_blocks,
|
|
||||||
block_size,
|
|
||||||
k_h,
|
|
||||||
d,
|
|
||||||
), "Input shape mismatch!"
|
|
||||||
assert b == 1, f"invalid batch size {b=}"
|
assert b == 1, f"invalid batch size {b=}"
|
||||||
assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
|
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
|
||||||
|
cache_shape = (num_blocks, k_h, block_size, d)
|
||||||
|
assert (tuple(key_cache.shape) == cache_shape
|
||||||
|
), f"{key_cache.shape=} mismatch, expect {cache_shape}"
|
||||||
|
assert (tuple(value_cache.shape) == cache_shape
|
||||||
|
), f"{value_cache.shape=} mismatch, expect {cache_shape}"
|
||||||
|
assert key is None or tuple(key.shape) == (
|
||||||
|
1,
|
||||||
|
k_h,
|
||||||
|
d,
|
||||||
|
seqlen_q,
|
||||||
|
), f"key shape {key.shape} mismatch!"
|
||||||
|
assert value is None or tuple(value.shape) == (
|
||||||
|
1,
|
||||||
|
k_h,
|
||||||
|
seqlen_q,
|
||||||
|
d,
|
||||||
|
), f"value shape {value.shape} mismatch!"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
nl.program_ndim() == 2
|
||||||
|
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
|
||||||
|
batch_id = nl.program_id(axis=0)
|
||||||
|
head_id = nl.program_id(axis=1)
|
||||||
|
|
||||||
|
(num_active_blocks, ) = block_tables.shape
|
||||||
|
context_kv_len = num_active_blocks * block_size
|
||||||
|
assert (
|
||||||
|
LARGE_TILE_SZ % B_F_SIZE == 0
|
||||||
|
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
|
||||||
|
assert (context_kv_len % LARGE_TILE_SZ == 0
|
||||||
|
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
|
||||||
|
|
||||||
|
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
||||||
|
assert is_power_of_2(
|
||||||
|
num_blocks_per_large_tile
|
||||||
|
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
|
||||||
|
if seqlen_q > B_F_SIZE:
|
||||||
|
MAX_REDUCTION_TILE = 2048
|
||||||
|
if seqlen_q // 2 > MAX_REDUCTION_TILE:
|
||||||
|
assert (
|
||||||
|
seqlen_q % MAX_REDUCTION_TILE == 0
|
||||||
|
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
|
||||||
|
else:
|
||||||
|
assert (seqlen_q % B_F_SIZE == 0
|
||||||
|
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
|
||||||
|
|
||||||
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
|
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
|
||||||
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
|
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
|
||||||
|
softmax_scale = softmax_scale or (1.0 / (d**0.5))
|
||||||
|
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
||||||
|
|
||||||
o = nl.ndarray((b, h, seqlen_q, d),
|
o = nl.ndarray((b, h, seqlen_q, d),
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
@@ -373,35 +555,38 @@ def flash_paged_attention(
|
|||||||
buffer=nl.sbuf,
|
buffer=nl.sbuf,
|
||||||
lazy_initialization=True,
|
lazy_initialization=True,
|
||||||
)
|
)
|
||||||
|
block_tables_sbuf = load_block_tables(
|
||||||
|
block_tables_hbm=block_tables,
|
||||||
|
num_tiles=num_large_k_tile,
|
||||||
|
num_blocks_per_tile=num_blocks_per_large_tile,
|
||||||
|
)
|
||||||
|
|
||||||
assert (
|
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
|
||||||
nl.program_ndim() == 2
|
if num_blocks_per_large_tile < B_P_SIZE:
|
||||||
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
|
# we checked num_blocks_per_tile is a power of 2
|
||||||
batch_id = nl.program_id(axis=0)
|
assert B_P_SIZE % num_blocks_per_large_tile == 0
|
||||||
head_id = nl.program_id(axis=1)
|
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
|
||||||
|
# We assume block_size >= block_size_tiling_factor
|
||||||
|
assert block_size % block_size_tiling_factor == 0
|
||||||
|
else:
|
||||||
|
block_size_tiling_factor = 1
|
||||||
|
tiled_block_size = block_size // block_size_tiling_factor
|
||||||
|
|
||||||
softmax_scale = softmax_scale or (1.0 / (d**0.5))
|
# Indirect DMA load must be placed along Partition Dimension
|
||||||
|
block_tables_sbuf = transform_block_tables_for_indirect_load(
|
||||||
|
block_tables_sbuf,
|
||||||
|
block_size_tiling_factor=block_size_tiling_factor,
|
||||||
|
num_head=k_h,
|
||||||
|
head_id=head_id,
|
||||||
|
)
|
||||||
|
|
||||||
(num_active_blocks, ) = block_tables.shape
|
# Flatten KV cache to be 2D for loading into SBUF
|
||||||
context_kv_len = num_active_blocks * block_size
|
new_cache_shape = (
|
||||||
assert (config.seq_tile_size >= 512
|
num_blocks * k_h * block_size_tiling_factor,
|
||||||
), f" seq tile_size {config.seq_tile_size} cannot be less than 512"
|
tiled_block_size * d,
|
||||||
assert (context_kv_len % LARGE_TILE_SZ == 0
|
)
|
||||||
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
|
key_cache = key_cache.reshape(new_cache_shape)
|
||||||
assert (
|
value_cache = value_cache.reshape(new_cache_shape)
|
||||||
LARGE_TILE_SZ % B_P_SIZE == 0
|
|
||||||
), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}"
|
|
||||||
assert (B_P_SIZE % block_size == 0
|
|
||||||
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
|
|
||||||
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
|
||||||
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
|
||||||
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
|
|
||||||
assert is_power_of_2(
|
|
||||||
num_blocks_per_large_tile
|
|
||||||
), "The number of blocks in each large tile is expected of be power of 2"
|
|
||||||
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
|
|
||||||
|
|
||||||
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
|
|
||||||
|
|
||||||
# Global Flash Attention accumulators
|
# Global Flash Attention accumulators
|
||||||
o_buffer = nl.zeros(
|
o_buffer = nl.zeros(
|
||||||
@@ -411,7 +596,7 @@ def flash_paged_attention(
|
|||||||
lazy_initialization=True,
|
lazy_initialization=True,
|
||||||
)
|
)
|
||||||
l_buffer = nl.zeros(
|
l_buffer = nl.zeros(
|
||||||
(par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h),
|
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
|
||||||
dtype=acc_type,
|
dtype=acc_type,
|
||||||
buffer=nl.sbuf,
|
buffer=nl.sbuf,
|
||||||
lazy_initialization=True,
|
lazy_initialization=True,
|
||||||
@@ -423,50 +608,42 @@ def flash_paged_attention(
|
|||||||
lazy_initialization=True,
|
lazy_initialization=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for j in nl.sequential_range(0, num_large_k_tile):
|
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
|
||||||
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
|
||||||
dtype=kernel_dtype)
|
cur_k_tile = nl.ndarray(
|
||||||
cur_v_tile = nl.ndarray(
|
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||||
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
|
|
||||||
dtype=kernel_dtype,
|
dtype=kernel_dtype,
|
||||||
)
|
)
|
||||||
|
cur_v_tile = nl.ndarray(
|
||||||
for k_i in nl.affine_range(num_blocks_per_large_tile):
|
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
|
||||||
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
|
dtype=kernel_dtype,
|
||||||
head_id, :])
|
)
|
||||||
cur_k_tile[:, nl.ds(k_i *
|
load_kv_tile_from_cache(
|
||||||
block_size, block_size)] = nl.transpose(loaded)
|
cur_k_tile=cur_k_tile,
|
||||||
|
cur_v_tile=cur_v_tile,
|
||||||
load_tile_size = B_P_SIZE
|
key_cache=key_cache,
|
||||||
num_blocks_per_partition = load_tile_size // block_size
|
value_cache=value_cache,
|
||||||
for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
|
block_tables=block_tables_sbuf,
|
||||||
for block_in_partition in nl.affine_range(
|
large_k_tile_idx=large_k_tile_idx,
|
||||||
num_blocks_per_partition):
|
num_blocks_per_large_tile=num_blocks_per_large_tile,
|
||||||
v_i = (partition_idx * num_blocks_per_partition +
|
tiled_block_size=tiled_block_size,
|
||||||
block_in_partition)
|
B_P_SIZE=B_P_SIZE,
|
||||||
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
|
B_D_SIZE=B_D_SIZE,
|
||||||
head_id, :])
|
)
|
||||||
cur_v_tile[
|
|
||||||
partition_idx,
|
|
||||||
nl.ds(block_in_partition * block_size, block_size),
|
|
||||||
:,
|
|
||||||
] = loaded_v
|
|
||||||
|
|
||||||
for i in nl.affine_range(n_tile_q):
|
for i in nl.affine_range(n_tile_q):
|
||||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
cur_mask = nl.load(mask[
|
||||||
dtype=mask.dtype)
|
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||||
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
|
||||||
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
|
])
|
||||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
|
||||||
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
|
|
||||||
])
|
|
||||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||||
q_sbuf_tile = nl.load(
|
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||||
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
|
nl.ds(i *
|
||||||
dtype=kernel_dtype,
|
B_P_SIZE, B_P_SIZE)])
|
||||||
) # load (d, 128) tile in SBUF
|
if q_sbuf_tile.dtype != kernel_dtype:
|
||||||
|
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||||
|
|
||||||
_flash_attention_core(
|
_flash_attention_core(
|
||||||
@@ -474,15 +651,15 @@ def flash_paged_attention(
|
|||||||
k=cur_k_tile,
|
k=cur_k_tile,
|
||||||
v=cur_v_tile,
|
v=cur_v_tile,
|
||||||
o_buffer=o_buffer[i, i_q_h],
|
o_buffer=o_buffer[i, i_q_h],
|
||||||
l_buffer=l_buffer[:, i, i_q_h],
|
l_buffer=l_buffer[i, i_q_h],
|
||||||
m_buffer=m_buffer[i, i_q_h],
|
m_buffer=m_buffer[i, i_q_h],
|
||||||
q_tile_idx=i,
|
|
||||||
kernel_dtype=kernel_dtype,
|
kernel_dtype=kernel_dtype,
|
||||||
acc_type=acc_type,
|
acc_type=acc_type,
|
||||||
flash_config=config,
|
|
||||||
use_causal_mask=False,
|
|
||||||
tile_mask=cur_mask,
|
tile_mask=cur_mask,
|
||||||
initialize=j == 0,
|
use_causal_mask=False,
|
||||||
|
q_tile_idx=i,
|
||||||
|
initialize=large_k_tile_idx == 0,
|
||||||
|
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||||
B_P_SIZE=B_P_SIZE,
|
B_P_SIZE=B_P_SIZE,
|
||||||
B_F_SIZE=B_F_SIZE,
|
B_F_SIZE=B_F_SIZE,
|
||||||
B_D_SIZE=B_D_SIZE,
|
B_D_SIZE=B_D_SIZE,
|
||||||
@@ -492,62 +669,58 @@ def flash_paged_attention(
|
|||||||
if key is not None and value is not None:
|
if key is not None and value is not None:
|
||||||
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
||||||
LARGE_TILE_SZ = seqlen_q
|
LARGE_TILE_SZ = seqlen_q
|
||||||
active_config = FlashConfig(
|
|
||||||
seq_tile_size=LARGE_TILE_SZ,
|
|
||||||
should_transpose_v=config.should_transpose_v,
|
|
||||||
)
|
|
||||||
|
|
||||||
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
|
||||||
dtype=kernel_dtype)
|
dtype=kernel_dtype)
|
||||||
cur_v_tile = nl.ndarray(
|
cur_v_tile = nl.ndarray(
|
||||||
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
|
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
|
||||||
dtype=kernel_dtype,
|
dtype=kernel_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :])
|
loaded = nl.load(key[batch_id, head_id, :, :])
|
||||||
|
if loaded.dtype != kernel_dtype:
|
||||||
|
loaded = nl.copy(loaded, dtype=kernel_dtype)
|
||||||
|
cur_k_tile[:, :] = loaded
|
||||||
|
|
||||||
load_tile_size = B_P_SIZE
|
|
||||||
v_hbm_tile = value[batch_id, head_id]
|
v_hbm_tile = value[batch_id, head_id]
|
||||||
for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
|
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
|
||||||
load_v_tile(
|
load_v_tile(
|
||||||
v_hbm_tile=v_hbm_tile,
|
v_hbm_tile=v_hbm_tile,
|
||||||
cur_v_tile=cur_v_tile,
|
cur_v_tile=cur_v_tile,
|
||||||
j=0,
|
large_tile_idx=0,
|
||||||
v_i=v_i,
|
v_i=v_i,
|
||||||
config=active_config,
|
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in nl.affine_range(n_tile_q):
|
for i in nl.affine_range(n_tile_q):
|
||||||
cur_mask = nl.load(
|
cur_mask = nl.load(mask[
|
||||||
mask[
|
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
||||||
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
])
|
||||||
],
|
|
||||||
dtype=mask.dtype,
|
|
||||||
)
|
|
||||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||||
|
|
||||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||||
q_sbuf_tile = nl.load(
|
q_sbuf_tile = nl.load(q_hbm_tile[:,
|
||||||
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
|
nl.ds(i *
|
||||||
dtype=kernel_dtype,
|
B_P_SIZE, B_P_SIZE)])
|
||||||
) # load (d, 128) tile in SBUF
|
if q_sbuf_tile.dtype != kernel_dtype:
|
||||||
|
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
|
||||||
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
q_tile[:, :] = q_sbuf_tile * softmax_scale
|
||||||
_flash_attention_core(
|
_flash_attention_core(
|
||||||
q_local_tile=q_tile,
|
q_local_tile=q_tile,
|
||||||
k=cur_k_tile,
|
k=cur_k_tile,
|
||||||
v=cur_v_tile,
|
v=cur_v_tile,
|
||||||
o_buffer=o_buffer[i, i_q_h],
|
o_buffer=o_buffer[i, i_q_h],
|
||||||
l_buffer=l_buffer[:, i, i_q_h],
|
l_buffer=l_buffer[i, i_q_h],
|
||||||
m_buffer=m_buffer[i, i_q_h],
|
m_buffer=m_buffer[i, i_q_h],
|
||||||
q_tile_idx=i,
|
|
||||||
kernel_dtype=kernel_dtype,
|
kernel_dtype=kernel_dtype,
|
||||||
acc_type=acc_type,
|
acc_type=acc_type,
|
||||||
flash_config=active_config,
|
|
||||||
use_causal_mask=True,
|
|
||||||
tile_mask=cur_mask,
|
tile_mask=cur_mask,
|
||||||
|
use_causal_mask=True,
|
||||||
|
q_tile_idx=i,
|
||||||
initialize=False,
|
initialize=False,
|
||||||
|
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||||
B_P_SIZE=B_P_SIZE,
|
B_P_SIZE=B_P_SIZE,
|
||||||
B_F_SIZE=B_F_SIZE,
|
B_F_SIZE=B_F_SIZE,
|
||||||
B_D_SIZE=B_D_SIZE,
|
B_D_SIZE=B_D_SIZE,
|
||||||
@@ -559,8 +732,8 @@ def flash_paged_attention(
|
|||||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||||
for i in nl.affine_range(n_tile_q):
|
for i in nl.affine_range(n_tile_q):
|
||||||
out = nl.multiply(
|
out = nl.multiply(
|
||||||
o_buffer[i, i_q_h, :, :],
|
o_buffer[i, i_q_h],
|
||||||
nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]),
|
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
|
||||||
dtype=kernel_dtype,
|
dtype=kernel_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -589,7 +762,7 @@ def flash_paged_attention(
|
|||||||
head_id * q_h_per_k_h + i_q_h,
|
head_id * q_h_per_k_h + i_q_h,
|
||||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||||
],
|
],
|
||||||
l_buffer[:, i, i_q_h],
|
l_buffer[i, i_q_h],
|
||||||
)
|
)
|
||||||
nl.store(
|
nl.store(
|
||||||
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
|
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
|
||||||
@@ -601,6 +774,49 @@ def flash_paged_attention(
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
|
||||||
|
"""
|
||||||
|
Reorder the mask to make it compatible with the flash attention kernel.
|
||||||
|
|
||||||
|
We vectorize KV cache read to improve DMA utilization. However, the layout
|
||||||
|
that maximizes DMA bandwidth changes the order tokens are consumed.
|
||||||
|
|
||||||
|
The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
|
||||||
|
tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And
|
||||||
|
each step the engine consumes a column (rather than a row) of B_P_SIZE
|
||||||
|
tokens. Therefore, the tokens are visited in a strided way.
|
||||||
|
|
||||||
|
To make sure mask matches the order tokens are consumed, we need to properly
|
||||||
|
transpose mask.
|
||||||
|
"""
|
||||||
|
total_query_len, total_seq_len = mask.shape
|
||||||
|
context_kv_len = total_seq_len - total_query_len
|
||||||
|
|
||||||
|
B_P_SIZE = 128
|
||||||
|
assert (LARGE_TILE_SZ
|
||||||
|
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
|
||||||
|
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
|
||||||
|
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
|
||||||
|
if tiled_block_size > 1:
|
||||||
|
# Mask reordering is needed when tiled_block_size > 1
|
||||||
|
device = mask.device
|
||||||
|
mask = mask.cpu()
|
||||||
|
context_mask = mask[:, :context_kv_len]
|
||||||
|
context_mask = context_mask.view(
|
||||||
|
total_query_len,
|
||||||
|
context_kv_len // LARGE_TILE_SZ,
|
||||||
|
num_tiled_blocks // B_P_SIZE,
|
||||||
|
B_P_SIZE,
|
||||||
|
tiled_block_size,
|
||||||
|
)
|
||||||
|
context_mask = context_mask.transpose(3, 4).reshape(
|
||||||
|
total_query_len, context_kv_len)
|
||||||
|
new_mask = mask[:, context_kv_len:]
|
||||||
|
return torch.concat([context_mask, new_mask], dim=1).to(device)
|
||||||
|
else:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def flash_attn_varlen_nkifunc(
|
def flash_attn_varlen_nkifunc(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -612,13 +828,32 @@ def flash_attn_varlen_nkifunc(
|
|||||||
n_kv_head=None,
|
n_kv_head=None,
|
||||||
head_size=None,
|
head_size=None,
|
||||||
LARGE_TILE_SZ=2048,
|
LARGE_TILE_SZ=2048,
|
||||||
return_debug_tensors=False,
|
|
||||||
mixed_precision=True,
|
mixed_precision=True,
|
||||||
):
|
):
|
||||||
config = FlashConfig(
|
"""
|
||||||
seq_tile_size=LARGE_TILE_SZ,
|
Compute flash paged attention for variable length sequences.
|
||||||
should_transpose_v=False,
|
|
||||||
)
|
This function is a wrapper around the flash attention NKI kernel. It takes
|
||||||
|
in the following arguments:
|
||||||
|
- query: (1, n_heads, d, seq_q)
|
||||||
|
- key: (1, n_kv_heads, d, seq_k)
|
||||||
|
- value: (1, n_kv_heads, seq_v, d)
|
||||||
|
- key_cache: (n_blocks, n_kv_heads, block_size, d)
|
||||||
|
- value_cache: (n_blocks, n_kv_heads, block_size, d)
|
||||||
|
- block_tables: (n_active_blocks, )
|
||||||
|
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- attn_mask must be reordered outside using `reorder_context_mask`
|
||||||
|
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
|
||||||
|
for better DMA throughput
|
||||||
|
"""
|
||||||
|
if n_kv_head is None:
|
||||||
|
n_kv_head = key_cache.shape[1]
|
||||||
|
assert key_cache.shape[1] == n_kv_head
|
||||||
|
if head_size is None:
|
||||||
|
head_size = key_cache.shape[-1]
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
@@ -628,15 +863,9 @@ def flash_attn_varlen_nkifunc(
|
|||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
mask=attn_mask,
|
mask=attn_mask,
|
||||||
softmax_scale=1.0 / (head_size**0.5),
|
softmax_scale=1.0 / (head_size**0.5),
|
||||||
config=config,
|
|
||||||
mixed_precision=mixed_precision,
|
mixed_precision=mixed_precision,
|
||||||
return_debug_tensors=return_debug_tensors,
|
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||||
)
|
)
|
||||||
_, n_kv_head, _, _ = key.shape
|
|
||||||
|
|
||||||
if return_debug_tensors:
|
o = flash_paged_attention[1, n_kv_head](**kwargs)
|
||||||
o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs)
|
return o
|
||||||
return o, *debug_tensors
|
|
||||||
else:
|
|
||||||
o = flash_paged_attention[1, n_kv_head](**kwargs)
|
|
||||||
return o
|
|
||||||
|
|||||||
Reference in New Issue
Block a user