[FIX] Add NO_MUL activation support for modular kernel path (#31528)

Signed-off-by: dafrimi <dafrimi@nvidia.com>
Signed-off-by: <>
Co-authored-by: root <root@gpu-267.slurm-workers-slurm.slurm.svc.cluster.local>
Co-authored-by: root <root@gpu-537.slurm-workers-slurm.slurm.svc.cluster.local>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: root <root@pool0-01777.cm.cluster>
This commit is contained in:
danielafrimi
2026-01-12 18:55:49 +02:00
committed by GitHub
parent 6bc9c8473e
commit 3f72639d36
17 changed files with 368 additions and 71 deletions

View File

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for MoE with non-gated activations (*_no_mul).
These tests verify that MoE layers work correctly with activations like
silu_no_mul, gelu_no_mul, relu2_no_mul where the activation output dimension
equals N (not N // 2 like gated activations).
"""
import pytest
import torch
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import (
GELU_NO_MUL,
RELU2_NO_MUL,
SILU_NO_MUL,
)
from vllm.platforms import current_platform
# Test parameters
M_SIZES = [1, 16, 64]
N_SIZES = [128, 256]
K_SIZES = [64, 128]
TOPK_VALUES = [1, 2]
NUM_EXPERTS = 8
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
def make_test_tensors(
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
dtype: torch.dtype = torch.bfloat16,
device: str = "cuda",
):
"""Create test tensors for MoE with non-gated activation.
For non-gated activations (*_no_mul):
- w1: (E, N, K) - projects from K to N
- w2: (E, K, N) - projects from N back to K (note: N, not N//2)
"""
hidden_states = torch.randn(m, k, dtype=dtype, device=device)
# For non-gated: w1 projects K -> N, w2 projects N -> K
w1 = torch.randn(num_experts, n, k, dtype=dtype, device=device) * 0.1
w2 = torch.randn(num_experts, k, n, dtype=dtype, device=device) * 0.1
topk_weights = torch.ones(m, topk, dtype=torch.float32, device=device) / topk
topk_ids = torch.randint(0, num_experts, (m, topk), device=device)
return hidden_states, w1, w2, topk_weights, topk_ids
@pytest.mark.skipif(
not current_platform.has_device_capability(80),
reason="Requires compute capability >= 8.0",
)
@pytest.mark.parametrize("m", M_SIZES)
@pytest.mark.parametrize("n", N_SIZES)
@pytest.mark.parametrize("k", K_SIZES)
@pytest.mark.parametrize("topk", TOPK_VALUES)
@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS)
@torch.inference_mode()
def test_triton_experts_no_mul_activation(
m: int,
n: int,
k: int,
topk: int,
activation: str,
):
hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
m, n, k, NUM_EXPERTS, topk
)
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
ws1_shape, ws2_shape, out_shape = experts.workspace_shapes(
M=m,
N=n,
K=k,
topk=topk,
global_num_experts=NUM_EXPERTS,
local_num_experts=NUM_EXPERTS,
expert_tokens_meta=None,
activation=activation,
)
# Verify workspace shapes are correct for no_mul activation
# workspace1 should handle activation_out_dim = N (not N//2)
assert ws1_shape == (m, topk, max(n, k)), (
f"workspace1 shape mismatch: expected {(m, topk, max(n, k))}, got {ws1_shape}"
)
# workspace2 should handle max(N, K) for intermediate_cache1/cache3
assert ws2_shape == (m, topk, max(n, k)), (
f"workspace2 shape mismatch: expected {(m, topk, max(n, k))}, got {ws2_shape}"
)
assert out_shape == (m, k), (
f"output shape mismatch: expected {(m, k)}, got {out_shape}"
)
workspace1 = torch.empty(
ws1_shape[0] * ws1_shape[1] * ws1_shape[2],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
workspace2 = torch.empty(
ws2_shape[0] * ws2_shape[1] * ws2_shape[2],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
output = torch.zeros(m, k, dtype=hidden_states.dtype, device=hidden_states.device)
experts.apply(
output=output,
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=NUM_EXPERTS,
expert_map=None,
a1q_scale=None,
a2_scale=None,
workspace13=workspace1,
workspace2=workspace2,
expert_tokens_meta=None,
apply_router_weight_on_input=False,
)
assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}"
assert not torch.isnan(output).any(), "Output contains NaN"
assert not torch.isinf(output).any(), "Output contains Inf"
assert output.abs().sum() > 0, "Output is all zeros"
@pytest.mark.skipif(
not current_platform.has_device_capability(80),
reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_workspace_shapes_no_mul_vs_gated():
"""Test that workspace shapes differ correctly between gated and non-gated."""
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
M, N, K, topk = 64, 256, 128, 2
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, SILU_NO_MUL
)
ws1_gated, _, out_gated = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, "silu"
)
# For no_mul: activation_out_dim = N
# For gated: activation_out_dim = N // 2
# workspace1 should use max(activation_out_dim, K)
activation_out_dim_no_mul = N
activation_out_dim_gated = N // 2
assert ws1_no_mul[2] == max(activation_out_dim_no_mul, K), (
f"no_mul workspace1 last dim should be max({activation_out_dim_no_mul}, {K})"
)
assert ws1_gated[2] == max(activation_out_dim_gated, K), (
f"gated workspace1 last dim should be max({activation_out_dim_gated}, {K})"
)
# Output shapes should be the same
assert out_no_mul == out_gated == (M, K)
@pytest.mark.skipif(
not current_platform.has_device_capability(80),
reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_adjust_n_for_activation():
"""Test the adjust_N_for_activation method."""
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)
N = 256
# Gated activations should return N // 2
assert experts.adjust_N_for_activation(N, "silu") == N // 2
assert experts.adjust_N_for_activation(N, "gelu") == N // 2
# Non-gated activations should return N
assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N

View File

@@ -305,6 +305,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader # FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks # DP ranks in the case of TP > 1. At the moment, all the Ranks
@@ -312,8 +313,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers = self.num_dispatchers num_dispatchers = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output) return (workspace13, workspace2, output)

View File

@@ -355,9 +355,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K)) workspace2 = (M * topk, max(activation_out_dim, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -402,11 +404,17 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
assert num_dp is not None assert num_dp is not None
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) workspace2 = (
self.max_experts_per_worker,
M * num_dp,
max(activation_out_dim, K),
)
output = (self.max_experts_per_worker, M, K) output = (self.max_experts_per_worker, M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -635,13 +643,15 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1: tuple[int, ...] = () workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = () workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = () output: tuple[int, ...] = ()
if self.use_batched_format: if self.use_batched_format:
workspace1 = (self.max_experts_per_worker, M, max(N, K)) workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, (N // 2)) workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
output = (self.max_experts_per_worker, M, K) output = (self.max_experts_per_worker, M, K)
else: else:
workspace1 = (M * topk, max(2 * N, K)) workspace1 = (M * topk, max(2 * N, K))
@@ -896,9 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K)) workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K)) workspace2 = (M * topk, max(activation_out_dim, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)

View File

@@ -143,6 +143,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None assert self.block_shape is not None
block_m = self.block_shape[0] block_m = self.block_shape[0]
@@ -151,7 +152,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
assert M_sum % block_m == 0 assert M_sum % block_m == 0
workspace1 = (M_sum, max(N // 2, K)) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K)) workspace2 = (M_sum, max(N, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -163,11 +165,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_k = self.block_shape[1] block_k = self.block_shape[1]
scale_fmt = DeepGemmQuantScaleFMT.from_oracle() scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant # 1. DeepGemm UE8M0: use packed per-token-group quant
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
M_sum, N = input.size()
act_out = torch.empty( act_out = torch.empty(
(M_sum, N // 2), dtype=input.dtype, device=input.device (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
) )
self.activation(activation, act_out, input) self.activation(activation, act_out, input)
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm( a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
@@ -187,8 +191,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
) )
# 3. fallback path for non-SiLU activations in nonUE8M0 cases. # 3. fallback path for non-SiLU activations in nonUE8M0 cases.
M_sum, N = input.size() act_out = torch.empty(
act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device) (M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input) self.activation(activation, act_out, input)
return per_token_group_quant_fp8( return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output act_out, block_k, column_major_scales=True, out_q=output
@@ -254,8 +259,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
) )
activation_out_dim = self.adjust_N_for_activation(N, activation)
quant_out = _resize_cache( quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
) )
a2q, a2q_scale = self._act_mul_quant( a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation input=mm1_out.view(-1, N), output=quant_out, activation=activation

View File

@@ -76,6 +76,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError raise NotImplementedError

View File

@@ -91,6 +91,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles # We use global_num_experts due to how moe_align_block_size handles
# expert_maps. # expert_maps.

View File

@@ -103,6 +103,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles # We use global_num_experts due to how moe_align_block_size handles
# expert_maps. # expert_maps.

View File

@@ -673,6 +673,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
@@ -867,12 +868,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers num_dp = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts
max_num_tokens = self.max_num_tokens max_num_tokens = self.max_num_tokens
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) workspace2 = (num_experts, max_num_tokens * num_dp, activation_out_dim)
output = (num_experts, max_num_tokens * num_dp, K) output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output) return (workspace13, workspace2, output)
@@ -947,7 +950,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N)) intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace2, (E, max_num_tokens, activation_out_dim)
)
# TODO(bnell): should this be done for any quantized type? # TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8: if self.quant_config.use_fp8_w8a8:
@@ -978,7 +984,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (bnell): use triton utility from batched deep gemm. # TODO (bnell): use triton utility from batched deep gemm.
self.activation( self.activation(
activation, activation,
intermediate_cache2.view(-1, N // 2), intermediate_cache2.view(-1, activation_out_dim),
intermediate_cache1.view(-1, N), intermediate_cache1.view(-1, N),
) )

View File

@@ -640,6 +640,7 @@ class MarlinExperts(MarlinExpertsBase):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in # Modular Kernel provisions output buffer from workspace1. However in
# the fused_marlin_moe() function, the final torch.sum(), is defined # the fused_marlin_moe() function, the final torch.sum(), is defined
@@ -768,6 +769,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dispatchers = self.num_dispatchers num_dispatchers = self.num_dispatchers
num_experts = local_num_experts num_experts = local_num_experts

View File

@@ -9,7 +9,6 @@ from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -43,7 +42,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
activation_without_mul, apply_moe_activation,
disable_inplace, disable_inplace,
moe_kernel_quantize_input, moe_kernel_quantize_input,
) )
@@ -1957,11 +1956,6 @@ def fused_experts(
) )
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def _get_config_quant_dtype( def _get_config_quant_dtype(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool, use_int8_w8a8: bool,
@@ -2094,8 +2088,13 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
N, activation
)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype (M * top_k_num, activation_out_dim),
device=hidden_states.device,
dtype=hidden_states.dtype,
) )
if hidden_states.dtype == torch.bfloat16: if hidden_states.dtype == torch.bfloat16:
@@ -2235,29 +2234,9 @@ def fused_experts_impl(
B_bias=w1_bias, B_bias=w1_bias,
) )
# Activation function with multiplication apply_moe_activation(
if activation == "silu": activation, intermediate_cache2, intermediate_cache1.view(-1, N)
torch.ops._C.silu_and_mul( )
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
# Activation function without multiplication
elif activation == SILU_NO_MUL:
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2, A=intermediate_cache2,
@@ -2336,8 +2315,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K)) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M, topk, max(activation_out_dim, K))
workspace2 = (M, topk, max(N, K)) workspace2 = (M, topk, max(N, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -2412,8 +2393,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note that the output tensor might be in workspace1 # Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
cache2_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache( intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2) workspace13, (num_tokens * top_k_num, cache2_dim)
) )
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
@@ -2565,8 +2547,9 @@ class TritonWNA16Experts(TritonExperts):
# Note that the output tensor might be in workspace1 # Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache( intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2) workspace13, (num_tokens * top_k_num, activation_out_dim)
) )
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

View File

@@ -323,10 +323,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (0, 0) workspace1 = (0, 0)
workspace2 = (M * topk, N // 2) workspace2 = (M * topk, activation_out_dim)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -415,9 +417,11 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
workspace1 = (M * topk, N // 2) activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, activation_out_dim)
workspace2 = (M * topk, max(N, K)) workspace2 = (M * topk, max(N, K))
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@@ -443,8 +447,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
if self.quant_config is None: # Use local variable to help mypy narrow the type after None check
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = self.quant_config
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
@@ -462,12 +468,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
# type check, uint8 means mxfp4 # type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16 assert hidden_states.dtype == torch.bfloat16
assert ( assert (
self.quant_config.w1_bias is None quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
or self.quant_config.w1_bias.dtype == torch.float32
) )
assert ( assert (
self.quant_config.w2_bias is None quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
or self.quant_config.w2_bias.dtype == torch.float32
) )
# Shape check, only check non-mxfp4 # Shape check, only check non-mxfp4
@@ -485,17 +489,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
# Note that the output tensor might be in workspace13 # Note that the output tensor might be in workspace13
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N)) intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K)) intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2)) activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(workspace13, (M * topk, activation_out_dim))
gammas = routing_data.gate_scal if routing_data else None gammas = routing_data.gate_scal if routing_data else None
matmul_ogs( matmul_ogs(
hidden_states, hidden_states,
w1, w1,
self.quant_config.w1_bias, quant_config.w1_bias,
routing_data, routing_data,
gather_indx=gather_indx, gather_indx=gather_indx,
precision_config=self.quant_config.w1_precision, precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None, gammas=gammas if apply_router_weight_on_input else None,
fused_activation=None, fused_activation=None,
y=intermediate_cache1, y=intermediate_cache1,
@@ -515,10 +520,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
matmul_ogs( matmul_ogs(
intermediate_cache2[gather_indx.src_indx], intermediate_cache2[gather_indx.src_indx],
w2, w2,
self.quant_config.w2_bias, quant_config.w2_bias,
routing_data, routing_data,
scatter_indx=scatter_indx, scatter_indx=scatter_indx,
precision_config=self.quant_config.w2_precision, precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas, gammas=None if apply_router_weight_on_input else gammas,
y=intermediate_cache3, y=intermediate_cache3,
) )

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache,
apply_moe_activation,
count_expert_num_tokens, count_expert_num_tokens,
disable_inplace, disable_inplace,
) )
@@ -542,6 +543,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
""" """
Compute the shapes for the temporary and final outputs of the two gemms Compute the shapes for the temporary and final outputs of the two gemms
@@ -572,19 +574,31 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod
def adjust_N_for_activation(N: int, activation: str) -> int:
"""
Calculate the output dimension for the activation function.
For *_no_mul activations (e.g. relu2_no_mul),
there's no gate/up split, so output size equals input size (N).
For regular gated activations (e.g., silu, gelu, swigluoai),
output size is N // 2 due to gate × activation(up) multiplication.
Args:
N: The intermediate size (width of w1/w3 weights).
activation: The activation function name.
Returns:
The output dimension after activation.
"""
is_no_mul = activation.endswith("_no_mul")
return N if is_no_mul else N // 2
def activation( def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None: ) -> None:
assert output.size(-1) * 2 == input.size(-1) apply_moe_activation(activation, output, input)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
def enable_chunking(self): def enable_chunking(self):
return ( return (
@@ -761,6 +775,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Allocate temporary and output buffers for the fused experts op. Allocate temporary and output buffers for the fused experts op.
@@ -796,6 +811,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# amount of workspace. Mark it None, so we allocate for # amount of workspace. Mark it None, so we allocate for
# the worst-case scenario. # the worst-case scenario.
expert_tokens_meta=None, expert_tokens_meta=None,
activation=activation,
) )
) )
@@ -814,6 +830,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
# Get final output shape based on the full M size. # Get final output shape based on the full M size.
@@ -825,6 +842,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
@@ -1043,6 +1061,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
for chunk_idx in range(num_chunks): for chunk_idx in range(num_chunks):

View File

@@ -299,6 +299,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER. # Workspaces are managed internally by AITER.
workspace1 = (0,) workspace1 = (0,)

View File

@@ -39,6 +39,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100. # Small batch fallback for sm100.
if self.is_sm100 and M <= 8: if self.is_sm100 and M <= 8:
@@ -50,6 +51,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
else: else:
return self.experts.workspace_shapes( return self.experts.workspace_shapes(
@@ -60,6 +62,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
def _select_experts_impl( def _select_experts_impl(

View File

@@ -35,6 +35,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton # Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm # workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -48,6 +49,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
else: else:
return self.fallback_experts.workspace_shapes( return self.fallback_experts.workspace_shapes(
@@ -58,6 +60,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts, global_num_experts,
local_num_experts, local_num_experts,
expert_tokens_meta, expert_tokens_meta,
activation,
) )
def _select_experts_impl( def _select_experts_impl(

View File

@@ -57,6 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int, global_num_experts: int,
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer. # The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,) workspace1 = (0,)

View File

@@ -4,6 +4,7 @@ import functools
from math import prod from math import prod
import torch import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -324,6 +325,55 @@ def activation_without_mul(activation: str) -> str:
return activation + "_no_mul" return activation + "_no_mul"
RELU2_NO_MUL: str = activation_without_mul("relu2")
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def apply_moe_activation(
activation: str,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""
Apply MoE activation function.
For *_and_mul activations (silu, gelu, swigluoai):
- Expects output.size(-1) * 2 == input.size(-1)
For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
- Expects output.size(-1) == input.size(-1)
"""
is_no_mul = activation.endswith("_no_mul")
if is_no_mul:
assert output.size(-1) == input.size(-1), (
f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
)
else:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL:
torch.square(F.relu(input), out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output
# Torch custom ops can't deal with outputs aliasing inputs so we need to # Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9. # disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378 # See https://github.com/vllm-project/vllm/issues/26378