[ROCm] Fix MoE kernel test failures on gfx950 (#37833)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -1,15 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
||||
create_fused_moe_router,
|
||||
)
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _is_aiter_capable() -> bool:
    """Check if the platform supports AITER (gfx942/gfx950)."""
    # AITER is ROCm-only; bail out early on every other platform.
    if not current_platform.is_rocm():
        return False
    # _ON_MI3XX may not exist on older vLLM/ROCm builds, so treat an
    # import failure the same as "not capable".
    try:
        from vllm.platforms.rocm import _ON_MI3XX
    except ImportError:
        return False
    return _ON_MI3XX
|
||||
|
||||
|
||||
# Test parameters: (m, k) problem sizes fed to make_test_data —
# presumably (num_tokens, hidden_size); confirm against make_test_data.
MK_S = [(32, 256), (64, 512)]
|
||||
@@ -96,6 +112,60 @@ def assert_routing_results_close(
|
||||
)
|
||||
|
||||
|
||||
def assert_aiter_routing_valid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    renormalize: bool,
    routed_scaling_factor: float = 1.0,
):
    """Validate AITER routing outputs are structurally correct.

    AITER grouped_topk is a fundamentally different implementation from
    the Python baseline (different group selection, scoring internals),
    so numerical comparison is not meaningful. Instead we verify the
    outputs satisfy the routing contract: correct shapes, valid expert
    IDs, no duplicate experts per token, non-negative weights, and
    proper normalization.

    Args:
        topk_weights: ``(n_tokens, top_k)`` routing weights.
        topk_ids: ``(n_tokens, top_k)`` selected expert indices.
        top_k: number of experts selected per token.
        num_experts: total expert count; IDs must lie in
            ``[0, num_experts)``.
        renormalize: whether the router renormalized weights, in which
            case each token's weights must sum to
            ``routed_scaling_factor``.
        routed_scaling_factor: scaling applied after renormalization to
            1.0, so per-token sums equal this value.

    Raises:
        AssertionError: if any structural property is violated.
    """
    n_tokens = topk_weights.shape[0]

    # Shape
    assert topk_weights.shape == (n_tokens, top_k), (
        f"weights shape {topk_weights.shape} != ({n_tokens}, {top_k})"
    )
    assert topk_ids.shape == (n_tokens, top_k), (
        f"ids shape {topk_ids.shape} != ({n_tokens}, {top_k})"
    )

    # Expert IDs in valid range
    assert (topk_ids >= 0).all() and (topk_ids < num_experts).all(), (
        f"expert IDs out of range [0, {num_experts}): "
        f"min={topk_ids.min().item()}, max={topk_ids.max().item()}"
    )

    # No duplicate expert IDs per token. Vectorized: sort each row and
    # look for equal adjacent entries, instead of a per-token Python
    # loop calling .unique() (which forces one host/device sync per
    # token). A row with top_k == 1 can never have duplicates.
    if top_k > 1:
        sorted_ids, _ = topk_ids.sort(dim=-1)
        dup_rows = (sorted_ids[:, 1:] == sorted_ids[:, :-1]).any(dim=-1)
        assert not dup_rows.any(), (
            f"tokens with duplicate expert IDs: "
            f"{dup_rows.nonzero(as_tuple=True)[0].tolist()}"
        )

    # Weights are non-negative
    assert (topk_weights >= 0).all(), "negative routing weights"

    # If renormalized, weights should sum to ~scaling_factor per token
    # (renormalization to 1.0 happens before scaling)
    if renormalize:
        expected_sum = routed_scaling_factor
        sums = topk_weights.sum(dim=-1)
        torch.testing.assert_close(
            sums,
            torch.full_like(sums, expected_sum),
            atol=1e-3,
            rtol=1e-3,
        )
|
||||
|
||||
|
||||
def baseline_fused_topk(
|
||||
router_logits: torch.Tensor, top_k: int, renormalize: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -400,10 +470,7 @@ def test_grouped_topk(
|
||||
|
||||
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
|
||||
|
||||
# Get router output
|
||||
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
|
||||
|
||||
# Compute baseline
|
||||
# Compute baseline (pure Python implementation)
|
||||
baseline_weights, baseline_ids = baseline_grouped_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
@@ -415,8 +482,32 @@ def test_grouped_topk(
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
|
||||
# Test 1: Python/Triton path against baseline (exact match)
|
||||
with patch(
|
||||
"vllm.model_executor.layers.fused_moe.router.grouped_topk_router.rocm_aiter_ops.is_fused_moe_enabled",
|
||||
return_value=False,
|
||||
):
|
||||
py_weights, py_ids = router.select_experts(hidden_states, router_logits)
|
||||
assert_routing_results_close(py_weights, py_ids, baseline_weights, baseline_ids)
|
||||
|
||||
# Test 2: AITER path — verify outputs are structurally valid.
|
||||
# AITER grouped_topk is a different implementation so we can't
|
||||
# compare numerically against the Python baseline.
|
||||
if _is_aiter_capable():
|
||||
# Force-enable AITER for gfx942/gfx950 regardless of env var,
|
||||
# so CI always exercises this path on capable hardware.
|
||||
with patch.object(rocm_aiter_ops, "_AITER_ENABLED", True):
|
||||
aiter_weights, aiter_ids = router.select_experts(
|
||||
hidden_states, router_logits
|
||||
)
|
||||
assert_aiter_routing_valid(
|
||||
aiter_weights,
|
||||
aiter_ids,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,k", MK_S)
|
||||
|
||||
Reference in New Issue
Block a user