[NVIDIA][test] Tests for flashinfer TRTLLM BF16 MoE (#33715)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -318,3 +318,44 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
torch.testing.assert_close(
|
||||
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Converting BF16 MoE weights to the FlashInfer TRTLLM block layout
    must yield 4D tensors that preserve the element count, the dtype, and
    the leading (per-expert) dimension of the inputs."""
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    # Gate+up projection (w13) and down projection (w2), one slab per expert.
    w13_shape = (num_experts, 2 * intermediate, hidden)
    w2_shape = (num_experts, hidden, intermediate)
    w13 = torch.randn(w13_shape, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(w2_shape, dtype=torch.bfloat16, device="cuda")

    layout_cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        layout_cache, w13, w2
    )

    # Block layout introduces one extra dimension.
    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    # Re-layout must be a pure permutation of the payload.
    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    # The expert axis stays in front.
    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts
@@ -1558,3 +1558,103 @@ def test_batched_fused_marlin_moe(
|
||||
marlin_output = br.run(a, kwargs)
|
||||
|
||||
torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", [(32, 1024, 1024)])
@pytest.mark.parametrize("e,topk", [(8, 2)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(
    not current_platform.is_device_capability_family(100),
    reason="TRTLLM backend test only runs on Blackwell GPUs (SM10x).",
)
def test_unquantized_bf16_flashinfer_trtllm_backend(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    monkeypatch,
    workspace_init,
):
    """
    Test BF16 unquantized MoE with FlashInfer TRTLLM backend.

    End-to-end check that, with VLLM_USE_FLASHINFER_MOE_FP16 set on a
    Blackwell-family GPU, UnquantizedFusedMoEMethod selects the
    FLASHINFER_TRTLLM backend, uses the monolithic forward path, and its
    output roughly matches the pure-torch MoE baseline (torch_moe).
    """
    set_random_seed(7)

    # Opt into the FlashInfer FP16/BF16 MoE path before backend selection runs.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

    from vllm.model_executor.layers.fused_moe.config import (
        FusedMoEConfig,
        FusedMoEParallelConfig,
        RoutingMethodType,
    )
    from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
        UnquantizedMoeBackend,
    )
    from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
        UnquantizedFusedMoEMethod,
    )

    # Setup test data (divided by 10 to keep BF16 activations well-scaled)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    router_logits = torch.randn((m, e), device="cuda", dtype=dtype)

    # Single-rank config: no EP/DP parallelism, all e experts are local.
    moe_config = FusedMoEConfig(
        num_experts=e,
        experts_per_token=topk,
        hidden_dim=k,
        intermediate_size_per_partition=n,
        num_local_experts=e,
        activation="silu",
        device="cuda",
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        in_dtype=dtype,
        is_act_and_mul=True,
        routing_method=RoutingMethodType.Renormalize,
        max_num_tokens=m,
    )

    # NOTE(review): `vllm_config` is not defined in this block — presumably a
    # module-level fixture/global of the test file; confirm against the file.
    with set_current_vllm_config(vllm_config):
        quant_method = UnquantizedFusedMoEMethod(moe_config)

    # Verify TRTLLM backend was selected
    assert (
        quant_method.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
    ), f"Expected FLASHINFER_TRTLLM backend, got {quant_method.unquantized_backend}"

    # Verify it's using monolithic path
    assert quant_method.is_monolithic, (
        "FLASHINFER_TRTLLM backend should use monolithic forward"
    )
    # Minimal stand-in for a FusedMoE layer: only the attributes the quant
    # method reads. Weights are cloned so the originals stay available for
    # the baseline (process_weights_after_loading may re-layout them in place
    # — assumption, confirm against the method's implementation).
    layer = torch.nn.Module()
    layer.w13_weight = Parameter(w1.clone(), requires_grad=False)
    layer.w2_weight = Parameter(w2.clone(), requires_grad=False)
    layer.global_num_experts = e
    layer.local_num_experts = e
    layer.top_k = topk
    layer.num_expert_group = 1
    layer.topk_group = 1
    layer.intermediate_size_per_partition = n
    layer.ep_rank = 0
    layer.activation = "silu"
    layer.e_score_correction_bias = None
    layer.routing_method_type = RoutingMethodType.Renormalize

    quant_method.process_weights_after_loading(layer)

    trtllm_output = quant_method.forward_monolithic_cuda(
        layer=layer,
        x=a,
        router_logits=router_logits,
    )

    # Compute torch baseline
    w1_original = w1.clone()
    w2_original = w2.clone()
    baseline_output = torch_moe(a, w1_original, w2_original, router_logits, topk)

    # Loose elementwise tolerance: require >92.5% of elements close rather
    # than all of them, since BF16 kernels differ in accumulation order.
    close = torch.isclose(trtllm_output, baseline_output, atol=1e-1, rtol=0.85)
    assert close.float().mean() > 0.925
|
||||
132
tests/kernels/moe/test_unquantized_backend_selection.py
Normal file
132
tests/kernels/moe/test_unquantized_backend_selection.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
|
||||
UnquantizedMoeBackend,
|
||||
select_unquantized_moe_backend,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "platform_method,expected_backend",
    [
        ("is_cuda", UnquantizedMoeBackend.TRITON),  # Default CUDA without FlashInfer
        ("is_rocm", UnquantizedMoeBackend.TRITON),
        ("is_cpu", UnquantizedMoeBackend.CPU),
        ("is_xpu", UnquantizedMoeBackend.XPU),
        ("is_tpu", UnquantizedMoeBackend.TPU),
        ("is_out_of_tree", UnquantizedMoeBackend.OOT),
    ],
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=False,
)
def test_select_default_backend_by_platform(
    mock_has_flashinfer,
    monkeypatch,
    platform_method,
    expected_backend,
):
    """Each platform maps to its default backend when FlashInfer is absent."""
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        # Exactly one platform check reports True: the one under test.
        platform_checks = (
            "is_cuda",
            "is_rocm",
            "is_cpu",
            "is_xpu",
            "is_tpu",
            "is_out_of_tree",
        )
        for check in platform_checks:
            getattr(mock_platform, check).return_value = check == platform_method

        moe_config = make_dummy_moe_config()
        selected_backend = select_unquantized_moe_backend(
            moe_config=moe_config,
            use_ep=False,
            use_dp=False,
        )

        assert selected_backend == expected_backend
||||
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=True,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
    return_value=(True, None),
)
def test_select_cuda_flashinfer_trtllm_backend(
    mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
    """CUDA + FlashInfer + TRTLLM-supported config selects FLASHINFER_TRTLLM
    when the opt-in env var is set."""
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        # CUDA is the only platform check that reports True.
        checks = ("is_cuda", "is_rocm", "is_cpu", "is_xpu", "is_tpu", "is_out_of_tree")
        for check in checks:
            getattr(mock_platform, check).return_value = check == "is_cuda"

        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

        selected_backend = select_unquantized_moe_backend(
            moe_config=make_dummy_moe_config(),
            use_ep=True,
            use_dp=False,
        )

        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
||||
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=True,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
    return_value=(False, None),
)
def test_select_cuda_flashinfer_cutlass_backend(
    mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
    """When FlashInfer is present but the TRTLLM BF16 path rejects the config,
    a CUDA platform with SM90+ capability falls back to FLASHINFER_CUTLASS."""
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        # CUDA-only platform; the mocked device capability check passes (SM90+).
        checks = ("is_cuda", "is_rocm", "is_cpu", "is_xpu", "is_tpu", "is_out_of_tree")
        for check in checks:
            getattr(mock_platform, check).return_value = check == "is_cuda"
        mock_platform.has_device_capability.return_value = True  # SM90+

        # Enable FlashInfer via env var
        monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

        selected_backend = select_unquantized_moe_backend(
            moe_config=make_dummy_moe_config(),
            use_ep=True,  # CUTLASS requires EP
            use_dp=False,  # CUTLASS doesn't support DP
        )

        assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
||||
Reference in New Issue
Block a user