[ROCm] Fix MoE kernel test failures on gfx950 (#37833)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-25 13:46:40 -05:00
committed by GitHub
parent e38817fadb
commit 7d6917bef5
12 changed files with 478 additions and 86 deletions

View File

@@ -32,6 +32,14 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.utils.import_utils import (
has_aiter,
has_deep_ep,
@@ -152,6 +160,39 @@ class Config:
return vllm_config, env_dict
def fe_supports_quant_scheme(self) -> bool:
"""Check if the fused experts class supports this quant config.
See https://github.com/ROCm/aiter/issues/2419 for AITER gaps."""
if self.quant_config is None or self.quant_dtype is None:
return True
if self.quant_dtype != torch.float8_e4m3fn:
return True
# Derive QuantKeys from test config
if self.quant_block_shape is not None:
w_key = kFp8Static128BlockSym
a_key = kFp8Dynamic128Sym
elif self.is_per_out_ch_quant:
w_key = kFp8StaticChannelSym
a_key = (
kFp8DynamicTokenSym
if self.is_per_act_token_quant
else kFp8StaticTensorSym
)
else:
w_key = kFp8StaticTensorSym
a_key = (
kFp8DynamicTensorSym
if self.is_per_act_token_quant
else kFp8StaticTensorSym
)
fe_cls = self.fused_experts_type
if hasattr(fe_cls, "_supports_quant_scheme"):
try:
return fe_cls._supports_quant_scheme(w_key, a_key)
except NotImplementedError:
pass
return True
def is_fp8_block_quantized(self):
return (
self.quant_dtype == torch.float8_e4m3fn
@@ -253,6 +294,15 @@ class Config:
f"{self.fe_supported_types()}."
)
# Check quant scheme compatibility with fused experts class
if not self.fe_supports_quant_scheme():
return False, (
f"FE {self.fused_experts_type.__name__} does not support "
f"quant scheme (per_out_ch={self.is_per_out_ch_quant}, "
f"per_act_token={self.is_per_act_token_quant}, "
f"block={self.quant_block_shape})"
)
# Check block quantization support
is_block_quantized = self.quant_block_shape is not None
if is_block_quantized and self.quant_dtype is None:

View File

@@ -384,12 +384,21 @@ def test_legacy_routing(
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
topk_ids = sparse_logits.indx.to(torch.long)
topk_weights = sparse_logits.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, bitmatrix = topk_result
from triton_kernels.routing import routing_from_bitmatrix
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
)
else:
topk_ids = topk_result.indx.to(torch.long)
topk_weights = topk_result.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
routing_data, gather_indx, scatter_indx = legacy_routing(
gating_output, topk, sm_first=sm_first

View File

@@ -108,6 +108,23 @@ def rank_worker(
# inputs for rank
rank_tensors = RankTensors.make(config, pgi)
# Skip unsupported: AITER block-scaled MoE does not
# support apply_router_weight_on_input (topk=1 path).
# https://github.com/ROCm/aiter/issues/2418
if (
topk == 1
and config.supports_apply_weight_on_input()
and getattr(config.fused_experts_type, "__name__", "") == "AiterExperts"
and config.quant_block_shape is not None
):
print(
f"Skipping[{pgi.rank}]: m={m}, topk={topk}"
" (AITER block-scaled + weight-on-input,"
" https://github.com/ROCm/aiter/issues/2418)"
)
count -= 1
continue
# modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors)
@@ -121,7 +138,48 @@ def rank_worker(
atol = 3e-2
rtol = 3e-2
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
# On ROCm, AITER FP8 fused MoE uses hardware FP8
# dot-product which can produce slightly larger error
# than dequant+f32 matmul at FP8 representable-value
# boundaries. Allow a small percentage of elements to
# exceed the base tolerance by a bounded margin.
# https://github.com/ROCm/aiter/issues/2421
from vllm.platforms import current_platform as _cp
is_aiter_fp8 = (
_cp.is_rocm()
and getattr(config.fused_experts_type, "__name__", "") == "AiterExperts"
and config.quant_config is not None
)
if is_aiter_fp8:
diff = (ref_out - mk_out).abs()
n_total = diff.numel()
max_diff = diff.max().item()
n_exceed = int((diff > atol).sum().item())
pct_exceed = n_exceed / n_total * 100
# FP8 hw matmul vs f32 reference: up to ~4% of
# elements may exceed base tolerance, but max
# error should stay within 4x base tolerance.
max_pct_allowed = 5.0
relaxed_atol = atol * 4
print(
f"[AITER FP8 precision] "
f"max_diff={max_diff:.6f}, "
f"exceed_atol={n_exceed}/{n_total} "
f"({pct_exceed:.4f}%), "
f"max_pct_allowed={max_pct_allowed}%, "
f"relaxed_limit={relaxed_atol}"
)
assert pct_exceed <= max_pct_allowed, (
f"AITER FP8: {pct_exceed:.2f}% elements exceed "
f"atol={atol} (max allowed {max_pct_allowed}%)"
)
assert max_diff <= relaxed_atol, (
f"AITER FP8: max_diff={max_diff:.6f} exceeds "
f"relaxed limit {relaxed_atol}"
)
else:
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe())
except Exception as ex:
format_result(verbose, config.describe(), ex)

View File

@@ -1,15 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import patch
import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.router_factory import (
create_fused_moe_router,
)
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
def _is_aiter_capable() -> bool:
    """Return whether the AITER path can be exercised on this platform.

    Non-ROCm platforms are never capable. On ROCm the answer is the
    ``_ON_MI3XX`` flag from ``vllm.platforms.rocm`` (gfx942/gfx950 per the
    original note; confirm against that module). If the flag cannot be
    imported, the platform is treated as not capable.
    """
    if current_platform.is_rocm():
        try:
            from vllm.platforms.rocm import _ON_MI3XX
        except ImportError:
            return False
        return _ON_MI3XX
    return False
# Test parameters
MK_S = [(32, 256), (64, 512)]
@@ -96,6 +112,60 @@ def assert_routing_results_close(
)
def assert_aiter_routing_valid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    renormalize: bool,
    routed_scaling_factor: float = 1.0,
):
    """Check that AITER routing outputs honour the routing contract.

    AITER grouped_topk is a different implementation from the Python
    baseline, so its outputs are not compared numerically. Instead the
    following structural properties are asserted:

    * ``topk_weights`` and ``topk_ids`` are both ``(n_tokens, top_k)``,
      with ``n_tokens`` taken from ``topk_weights.shape[0]``;
    * every expert ID lies in ``[0, num_experts)``;
    * no token selects the same expert twice;
    * no weight is negative;
    * when ``renormalize`` is set, each token's weights sum to
      ``routed_scaling_factor`` within 1e-3 (absolute and relative).

    Raises:
        AssertionError: if any of the properties above is violated.
    """
    n_tokens = topk_weights.shape[0]
    expected_shape = (n_tokens, top_k)

    # Shape
    assert topk_weights.shape == expected_shape, (
        f"weights shape {topk_weights.shape} != ({n_tokens}, {top_k})"
    )
    assert topk_ids.shape == expected_shape, (
        f"ids shape {topk_ids.shape} != ({n_tokens}, {top_k})"
    )

    # Expert IDs in valid range
    in_range = (topk_ids >= 0) & (topk_ids < num_experts)
    assert in_range.all(), (
        f"expert IDs out of range [0, {num_experts}): "
        f"min={topk_ids.min().item()}, max={topk_ids.max().item()}"
    )

    # No duplicate expert IDs per token: a row of top_k entries is
    # duplicate-free exactly when it has top_k distinct values.
    for token_idx, token_ids in enumerate(topk_ids):
        n_distinct = token_ids.unique().numel()
        assert n_distinct == top_k, (
            f"token {token_idx}: duplicate expert IDs {token_ids.tolist()}"
        )

    # Weights are non-negative
    assert (topk_weights >= 0).all(), "negative routing weights"

    if not renormalize:
        return

    # Renormalization to 1.0 happens before scaling, so the per-token
    # sum should land on the scaling factor.
    per_token_sum = topk_weights.sum(dim=-1)
    target = torch.full_like(per_token_sum, routed_scaling_factor)
    torch.testing.assert_close(per_token_sum, target, atol=1e-3, rtol=1e-3)
def baseline_fused_topk(
router_logits: torch.Tensor, top_k: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -400,10 +470,7 @@ def test_grouped_topk(
hidden_states, router_logits = make_test_data(m, k, global_num_experts)
# Get router output
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
# Compute baseline
# Compute baseline (pure Python implementation)
baseline_weights, baseline_ids = baseline_grouped_topk(
router_logits,
top_k,
@@ -415,8 +482,32 @@ def test_grouped_topk(
routed_scaling_factor,
)
# Compare results
assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids)
# Test 1: Python/Triton path against baseline (exact match)
with patch(
"vllm.model_executor.layers.fused_moe.router.grouped_topk_router.rocm_aiter_ops.is_fused_moe_enabled",
return_value=False,
):
py_weights, py_ids = router.select_experts(hidden_states, router_logits)
assert_routing_results_close(py_weights, py_ids, baseline_weights, baseline_ids)
# Test 2: AITER path — verify outputs are structurally valid.
# AITER grouped_topk is a different implementation so we can't
# compare numerically against the Python baseline.
if _is_aiter_capable():
# Force-enable AITER for gfx942/gfx950 regardless of env var,
# so CI always exercises this path on capable hardware.
with patch.object(rocm_aiter_ops, "_AITER_ENABLED", True):
aiter_weights, aiter_ids = router.select_experts(
hidden_states, router_logits
)
assert_aiter_routing_valid(
aiter_weights,
aiter_ids,
top_k,
global_num_experts,
renormalize,
routed_scaling_factor,
)
@pytest.mark.parametrize("m,k", MK_S)

View File

@@ -14,6 +14,7 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -51,6 +52,60 @@ class SimpleSharedExperts(nn.Module):
return self.down(nn.functional.silu(gate) * up)
def _assert_close(
actual: torch.Tensor,
expected: torch.Tensor,
atol: float,
rtol: float,
label: str,
) -> None:
"""assert_close that prints diff diagnostics on both success and failure."""
actual_nans = int(actual.isnan().sum().item())
expected_nans = int(expected.isnan().sum().item())
actual_zeros = int((actual == 0).sum().item())
expected_zeros = int((expected == 0).sum().item())
n_total = actual.numel()
diff = (actual - expected).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
n_exceed = int((diff > atol).sum().item())
pct_exceed = n_exceed / n_total * 100
print(
f"[{label}] "
f"shape={list(actual.shape)}, "
f"max_diff={max_diff:.6e}, "
f"mean_diff={mean_diff:.6e}, "
f"exceed_atol({atol})={n_exceed}/{n_total} ({pct_exceed:.2f}%), "
f"actual=[{actual.min().item():.4f}, {actual.max().item():.4f}], "
f"expected=[{expected.min().item():.4f}, {expected.max().item():.4f}], "
f"nan(actual/expected)={actual_nans}/{expected_nans}, "
f"zeros(actual/expected)={actual_zeros}/{expected_zeros}"
)
assert actual_nans == 0, (
f"{label}: actual has {actual_nans}/{n_total} NaN values "
f"(expected has {expected_nans}). "
f"This indicates a kernel bug, not a precision issue."
)
assert expected_nans == 0, (
f"{label}: expected has {expected_nans}/{n_total} NaN values. "
f"This indicates a kernel bug, not a precision issue."
)
torch.testing.assert_close(
actual,
expected,
atol=atol,
rtol=rtol,
msg=(
f"{label}: max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}, "
f"exceed_atol({atol})={n_exceed}/{n_total} ({pct_exceed:.2f}%)"
),
)
@pytest.fixture(autouse=True)
def setup_cuda():
if not torch.cuda.is_available():
@@ -61,6 +116,9 @@ def setup_cuda():
@pytest.mark.parametrize("num_tokens", [1, 32])
@pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@pytest.mark.skipif(
is_torch_equal_or_newer("2.10.0"),
reason="Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995",
@@ -70,14 +128,24 @@ def test_routed_input_transform_inside_vs_outside(
hidden_size: int,
latent_size: int,
dtype: torch.dtype,
use_rocm_aiter: bool,
dist_init,
workspace_init,
monkeypatch,
):
"""Compare SharedFusedMoE with transform inside vs manually applying outside.
Method A (inside): SharedFusedMoE with routed_input_transform
Method B (outside): Manually transform, then SharedFusedMoE without transform
"""
if current_platform.is_rocm() and use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "1" if use_rocm_aiter else "0")
from vllm._aiter_ops import rocm_aiter_ops
rocm_aiter_ops.refresh_env_variables()
torch.manual_seed(42)
torch.cuda.manual_seed(42)
num_experts = 8
top_k = 2
@@ -125,7 +193,13 @@ def test_routed_input_transform_inside_vs_outside(
prefix="moe_without_transform",
)
# Weights are created via torch.empty (uninitialized).
# Initialize with seeded random values for reproducibility.
with torch.no_grad():
moe_with_transform.w13_weight.normal_()
moe_with_transform.w13_weight.div_(10)
moe_with_transform.w2_weight.normal_()
moe_with_transform.w2_weight.div_(10)
moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight)
moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight)
@@ -139,9 +213,14 @@ def test_routed_input_transform_inside_vs_outside(
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Clone inputs so any in-place modification by Method A
# cannot affect Method B's computation.
hidden_states_A = hidden_states.clone()
router_logits_A = router_logits.clone()
with set_forward_context(None, vllm_config, num_tokens=num_tokens):
shared_out_A, routed_out_A = moe_with_transform(
hidden_states, router_logits
hidden_states_A, router_logits_A
)
transformed_hidden = routed_transform(hidden_states)
@@ -149,19 +228,19 @@ def test_routed_input_transform_inside_vs_outside(
transformed_hidden, router_logits
)
torch.testing.assert_close(
expected_shared_out = shared_experts(hidden_states)
_assert_close(
routed_out_A,
routed_out_B,
atol=1e-3,
rtol=1e-3,
msg="Routed output should match: transform inside vs outside",
label="Routed output: transform inside vs outside",
)
expected_shared_out = shared_experts(hidden_states)
torch.testing.assert_close(
_assert_close(
shared_out_A,
expected_shared_out,
atol=1e-3,
rtol=1e-3,
label="Shared expert output",
)

View File

@@ -10,18 +10,15 @@ import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
persistent_masked_m_silu_mul_quant,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import set_random_seed
if current_platform.is_fp8_fnuz():
pytest.skip(
"Tests in this file require float8_e4m3fn and platform does not support",
allow_module_level=True,
)
fp8_dtype = torch.float8_e4m3fn
fp8_dtype = current_platform.fp8_dtype()
CASES = [
(1, 1, 128, fp8_dtype),
@@ -58,22 +55,21 @@ def as_uint8(x) -> torch.Tensor:
def silu(x: torch.Tensor) -> torch.Tensor:
# Reference SiLU, x / (1 + exp(-x)), evaluated on a float32 copy of x.
# NOTE(review): this span comes from a rendered diff hunk, so the replaced
# lines (one_f32 denominator, float32 assert, unconditional bf16 return)
# and their replacements (1.0 denominator, platform-dependent return)
# appear back to back; confirm which lines are live against the real file.
one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32)
x_f32 = x.to(torch.float32)
act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32))
assert act_f32.dtype == torch.float32
return act_f32.to(torch.bfloat16)
act_f32 = x_f32 / (1.0 + torch.exp(-x_f32))
# The return dtype follows the kernel being checked on each platform.
if current_platform.is_cuda():
# C++ kernel returns bf16
return act_f32.to(torch.bfloat16)
# Triton fallback stays in f32
return act_f32
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
fp8_min_val, fp8_max_val = get_fp8_min_max()
eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16)
one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16)
fp8_max_bf16 = torch.tensor(
[torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16
)
fp8_min_bf16 = torch.tensor(
[torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16
)
fp8_max_bf16 = torch.tensor([fp8_max_val], device=x.device, dtype=torch.bfloat16)
fp8_min_bf16 = torch.tensor([fp8_min_val], device=x.device, dtype=torch.bfloat16)
fp8_max_inv = one_bf16 / fp8_max_bf16
assert fp8_max_inv.dtype == torch.bfloat16
@@ -81,22 +77,36 @@ def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
num_groups = x.numel() // group_size
x_og_shape = x.shape
x = x.to(torch.bfloat16)
x = x.view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
assert amax.dtype == torch.bfloat16
s = amax * fp8_max_inv
if current_platform.is_cuda():
# C++ kernel computes entirely in bf16
x = x.to(torch.bfloat16)
x = x.view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
assert amax.dtype == torch.bfloat16
s = amax * fp8_max_inv
if ceil_ue8m0:
s = torch.exp2(
torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16)
).to(torch.bfloat16)
if ceil_ue8m0:
s = torch.exp2(
torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16)
).to(torch.bfloat16)
inv_s = one_bf16 / s
inv_s = inv_s.view((num_groups, 1))
xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to(
fp8_dtype
)
inv_s = one_bf16 / s
inv_s = inv_s.view((num_groups, 1))
xq = torch.clamp(
x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()
).to(fp8_dtype)
else:
# Triton fallback computes in f32. Use multiply-by-reciprocal
# to match Triton's constexpr evaluation of 1.0/fp8_max.
fp8_min_f, fp8_max_f = get_fp8_min_max()
x = x.to(torch.float32).view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=1e-10)
s = amax * (1.0 / fp8_max_f)
if ceil_ue8m0:
s = torch.exp2(torch.ceil(torch.log2(s)))
inv_s = (1.0 / s).view((num_groups, 1))
xq = torch.clamp(x * inv_s, min=fp8_min_f, max=fp8_max_f).to(fp8_dtype)
xq = xq.view(x_og_shape)
xs = s.view((-1, xq.size(-1) // group_size))
@@ -112,12 +122,10 @@ def silu_mul_quant(
assert gate.dtype == torch.bfloat16
assert up.dtype == torch.bfloat16
act_bf16 = silu(gate)
assert act_bf16.dtype == torch.bfloat16
act = silu(gate)
# act & mul
a_m = act_bf16 * up
assert a_m.dtype == torch.bfloat16
a_m = act * up
q, s = do_quant(a_m, group_size, ceil_ue8m0)
return q, s
@@ -221,8 +229,12 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
scale_fmts = [
DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
# UE8M0 (int32 packed) scales require the C++ kernel which is
# not available on ROCm (#ifndef USE_ROCM).
# https://github.com/ROCm/aiter/issues/2420
if current_platform.is_cuda():
scale_fmts.append(DeepGemmQuantScaleFMT.UE8M0)
# Run the SiLU V2 kernel
for scale_fmt in scale_fmts:
@@ -274,10 +286,23 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
for e in range(E):
nt = tokens_per_expert[e].item()
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
)
if current_platform.is_rocm():
# On ROCm the Triton fallback kernel uses f32 math
# intrinsics (tl.exp) that may differ from PyTorch's
# torch.exp by 1 ULP. At FP8 quantization
# boundaries this can flip one representable value.
# Allow 1 FP8 quantum of tolerance.
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
atol=32.0,
rtol=0.2,
)
else:
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
G = H // group_size

View File

@@ -16,7 +16,7 @@ from vllm.platforms import current_platform
"platform_method,expected_backend",
[
("is_cuda", UnquantizedMoeBackend.TRITON), # Default CUDA without FlashInfer
("is_rocm", UnquantizedMoeBackend.TRITON),
("is_rocm", UnquantizedMoeBackend.TRITON), # ROCm without AITER
("is_cpu", UnquantizedMoeBackend.CPU),
("is_xpu", UnquantizedMoeBackend.XPU),
("is_tpu", UnquantizedMoeBackend.TPU),
@@ -27,13 +27,19 @@ from vllm.platforms import current_platform
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=False,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
return_value=False,
)
def test_select_default_backend_by_platform(
mock_aiter_enabled,
mock_has_flashinfer,
monkeypatch,
platform_method,
expected_backend,
):
"""Test backend selection for different platforms."""
"""Test default backend selection per platform with all optional
accelerators (FlashInfer, AITER) disabled."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
@@ -58,6 +64,39 @@ def test_select_default_backend_by_platform(
assert selected_backend == expected_backend
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
    return_value=False,
)
@patch(
    "vllm.model_executor.layers.fused_moe.oracle.unquantized.rocm_aiter_ops.is_fused_moe_enabled",
    return_value=True,
)
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="ROCm-specific backend selection test"
)
def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
    """AITER is selected on ROCm when AITER fused MoE reports enabled.

    FlashInfer is patched out and the platform object seen by the
    selector is replaced with a mock that answers True only for
    ``is_rocm``.
    """
    # Platform predicates the mocked platform should answer; ROCm only.
    platform_answers = {
        "is_cuda": False,
        "is_rocm": True,
        "is_cpu": False,
        "is_xpu": False,
        "is_tpu": False,
        "is_out_of_tree": False,
    }
    with patch(
        "vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
    ) as mock_platform:
        for predicate, answer in platform_answers.items():
            getattr(mock_platform, predicate).return_value = answer

        chosen = select_unquantized_moe_backend(
            moe_config=make_dummy_moe_config(),
            use_ep=False,
            use_dp=False,
        )

        assert chosen == UnquantizedMoeBackend.AITER
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=True,