[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
from triton_kernels.topk import topk as topk_fn
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
legacy_routing,
|
||||
make_routing_data,
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from .utils import shuffle_weight
|
||||
|
||||
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
if w_dtype != "mx4":
|
||||
pytest.skip("NYI")
|
||||
else: # quantize to mx4
|
||||
# careful on the padding here, the activation padding need to be
|
||||
# multiple of 64, the actual engine is not implemented
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
|
||||
# Padding alignment depends on the platform. On CDNA4 the scale
|
||||
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
|
||||
# SCALE_N % 32 == 0 (2*N % 512), matching the production
|
||||
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
|
||||
# On CUDA (Hopper) the scale layout pads internally, so the
|
||||
# original 64/128 alignment is sufficient.
|
||||
if current_platform.is_rocm():
|
||||
k_align, n2_align = 256, 512
|
||||
else:
|
||||
k_align, n2_align = 64, 128
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
|
||||
|
||||
w2_bottom_pad = w1_right_pad // 2
|
||||
w2_right_pad = w1_bottom_pad
|
||||
@@ -367,52 +371,3 @@ def test_unit_shuffle():
|
||||
)
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
    num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
    """Verify ``legacy_routing`` against the reference bitmatrix routing path.

    Builds the reference gather/scatter indices directly from ``topk_fn``
    (via ``routing_from_bitmatrix`` or ``make_routing_data``, depending on
    what ``topk_fn`` returns on this platform) and requires the indices
    produced by ``legacy_routing`` to match exactly (zero tolerance).
    """
    set_random_seed(0)
    gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)

    # When renormalize is False, softmax is applied before the top-k selection.
    softmax_first = not renormalize
    logits = torch.softmax(gating_output, dim=-1) if softmax_first else gating_output
    topk_result = topk_fn(logits, topk, apply_softmax=not softmax_first)

    # topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
    if isinstance(topk_result, tuple):
        weights, raw_ids, bitmatrix = topk_result
        from triton_kernels.routing import routing_from_bitmatrix

        reference = routing_from_bitmatrix(
            bitmatrix, weights, raw_ids, num_experts, topk
        )
    else:
        reference = make_routing_data(
            topk_result.indx.to(torch.long), topk_result.vals, num_experts
        )
    _, gather_indx_ref, scatter_indx_ref = reference

    _, gather_indx, scatter_indx = legacy_routing(
        gating_output, topk, sm_first=softmax_first
    )

    # Gather and scatter index tensors must match the reference bit-for-bit.
    for ref_indx, tri_indx in (
        (gather_indx_ref, gather_indx),
        (scatter_indx_ref, scatter_indx),
    ):
        assert_close(ref=ref_indx.src_indx, tri=tri_indx.src_indx, maxtol=0, rmstol=0)
        assert_close(ref=ref_indx.dst_indx, tri=tri_indx.dst_indx, maxtol=0, rmstol=0)
|
||||
|
||||
Reference in New Issue
Block a user