[CI Perf] Prune tests in tests/kernels/attention/ (#22936)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -12,14 +12,16 @@ from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
BLOCK_SIZES = [16]
|
||||
DTYPES = [torch.bfloat16]
|
||||
QDTYPES = [None, torch.float8_e4m3fn]
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
SOFT_CAPS = [None, 50.0]
|
||||
SLIDING_WINDOWS = [None, 256]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
@@ -83,9 +85,9 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
@@ -198,9 +200,9 @@ def test_flash_attn_with_paged_kv(
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
|
||||
Reference in New Issue
Block a user