diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml index 76b1d7962..ec1c2b392 100644 --- a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml @@ -3,4 +3,4 @@ model_name: openai/gpt-oss-20b metric_threshold: 0.568 reasoning_effort: low -server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN" \ No newline at end of file +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2" \ No newline at end of file diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml index 850a6d28b..4ff2648ca 100644 --- a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml @@ -3,6 +3,6 @@ model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 metric_threshold: 0.568 reasoning_effort: low -server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter" +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2" env: - VLLM_ROCM_USE_AITER: "1" + VLLM_ROCM_USE_AITER: "1" \ No newline at end of file diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml index 903f30e59..5ae665a04 100644 --- a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml @@ -3,4 +3,4 @@ model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 metric_threshold: 0.568 reasoning_effort: low -server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton" \ No newline at end of file +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2" \ No newline at end of file diff --git a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml index f7dd14784..81270e010 100644 --- a/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml +++ b/tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml @@ -3,6 +3,6 @@ model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8 metric_threshold: 0.568 reasoning_effort: low -server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN" +server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2" env: VLLM_ROCM_USE_AITER: "1" \ No newline at end of file diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 172938f18..032b4fc04 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close -from triton_kernels.topk import topk as topk_fn from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - legacy_routing, - make_routing_data, triton_kernel_moe_forward, ) from vllm.utils.math_utils import round_up -from vllm.utils.torch_utils import set_random_seed from .utils import shuffle_weight @@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): if w_dtype != "mx4": pytest.skip("NYI") else: # quantize to mx4 - # careful on the padding here, the activation padding need to be - # multiple of 64, the actual engine is not implemented - w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1] - w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2] + # Padding alignment depends on the platform. On CDNA4 the scale + # swizzle requires SCALE_K % 8 == 0 (K % 256) and + # SCALE_N % 32 == 0 (2*N % 512), matching the production + # alignment in mxfp4_round_up_hidden_size_and_intermediate_size. + # On CUDA (Hopper) the scale layout pads internally, so the + # original 64/128 alignment is sufficient. + if current_platform.is_rocm(): + k_align, n2_align = 256, 512 + else: + k_align, n2_align = 64, 128 + w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1] + w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2] w2_bottom_pad = w1_right_pad // 2 w2_right_pad = w1_bottom_pad @@ -367,52 +371,3 @@ def test_unit_shuffle(): ) assert_close(ref=out_ref, tri=out) - - -@pytest.mark.parametrize("num_tokens", [2, 8, 64]) -@pytest.mark.parametrize("num_experts", [32, 128]) -@pytest.mark.parametrize("topk", [1, 4]) -@pytest.mark.parametrize("renormalize", [True, False]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_legacy_routing( - num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype -): - set_random_seed(0) - gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype) - - sm_first = not renormalize - logits = gating_output - if sm_first: - logits = torch.softmax(logits, dim=-1) - topk_result = topk_fn(logits, topk, apply_softmax=not sm_first) - # topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm. - if isinstance(topk_result, tuple): - topk_weights, topk_ids_raw, bitmatrix = topk_result - from triton_kernels.routing import routing_from_bitmatrix - - routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix( - bitmatrix, topk_weights, topk_ids_raw, num_experts, topk - ) - else: - topk_ids = topk_result.indx.to(torch.long) - topk_weights = topk_result.vals - routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data( - topk_ids, topk_weights, num_experts - ) - - routing_data, gather_indx, scatter_indx = legacy_routing( - gating_output, topk, sm_first=sm_first - ) - - assert_close( - ref=gather_indx_ref.src_indx, tri=gather_indx.src_indx, maxtol=0, rmstol=0 - ) - assert_close( - ref=gather_indx_ref.dst_indx, tri=gather_indx.dst_indx, maxtol=0, rmstol=0 - ) - assert_close( - ref=scatter_indx_ref.src_indx, tri=scatter_indx.src_indx, maxtol=0, rmstol=0 - ) - assert_close( - ref=scatter_indx_ref.dst_indx, tri=scatter_indx.dst_indx, maxtol=0, rmstol=0 - ) diff --git a/tests/kernels/quantization/test_mxfp4_triton_ep.py b/tests/kernels/quantization/test_mxfp4_triton_ep.py index 6c8aebe42..045bc63de 100644 --- a/tests/kernels/quantization/test_mxfp4_triton_ep.py +++ b/tests/kernels/quantization/test_mxfp4_triton_ep.py @@ -4,12 +4,9 @@ Tests that triton_kernel_moe_forward correctly applies expert_map remapping when expert parallelism (EP) is enabled. -Previously, legacy_routing was always used and it produced routing data -with global expert IDs that didn't correspond to local weight indices, -causing illegal memory access with EP. The fix splits routing: when -expert_map is provided, topk selection is performed first, expert_map is -applied to remap global→local IDs, and make_routing_data builds routing -structures from the local IDs. +Both EP and non-EP paths use topk + make_routing_data. When expert_map +is provided, global expert IDs are remapped to local IDs before building +routing structures. """ from unittest.mock import MagicMock, patch @@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap: @pytest.mark.parametrize("expert_map_present", [False, True]) def test_routing_path_selection(self, expert_map_present): - """Verify that the EP-aware routing path is taken when expert_map - is present, and the legacy_routing path is taken otherwise.""" + """Verify that both EP and non-EP paths use topk + make_routing_data, + and that expert_map remapping is applied when present.""" device = "cuda" if torch.cuda.is_available() else "cpu" - # This is a structural test: we mock the routing functions to - # verify the correct path is exercised. mock_expert_map = ( torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None ) with ( - patch( - "vllm.model_executor.layers.fused_moe." - "gpt_oss_triton_kernels_moe.legacy_routing" - ) as mock_legacy, patch("triton_kernels.topk.topk") as mock_topk, patch( "vllm.model_executor.layers.fused_moe." @@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap: triton_kernel_moe_forward, ) - # Set up return values mock_routing_data = MagicMock() mock_gather = MagicMock() mock_scatter = MagicMock() - if expert_map_present: - sparse_result = MagicMock() - sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32) - sparse_result.vals = torch.tensor([[0.6, 0.4]]) - mock_topk.return_value = sparse_result - mock_make_routing.return_value = ( - mock_routing_data, - mock_gather, - mock_scatter, - ) - else: - mock_legacy.return_value = ( - mock_routing_data, - mock_gather, - mock_scatter, - ) + sparse_result = MagicMock() + sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32) + sparse_result.vals = torch.tensor([[0.6, 0.4]]) + mock_topk.return_value = sparse_result + mock_make_routing.return_value = ( + mock_routing_data, + mock_gather, + mock_scatter, + ) mock_fused_experts.return_value = torch.zeros((1, 8), device=device) @@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap: expert_map=mock_expert_map, ) + # Both paths use topk + make_routing_data + mock_topk.assert_called_once() + mock_make_routing.assert_called_once() + if expert_map_present: - # EP path: should use topk + make_routing_data, NOT - # legacy_routing - mock_topk.assert_called_once() - mock_make_routing.assert_called_once() - mock_legacy.assert_not_called() # expert_map should be None in the fused_experts call # (already applied) call_kwargs = mock_fused_experts.call_args assert call_kwargs[1].get("expert_map") is None or ( len(call_kwargs[0]) > 0 ) - else: - # Non-EP path: should use legacy_routing - mock_legacy.assert_called_once() - mock_topk.assert_not_called() - mock_make_routing.assert_not_called() diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index e03ecd01a..a21ddaba0 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -47,7 +47,6 @@ if has_triton_kernels(): BIT, Bitmatrix, ) - from triton_kernels.topk import topk try: from triton_kernels.tensor import ( @@ -89,6 +88,7 @@ def pack_bitmatrix( offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] indices = tl.load(topk_ids + offsets, mask=mask, other=-1) + valid = indices >= 0 div = indices // 32 rem = indices % 32 one = tl.cast(1, tl.uint32) @@ -99,8 +99,13 @@ def pack_bitmatrix( offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) # All topks that need to go into this column has the correct bit set. # Other bits are 0. x is a 2D tensor. + # Guard with `valid` to prevent negative indices from producing + # spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets + # bit 31). x = tl.where( - div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0 + valid[:, :, None] & (div[:, :, None] == offs[None, None, :]), + (one << rem)[:, :, None], + 0, ) # Reduce x to get a single int32_t bitpack. y = tl.reduce_or(x, axis=1) @@ -108,93 +113,6 @@ def pack_bitmatrix( tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) -def legacy_routing_from_bitmatrix( - bitmatrix: "Bitmatrix", - expt_scal: torch.Tensor, - expt_indx: torch.Tensor, - n_expts_tot: int, - n_expts_act: int, -) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]: - """ - Replacement for the removed triton_kernels.routing.routing_from_bitmatrix. - Creates routing data from a bitmatrix representation. - """ - if use_legacy_triton_kernels: - from triton_kernels.routing import routing_from_bitmatrix - - return routing_from_bitmatrix( - bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act - ) - sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix) - dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx - combine_indx = sparse_logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata( - sparse_logits.mask_metadata.col_sum, - dispatch_indx.shape[0], - ) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData( - gate_scal, - ragged_batch_metadata.block_sizes, - n_expts_tot, - n_expts_act, - ragged_batch_metadata, - ) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx - - -def legacy_routing_from_sparsematrix( - sparse_logits: "SparseMatrix", - n_expts_tot: int, - n_expts_act: int, -) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]: - """ - Creates routing data from a SparseMatrix representation. - """ - dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx - combine_indx = sparse_logits.mask_metadata.col_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata( - sparse_logits.mask_metadata.col_sum, - dispatch_indx.shape[0], - ) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData( - gate_scal, - ragged_batch_metadata.block_sizes, - n_expts_tot, - n_expts_act, - ragged_batch_metadata, - ) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx - - -def legacy_routing( - logits: torch.Tensor, - n_expts_act: int, - sm_first: bool = False, -) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]: - """ - Replacement for the removed triton_kernels.routing.routing function. - Computes routing data from gating logits. - """ - if use_legacy_triton_kernels: - from triton_kernels.routing import routing - - return routing(logits, n_expts_act, sm_first=sm_first) - if sm_first: - logits = torch.softmax(logits, dim=-1) - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first) - return legacy_routing_from_sparsematrix( - sparse_logits, - logits.shape[-1], - n_expts_act, - ) - - def triton_kernel_moe_forward( hidden_states: torch.Tensor, w1, # Tensor or triton_kernels.Tensor @@ -241,26 +159,22 @@ def triton_kernel_moe_forward( unpadded_K_w2=unpadded_K_w2, ) - if expert_map is not None: - # With expert parallelism, legacy_routing produces routing data - # using global expert IDs which don't correspond to local weight - # indices. Split the routing into topk selection + expert_map - # remapping + local routing data construction (matching the - # approach used by OAITritonExperts.apply). - from triton_kernels.topk import topk as topk_fn + from triton_kernels.topk import topk as topk_fn - sm_first = not renormalize - logits = gating_output - if sm_first: - logits = torch.softmax(logits, dim=-1) - topk_result = topk_fn(logits, topk, apply_softmax=not sm_first) - # topk may return a tuple (vals, indx, bitmatrix) or a - # SparseMatrix depending on the triton_kernels version. - if isinstance(topk_result, tuple): - topk_weights, topk_ids_raw, _ = topk_result - else: - topk_weights = topk_result.vals - topk_ids_raw = topk_result.indx + sm_first = not renormalize + logits = gating_output + if sm_first: + logits = torch.softmax(logits, dim=-1) + topk_result = topk_fn(logits, topk, apply_softmax=not sm_first) + # topk may return a tuple (vals, indx, bitmatrix) or a + # SparseMatrix depending on the triton_kernels version. + if isinstance(topk_result, tuple): + topk_weights, topk_ids_raw, _ = topk_result + else: + topk_weights = topk_result.vals + topk_ids_raw = topk_result.indx + + if expert_map is not None: # topk_ids_raw contains global expert IDs - remap to local. topk_ids = expert_map[topk_ids_raw.to(torch.long)] local_num_experts = w1.shape[0] @@ -271,8 +185,9 @@ def triton_kernel_moe_forward( effective_expert_map = None effective_global_num_experts = local_num_experts else: - routing_data, gather_idx, scatter_idx = legacy_routing( - gating_output, topk, sm_first=not renormalize + topk_ids = topk_ids_raw.to(torch.long) + routing_data, gather_idx, scatter_idx = make_routing_data( + topk_ids, topk_weights, gating_output.shape[-1] ) effective_expert_map = expert_map effective_global_num_experts = global_num_experts @@ -539,10 +454,31 @@ def make_routing_data( # matmul_ogs expects invalid topk_weights to be -1s topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) - routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix( - bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk - ) + if use_legacy_triton_kernels: + from triton_kernels.routing import routing_from_bitmatrix + + return routing_from_bitmatrix( + bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk + ) + + sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix) + dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx + combine_indx = sparse_logits.mask_metadata.col_sorted_indx + ragged_batch_metadata = make_ragged_tensor_metadata( + sparse_logits.mask_metadata.col_sum, + dispatch_indx.shape[0], + ) + gate_scal = sparse_logits.vals.flatten()[combine_indx] + routing_data = RoutingData( + gate_scal, + ragged_batch_metadata.block_sizes, + num_local_experts, + num_topk, + ragged_batch_metadata, + ) + gather_indx = GatherIndx(combine_indx, dispatch_indx) + scatter_indx = ScatterIndx(dispatch_indx, combine_indx) return routing_data, gather_indx, scatter_indx