[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -3,4 +3,4 @@
|
||||
model_name: openai/gpt-oss-20b
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN"
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"
|
||||
@@ -3,6 +3,6 @@
|
||||
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_ROCM_USE_AITER: "1"
|
||||
VLLM_ROCM_USE_AITER: "1"
|
||||
@@ -3,4 +3,4 @@
|
||||
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"
|
||||
@@ -3,6 +3,6 @@
|
||||
model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
|
||||
metric_threshold: 0.568
|
||||
reasoning_effort: low
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN"
|
||||
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"
|
||||
env:
|
||||
VLLM_ROCM_USE_AITER: "1"
|
||||
@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
from triton_kernels.topk import topk as topk_fn
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
legacy_routing,
|
||||
make_routing_data,
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from .utils import shuffle_weight
|
||||
|
||||
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
if w_dtype != "mx4":
|
||||
pytest.skip("NYI")
|
||||
else: # quantize to mx4
|
||||
# careful on the padding here, the activation padding need to be
|
||||
# multiple of 64, the actual engine is not implemented
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
|
||||
# Padding alignment depends on the platform. On CDNA4 the scale
|
||||
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
|
||||
# SCALE_N % 32 == 0 (2*N % 512), matching the production
|
||||
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
|
||||
# On CUDA (Hopper) the scale layout pads internally, so the
|
||||
# original 64/128 alignment is sufficient.
|
||||
if current_platform.is_rocm():
|
||||
k_align, n2_align = 256, 512
|
||||
else:
|
||||
k_align, n2_align = 64, 128
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
|
||||
|
||||
w2_bottom_pad = w1_right_pad // 2
|
||||
w2_right_pad = w1_bottom_pad
|
||||
@@ -367,52 +371,3 @@ def test_unit_shuffle():
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
def test_legacy_routing(
|
||||
num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
|
||||
):
|
||||
set_random_seed(0)
|
||||
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
|
||||
|
||||
sm_first = not renormalize
|
||||
logits = gating_output
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
|
||||
if isinstance(topk_result, tuple):
|
||||
topk_weights, topk_ids_raw, bitmatrix = topk_result
|
||||
from triton_kernels.routing import routing_from_bitmatrix
|
||||
|
||||
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
|
||||
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
|
||||
)
|
||||
else:
|
||||
topk_ids = topk_result.indx.to(torch.long)
|
||||
topk_weights = topk_result.vals
|
||||
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
|
||||
topk_ids, topk_weights, num_experts
|
||||
)
|
||||
|
||||
routing_data, gather_indx, scatter_indx = legacy_routing(
|
||||
gating_output, topk, sm_first=sm_first
|
||||
)
|
||||
|
||||
assert_close(
|
||||
ref=gather_indx_ref.src_indx, tri=gather_indx.src_indx, maxtol=0, rmstol=0
|
||||
)
|
||||
assert_close(
|
||||
ref=gather_indx_ref.dst_indx, tri=gather_indx.dst_indx, maxtol=0, rmstol=0
|
||||
)
|
||||
assert_close(
|
||||
ref=scatter_indx_ref.src_indx, tri=scatter_indx.src_indx, maxtol=0, rmstol=0
|
||||
)
|
||||
assert_close(
|
||||
ref=scatter_indx_ref.dst_indx, tri=scatter_indx.dst_indx, maxtol=0, rmstol=0
|
||||
)
|
||||
|
||||
@@ -4,12 +4,9 @@
|
||||
Tests that triton_kernel_moe_forward correctly applies expert_map
|
||||
remapping when expert parallelism (EP) is enabled.
|
||||
|
||||
Previously, legacy_routing was always used and it produced routing data
|
||||
with global expert IDs that didn't correspond to local weight indices,
|
||||
causing illegal memory access with EP. The fix splits routing: when
|
||||
expert_map is provided, topk selection is performed first, expert_map is
|
||||
applied to remap global→local IDs, and make_routing_data builds routing
|
||||
structures from the local IDs.
|
||||
Both EP and non-EP paths use topk + make_routing_data. When expert_map
|
||||
is provided, global expert IDs are remapped to local IDs before building
|
||||
routing structures.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
|
||||
|
||||
@pytest.mark.parametrize("expert_map_present", [False, True])
|
||||
def test_routing_path_selection(self, expert_map_present):
|
||||
"""Verify that the EP-aware routing path is taken when expert_map
|
||||
is present, and the legacy_routing path is taken otherwise."""
|
||||
"""Verify that both EP and non-EP paths use topk + make_routing_data,
|
||||
and that expert_map remapping is applied when present."""
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
# This is a structural test: we mock the routing functions to
|
||||
# verify the correct path is exercised.
|
||||
mock_expert_map = (
|
||||
torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.layers.fused_moe."
|
||||
"gpt_oss_triton_kernels_moe.legacy_routing"
|
||||
) as mock_legacy,
|
||||
patch("triton_kernels.topk.topk") as mock_topk,
|
||||
patch(
|
||||
"vllm.model_executor.layers.fused_moe."
|
||||
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
|
||||
# Set up return values
|
||||
mock_routing_data = MagicMock()
|
||||
mock_gather = MagicMock()
|
||||
mock_scatter = MagicMock()
|
||||
|
||||
if expert_map_present:
|
||||
sparse_result = MagicMock()
|
||||
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
|
||||
sparse_result.vals = torch.tensor([[0.6, 0.4]])
|
||||
mock_topk.return_value = sparse_result
|
||||
mock_make_routing.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
else:
|
||||
mock_legacy.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
sparse_result = MagicMock()
|
||||
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
|
||||
sparse_result.vals = torch.tensor([[0.6, 0.4]])
|
||||
mock_topk.return_value = sparse_result
|
||||
mock_make_routing.return_value = (
|
||||
mock_routing_data,
|
||||
mock_gather,
|
||||
mock_scatter,
|
||||
)
|
||||
|
||||
mock_fused_experts.return_value = torch.zeros((1, 8), device=device)
|
||||
|
||||
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
|
||||
expert_map=mock_expert_map,
|
||||
)
|
||||
|
||||
# Both paths use topk + make_routing_data
|
||||
mock_topk.assert_called_once()
|
||||
mock_make_routing.assert_called_once()
|
||||
|
||||
if expert_map_present:
|
||||
# EP path: should use topk + make_routing_data, NOT
|
||||
# legacy_routing
|
||||
mock_topk.assert_called_once()
|
||||
mock_make_routing.assert_called_once()
|
||||
mock_legacy.assert_not_called()
|
||||
# expert_map should be None in the fused_experts call
|
||||
# (already applied)
|
||||
call_kwargs = mock_fused_experts.call_args
|
||||
assert call_kwargs[1].get("expert_map") is None or (
|
||||
len(call_kwargs[0]) > 0
|
||||
)
|
||||
else:
|
||||
# Non-EP path: should use legacy_routing
|
||||
mock_legacy.assert_called_once()
|
||||
mock_topk.assert_not_called()
|
||||
mock_make_routing.assert_not_called()
|
||||
|
||||
@@ -47,7 +47,6 @@ if has_triton_kernels():
|
||||
BIT,
|
||||
Bitmatrix,
|
||||
)
|
||||
from triton_kernels.topk import topk
|
||||
|
||||
try:
|
||||
from triton_kernels.tensor import (
|
||||
@@ -89,6 +88,7 @@ def pack_bitmatrix(
|
||||
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
|
||||
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
|
||||
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
|
||||
valid = indices >= 0
|
||||
div = indices // 32
|
||||
rem = indices % 32
|
||||
one = tl.cast(1, tl.uint32)
|
||||
@@ -99,8 +99,13 @@ def pack_bitmatrix(
|
||||
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
|
||||
# All topks that need to go into this column has the correct bit set.
|
||||
# Other bits are 0. x is a 2D tensor.
|
||||
# Guard with `valid` to prevent negative indices from producing
|
||||
# spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets
|
||||
# bit 31).
|
||||
x = tl.where(
|
||||
div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0
|
||||
valid[:, :, None] & (div[:, :, None] == offs[None, None, :]),
|
||||
(one << rem)[:, :, None],
|
||||
0,
|
||||
)
|
||||
# Reduce x to get a single int32_t bitpack.
|
||||
y = tl.reduce_or(x, axis=1)
|
||||
@@ -108,93 +113,6 @@ def pack_bitmatrix(
|
||||
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
|
||||
|
||||
|
||||
def legacy_routing_from_bitmatrix(
|
||||
bitmatrix: "Bitmatrix",
|
||||
expt_scal: torch.Tensor,
|
||||
expt_indx: torch.Tensor,
|
||||
n_expts_tot: int,
|
||||
n_expts_act: int,
|
||||
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
|
||||
"""
|
||||
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
|
||||
Creates routing data from a bitmatrix representation.
|
||||
"""
|
||||
if use_legacy_triton_kernels:
|
||||
from triton_kernels.routing import routing_from_bitmatrix
|
||||
|
||||
return routing_from_bitmatrix(
|
||||
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
|
||||
)
|
||||
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
|
||||
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
||||
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
||||
ragged_batch_metadata = make_ragged_tensor_metadata(
|
||||
sparse_logits.mask_metadata.col_sum,
|
||||
dispatch_indx.shape[0],
|
||||
)
|
||||
gate_scal = sparse_logits.vals.flatten()[combine_indx]
|
||||
routing_data = RoutingData(
|
||||
gate_scal,
|
||||
ragged_batch_metadata.block_sizes,
|
||||
n_expts_tot,
|
||||
n_expts_act,
|
||||
ragged_batch_metadata,
|
||||
)
|
||||
gather_idx = GatherIndx(combine_indx, dispatch_indx)
|
||||
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
|
||||
return routing_data, gather_idx, scatter_idx
|
||||
|
||||
|
||||
def legacy_routing_from_sparsematrix(
|
||||
sparse_logits: "SparseMatrix",
|
||||
n_expts_tot: int,
|
||||
n_expts_act: int,
|
||||
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
|
||||
"""
|
||||
Creates routing data from a SparseMatrix representation.
|
||||
"""
|
||||
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
||||
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
||||
ragged_batch_metadata = make_ragged_tensor_metadata(
|
||||
sparse_logits.mask_metadata.col_sum,
|
||||
dispatch_indx.shape[0],
|
||||
)
|
||||
gate_scal = sparse_logits.vals.flatten()[combine_indx]
|
||||
routing_data = RoutingData(
|
||||
gate_scal,
|
||||
ragged_batch_metadata.block_sizes,
|
||||
n_expts_tot,
|
||||
n_expts_act,
|
||||
ragged_batch_metadata,
|
||||
)
|
||||
gather_idx = GatherIndx(combine_indx, dispatch_indx)
|
||||
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
|
||||
return routing_data, gather_idx, scatter_idx
|
||||
|
||||
|
||||
def legacy_routing(
|
||||
logits: torch.Tensor,
|
||||
n_expts_act: int,
|
||||
sm_first: bool = False,
|
||||
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
|
||||
"""
|
||||
Replacement for the removed triton_kernels.routing.routing function.
|
||||
Computes routing data from gating logits.
|
||||
"""
|
||||
if use_legacy_triton_kernels:
|
||||
from triton_kernels.routing import routing
|
||||
|
||||
return routing(logits, n_expts_act, sm_first=sm_first)
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
|
||||
return legacy_routing_from_sparsematrix(
|
||||
sparse_logits,
|
||||
logits.shape[-1],
|
||||
n_expts_act,
|
||||
)
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
@@ -241,26 +159,22 @@ def triton_kernel_moe_forward(
|
||||
unpadded_K_w2=unpadded_K_w2,
|
||||
)
|
||||
|
||||
if expert_map is not None:
|
||||
# With expert parallelism, legacy_routing produces routing data
|
||||
# using global expert IDs which don't correspond to local weight
|
||||
# indices. Split the routing into topk selection + expert_map
|
||||
# remapping + local routing data construction (matching the
|
||||
# approach used by OAITritonExperts.apply).
|
||||
from triton_kernels.topk import topk as topk_fn
|
||||
from triton_kernels.topk import topk as topk_fn
|
||||
|
||||
sm_first = not renormalize
|
||||
logits = gating_output
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# topk may return a tuple (vals, indx, bitmatrix) or a
|
||||
# SparseMatrix depending on the triton_kernels version.
|
||||
if isinstance(topk_result, tuple):
|
||||
topk_weights, topk_ids_raw, _ = topk_result
|
||||
else:
|
||||
topk_weights = topk_result.vals
|
||||
topk_ids_raw = topk_result.indx
|
||||
sm_first = not renormalize
|
||||
logits = gating_output
|
||||
if sm_first:
|
||||
logits = torch.softmax(logits, dim=-1)
|
||||
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
|
||||
# topk may return a tuple (vals, indx, bitmatrix) or a
|
||||
# SparseMatrix depending on the triton_kernels version.
|
||||
if isinstance(topk_result, tuple):
|
||||
topk_weights, topk_ids_raw, _ = topk_result
|
||||
else:
|
||||
topk_weights = topk_result.vals
|
||||
topk_ids_raw = topk_result.indx
|
||||
|
||||
if expert_map is not None:
|
||||
# topk_ids_raw contains global expert IDs - remap to local.
|
||||
topk_ids = expert_map[topk_ids_raw.to(torch.long)]
|
||||
local_num_experts = w1.shape[0]
|
||||
@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
|
||||
effective_expert_map = None
|
||||
effective_global_num_experts = local_num_experts
|
||||
else:
|
||||
routing_data, gather_idx, scatter_idx = legacy_routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
topk_ids = topk_ids_raw.to(torch.long)
|
||||
routing_data, gather_idx, scatter_idx = make_routing_data(
|
||||
topk_ids, topk_weights, gating_output.shape[-1]
|
||||
)
|
||||
effective_expert_map = expert_map
|
||||
effective_global_num_experts = global_num_experts
|
||||
@@ -539,10 +454,31 @@ def make_routing_data(
|
||||
|
||||
# matmul_ogs expects invalid topk_weights to be -1s
|
||||
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
|
||||
routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(
|
||||
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
|
||||
)
|
||||
|
||||
if use_legacy_triton_kernels:
|
||||
from triton_kernels.routing import routing_from_bitmatrix
|
||||
|
||||
return routing_from_bitmatrix(
|
||||
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
|
||||
)
|
||||
|
||||
sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
|
||||
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
|
||||
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
|
||||
ragged_batch_metadata = make_ragged_tensor_metadata(
|
||||
sparse_logits.mask_metadata.col_sum,
|
||||
dispatch_indx.shape[0],
|
||||
)
|
||||
gate_scal = sparse_logits.vals.flatten()[combine_indx]
|
||||
routing_data = RoutingData(
|
||||
gate_scal,
|
||||
ragged_batch_metadata.block_sizes,
|
||||
num_local_experts,
|
||||
num_topk,
|
||||
ragged_batch_metadata,
|
||||
)
|
||||
gather_indx = GatherIndx(combine_indx, dispatch_indx)
|
||||
scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
|
||||
return routing_data, gather_indx, scatter_indx
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user