vllm/tests/kernels/quantization/test_mxfp4_triton_ep.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.
Previously, legacy_routing was always used and it produced routing data
with global expert IDs that didn't correspond to local weight indices,
causing illegal memory access with EP. The fix splits routing: when
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global->local IDs, and make_routing_data builds routing
structures from the local IDs.
"""

from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.model_executor.layers.quantization.mxfp4 import (
    Mxfp4Backend,
    Mxfp4MoEMethod,
)


def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
    """Create a mock FusedMoEConfig with the given EP size."""
    parallel_config = MagicMock()
    parallel_config.ep_size = ep_size

    moe_config = MagicMock()
    moe_config.ep_size = ep_size
    moe_config.is_lora_enabled = False
    moe_config.moe_parallel_config = parallel_config
    return moe_config


class TestMxfp4TritonIsMonolithic:
    """Verify that is_monolithic is always True for the TRITON backend,
    regardless of EP size, since triton_kernel_moe_forward now handles
    expert_map remapping internally."""

    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            (Mxfp4Backend.TRITON, 1, True),
            (Mxfp4Backend.TRITON, 2, True),
            (Mxfp4Backend.TRITON, 4, True),
            # SM100 backends are always monolithic
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
            # MARLIN is never monolithic
            (Mxfp4Backend.MARLIN, 1, False),
            (Mxfp4Backend.MARLIN, 2, False),
        ],
        ids=[
            "triton-no-ep",
            "triton-ep2",
            "triton-ep4",
            "sm100-trtllm-no-ep",
            "sm100-trtllm-ep2",
            "sm100-bf16-no-ep",
            "sm100-bf16-ep2",
            "marlin-no-ep",
            "marlin-ep2",
        ],
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """is_monolithic should be True for TRITON regardless of EP size."""
        mock_get_backend.return_value = backend

        mock_compilation_config = MagicMock()
        mock_compilation_config.max_cudagraph_capture_size = 1024
        mock_vllm_config = MagicMock()
        mock_vllm_config.compilation_config = mock_compilation_config
        mock_get_config.return_value = mock_vllm_config

        moe_config = _make_mock_moe_config(ep_size=ep_size)
        method = Mxfp4MoEMethod(moe_config)

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )


class TestTritonMoeForwardExpertMap:
    """Test that triton_kernel_moe_forward applies expert_map remapping
    when expert_map is provided (EP active)."""

    @pytest.mark.parametrize("expert_map_present", [False, True])
    def test_routing_path_selection(self, expert_map_present):
        """Verify that the EP-aware routing path is taken when expert_map
        is present, and the legacy_routing path is taken otherwise."""
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # This is a structural test: we mock the routing functions to
        # verify the correct path is exercised.
        mock_expert_map = (
            torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
        )

        with (
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.legacy_routing"
            ) as mock_legacy,
            patch("triton_kernels.topk.topk") as mock_topk,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.make_routing_data"
            ) as mock_make_routing,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
            ) as mock_fused_experts,
        ):
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            # Set up return values
            mock_routing_data = MagicMock()
            mock_gather = MagicMock()
            mock_scatter = MagicMock()

            if expert_map_present:
                sparse_result = MagicMock()
                sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
                sparse_result.vals = torch.tensor([[0.6, 0.4]])
                mock_topk.return_value = sparse_result
                mock_make_routing.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )
            else:
                mock_legacy.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )

            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)

            hidden = torch.randn((1, 8), device=device)
            w1 = torch.randn((2, 8, 16), device=device)
            w2 = torch.randn((2, 8, 8), device=device)
            logits = torch.randn((1, 4), device=device)

            triton_kernel_moe_forward(
                hidden_states=hidden,
                w1=w1,
                w2=w2,
                gating_output=logits,
                topk=2,
                renormalize=True,
                expert_map=mock_expert_map,
            )

            if expert_map_present:
                # EP path: should use topk + make_routing_data, NOT
                # legacy_routing
                mock_topk.assert_called_once()
                mock_make_routing.assert_called_once()
                mock_legacy.assert_not_called()

                # expert_map should be None in the fused_experts call
                # (already applied)
                call_args, call_kwargs = mock_fused_experts.call_args
                assert call_kwargs.get("expert_map") is None or len(call_args) > 0
            else:
                # Non-EP path: should use legacy_routing
                mock_legacy.assert_called_once()
                mock_topk.assert_not_called()
                mock_make_routing.assert_not_called()
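

# A typical way to run just this file (assumed invocation; path as shown in
# the header above, relative to a vLLM checkout -- no GPU is strictly required
# since the routing and kernel entry points are mocked):
#     pytest vllm/tests/kernels/quantization/test_mxfp4_triton_ep.py -v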