[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

@@ -78,6 +78,7 @@ CASES = [
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("soft_cap", [None, 50])
 @pytest.mark.parametrize("num_blocks", [2048])
+@pytest.mark.parametrize("fa_version", [2, 3])
 @torch.inference_mode()
 def test_cascade(
     seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
@@ -87,8 +88,14 @@ def test_cascade(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
+    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
+                            or torch.cuda.get_device_capability() == (8, 9)):
+        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
+                    "insufficient shared memory for some shapes")
+
     current_platform.seed_everything(0)
 
     window_size = (-1, -1)
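
A quick aside on the gate added above: torch.cuda.get_device_capability() returns the GPU's (major, minor) compute-capability tuple, so the test skips FA3 on sm86 and sm89 parts, which expose less shared memory per SM than sm80/sm90. A minimal sketch of the same check, factored into a hypothetical helper (the helper name and the membership test are mine, not from the diff):

    import pytest
    import torch

    def maybe_skip_fa3(fa_version: int) -> None:
        # (major, minor), e.g. (8, 0) for A100, (8, 6) for A10, (8, 9) for L4
        capability = torch.cuda.get_device_capability()
        if fa_version == 3 and capability in ((8, 6), (8, 9)):
            pytest.skip("FA3 needs more shared memory than sm86/sm89 offer "
                        "for some shapes")
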
@@ -118,9 +125,7 @@ def test_cascade(
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
-    cu_kv_lens = torch.tensor([0] + kv_lens,
-                              dtype=torch.int32).cumsum(dim=0,
-                                                        dtype=torch.int32)
+    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
                                  num_blocks,
@@ -140,7 +145,7 @@ def test_cascade(
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=cu_query_lens,
-        cu_seqlens_k=cu_kv_lens,
+        seqused_k=kv_lens_tensor,
         max_seqlen_q=max_query_len,
         max_seqlen_k=max_kv_len,
         softmax_scale=scale,
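
Worth noting on the hunk above: cu_seqlens_k is a cumulative-offset tensor with num_seqs + 1 entries, while seqused_k holds each sequence's used KV length directly. A hedged sketch of how the two encodings relate (the kv_lens values are made up for illustration, not taken from the test):

    import torch

    kv_lens = [1328, 18, 463]  # illustrative per-sequence KV lengths
    seqused_k = torch.tensor(kv_lens, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0] + kv_lens,
                                dtype=torch.int32).cumsum(dim=0,
                                                          dtype=torch.int32)
    # cu_seqlens_k == [0, 1328, 1346, 1809]: start offsets per sequence,
    # so adjacent differences recover the per-sequence lengths.
    assert torch.equal(seqused_k, cu_seqlens_k[1:] - cu_seqlens_k[:-1])
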
@@ -154,10 +159,8 @@ def test_cascade(
     assert all(common_prefix_len < kv_len for kv_len in kv_lens)
     cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
                                         dtype=torch.int32)
-    cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
-    cu_suffix_kv_lens = (
-        cu_kv_lens -
-        torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
+    prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
+    suffix_kv_lens = kv_lens_tensor - common_prefix_len
     output = torch.empty_like(query)
     cascade_attention(
         output=output,
@@ -167,8 +170,8 @@ def test_cascade(
         cu_query_lens=cu_query_lens,
         max_query_len=max_query_len,
         cu_prefix_query_lens=cu_prefix_query_lens,
-        cu_prefix_kv_lens=cu_prefix_kv_lens,
-        cu_suffix_kv_lens=cu_suffix_kv_lens,
+        prefix_kv_lens=prefix_kv_lens,
+        suffix_kv_lens=suffix_kv_lens,
         max_kv_len=max_kv_len,
         softmax_scale=scale,
         alibi_slopes=None,
@@ -176,6 +179,7 @@ def test_cascade(
         logits_soft_cap=soft_cap if soft_cap is not None else 0,
         block_table=block_tables,
         common_prefix_len=common_prefix_len,
+        fa_version=fa_version,
     )
 
     # Compare the results.
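
For context on the last three hunks: the test now feeds cascade_attention per-sequence lengths instead of cumulative offsets. The idea behind cascade attention is that the shared prefix is processed once for the whole batch and each sequence then attends only over its own suffix, so the bookkeeping reduces to a single subtraction. A minimal sketch with made-up lengths (only plain torch semantics are assumed here):

    import torch

    common_prefix_len = 256
    kv_lens_tensor = torch.tensor([1328, 466, 463], dtype=torch.int32)

    # A single shared prefix -> a one-element length tensor.
    prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
    # Each sequence's suffix is whatever remains after the shared prefix.
    suffix_kv_lens = kv_lens_tensor - common_prefix_len  # [1072, 210, 207]
    # Mirrors the test's precondition that every kv_len exceeds the prefix.
    assert (suffix_kv_lens > 0).all()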