diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
index 630ea2e3f..f659ec56c 100644
--- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -21,12 +21,16 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
 from triton_kernels.tensor_details import layout
 from triton_kernels.testing import assert_close
+from triton_kernels.topk import topk as topk_fn
 
 from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    legacy_routing,
+    make_routing_data,
     triton_kernel_moe_forward,
 )
 from vllm.utils.math_utils import round_up
+from vllm.utils.torch_utils import set_random_seed
 
 from .utils import shuffle_weight
 
@@ -355,3 +359,43 @@ def test_unit_shuffle():
     )
 
     assert_close(ref=out_ref, tri=out)
+
+
+@pytest.mark.parametrize("num_tokens", [2, 8, 64])
+@pytest.mark.parametrize("num_experts", [32, 128])
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("renormalize", [True, False])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_legacy_routing(
+    num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
+):
+    set_random_seed(0)
+    gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
+
+    sm_first = not renormalize
+    logits = gating_output
+    if sm_first:
+        logits = torch.softmax(logits, dim=-1)
+    sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
+    topk_ids = sparse_logits.indx.to(torch.long)
+    topk_weights = sparse_logits.vals
+    routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
+        topk_ids, topk_weights, num_experts
+    )
+
+    routing_data, gather_indx, scatter_indx = legacy_routing(
+        gating_output, topk, sm_first=sm_first
+    )
+
+    assert_close(
+        ref=gather_indx_ref.src_indx, tri=gather_indx.src_indx, maxtol=0, rmstol=0
+    )
+    assert_close(
+        ref=gather_indx_ref.dst_indx, tri=gather_indx.dst_indx, maxtol=0, rmstol=0
+    )
+    assert_close(
+        ref=scatter_indx_ref.src_indx, tri=scatter_indx.src_indx, maxtol=0, rmstol=0
+    )
+    assert_close(
+        ref=scatter_indx_ref.dst_indx, tri=scatter_indx.dst_indx, maxtol=0, rmstol=0
+    )
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 82b0a21cb..5e7e7aa46 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -142,6 +142,33 @@ def legacy_routing_from_bitmatrix(
     return routing_data, gather_idx, scatter_idx
 
 
+def legacy_routing_from_sparsematrix(
+    sparse_logits: "SparseMatrix",
+    n_expts_tot: int,
+    n_expts_act: int,
+) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
+    """
+    Creates routing data from a SparseMatrix representation.
+    """
+    dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
+    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
+    ragged_batch_metadata = make_ragged_tensor_metadata(
+        sparse_logits.mask_metadata.col_sum,
+        dispatch_indx.shape[0],
+    )
+    gate_scal = sparse_logits.vals.flatten()[combine_indx]
+    routing_data = RoutingData(
+        gate_scal,
+        ragged_batch_metadata.block_sizes,
+        n_expts_tot,
+        n_expts_act,
+        ragged_batch_metadata,
+    )
+    gather_idx = GatherIndx(combine_indx, dispatch_indx)
+    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
+    return routing_data, gather_idx, scatter_idx
+
+
 def legacy_routing(
     logits: torch.Tensor,
     n_expts_act: int,
@@ -158,10 +185,8 @@
     if sm_first:
         logits = torch.softmax(logits, dim=-1)
     sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
-    return legacy_routing_from_bitmatrix(
-        sparse_logits.mask,
-        sparse_logits.vals,
-        sparse_logits.indx,
+    return legacy_routing_from_sparsematrix(
+        sparse_logits,
         logits.shape[-1],
         n_expts_act,
     )