[CI Perf] Prune tests in tests/kernels/quantization/ (#22942)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -60,10 +60,18 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
 num_logprobs)


-@pytest.mark.parametrize("M", [1, 33, 64, 512])
-@pytest.mark.parametrize("N", [256, 971, 20486])
-@pytest.mark.parametrize("K", [128, 496, 1024])
-@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+MNK_FACTORS = [
+    (1, 256, 128),
+    (33, 256, 496),
+    (64, 971, 1024),
+    (64, 20486, 128),
+    (512, 256, 496),
+    (512, 20486, 1024),
+]
+
+
+@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
 @pytest.mark.parametrize("in_dtype", get_8bit_types())
 @pytest.mark.parametrize("use_scalar_scale_a", [True, False])
 @pytest.mark.parametrize("use_scalar_scale_b", [True, False])
Reference in New Issue
Block a user