[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-04-06 21:57:09 -05:00
committed by GitHub
parent 62095e82c1
commit 2df2c85be4
7 changed files with 84 additions and 216 deletions

View File

@@ -3,4 +3,4 @@
model_name: openai/gpt-oss-20b
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN"
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"

View File

@@ -3,6 +3,6 @@
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"
env:
VLLM_ROCM_USE_AITER: "1"
VLLM_ROCM_USE_AITER: "1"

View File

@@ -3,4 +3,4 @@
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"

View File

@@ -3,6 +3,6 @@
model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN"
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"
env:
VLLM_ROCM_USE_AITER: "1"

View File

@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from triton_kernels.topk import topk as topk_fn
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
legacy_routing,
make_routing_data,
triton_kernel_moe_forward,
)
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from .utils import shuffle_weight
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
if w_dtype != "mx4":
pytest.skip("NYI")
else: # quantize to mx4
# careful on the padding here, the activation padding need to be
# multiple of 64, the actual engine is not implemented
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
# Padding alignment depends on the platform. On CDNA4 the scale
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
# SCALE_N % 32 == 0 (2*N % 512), matching the production
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
# On CUDA (Hopper) the scale layout pads internally, so the
# original 64/128 alignment is sufficient.
if current_platform.is_rocm():
k_align, n2_align = 256, 512
else:
k_align, n2_align = 64, 128
w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad
@@ -367,52 +371,3 @@ def test_unit_shuffle():
)
assert_close(ref=out_ref, tri=out)
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
set_random_seed(0)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
sm_first = not renormalize
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, bitmatrix = topk_result
from triton_kernels.routing import routing_from_bitmatrix
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
)
else:
topk_ids = topk_result.indx.to(torch.long)
topk_weights = topk_result.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
routing_data, gather_indx, scatter_indx = legacy_routing(
gating_output, topk, sm_first=sm_first
)
assert_close(
ref=gather_indx_ref.src_indx, tri=gather_indx.src_indx, maxtol=0, rmstol=0
)
assert_close(
ref=gather_indx_ref.dst_indx, tri=gather_indx.dst_indx, maxtol=0, rmstol=0
)
assert_close(
ref=scatter_indx_ref.src_indx, tri=scatter_indx.src_indx, maxtol=0, rmstol=0
)
assert_close(
ref=scatter_indx_ref.dst_indx, tri=scatter_indx.dst_indx, maxtol=0, rmstol=0
)

View File

@@ -4,12 +4,9 @@
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.
Previously, legacy_routing was always used and it produced routing data
with global expert IDs that didn't correspond to local weight indices,
causing illegal memory access with EP. The fix splits routing: when
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
Both EP and non-EP paths use topk + make_routing_data. When expert_map
is provided, global expert IDs are remapped to local IDs before building
routing structures.
"""
from unittest.mock import MagicMock, patch
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
@pytest.mark.parametrize("expert_map_present", [False, True])
def test_routing_path_selection(self, expert_map_present):
"""Verify that the EP-aware routing path is taken when expert_map
is present, and the legacy_routing path is taken otherwise."""
"""Verify that both EP and non-EP paths use topk + make_routing_data,
and that expert_map remapping is applied when present."""
device = "cuda" if torch.cuda.is_available() else "cpu"
# This is a structural test: we mock the routing functions to
# verify the correct path is exercised.
mock_expert_map = (
torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
)
with (
patch(
"vllm.model_executor.layers.fused_moe."
"gpt_oss_triton_kernels_moe.legacy_routing"
) as mock_legacy,
patch("triton_kernels.topk.topk") as mock_topk,
patch(
"vllm.model_executor.layers.fused_moe."
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
triton_kernel_moe_forward,
)
# Set up return values
mock_routing_data = MagicMock()
mock_gather = MagicMock()
mock_scatter = MagicMock()
if expert_map_present:
sparse_result = MagicMock()
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
sparse_result.vals = torch.tensor([[0.6, 0.4]])
mock_topk.return_value = sparse_result
mock_make_routing.return_value = (
mock_routing_data,
mock_gather,
mock_scatter,
)
else:
mock_legacy.return_value = (
mock_routing_data,
mock_gather,
mock_scatter,
)
sparse_result = MagicMock()
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
sparse_result.vals = torch.tensor([[0.6, 0.4]])
mock_topk.return_value = sparse_result
mock_make_routing.return_value = (
mock_routing_data,
mock_gather,
mock_scatter,
)
mock_fused_experts.return_value = torch.zeros((1, 8), device=device)
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
expert_map=mock_expert_map,
)
# Both paths use topk + make_routing_data
mock_topk.assert_called_once()
mock_make_routing.assert_called_once()
if expert_map_present:
# EP path: should use topk + make_routing_data, NOT
# legacy_routing
mock_topk.assert_called_once()
mock_make_routing.assert_called_once()
mock_legacy.assert_not_called()
# expert_map should be None in the fused_experts call
# (already applied)
call_kwargs = mock_fused_experts.call_args
assert call_kwargs[1].get("expert_map") is None or (
len(call_kwargs[0]) > 0
)
else:
# Non-EP path: should use legacy_routing
mock_legacy.assert_called_once()
mock_topk.assert_not_called()
mock_make_routing.assert_not_called()