[FIX] Add NO_MUL activation support for modular kernel path (#31528)

Signed-off-by: dafrimi <dafrimi@nvidia.com>
Signed-off-by: <>
Co-authored-by: root <root@gpu-267.slurm-workers-slurm.slurm.svc.cluster.local>
Co-authored-by: root <root@gpu-537.slurm-workers-slurm.slurm.svc.cluster.local>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: root <root@pool0-01777.cm.cluster>
This commit is contained in:
danielafrimi
2026-01-12 18:55:49 +02:00
committed by GitHub
parent 6bc9c8473e
commit 3f72639d36
17 changed files with 368 additions and 71 deletions

View File

@@ -305,6 +305,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
@@ -312,8 +313,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output)

View File

@@ -355,9 +355,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
workspace2 = (M * topk, max(activation_out_dim, K))
output = (M, K)
return (workspace1, workspace2, output)
@@ -402,11 +404,17 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
assert num_dp is not None
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
workspace2 = (
self.max_experts_per_worker,
M * num_dp,
max(activation_out_dim, K),
)
output = (self.max_experts_per_worker, M, K)
return (workspace1, workspace2, output)
@@ -635,13 +643,15 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
workspace1 = (self.max_experts_per_worker, M, max(N, K))
workspace2 = (self.max_experts_per_worker, M, (N // 2))
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
output = (self.max_experts_per_worker, M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
@@ -896,9 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, max(N // 2, K))
workspace2 = (M * topk, max(activation_out_dim, K))
output = (M, K)
return (workspace1, workspace2, output)

View File

@@ -143,6 +143,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
assert self.block_shape is not None
block_m = self.block_shape[0]
@@ -151,7 +152,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N // 2, K))
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
@@ -163,11 +165,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_k = self.block_shape[1]
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
M_sum, N = input.size()
act_out = torch.empty(
(M_sum, N // 2), dtype=input.dtype, device=input.device
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
@@ -187,8 +191,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
        # 3. fallback path for non-SiLU activations in non-UE8M0 cases.
M_sum, N = input.size()
act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
@@ -254,8 +259,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
)
activation_out_dim = self.adjust_N_for_activation(N, activation)
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation

View File

@@ -76,6 +76,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
raise NotImplementedError

View File

@@ -91,6 +91,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.

View File

@@ -103,6 +103,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.

View File

@@ -673,6 +673,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
@@ -867,12 +868,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
workspace2 = (num_experts, max_num_tokens * num_dp, activation_out_dim)
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output)
@@ -947,7 +950,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace2, (E, max_num_tokens, activation_out_dim)
)
# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
@@ -978,7 +984,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (bnell): use triton utility from batched deep gemm.
self.activation(
activation,
intermediate_cache2.view(-1, N // 2),
intermediate_cache2.view(-1, activation_out_dim),
intermediate_cache1.view(-1, N),
)

View File

@@ -640,6 +640,7 @@ class MarlinExperts(MarlinExpertsBase):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Modular Kernel provisions output buffer from workspace1. However in
        # the fused_marlin_moe() function, the final torch.sum() is defined
@@ -768,6 +769,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts

View File

@@ -9,7 +9,6 @@ from collections.abc import Callable
from typing import Any
import torch
import torch.nn.functional as F
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -43,7 +42,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
activation_without_mul,
apply_moe_activation,
disable_inplace,
moe_kernel_quantize_input,
)
@@ -1957,11 +1956,6 @@ def fused_experts(
)
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
@@ -2094,8 +2088,13 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
N, activation
)
intermediate_cache2 = torch.empty(
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
(M * top_k_num, activation_out_dim),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.dtype == torch.bfloat16:
@@ -2235,29 +2234,9 @@ def fused_experts_impl(
B_bias=w1_bias,
)
# Activation function with multiplication
if activation == "silu":
torch.ops._C.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
# Activation function without multiplication
elif activation == SILU_NO_MUL:
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
elif activation == GELU_NO_MUL:
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
elif activation == RELU2_NO_MUL:
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
apply_moe_activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
@@ -2336,8 +2315,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace1 = (M, topk, max(N // 2, K))
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M, topk, max(activation_out_dim, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
@@ -2412,8 +2393,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
cache2_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2)
workspace13, (num_tokens * top_k_num, cache2_dim)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
@@ -2565,8 +2547,9 @@ class TritonWNA16Experts(TritonExperts):
# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, N // 2)
workspace13, (num_tokens * top_k_num, activation_out_dim)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

View File

@@ -323,10 +323,12 @@ class OAITritonExperts(BaseOAITritonExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # workspaces are allocated inside the kernel
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (0, 0)
workspace2 = (M * topk, N // 2)
workspace2 = (M * topk, activation_out_dim)
output = (M, K)
return (workspace1, workspace2, output)
@@ -415,9 +417,11 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # workspaces are allocated inside the kernel
workspace1 = (M * topk, N // 2)
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M * topk, activation_out_dim)
workspace2 = (M * topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
@@ -443,8 +447,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
if self.quant_config is None:
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
# Use local variable to help mypy narrow the type after None check
quant_config = self.quant_config
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
if expert_map is not None:
topk_ids = expert_map[topk_ids]
@@ -462,12 +468,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert (
self.quant_config.w1_bias is None
or self.quant_config.w1_bias.dtype == torch.float32
quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
)
assert (
self.quant_config.w2_bias is None
or self.quant_config.w2_bias.dtype == torch.float32
quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
)
# Shape check, only check non-mxfp4
@@ -485,17 +489,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
# Note that the output tensor might be in workspace13
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
activation_out_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(workspace13, (M * topk, activation_out_dim))
gammas = routing_data.gate_scal if routing_data else None
matmul_ogs(
hidden_states,
w1,
self.quant_config.w1_bias,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=self.quant_config.w1_precision,
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=None,
y=intermediate_cache1,
@@ -515,10 +520,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
matmul_ogs(
intermediate_cache2[gather_indx.src_indx],
w2,
self.quant_config.w2_bias,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=self.quant_config.w2_precision,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=intermediate_cache3,
)

View File

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
apply_moe_activation,
count_expert_num_tokens,
disable_inplace,
)
@@ -542,6 +543,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
@@ -572,19 +574,31 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
@staticmethod
def adjust_N_for_activation(N: int, activation: str) -> int:
"""
Calculate the output dimension for the activation function.
For *_no_mul activations (e.g. relu2_no_mul),
there's no gate/up split, so output size equals input size (N).
For regular gated activations (e.g., silu, gelu, swigluoai),
output size is N // 2 due to gate × activation(up) multiplication.
Args:
N: The intermediate size (width of w1/w3 weights).
activation: The activation function name.
Returns:
The output dimension after activation.
"""
is_no_mul = activation.endswith("_no_mul")
return N if is_no_mul else N // 2
def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
torch.ops._C.swigluoai_and_mul(output, input)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
apply_moe_activation(activation, output, input)
def enable_chunking(self):
return (
@@ -761,6 +775,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: ExpertTokensMetadata | None,
activation: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Allocate temporary and output buffers for the fused experts op.
@@ -796,6 +811,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# amount of workspace. Mark it None, so we allocate for
# the worst-case scenario.
expert_tokens_meta=None,
activation=activation,
)
)
@@ -814,6 +830,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
# Get final output shape based on the full M size.
@@ -825,6 +842,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
# We can reuse the memory between cache1 and cache3 because by the
@@ -1043,6 +1061,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
for chunk_idx in range(num_chunks):

View File

@@ -299,6 +299,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Workspaces are managed internally by AITER.
workspace1 = (0,)

View File

@@ -39,6 +39,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Small batch fallback for sm100.
if self.is_sm100 and M <= 8:
@@ -50,6 +51,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
else:
return self.experts.workspace_shapes(
@@ -60,6 +62,7 @@ class TritonOrCutlassExperts(FallbackExperts):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
def _select_experts_impl(

View File

@@ -35,6 +35,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
@@ -48,6 +49,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
else:
return self.fallback_experts.workspace_shapes(
@@ -58,6 +60,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
global_num_experts,
local_num_experts,
expert_tokens_meta,
activation,
)
def _select_experts_impl(

View File

@@ -57,6 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)

View File

@@ -4,6 +4,7 @@ import functools
from math import prod
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -324,6 +325,55 @@ def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"
RELU2_NO_MUL: str = activation_without_mul("relu2")
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
def apply_moe_activation(
activation: str,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""
Apply MoE activation function.
For *_and_mul activations (silu, gelu, swigluoai):
- Expects output.size(-1) * 2 == input.size(-1)
For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
- Expects output.size(-1) == input.size(-1)
"""
is_no_mul = activation.endswith("_no_mul")
if is_no_mul:
assert output.size(-1) == input.size(-1), (
f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
)
else:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == "silu":
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == RELU2_NO_MUL:
torch.square(F.relu(input), out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output
# Torch custom ops can't deal with outputs aliasing inputs so we need to
# disable inplace for torch >= 2.9.
# See https://github.com/vllm-project/vllm/issues/26378