feat(moe): Add is_act_and_mul=False support for Triton MoE kernels (#31645)
Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
129
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
129
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test for is_act_and_mul=False MoE using Triton.
|
||||
|
||||
This tests the code path used by models like Nemotron-H that use
|
||||
non-fused activations (e.g., relu2_no_mul) instead of SwiGLU-style
|
||||
fused activations.
|
||||
|
||||
This feature is supported on both CUDA and ROCm (with AITER disabled).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Tests for is_act_and_mul=False MoE require CUDA or ROCm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def disable_aiter_on_rocm(monkeypatch):
    """Force the Triton MoE path by disabling AITER when running on ROCm.

    On CUDA (and other non-ROCm platforms) this fixture is a no-op. On ROCm
    it turns the AITER MoE kernels off via environment variables and refreshes
    the cached env state both before and after the test body runs.
    """
    if not current_platform.is_rocm():
        # Nothing to toggle outside ROCm.
        yield
        return

    from vllm._aiter_ops import rocm_aiter_ops

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "0")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "0")
    rocm_aiter_ops.refresh_env_variables()

    yield

    # monkeypatch restores the env vars on teardown; re-sync the cached flags.
    rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
|
||||
@pytest.fixture
def init_workspace():
    """Initialize the MoE workspace manager for the current CUDA device.

    Also seeds torch's RNG so parametrized cases see deterministic inputs,
    and tears the workspace manager down after the test completes.
    """
    from vllm.v1.worker import workspace as workspace_mod

    torch.manual_seed(42)
    workspace_mod.init_workspace_manager(torch.cuda.current_device())

    yield

    workspace_mod.reset_workspace_manager()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 256, 1024])
@pytest.mark.parametrize("k", [128, 512])
@pytest.mark.parametrize("e", [4, 8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("activation", ["relu2_no_mul", "silu_no_mul", "gelu_no_mul"])
@torch.inference_mode()
def test_moe_no_act_mul(
    disable_aiter_on_rocm,
    init_workspace,
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    activation: str,
):
    """End-to-end MoE run with is_act_and_mul=False through the Triton path."""
    from vllm.model_executor.layers.fused_moe import TritonExperts, fused_topk
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
    from vllm.model_executor.layers.fused_moe.modular_kernel import (
        FusedMoEModularKernel,
    )
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP,
    )

    device = "cuda"
    hidden_states = torch.randn((m, k), device=device, dtype=dtype)
    # Scaled-down random weights keep activations in a numerically tame range.
    w1 = torch.randn((e, n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

    router_logits = torch.randn((m, e), device=device, dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, router_logits, topk, renormalize=True
    )

    kernel = FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(FusedMoEQuantConfig.make(is_act_and_mul=False)),
    )

    output = kernel(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
    )

    assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"
    assert output.abs().sum() > 0, "Output is all zeros"
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_moe_workspace_shapes_no_act_mul(disable_aiter_on_rocm):
    """Check workspace_shapes sizing when is_act_and_mul=False."""
    from vllm.model_executor.layers.fused_moe import TritonExperts
    from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

    num_tokens, intermediate, hidden, topk = 64, 256, 128, 2

    experts = TritonExperts(FusedMoEQuantConfig.make(is_act_and_mul=False))
    ws1, ws2, out = experts.workspace_shapes(
        num_tokens, intermediate, hidden, topk, 8, 8, None
    )

    # Without the fused gate-mul, the post-activation width stays N, so the
    # first workspace's last dim must be max(N, K) instead of max(N // 2, K).
    assert ws1[2] == max(intermediate, hidden)
    assert out == (num_tokens, hidden)
|
||||
@@ -201,6 +201,11 @@ class FusedMoEQuantConfig:
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
|
||||
# Whether activation is fused with gate multiplication (SwiGLU-style).
|
||||
# When True: intermediate_size = N // 2 (gate and up are combined)
|
||||
# When False: intermediate_size = N (no gate multiplication)
|
||||
is_act_and_mul: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
"illegal quantization"
|
||||
@@ -435,6 +440,7 @@ class FusedMoEQuantConfig:
|
||||
w1_zp: torch.Tensor | None = None,
|
||||
w2_zp: torch.Tensor | None = None,
|
||||
weight_dtype: torch.dtype | str | None = None,
|
||||
is_act_and_mul: bool = True,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
@@ -494,6 +500,7 @@ class FusedMoEQuantConfig:
|
||||
_w2=FusedMoEQuantDesc(
|
||||
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
|
||||
),
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
@@ -806,6 +813,7 @@ def awq_marlin_moe_quant_config(
|
||||
def biased_moe_quant_config(
|
||||
w1_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
is_act_and_mul: bool = True,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations with biases.
|
||||
@@ -815,6 +823,7 @@ def biased_moe_quant_config(
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc(bias=w1_bias),
|
||||
_w2=FusedMoEQuantDesc(bias=w2_bias),
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -871,8 +871,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
|
||||
# For non-fused activations: N = intermediate, after act = N
|
||||
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, intermediate_size)
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
@@ -947,7 +950,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
|
||||
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
|
||||
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace2, (E, max_num_tokens, intermediate_size)
|
||||
)
|
||||
|
||||
# TODO(bnell): should this be done for any quantized type?
|
||||
if self.quant_config.use_fp8_w8a8:
|
||||
@@ -978,7 +985,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# TODO (bnell): use triton utility from batched deep gemm.
|
||||
self.activation(
|
||||
activation,
|
||||
intermediate_cache2.view(-1, N // 2),
|
||||
intermediate_cache2.view(-1, intermediate_size),
|
||||
intermediate_cache1.view(-1, N),
|
||||
)
|
||||
|
||||
|
||||
@@ -2292,7 +2292,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
|
||||
# For non-fused activations: N = intermediate, after act = N
|
||||
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
|
||||
workspace1 = (M, topk, max(intermediate_size, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
@@ -2367,8 +2370,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
|
||||
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace13, (num_tokens * top_k_num, N // 2)
|
||||
workspace13, (num_tokens * top_k_num, intermediate_size)
|
||||
)
|
||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||
|
||||
|
||||
@@ -603,9 +603,15 @@ class FusedMoE(CustomOp):
|
||||
"is_act_and_mul=False is supported only for unquantized "
|
||||
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
|
||||
)
|
||||
if not current_platform.is_cuda():
|
||||
# ROCm without AITER MoE uses Triton which supports
|
||||
# is_act_and_mul=False via standard PyTorch ops (F.silu, F.gelu)
|
||||
rocm_without_aiter_moe = (
|
||||
current_platform.is_rocm() and not rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
if not current_platform.is_cuda() and not rocm_without_aiter_moe:
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for CUDA for now"
|
||||
"is_act_and_mul=False is supported only for CUDA, or ROCm "
|
||||
"(when AITER MoE is disabled) for now"
|
||||
)
|
||||
|
||||
if self.enable_eplb and not self.quant_method.supports_eplb:
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from math import prod, sqrt
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
@@ -575,14 +575,35 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
def activation(
|
||||
self, activation: str, output: torch.Tensor, input: torch.Tensor
|
||||
) -> None:
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
# Fused activations (SwiGLU-style): output is half the size of input
|
||||
if activation == "silu":
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
torch.ops._C.silu_and_mul(output, input)
|
||||
elif activation == "gelu":
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
torch.ops._C.gelu_and_mul(output, input)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
torch.ops._C.swigluoai_and_mul(output, input)
|
||||
# Non-fused activations (is_act_and_mul=False): output same size as input
|
||||
elif activation == "silu_no_mul":
|
||||
assert output.size(-1) == input.size(-1)
|
||||
# Use out= argument to avoid intermediate tensor
|
||||
torch.sigmoid(input, out=output)
|
||||
output.mul_(input)
|
||||
elif activation == "gelu_no_mul":
|
||||
assert output.size(-1) == input.size(-1)
|
||||
# GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
|
||||
# Use out= and in-place ops to avoid intermediate tensors
|
||||
output.copy_(input).div_(sqrt(2))
|
||||
torch.erf(output, out=output)
|
||||
output.add_(1).mul_(input).mul_(0.5)
|
||||
elif activation == "relu2_no_mul":
|
||||
assert output.size(-1) == input.size(-1)
|
||||
# ReLU²: clamp has out=, then in-place square
|
||||
torch.clamp(input, min=0, out=output)
|
||||
output.square_()
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
|
||||
@@ -299,7 +299,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
is_act_and_mul=self.moe.is_act_and_mul,
|
||||
)
|
||||
elif not self.moe.is_act_and_mul:
|
||||
# Create a config with is_act_and_mul=False since
|
||||
# FUSED_MOE_UNQUANTIZED_CONFIG has is_act_and_mul=True
|
||||
return FusedMoEQuantConfig.make(is_act_and_mul=False)
|
||||
else:
|
||||
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
|
||||
Reference in New Issue
Block a user