[Kernel] Support sliding window in flash attention backend (#9403)

Author: Chen Zhang
Date: 2024-10-20 10:57:52 -07:00
Committed by: GitHub
Parent: 962d2c6349
Commit: 4fa3e33349
13 changed files with 41 additions and 61 deletions


@@ -78,6 +78,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     seed_everything(0)
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
+                   (-1, -1))
 
     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
     key_cache = torch.randn(num_blocks,
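
The new window_size tuple follows flash-attn's (left, right) convention: a query at position i may attend to keys in [i - left, i + right]. With left = sliding_window - 1 and right = 0, each token sees itself plus the sliding_window - 1 keys before it, i.e. exactly sliding_window keys. A minimal sketch of the mask this encodes (the helper below is illustrative, not part of the test file):

```python
# Illustrative helper (not in the test file): the boolean key mask implied by
# window_size = (sliding_window - 1, 0) for one query at position q_pos over
# kv_len cached keys.
import torch


def sliding_window_mask(kv_len: int, q_pos: int, sliding_window: int) -> torch.Tensor:
    key_pos = torch.arange(kv_len)
    left = sliding_window - 1          # keys strictly before the query
    right = 0                          # no future keys during decode
    return (key_pos >= q_pos - left) & (key_pos <= q_pos + right)


# kv_len=8, q_pos=7, sliding_window=4 keeps keys 4..7: exactly 4 keys.
print(sliding_window_mask(8, 7, 4))
```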
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
+        window_size=window_size,
     ).squeeze(1)
 
-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-    )
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None])
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_query_len = max(query_lens)
     max_kv_len = max(kv_lens)
-    window_size = ((sliding_window,
-                    sliding_window) if sliding_window is not None else
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                    (-1, -1))
     scale = head_size**-0.5
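
The varlen test previously built a symmetric window (sliding_window, sliding_window); the replacement (sliding_window - 1, 0) matches the convention used in the first test. A small worked comparison, assuming that (left, right) interpretation (window_keys is an illustrative helper, not test code):

```python
# Worked comparison: which keys a query at position 7 may attend to, with
# sliding_window = 4 and 8 keys in the cache.
import torch


def window_keys(kv_len, q_pos, left, right):
    key_pos = torch.arange(kv_len)
    keep = (key_pos >= q_pos - left) & (key_pos <= q_pos + right)
    return key_pos[keep].tolist()


# Old: window_size = (sliding_window, sliding_window) keeps one extra key on the
# left; the symmetric right window is clipped by the causal mask in practice.
print(window_keys(8, 7, 4, 4))  # [3, 4, 5, 6, 7] -> 5 keys
# New: window_size = (sliding_window - 1, 0) keeps exactly sliding_window keys.
print(window_keys(8, 7, 3, 0))  # [4, 5, 6, 7] -> 4 keys
```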