[NVIDIA][test] Tests for flashinfer TRTLLM BF16 MoE (#33715)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Linda
2026-02-11 13:38:11 +01:00
committed by GitHub
parent 0f5e55e7a8
commit 275e0d2a99
7 changed files with 296 additions and 1 deletions

View File

@@ -318,3 +318,44 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)
@pytest.mark.parametrize(
    ("num_experts", "intermediate", "hidden"),
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Check structural invariants of the TRTLLM block-layout conversion.

    Random bfloat16 ``w13`` / ``w2`` expert weights are allocated on the
    CUDA device and passed, together with an initially empty cache dict,
    to ``convert_moe_weights_to_flashinfer_trtllm_block_layout``.

    Only structural properties of the two returned tensors are verified:
    each must be 4-D, hold the same number of elements as its input, be
    bfloat16, and keep ``num_experts`` as its leading dimension. The
    converted values themselves are not compared against a reference.
    """
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    # Shared construction options for both weight tensors.
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    w13_shape = (num_experts, 2 * intermediate, hidden)
    w2_shape = (num_experts, hidden, intermediate)

    # w13 is drawn before w2 so the random stream is consumed in a fixed order.
    w13 = torch.randn(w13_shape, **tensor_opts)
    w2 = torch.randn(w2_shape, **tensor_opts)

    # The converter takes a caller-owned cache as its first argument.
    layout_cache: dict[torch.Size, torch.Tensor] = {}
    (
        w13_converted,
        w2_converted,
    ) = convert_moe_weights_to_flashinfer_trtllm_block_layout(layout_cache, w13, w2)

    # Rank: both outputs must be 4-D.
    assert w13_converted.dim() == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.dim() == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    # The conversion may rearrange elements but must not add or drop any.
    assert w13.numel() == w13_converted.numel(), "W13 element count should be preserved"
    assert w2.numel() == w2_converted.numel(), "W2 element count should be preserved"

    # dtype is unchanged by the conversion.
    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    # The expert dimension stays in front.
    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts