fix(mxfp4): Disable monolithic path for TRITON backend with EP (#34270)
Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
tests/kernels/quantization/test_mxfp4_triton_ep.py (new file, 194 additions)
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.

Previously, legacy_routing was always used and it produced routing data
with global expert IDs that didn't correspond to local weight indices,
causing illegal memory access with EP. The fix splits routing: when
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
"""

from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.model_executor.layers.quantization.mxfp4 import (
    Mxfp4Backend,
    Mxfp4MoEMethod,
)


def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
    """Create a mock FusedMoEConfig with the given EP size."""
    parallel_config = MagicMock()
    parallel_config.ep_size = ep_size

    moe_config = MagicMock()
    moe_config.ep_size = ep_size
    moe_config.is_lora_enabled = False
    moe_config.moe_parallel_config = parallel_config
    return moe_config


class TestMxfp4TritonIsMonolithic:
    """Verify that is_monolithic is always True for the TRITON backend,
    regardless of EP size, since triton_kernel_moe_forward now handles
    expert_map remapping internally."""

    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            (Mxfp4Backend.TRITON, 1, True),
            (Mxfp4Backend.TRITON, 2, True),
            (Mxfp4Backend.TRITON, 4, True),
            # SM100 backends are always monolithic
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
            # MARLIN is never monolithic
            (Mxfp4Backend.MARLIN, 1, False),
            (Mxfp4Backend.MARLIN, 2, False),
        ],
        ids=[
            "triton-no-ep",
            "triton-ep2",
            "triton-ep4",
            "sm100-trtllm-no-ep",
            "sm100-trtllm-ep2",
            "sm100-bf16-no-ep",
            "sm100-bf16-ep2",
            "marlin-no-ep",
            "marlin-ep2",
        ],
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """is_monolithic should be True for TRITON regardless of EP size."""
        mock_get_backend.return_value = backend

        mock_compilation_config = MagicMock()
        mock_compilation_config.max_cudagraph_capture_size = 1024
        mock_vllm_config = MagicMock()
        mock_vllm_config.compilation_config = mock_compilation_config
        mock_get_config.return_value = mock_vllm_config

        moe_config = _make_mock_moe_config(ep_size=ep_size)
        method = Mxfp4MoEMethod(moe_config)

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )


class TestTritonMoeForwardExpertMap:
    """Test that triton_kernel_moe_forward applies expert_map remapping
    when expert_map is provided (EP active)."""

    @pytest.mark.parametrize("expert_map_present", [False, True])
    def test_routing_path_selection(self, expert_map_present):
        """Verify that the EP-aware routing path is taken when expert_map
        is present, and the legacy_routing path is taken otherwise."""

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # This is a structural test: we mock the routing functions to
        # verify the correct path is exercised.
        mock_expert_map = (
            torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
        )

        with (
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.legacy_routing"
            ) as mock_legacy,
            patch("triton_kernels.topk.topk") as mock_topk,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.make_routing_data"
            ) as mock_make_routing,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
            ) as mock_fused_experts,
        ):
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            # Set up return values
            mock_routing_data = MagicMock()
            mock_gather = MagicMock()
            mock_scatter = MagicMock()

            if expert_map_present:
                sparse_result = MagicMock()
                sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
                sparse_result.vals = torch.tensor([[0.6, 0.4]])
                mock_topk.return_value = sparse_result
                mock_make_routing.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )
            else:
                mock_legacy.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )

            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)

            hidden = torch.randn((1, 8), device=device)
            w1 = torch.randn((2, 8, 16), device=device)
            w2 = torch.randn((2, 8, 8), device=device)
            logits = torch.randn((1, 4), device=device)

            triton_kernel_moe_forward(
                hidden_states=hidden,
                w1=w1,
                w2=w2,
                gating_output=logits,
                topk=2,
                renormalize=True,
                expert_map=mock_expert_map,
            )

            if expert_map_present:
                # EP path: should use topk + make_routing_data, NOT
                # legacy_routing
                mock_topk.assert_called_once()
                mock_make_routing.assert_called_once()
                mock_legacy.assert_not_called()
                # expert_map should be None in the fused_experts call
                # (already applied)
                call_kwargs = mock_fused_experts.call_args
                assert call_kwargs[1].get("expert_map") is None or (
                    len(call_kwargs[0]) > 0
                )
            else:
                # Non-EP path: should use legacy_routing
                mock_legacy.assert_called_once()
                mock_topk.assert_not_called()
                mock_make_routing.assert_not_called()
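A note on the expert_map convention the tests above rely on: expert_map is a per-rank tensor indexed by global expert ID whose entries give the local weight index on the current EP rank, with -1 marking experts owned by other ranks. A minimal runnable sketch of that remapping (illustrative only, not part of the commit; the sharding layout below is hypothetical):

import torch

# Hypothetical layout: 4 global experts sharded over 2 EP ranks; this rank
# owns global experts 0 and 2 as local experts 0 and 1 (the same map the
# test above constructs).
expert_map = torch.tensor([0, -1, 1, -1])

# Global expert IDs selected by top-k routing for two tokens.
global_topk_ids = torch.tensor([[0, 2], [2, 3]])

# Advanced indexing performs the global -> local remap in one shot;
# -1 entries flag tokens routed to experts that live on other ranks.
local_topk_ids = expert_map[global_topk_ids]
print(local_topk_ids)  # tensor([[ 0,  1], [ 1, -1]])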
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -179,9 +179,35 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = legacy_routing(
-        gating_output, topk, sm_first=not renormalize
-    )
+    if expert_map is not None:
+        # With expert parallelism, legacy_routing produces routing data
+        # using global expert IDs which don't correspond to local weight
+        # indices. Split the routing into topk selection + expert_map
+        # remapping + local routing data construction (matching the
+        # approach used by OAITritonExperts.apply).
+        from triton_kernels.topk import topk as topk_fn
+
+        sm_first = not renormalize
+        logits = gating_output
+        if sm_first:
+            logits = torch.softmax(logits, dim=-1)
+        sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
+        # sparse_logits.indx contains global expert IDs – remap to local.
+        topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
+        topk_weights = sparse_logits.vals
+        local_num_experts = w1.size(0)
+        routing_data, gather_idx, scatter_idx = make_routing_data(
+            topk_ids, topk_weights, local_num_experts
+        )
+        # expert_map already applied; pass None downstream.
+        effective_expert_map = None
+        effective_global_num_experts = local_num_experts
+    else:
+        routing_data, gather_idx, scatter_idx = legacy_routing(
+            gating_output, topk, sm_first=not renormalize
+        )
+        effective_expert_map = expert_map
+        effective_global_num_experts = global_num_experts
 
     output = torch.empty_like(hidden_states)
 
@@ -197,8 +223,8 @@ def triton_kernel_moe_forward(
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
+        global_num_experts=effective_global_num_experts,
+        expert_map=effective_expert_map,
     )
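For intuition, here is a compact sketch of the new EP-aware routing split (illustrative only: torch.topk stands in for triton_kernels.topk, whose exact apply_softmax semantics are an assumption here, and make_routing_data from the diff above is only referenced in a comment):

import torch

def ep_routing_sketch(gating_output, topk, renormalize, expert_map):
    # Mirrors sm_first = not renormalize: softmax the full logits before
    # selection when the selected weights are not renormalized afterwards.
    logits = gating_output
    if not renormalize:
        logits = torch.softmax(logits, dim=-1)
    topk_vals, global_ids = torch.topk(logits, topk, dim=-1)
    if renormalize:
        # Stand-in for apply_softmax=True: normalize the selected values
        # (an assumption about triton_kernels.topk's behavior).
        topk_vals = torch.softmax(topk_vals, dim=-1)
    # Remap global expert IDs to local weight indices; -1 = not on this rank.
    local_ids = expert_map[global_ids]
    # The real code then calls make_routing_data(local_ids, topk_vals,
    # local_num_experts) and passes expert_map=None downstream.
    return local_ids, topk_vals

ids, weights = ep_routing_sketch(
    torch.randn(2, 4), topk=2, renormalize=True,
    expert_map=torch.tensor([0, -1, 1, -1]),
)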