From 63c0889416f0d4c3979c4046c6bf41c43143080c Mon Sep 17 00:00:00 2001
From: Roy Wang
Date: Sun, 1 Feb 2026 05:10:24 +0800
Subject: [PATCH] [Misc] Fix flashinfer related tests (#33462)

Signed-off-by: esmeetu
---
 tests/kernels/moe/test_moe.py                            | 2 +-
 .../quantization/test_flashinfer_nvfp4_scaled_mm.py      | 2 +-
 tests/kernels/quantization/test_fp8_quant.py             | 4 ++--
 .../layers/quantization/utils/nvfp4_utils.py             | 7 ++++---
 vllm/utils/flashinfer.py                                 | 2 +-
 5 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index aaa048ab8..a304e70fc 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -412,7 +412,7 @@ def test_naive_block_assignment_moe(
     monkeypatch,
     workspace_init,
 ):
-    current_platform.seed_everything(7)
+    set_random_seed(7)
 
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
 
diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
index 04e28dd20..e414ba7d2 100644
--- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
+++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
@@ -74,7 +74,7 @@ def get_ref_results(
 @pytest.mark.parametrize("shape", SHAPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
+@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"])
 @pytest.mark.parametrize("autotune", [False, True])
 @torch.inference_mode()
 def test_flashinfer_nvfp4_gemm(
diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py
index 325665c48..ce94d3397 100644
--- a/tests/kernels/quantization/test_fp8_quant.py
+++ b/tests/kernels/quantization/test_fp8_quant.py
@@ -174,7 +174,7 @@ def test_static_fp8_quant_group_2d(
         f"group_shape ({group_shape[0]}, {group_shape[1]})"
     )
 
-    current_platform.seed_everything(seed)
+    set_random_seed(seed)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
 
     ref_out, scale = scaled_quantize(
@@ -202,7 +202,7 @@ def test_static_fp8_quant_1d_scale(
     group_shape: tuple[int, int],
 ) -> None:
     """Test static FP8 quantization with 1D scale (per-token or per-channel)."""
-    current_platform.seed_everything(seed)
+    set_random_seed(seed)
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
 
     ref_out, scale_2d = scaled_quantize(
diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
index 8b2549be0..7e1d9991c 100644
--- a/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
@@ -154,9 +154,10 @@ def convert_to_nvfp4_linear_kernel_format(
         )
         layer.weight = torch.nn.Parameter(weight, requires_grad=False)
         layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-    elif (
-        backend == NvFp4LinearBackend.VLLM_CUTLASS
-        or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS
+    elif backend in (
+        NvFp4LinearBackend.VLLM_CUTLASS,
+        NvFp4LinearBackend.FLASHINFER_CUTLASS,
+        NvFp4LinearBackend.FLASHINFER_CUDNN,
     ):
         weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass(
             layer.weight.data, layer.weight_scale.data
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index cf5089247..f8cb1e14e 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -521,7 +521,7 @@ def flashinfer_scaled_fp4_mm(
     assert a.stride(-1) == 1 and b.stride(-1) == 1
     assert a.shape[1] == b.shape[1]
 
-    if backend == "cutlass":
+    if backend in ("cutlass", "cudnn"):
         block_scale_a = block_scale_a.view(torch.uint8)
         block_scale_b = block_scale_b.view(torch.uint8)
 