[CI Perf]Prune Tests in kernel/mamba (#26538)
Signed-off-by: Fardin Hoque <kfhfar@amazon.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -229,8 +229,8 @@ def selective_scan_opcheck_fn(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wtype", [torch.float32])
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("seqlen", [128, 1024, 4096])
|
||||
@pytest.mark.parametrize("has_delta_bias", [True])
|
||||
@pytest.mark.parametrize("delta_softplus", [True])
|
||||
@pytest.mark.parametrize("has_z", [True])
|
||||
@@ -238,7 +238,7 @@ def selective_scan_opcheck_fn(
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
|
||||
@pytest.mark.parametrize("scan_chunks", [1, 3])
|
||||
def test_selective_scan(
|
||||
is_variable_B,
|
||||
is_variable_C,
|
||||
@@ -375,9 +375,9 @@ def test_selective_scan(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dstate", [16, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
@@ -413,7 +413,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
|
||||
@pytest.mark.parametrize("wtype", [torch.float32])
|
||||
@pytest.mark.parametrize("itype", [torch.float32])
|
||||
@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
|
||||
@pytest.mark.parametrize("return_last_state", [True])
|
||||
@pytest.mark.parametrize("has_delta_bias", [True])
|
||||
@pytest.mark.parametrize("delta_softplus", [True])
|
||||
@@ -589,9 +589,9 @@ def test_selective_scan_varlen(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dstate", [16, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
# tests correctness in case subset of the sequences are padded
|
||||
@pytest.mark.parametrize("with_padding", [True, False])
|
||||
@@ -679,11 +679,11 @@ def test_selective_state_update_with_batch_indices(
|
||||
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("tie_hdim", [False, True])
|
||||
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("ngroups", [1, 4])
|
||||
@pytest.mark.parametrize("dstate", [16, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 4096])
|
||||
def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dim, dstate, ngroups, has_z, tie_hdim, itype
|
||||
|
||||
Reference in New Issue
Block a user