[Kernel] Flash Attention 3 Support (#12093)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-01-23 09:45:48 -05:00
committed by GitHub
parent c5b4b11d7f
commit 978b45f399
8 changed files with 151 additions and 83 deletions

View File

@@ -78,6 +78,7 @@ CASES = [
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
@@ -87,8 +88,14 @@ def test_cascade(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
current_platform.seed_everything(0)
window_size = (-1, -1)
@@ -118,9 +125,7 @@ def test_cascade(
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
@@ -140,7 +145,7 @@ def test_cascade(
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
@@ -154,10 +159,8 @@ def test_cascade(
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
cu_suffix_kv_lens = (
cu_kv_lens -
torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)
cascade_attention(
output=output,
@@ -167,8 +170,8 @@ def test_cascade(
cu_query_lens=cu_query_lens,
max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len,
softmax_scale=scale,
alibi_slopes=None,
@@ -176,6 +179,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
fa_version=fa_version,
)
# Compare the results.