[NVIDIA][test] Tests for flashinfer TRTLLM BF16 MoE (#33715)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
Linda authored 2026-02-11 13:38:11 +01:00, committed by GitHub
parent 0f5e55e7a8, commit 275e0d2a99
7 changed files with 296 additions and 1 deletion

@@ -1558,3 +1558,103 @@ def test_batched_fused_marlin_moe(
    marlin_output = br.run(a, kwargs)

    torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)


@pytest.mark.parametrize("m,n,k", [(32, 1024, 1024)])
@pytest.mark.parametrize("e,topk", [(8, 2)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(
    not current_platform.is_device_capability_family(100),
    reason="TRTLLM backend test only runs on Blackwell GPUs (SM10x).",
)
def test_unquantized_bf16_flashinfer_trtllm_backend(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    monkeypatch,
    workspace_init,
):
    """
    Test BF16 unquantized MoE with FlashInfer TRTLLM backend.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")

    from vllm.model_executor.layers.fused_moe.config import (
        FusedMoEConfig,
        FusedMoEParallelConfig,
        RoutingMethodType,
    )
    from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
        UnquantizedMoeBackend,
    )
    from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
        UnquantizedFusedMoEMethod,
    )

    # Setup test data
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    router_logits = torch.randn((m, e), device="cuda", dtype=dtype)

    moe_config = FusedMoEConfig(
        num_experts=e,
        experts_per_token=topk,
        hidden_dim=k,
        intermediate_size_per_partition=n,
        num_local_experts=e,
        activation="silu",
        device="cuda",
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        in_dtype=dtype,
        is_act_and_mul=True,
        routing_method=RoutingMethodType.Renormalize,
        max_num_tokens=m,
    )

    with set_current_vllm_config(vllm_config):
        quant_method = UnquantizedFusedMoEMethod(moe_config)

        # Verify TRTLLM backend was selected
        assert (
            quant_method.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
        ), f"Expected FLASHINFER_TRTLLM backend, got {quant_method.unquantized_backend}"

        # Verify it's using monolithic path
        assert quant_method.is_monolithic, (
            "FLASHINFER_TRTLLM backend should use monolithic forward"
        )

        layer = torch.nn.Module()
        layer.w13_weight = Parameter(w1.clone(), requires_grad=False)
        layer.w2_weight = Parameter(w2.clone(), requires_grad=False)
        layer.global_num_experts = e
        layer.local_num_experts = e
        layer.top_k = topk
        layer.num_expert_group = 1
        layer.topk_group = 1
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.activation = "silu"
        layer.e_score_correction_bias = None
        layer.routing_method_type = RoutingMethodType.Renormalize

        quant_method.process_weights_after_loading(layer)

        trtllm_output = quant_method.forward_monolithic_cuda(
            layer=layer,
            x=a,
            router_logits=router_logits,
        )

        # Compute torch baseline
        w1_original = w1.clone()
        w2_original = w2.clone()
        baseline_output = torch_moe(a, w1_original, w2_original, router_logits, topk)

        close = torch.isclose(trtllm_output, baseline_output, atol=1e-1, rtol=0.85)
        assert close.float().mean() > 0.925
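
For reference, the `torch_moe` helper used as the baseline above comes from the shared MoE test utilities and is not part of this hunk. Below is a minimal sketch of the computation such a baseline is expected to perform, assuming the call signature used in the test (`torch_moe(a, w1, w2, router_logits, topk)`), renormalized top-k softmax routing to match `RoutingMethodType.Renormalize`, and a gate/up split of the stacked `w13` weight with SiLU-and-mul activation to match `is_act_and_mul=True`. The name `torch_moe_reference` is illustrative, not the actual utility.

import torch
import torch.nn.functional as F


def torch_moe_reference(a, w1, w2, router_logits, topk):
    # Renormalize routing: softmax over all experts, keep the top-k
    # weights per token, then rescale them so they sum to 1.
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(a)      # [m, k] accumulator in the input dtype
    n = w2.shape[-1]               # intermediate size per expert
    for expert in range(w1.shape[0]):
        token_idx, slot_idx = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = a[token_idx]
        # Assumed w13 layout: gate and up projections stacked along dim 0
        # of each expert, i.e. w1[expert] has shape [2 * n, k].
        gate_up = x @ w1[expert].t()                       # [tokens, 2 * n]
        hidden = F.silu(gate_up[:, :n]) * gate_up[:, n:]   # SiLU-and-mul
        expert_out = hidden @ w2[expert].t()               # back to hidden dim k
        scale = topk_weights[token_idx, slot_idx].unsqueeze(-1).to(a.dtype)
        out.index_add_(0, token_idx, expert_out * scale)
    return out

With routing and activation matching the `FusedMoEConfig` used in the test, comparing the fused TRTLLM output against such a dense per-expert reference at loose tolerances is what the final isclose/mean check encodes.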