[CI] Actually run tests/kernels/quantization/test_block_fp8.py in CI (#34274)

Author: Michael Goin
Date: 2026-02-26 19:58:03 -05:00
Committed by: GitHub
Parent: 38c498b8e3
Commit: 4fec53cfcb
3 changed files with 7 additions and 9 deletions

@@ -70,7 +70,7 @@ steps:
   - tests/kernels/moe/test_batched_deepgemm.py
   - tests/kernels/attention/test_deepgemm_attention.py
   commands:
-  - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm
+  - pytest -v -s kernels/quantization/test_block_fp8.py
   - pytest -v -s kernels/moe/test_deepgemm.py
   - pytest -v -s kernels/moe/test_batched_deepgemm.py
   - pytest -v -s kernels/attention/test_deepgemm_attention.py
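
Context for the change above: the old command's "-k deep_gemm" filter meant only the DeepGEMM-matching cases in the file ever ran in CI; the new command runs the whole file. As a quick local check, a sketch using pytest's standard collect-only mode (paths assume you run from the repo's tests/ directory, as the CI step does):

import pytest

# The first call mirrors the old filtered CI command, the second the new
# full run; --collect-only lists the selected tests without executing them.
pytest.main(["--collect-only", "-q",
             "kernels/quantization/test_block_fp8.py", "-k", "deep_gemm"])
pytest.main(["--collect-only", "-q",
             "kernels/quantization/test_block_fp8.py"])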

@@ -37,13 +37,15 @@ vllm_config = VllmConfig()
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
+# Quantization test configs
 NUM_TOKENS = [7, 2050]
 D = [512, 4096, 5120, 13824]
 GROUP_SIZE = [64, 128, 512]
 COLUMN_MAJOR_SCALES = [True, False]
 TMA_ALIGNED_SCALES = [True, False]
-M = [1, 7, 8, 83, 84, 4096]
-N = [128, 512, 7168, 7748, 13824]
+# Matmul test configs
+M = [1, 7, 8, 83, 4096]
+N = [128, 512, 576, 7168, 13824]
 K = [256, 3884, 4096, 13824, 16384]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
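
For context, these module-level lists drive pytest parametrization, and the new comments split the quantization configs from the matmul configs. A minimal sketch of how the matmul lists fan out into cases (hypothetical test name; the real file wires these up with pytest.mark.parametrize in much this way):

import itertools

import pytest

M = [1, 7, 8, 83, 4096]
N = [128, 512, 576, 7168, 13824]
K = [256, 3884, 4096, 13824, 16384]

# 5 * 5 * 5 = 125 (M, N, K) shape combinations per dtype/seed.
@pytest.mark.parametrize("M,N,K", itertools.product(M, N, K))
def test_shape_fanout(M, N, K):  # illustrative only
    assert M > 0 and N > 0 and K > 0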
@@ -162,8 +164,6 @@ def test_w8a8_block_fp8_cutlass_matmul():
     k_tiles = (K + block_k - 1) // block_k
     Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
-    # Hopper requires row-major format for scales
-    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
     A_fp8, As = per_token_group_quant_fp8(
         A_fp32, block_size[1], column_major_scales=False
@@ -174,9 +174,7 @@ def test_w8a8_block_fp8_cutlass_matmul():
     )
     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
-    out = cutlass_scaled_mm(
-        A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
-    )
+    out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs, block_size, out_dtype)
     rel_diff = torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
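
To follow why passing Bs directly is correct here: Bs holds one fp32 scale per (block_n x block_k) tile of B, and cutlass_scaled_mm is checked against a dequantize-then-matmul reference. A rough sketch of those reference semantics (an assumption for illustration, not vLLM's actual native_w8a8_block_matmul):

import torch

def ref_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype):
    # Assumed shapes: A is (M, K), B is (N, K), As is (M, K // block_k)
    # with one scale per token group, Bs is (n_tiles, k_tiles) with one
    # scale per weight tile. Expand scales, dequantize to fp32, matmul.
    block_n, block_k = block_size
    M, K = A_fp8.shape
    N = B_fp8.shape[0]
    A = A_fp8.to(torch.float32) * As.repeat_interleave(block_k, dim=1)[:, :K]
    B_scale = Bs.repeat_interleave(block_n, dim=0)[:N]
    B_scale = B_scale.repeat_interleave(block_k, dim=1)[:, :K]
    B = B_fp8.to(torch.float32) * B_scale
    return (A @ B.T).to(out_dtype)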

@@ -734,7 +734,7 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
     # Verify DeepGEMM N/K dims requirements
     # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
-    # test inside kernels/quatization/test_block_fp8.py
+    # test inside kernels/quantization/test_block_fp8.py
     N_MULTIPLE = 64
     K_MULTIPLE = 128
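
A minimal sketch of the alignment check these constants imply (an assumption based on this hunk alone; the real function weighs further conditions before choosing FlashInfer over DeepGEMM):

N_MULTIPLE = 64
K_MULTIPLE = 128

def deepgemm_dims_ok(N: int, K: int) -> bool:
    # Hypothetical helper: DeepGEMM's block-scaled kernels need N and K
    # aligned to these multiples, kept in sync with the test noted above.
    return N % N_MULTIPLE == 0 and K % K_MULTIPLE == 0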