[CI Perf]Prune Tests in kernel/mamba (#26538)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Fardin Hoque
2025-10-13 15:22:31 -07:00
committed by GitHub
parent 314285d4f2
commit 577c72a227
4 changed files with 21 additions and 36 deletions

View File

@@ -188,9 +188,9 @@ def generate_continuous_batched_examples(
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [4, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
# this tests the kernels on a single example (bs=1)
@@ -254,15 +254,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("n_heads", [4, 8])
@pytest.mark.parametrize("d_head", [5, 16, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(
64,
@@ -270,16 +269,7 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
2,
[(4, 4), (4, 4), (4, 4), (4, 4)],
), # chunk_size larger than cont batches
(
64,
8,
5,
[
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
],
), # mode examples with varied lengths
(64, 8, 5, [(64, 32, 16, 8, 8)]),
# large-ish chunk_size (256)
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
(
@@ -359,11 +349,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize(
"seqlens",
[
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
],
[(16, 20), (270, 88, 212, 203)],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# This test verifies the correctness of the chunked prefill implementation