[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-04-06 21:57:09 -05:00
committed by GitHub
parent 62095e82c1
commit 2df2c85be4
7 changed files with 84 additions and 216 deletions

View File

@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close
from triton_kernels.topk import topk as topk_fn
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
legacy_routing,
make_routing_data,
triton_kernel_moe_forward,
)
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from .utils import shuffle_weight
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
if w_dtype != "mx4":
pytest.skip("NYI")
else: # quantize to mx4
# Careful with the padding here: the activation padding needs to be a
# multiple of 64; the actual engine is not implemented.
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
# Padding alignment depends on the platform. On CDNA4 the scale
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
# SCALE_N % 32 == 0 (2*N % 512), matching the production
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
# On CUDA (Hopper) the scale layout pads internally, so the
# original 64/128 alignment is sufficient.
if current_platform.is_rocm():
k_align, n2_align = 256, 512
else:
k_align, n2_align = 64, 128
w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad
@@ -367,52 +371,3 @@ def test_unit_shuffle():
)
assert_close(ref=out_ref, tri=out)
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
    num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
    """Verify legacy_routing matches the reference bitmatrix-based routing.

    Builds reference routing/gather/scatter indices directly from the
    triton_kernels top-k result, then compares them index-for-index
    (zero tolerance) against the output of ``legacy_routing`` on the
    same gating logits.
    """
    set_random_seed(0)
    gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
    sm_first = not renormalize
    # If softmax runs before top-k, normalize here; otherwise topk_fn
    # applies the softmax itself (apply_softmax flag below).
    logits = torch.softmax(gating_output, dim=-1) if sm_first else gating_output
    topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
    # topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
    if isinstance(topk_result, tuple):
        topk_weights, topk_ids_raw, bitmatrix = topk_result
        from triton_kernels.routing import routing_from_bitmatrix

        reference = routing_from_bitmatrix(
            bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
        )
    else:
        reference = make_routing_data(
            topk_result.indx.to(torch.long), topk_result.vals, num_experts
        )
    _, gather_indx_ref, scatter_indx_ref = reference
    _, gather_indx, scatter_indx = legacy_routing(
        gating_output, topk, sm_first=sm_first
    )
    # Both gather and scatter index tensors must match the reference exactly.
    for ref_idx, tri_idx in (
        (gather_indx_ref, gather_indx),
        (scatter_indx_ref, scatter_indx),
    ):
        assert_close(ref=ref_idx.src_indx, tri=tri_idx.src_indx, maxtol=0, rmstol=0)
        assert_close(ref=ref_idx.dst_indx, tri=tri_idx.dst_indx, maxtol=0, rmstol=0)