[NVIDIA][test] Tests for flashinfer TRTLLM BF16 MoE (#33715)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
@@ -318,3 +318,44 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
    torch.testing.assert_close(
        output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
    )


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13 = torch.randn(
        (num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
    )
    w2 = torch.randn(
        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
    )

    cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )

    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts
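The test above only checks layout invariants (4D output, preserved element count, dtype, and expert dimension), not the exact TRTLLM block ordering. Below is a minimal sketch, on CPU, of the kind of block-layout reshape those invariants describe; the block size of 128 and the reshape order are illustrative assumptions and are not the actual convert_moe_weights_to_flashinfer_trtllm_block_layout implementation. Locally, the real test could be run with pytest's -k filter on convert_moe_weights_to_flashinfer_trtllm_block_layout (the containing test file path is not shown in this hunk).

    import torch

    def to_block_layout_sketch(w: torch.Tensor, block: int = 128) -> torch.Tensor:
        """Illustrative only: reshape a per-expert weight (E, rows, cols) into a
        4D blocked view (E, rows, cols // block, block). The real TRTLLM layout
        may permute or swizzle differently; this just satisfies the same
        invariants the test asserts (4D, same numel/dtype, expert dim first)."""
        num_experts, rows, cols = w.shape
        assert cols % block == 0, "cols must be divisible by the block size"
        return w.reshape(num_experts, rows, cols // block, block).contiguous()

    # Mirrors the (8, 2048, 1536) parametrization from the test above.
    w13 = torch.randn(8, 2 * 2048, 1536, dtype=torch.bfloat16)
    w13_blocked = to_block_layout_sketch(w13)
    assert w13_blocked.ndim == 4
    assert w13_blocked.numel() == w13.numel()
    assert w13_blocked.dtype == torch.bfloat16
    assert w13_blocked.shape[0] == 8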