[ROCm] Fix MoE kernel test failures on gfx950 (#37833)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-25 13:46:40 -05:00
committed by GitHub
parent e38817fadb
commit 7d6917bef5
12 changed files with 478 additions and 86 deletions

View File

@@ -1,15 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import patch
import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
def _is_aiter_capable() -> bool:
    """Return True when running on ROCm hardware that supports AITER
    (gfx942/gfx950), i.e. an MI3xx-class GPU."""
    if not current_platform.is_rocm():
        return False
    try:
        from vllm.platforms.rocm import _ON_MI3XX
    except ImportError:
        # Builds that predate the _ON_MI3XX flag cannot run AITER paths.
        return False
    return _ON_MI3XX
# Test parameters
MK_S = [(32, 256), (64, 512)]
@@ -96,6 +112,60 @@ def assert_routing_results_close(
)
def assert_aiter_routing_valid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    renormalize: bool,
    routed_scaling_factor: float = 1.0,
):
    """Validate AITER routing outputs are structurally correct.

    AITER grouped_topk is a fundamentally different implementation from
    the Python baseline (different group selection, scoring internals),
    so numerical comparison is not meaningful. Instead we verify the
    outputs satisfy the routing contract: correct shapes, valid expert
    IDs, no duplicate experts per token, non-negative weights, and
    proper normalization.

    Args:
        topk_weights: (n_tokens, top_k) routing weights.
        topk_ids: (n_tokens, top_k) selected expert indices.
        top_k: number of experts selected per token.
        num_experts: total expert count; valid IDs are [0, num_experts).
        renormalize: whether the router renormalized weights per token.
        routed_scaling_factor: expected per-token weight sum when
            renormalize is True (renormalization to 1.0 happens before
            scaling).

    Raises:
        AssertionError: if any structural property is violated.
    """
    n_tokens = topk_weights.shape[0]
    # Shape
    assert topk_weights.shape == (n_tokens, top_k), (
        f"weights shape {topk_weights.shape} != ({n_tokens}, {top_k})"
    )
    assert topk_ids.shape == (n_tokens, top_k), (
        f"ids shape {topk_ids.shape} != ({n_tokens}, {top_k})"
    )
    # Expert IDs in valid range
    assert (topk_ids >= 0).all() and (topk_ids < num_experts).all(), (
        f"expert IDs out of range [0, {num_experts}): "
        f"min={topk_ids.min().item()}, max={topk_ids.max().item()}"
    )
    # No duplicate expert IDs per token. Vectorized: sort each row and
    # compare adjacent entries — the original per-token Python loop with
    # .unique() forced one device sync per token.
    sorted_ids, _ = topk_ids.sort(dim=-1)
    dup_rows = (sorted_ids[:, 1:] == sorted_ids[:, :-1]).any(dim=-1)
    if dup_rows.any():
        i = int(dup_rows.nonzero()[0].item())
        raise AssertionError(
            f"token {i}: duplicate expert IDs {topk_ids[i].tolist()}"
        )
    # Weights are non-negative
    assert (topk_weights >= 0).all(), "negative routing weights"
    # If renormalized, weights should sum to ~scaling_factor per token
    # (renormalization to 1.0 happens before scaling)
    if renormalize:
        sums = topk_weights.sum(dim=-1)
        torch.testing.assert_close(
            sums,
            torch.full_like(sums, routed_scaling_factor),
            atol=1e-3,
            rtol=1e-3,
        )
def baseline_fused_topk(
router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -400,10 +470,7 @@ def test_grouped_topk(
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
# Compute baseline (pure Python implementation)
baseline_weights, baseline_ids = baseline_grouped_topk(
router_logits,
top_k,
@@ -415,8 +482,32 @@ def test_grouped_topk(
routed_scaling_factor,
)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
# Test 1: Python/Triton path against baseline (exact match)
with patch(
"vllm.model_executor.layers.fused_moe.router.grouped_topk_router.rocm_aiter_ops.is_fused_moe_enabled",
return_value=False,
):
py_weights, py_ids = router.select_experts(hidden_states, router_logits)
assert_routing_results_close(py_weights, py_ids, baseline_weights, baseline_ids)
# Test 2: AITER path — verify outputs are structurally valid.
# AITER grouped_topk is a different implementation so we can't
# compare numerically against the Python baseline.
if _is_aiter_capable():
# Force-enable AITER for gfx942/gfx950 regardless of env var,
# so CI always exercises this path on capable hardware.
with patch.object(rocm_aiter_ops, "_AITER_ENABLED", True):
aiter_weights, aiter_ids = router.select_experts(
hidden_states, router_logits
)
assert_aiter_routing_valid(
aiter_weights,
aiter_ids,
top_k,
global_num_experts,
renormalize,
routed_scaling_factor,
)
@pytest.mark.parametrize("m,k", MK_S)