feat(moe): Add is_act_and_mul=False support for Triton MoE kernels (#31645)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra
2026-01-08 07:57:09 +05:30
committed by GitHub
parent 0d7667419f
commit 25eef3dc2e
7 changed files with 191 additions and 9 deletions

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for is_act_and_mul=False MoE using Triton.
This tests the code path used by models like Nemotron-H that use
non-fused activations (e.g., relu2_no_mul) instead of SwiGLU-style
fused activations.
This feature is supported on both CUDA and ROCm (with AITER disabled).
"""
import pytest
import torch
from vllm.platforms import current_platform
pytestmark = pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Tests for is_act_and_mul=False MoE require CUDA or ROCm",
)
@pytest.fixture
def disable_aiter_on_rocm(monkeypatch):
    """Disable AITER on ROCm so the MoE layer falls back to the Triton path.

    On CUDA this fixture is a no-op. On ROCm it forces the AITER env vars
    to "0" and refreshes the cached env state before the test, then restores
    the original environment and refreshes again afterwards.
    """
    if current_platform.is_rocm():
        from vllm._aiter_ops import rocm_aiter_ops

        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "0")
        monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "0")
        rocm_aiter_ops.refresh_env_variables()
        yield
        # This fixture's teardown runs *before* the `monkeypatch` dependency
        # fixture undoes the env changes, so without an explicit undo the
        # refresh below would still observe the patched "0" values and the
        # pre-test AITER state would not actually be restored.
        monkeypatch.undo()
        rocm_aiter_ops.refresh_env_variables()
    else:
        # On CUDA, no special setup needed
        yield
@pytest.fixture
def init_workspace():
    """Seed the RNG and set up the MoE workspace manager; tear it down after."""
    from vllm.v1.worker.workspace import (
        init_workspace_manager,
        reset_workspace_manager,
    )

    # Deterministic weights/activations across parametrized runs.
    torch.manual_seed(42)
    device = torch.cuda.current_device()
    init_workspace_manager(device)
    yield
    reset_workspace_manager()
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 256, 1024])
@pytest.mark.parametrize("k", [128, 512])
@pytest.mark.parametrize("e", [4, 8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("activation", ["relu2_no_mul", "silu_no_mul", "gelu_no_mul"])
@torch.inference_mode()
def test_moe_no_act_mul(
    disable_aiter_on_rocm,
    init_workspace,
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    activation: str,
):
    """Test MoE with is_act_and_mul=False using Triton."""
    from vllm.model_executor.layers.fused_moe import TritonExperts, fused_topk
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
    from vllm.model_executor.layers.fused_moe.modular_kernel import (
        FusedMoEModularKernel,
    )
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP,
    )

    hidden_states = torch.randn((m, k), device="cuda", dtype=dtype)
    # With is_act_and_mul=False the first GEMM is (e, n, k) — no gate half.
    w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    router_logits = torch.randn((m, e), device="cuda", dtype=dtype)

    quant_config = FusedMoEQuantConfig.make(is_act_and_mul=False)
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, router_logits, topk, renormalize=True
    )
    kernel = FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(quant_config),
    )
    result = kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
    )

    # Sanity checks: shape round-trips and the output is finite and nonzero.
    assert result.shape == (m, k), f"Expected shape {(m, k)}, got {result.shape}"
    assert not torch.isnan(result).any(), "Output contains NaN"
    assert not torch.isinf(result).any(), "Output contains Inf"
    assert result.abs().sum() > 0, "Output is all zeros"
@torch.inference_mode()
def test_moe_workspace_shapes_no_act_mul(disable_aiter_on_rocm):
    """Test workspace_shapes returns correct sizes for is_act_and_mul=False."""
    from vllm.model_executor.layers.fused_moe import TritonExperts
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    num_tokens, inter_n, hidden_k, topk = 64, 256, 128, 2
    cfg = FusedMoEQuantConfig.make(is_act_and_mul=False)
    experts = TritonExperts(cfg)
    ws1, ws2, out = experts.workspace_shapes(
        num_tokens, inter_n, hidden_k, topk, 8, 8, None
    )
    # Non-fused activation keeps the intermediate width at N (not N // 2),
    # so workspace1's last dim must cover max(N, K).
    assert ws1[2] == max(inter_n, hidden_k)
    assert out == (num_tokens, hidden_k)

View File

@@ -201,6 +201,11 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
# Whether activation is fused with gate multiplication (SwiGLU-style).
# When True: intermediate_size = N // 2 (gate and up are combined)
# When False: intermediate_size = N (no gate multiplication)
is_act_and_mul: bool = True
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
@@ -435,6 +440,7 @@ class FusedMoEQuantConfig:
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
is_act_and_mul: bool = True,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
@@ -494,6 +500,7 @@ class FusedMoEQuantConfig:
_w2=FusedMoEQuantDesc(
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
is_act_and_mul=is_act_and_mul,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
@@ -806,6 +813,7 @@ def awq_marlin_moe_quant_config(
def biased_moe_quant_config(
w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
is_act_and_mul: bool = True,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations with biases.
@@ -815,6 +823,7 @@ def biased_moe_quant_config(
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc(bias=w1_bias),
_w2=FusedMoEQuantDesc(bias=w2_bias),
is_act_and_mul=is_act_and_mul,
)

View File

@@ -871,8 +871,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
# For non-fused activations: N = intermediate, after act = N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
workspace2 = (num_experts, max_num_tokens * num_dp, intermediate_size)
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output)
@@ -947,7 +950,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
intermediate_cache2 = _resize_cache(
workspace2, (E, max_num_tokens, intermediate_size)
)
# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
@@ -978,7 +985,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (bnell): use triton utility from batched deep gemm.
self.activation(
activation,
intermediate_cache2.view(-1, N // 2),
intermediate_cache2.view(-1, intermediate_size),
intermediate_cache1.view(-1, N),
)

View File

@@ -2292,7 +2292,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K))
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
# For non-fused activations: N = intermediate, after act = N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
workspace1 = (M, topk, max(intermediate_size, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
@@ -2367,8 +2370,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2)
workspace13, (num_tokens * top_k_num, intermediate_size)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

View File

@@ -603,9 +603,15 @@ class FusedMoE(CustomOp):
"is_act_and_mul=False is supported only for unquantized "
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
if not current_platform.is_cuda():
# ROCm without AITER MoE uses Triton which supports
# is_act_and_mul=False via standard PyTorch ops (F.silu, F.gelu)
rocm_without_aiter_moe = (
current_platform.is_rocm() and not rocm_aiter_ops.is_fused_moe_enabled()
)
if not current_platform.is_cuda() and not rocm_without_aiter_moe:
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA for now"
"is_act_and_mul=False is supported only for CUDA, or ROCm "
"(when AITER MoE is disabled) for now"
)
if self.enable_eplb and not self.quant_method.supports_eplb:

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from math import prod
from math import prod, sqrt
from typing import final
import torch
@@ -575,14 +575,35 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def activation(
    self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
    """Apply the named MoE activation, writing the result into ``output``.

    Fused "act-and-mul" activations (SwiGLU-style: ``silu``, ``gelu``,
    ``swigluoai``) read a concatenated [gate, up] input, so ``output`` is
    half as wide as ``input``. The non-fused ``*_no_mul`` variants are
    elementwise, so ``output`` and ``input`` have the same width.
    The size checks are per-branch; an unconditional half-width assert here
    would reject every ``*_no_mul`` activation.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    # Fused activations (SwiGLU-style): output is half the size of input
    if activation == "silu":
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.gelu_and_mul(output, input)
    elif activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.swigluoai_and_mul(output, input)
    # Non-fused activations (is_act_and_mul=False): output same size as input
    elif activation == "silu_no_mul":
        assert output.size(-1) == input.size(-1)
        # SiLU(x) = x * sigmoid(x); out= avoids an intermediate tensor.
        torch.sigmoid(input, out=output)
        output.mul_(input)
    elif activation == "gelu_no_mul":
        assert output.size(-1) == input.size(-1)
        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        # Use out= and in-place ops to avoid intermediate tensors
        output.copy_(input).div_(sqrt(2))
        torch.erf(output, out=output)
        output.add_(1).mul_(input).mul_(0.5)
    elif activation == "relu2_no_mul":
        assert output.size(-1) == input.size(-1)
        # ReLU²: clamp has out=, then in-place square
        torch.clamp(input, min=0, out=output)
        output.square_()
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

View File

@@ -299,7 +299,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
is_act_and_mul=self.moe.is_act_and_mul,
)
elif not self.moe.is_act_and_mul:
# Create a config with is_act_and_mul=False since
# FUSED_MOE_UNQUANTIZED_CONFIG has is_act_and_mul=True
return FusedMoEQuantConfig.make(is_act_and_mul=False)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG