[FIX] Add NO_MUL activation support for modular kernel path (#31528)
Signed-off-by: dafrimi <dafrimi@nvidia.com> Signed-off-by: <> Co-authored-by: root <root@gpu-267.slurm-workers-slurm.slurm.svc.cluster.local> Co-authored-by: root <root@gpu-537.slurm-workers-slurm.slurm.svc.cluster.local> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: root <root@pool0-01777.cm.cluster>
This commit is contained in:
201
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
201
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for MoE with non-gated activations (*_no_mul).
|
||||
|
||||
These tests verify that MoE layers work correctly with activations like
|
||||
silu_no_mul, gelu_no_mul, relu2_no_mul where the activation output dimension
|
||||
equals N (not N // 2 like gated activations).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
GELU_NO_MUL,
|
||||
RELU2_NO_MUL,
|
||||
SILU_NO_MUL,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Test parameters
# Token counts (M), intermediate sizes (N), hidden sizes (K), and top-k
# fan-outs to sweep; kept small so the full cross product stays fast.
M_SIZES = [1, 16, 64]
N_SIZES = [128, 256]
K_SIZES = [64, 128]
TOPK_VALUES = [1, 2]
NUM_EXPERTS = 8
# Every non-gated (*_no_mul) activation variant exercised by these tests.
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
|
||||
|
||||
|
||||
def make_test_tensors(
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
):
    """Build inputs for a MoE layer that uses a non-gated activation.

    Unlike gated activations (where the activation halves the w1 output),
    *_no_mul activations keep the full intermediate width:
      - w1: (E, N, K), mapping K -> N
      - w2: (E, K, N), mapping N -> K (note: N, not N // 2)

    Returns:
        Tuple of (hidden_states, w1, w2, topk_weights, topk_ids).
    """
    # Token activations: one row of width K per token.
    hidden_states = torch.randn(m, k, dtype=dtype, device=device)

    # Small weights (scaled by 0.1) keep bf16 outputs numerically tame.
    w1 = torch.randn(num_experts, n, k, dtype=dtype, device=device) * 0.1
    w2 = torch.randn(num_experts, k, n, dtype=dtype, device=device) * 0.1

    # Random expert assignment with uniform routing weights (sum to 1).
    topk_ids = torch.randint(0, num_experts, (m, topk), device=device)
    topk_weights = torch.full(
        (m, topk), 1.0 / topk, dtype=torch.float32, device=device
    )

    return hidden_states, w1, w2, topk_weights, topk_ids
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@pytest.mark.parametrize("m", M_SIZES)
@pytest.mark.parametrize("n", N_SIZES)
@pytest.mark.parametrize("k", K_SIZES)
@pytest.mark.parametrize("topk", TOPK_VALUES)
@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS)
@torch.inference_mode()
def test_triton_experts_no_mul_activation(
    m: int,
    n: int,
    k: int,
    topk: int,
    activation: str,
):
    """Run TritonExperts end-to-end with a *_no_mul activation and sanity
    check both the reported workspace shapes and the final output."""
    hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
        m, n, k, NUM_EXPERTS, topk
    )

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    ws1_shape, ws2_shape, out_shape = experts.workspace_shapes(
        M=m,
        N=n,
        K=k,
        topk=topk,
        global_num_experts=NUM_EXPERTS,
        local_num_experts=NUM_EXPERTS,
        expert_tokens_meta=None,
        activation=activation,
    )

    # For no_mul activations the activation output width is N (not N // 2),
    # so both workspaces must cover max(N, K) in their last dimension.
    expected_ws = (m, topk, max(n, k))
    assert ws1_shape == expected_ws, (
        f"workspace1 shape mismatch: expected {expected_ws}, got {ws1_shape}"
    )
    assert ws2_shape == expected_ws, (
        f"workspace2 shape mismatch: expected {expected_ws}, got {ws2_shape}"
    )
    assert out_shape == (m, k), (
        f"output shape mismatch: expected {(m, k)}, got {out_shape}"
    )

    def _flat_workspace(shape):
        # apply() consumes the workspaces as flat 1-D buffers.
        return torch.empty(
            shape[0] * shape[1] * shape[2],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

    workspace1 = _flat_workspace(ws1_shape)
    workspace2 = _flat_workspace(ws2_shape)
    output = torch.zeros(m, k, dtype=hidden_states.dtype, device=hidden_states.device)

    experts.apply(
        output=output,
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        global_num_experts=NUM_EXPERTS,
        expert_map=None,
        a1q_scale=None,
        a2_scale=None,
        workspace13=workspace1,
        workspace2=workspace2,
        expert_tokens_meta=None,
        apply_router_weight_on_input=False,
    )

    assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}"
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isinf(output).any(), "Output contains Inf"
    assert output.abs().sum() > 0, "Output is all zeros"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_workspace_shapes_no_mul_vs_gated():
    """Workspace shapes must track the activation's output width:
    N for non-gated (*_no_mul), N // 2 for gated activations."""
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    M, N, K, topk = 64, 256, 128, 2

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    # Identical geometry, two activations — only the activation differs.
    ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
        M, N, K, topk, 8, 8, None, SILU_NO_MUL
    )
    ws1_gated, _, out_gated = experts.workspace_shapes(
        M, N, K, topk, 8, 8, None, "silu"
    )

    # workspace1's last dim is max(activation_out_dim, K) in both cases;
    # only activation_out_dim changes between the two variants.
    for ws1, act_out_dim, label in (
        (ws1_no_mul, N, "no_mul"),
        (ws1_gated, N // 2, "gated"),
    ):
        assert ws1[2] == max(act_out_dim, K), (
            f"{label} workspace1 last dim should be max({act_out_dim}, {K})"
        )

    # The final output geometry is activation-independent.
    assert out_no_mul == out_gated == (M, K)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_adjust_n_for_activation():
    """adjust_N_for_activation halves N only for gated activations."""
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    N = 256

    # Gated activations split w1 output into gate/up halves -> N // 2.
    for gated in ("silu", "gelu"):
        assert experts.adjust_N_for_activation(N, gated) == N // 2

    # Non-gated (*_no_mul) activations keep the full width -> N.
    for no_mul in (SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL):
        assert experts.adjust_N_for_activation(N, no_mul) == N
|
||||
@@ -305,6 +305,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# FIXME (varun): We should be able to dispatch only from the leader
|
||||
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
||||
@@ -312,8 +313,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
num_dispatchers = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
|
||||
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
|
||||
@@ -355,9 +355,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, max(N // 2, K))
|
||||
workspace2 = (M * topk, max(activation_out_dim, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
@@ -402,11 +404,17 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
|
||||
workspace2 = (
|
||||
self.max_experts_per_worker,
|
||||
M * num_dp,
|
||||
max(activation_out_dim, K),
|
||||
)
|
||||
output = (self.max_experts_per_worker, M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
@@ -635,13 +643,15 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
workspace1 = (self.max_experts_per_worker, M, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, M, (N // 2))
|
||||
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
|
||||
output = (self.max_experts_per_worker, M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
@@ -896,9 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
workspace2 = (M * topk, max(N // 2, K))
|
||||
workspace2 = (M * topk, max(activation_out_dim, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
|
||||
@@ -143,6 +143,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.block_shape is not None
|
||||
block_m = self.block_shape[0]
|
||||
@@ -151,7 +152,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
assert M_sum % block_m == 0
|
||||
|
||||
workspace1 = (M_sum, max(N // 2, K))
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M_sum, max(activation_out_dim, K))
|
||||
workspace2 = (M_sum, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
@@ -163,11 +165,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_k = self.block_shape[1]
|
||||
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
|
||||
|
||||
M_sum, N = input.size()
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
|
||||
# 1. DeepGemm UE8M0: use packed per-token-group quant
|
||||
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
||||
M_sum, N = input.size()
|
||||
act_out = torch.empty(
|
||||
(M_sum, N // 2), dtype=input.dtype, device=input.device
|
||||
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||
)
|
||||
self.activation(activation, act_out, input)
|
||||
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
@@ -187,8 +191,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
# 3. fallback path for non-SiLU activations in non‑UE8M0 cases.
|
||||
M_sum, N = input.size()
|
||||
act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
|
||||
act_out = torch.empty(
|
||||
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||
)
|
||||
self.activation(activation, act_out, input)
|
||||
return per_token_group_quant_fp8(
|
||||
act_out, block_k, column_major_scales=True, out_q=output
|
||||
@@ -254,8 +259,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
|
||||
)
|
||||
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
quant_out = _resize_cache(
|
||||
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
|
||||
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
|
||||
)
|
||||
a2q, a2q_scale = self._act_mul_quant(
|
||||
input=mm1_out.view(-1, N), output=quant_out, activation=activation
|
||||
|
||||
@@ -76,6 +76,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
|
||||
@@ -103,6 +103,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
|
||||
@@ -673,6 +673,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
@@ -867,12 +868,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = self.max_num_tokens
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, activation_out_dim)
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
@@ -947,7 +950,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace2, (E, max_num_tokens, activation_out_dim)
|
||||
)
|
||||
|
||||
# TODO(bnell): should this be done for any quantized type?
|
||||
if self.quant_config.use_fp8_w8a8:
|
||||
@@ -978,7 +984,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# TODO (bnell): use triton utility from batched deep gemm.
|
||||
self.activation(
|
||||
activation,
|
||||
intermediate_cache2.view(-1, N // 2),
|
||||
intermediate_cache2.view(-1, activation_out_dim),
|
||||
intermediate_cache1.view(-1, N),
|
||||
)
|
||||
|
||||
|
||||
@@ -640,6 +640,7 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Modular Kernel provisions output buffer from workspace1. However in
|
||||
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
||||
@@ -768,6 +769,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dispatchers = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
|
||||
@@ -9,7 +9,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@@ -43,7 +42,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
activation_without_mul,
|
||||
apply_moe_activation,
|
||||
disable_inplace,
|
||||
moe_kernel_quantize_input,
|
||||
)
|
||||
@@ -1957,11 +1956,6 @@ def fused_experts(
|
||||
)
|
||||
|
||||
|
||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
||||
|
||||
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
@@ -2094,8 +2088,13 @@ def fused_experts_impl(
|
||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||
N, activation
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
|
||||
(M * top_k_num, activation_out_dim),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
@@ -2235,29 +2234,9 @@ def fused_experts_impl(
|
||||
B_bias=w1_bias,
|
||||
)
|
||||
|
||||
# Activation function with multiplication
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
# Activation function without multiplication
|
||||
elif activation == SILU_NO_MUL:
|
||||
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
||||
elif activation == GELU_NO_MUL:
|
||||
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
||||
elif activation == RELU2_NO_MUL:
|
||||
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
|
||||
apply_moe_activation(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
@@ -2336,8 +2315,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M, topk, max(N // 2, K))
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M, topk, max(activation_out_dim, K))
|
||||
workspace2 = (M, topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
@@ -2412,8 +2393,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||
cache2_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace13, (num_tokens * top_k_num, N // 2)
|
||||
workspace13, (num_tokens * top_k_num, cache2_dim)
|
||||
)
|
||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||
|
||||
@@ -2565,8 +2547,9 @@ class TritonWNA16Experts(TritonExperts):
|
||||
|
||||
# Note that the output tensor might be in workspace1
|
||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(
|
||||
workspace13, (num_tokens * top_k_num, N // 2)
|
||||
workspace13, (num_tokens * top_k_num, activation_out_dim)
|
||||
)
|
||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||
|
||||
|
||||
@@ -323,10 +323,12 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (0, 0)
|
||||
workspace2 = (M * topk, N // 2)
|
||||
workspace2 = (M * topk, activation_out_dim)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
@@ -415,9 +417,11 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
workspace1 = (M * topk, N // 2)
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M * topk, activation_out_dim)
|
||||
workspace2 = (M * topk, max(N, K))
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
@@ -443,8 +447,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
# Use local variable to help mypy narrow the type after None check
|
||||
quant_config = self.quant_config
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
@@ -462,12 +468,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert (
|
||||
self.quant_config.w1_bias is None
|
||||
or self.quant_config.w1_bias.dtype == torch.float32
|
||||
quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
)
|
||||
assert (
|
||||
self.quant_config.w2_bias is None
|
||||
or self.quant_config.w2_bias.dtype == torch.float32
|
||||
quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
)
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
@@ -485,17 +489,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
# Note that the output tensor might be in workspace13
|
||||
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
|
||||
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
|
||||
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
intermediate_cache2 = _resize_cache(workspace13, (M * topk, activation_out_dim))
|
||||
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
self.quant_config.w1_bias,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=self.quant_config.w1_precision,
|
||||
precision_config=quant_config.w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=None,
|
||||
y=intermediate_cache1,
|
||||
@@ -515,10 +520,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
matmul_ogs(
|
||||
intermediate_cache2[gather_indx.src_indx],
|
||||
w2,
|
||||
self.quant_config.w2_bias,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=self.quant_config.w2_precision,
|
||||
precision_config=quant_config.w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=intermediate_cache3,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
@@ -542,6 +543,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
@@ -572,19 +574,31 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def adjust_N_for_activation(N: int, activation: str) -> int:
|
||||
"""
|
||||
Calculate the output dimension for the activation function.
|
||||
|
||||
For *_no_mul activations (e.g. relu2_no_mul),
|
||||
there's no gate/up split, so output size equals input size (N).
|
||||
|
||||
For regular gated activations (e.g., silu, gelu, swigluoai),
|
||||
output size is N // 2 due to gate × activation(up) multiplication.
|
||||
|
||||
Args:
|
||||
N: The intermediate size (width of w1/w3 weights).
|
||||
activation: The activation function name.
|
||||
|
||||
Returns:
|
||||
The output dimension after activation.
|
||||
"""
|
||||
is_no_mul = activation.endswith("_no_mul")
|
||||
return N if is_no_mul else N // 2
|
||||
|
||||
def activation(
|
||||
self, activation: str, output: torch.Tensor, input: torch.Tensor
|
||||
) -> None:
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(output, input)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(output, input)
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(output, input)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
apply_moe_activation(activation, output, input)
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
@@ -761,6 +775,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate temporary and output buffers for the fused experts op.
|
||||
@@ -796,6 +811,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# amount of workspace. Mark it None, so we allocate for
|
||||
# the worst-case scenario.
|
||||
expert_tokens_meta=None,
|
||||
activation=activation,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -814,6 +830,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
# Get final output shape based on the full M size.
|
||||
@@ -825,6 +842,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
@@ -1043,6 +1061,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
|
||||
@@ -299,6 +299,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Workspaces are managed internally by AITER.
|
||||
workspace1 = (0,)
|
||||
|
||||
@@ -39,6 +39,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and M <= 8:
|
||||
@@ -50,6 +51,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
else:
|
||||
return self.experts.workspace_shapes(
|
||||
@@ -60,6 +62,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
def _select_experts_impl(
|
||||
|
||||
@@ -35,6 +35,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
@@ -48,6 +49,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
else:
|
||||
return self.fallback_experts.workspace_shapes(
|
||||
@@ -58,6 +60,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
def _select_experts_impl(
|
||||
|
||||
@@ -57,6 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
|
||||
@@ -4,6 +4,7 @@ import functools
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -324,6 +325,55 @@ def activation_without_mul(activation: str) -> str:
|
||||
return activation + "_no_mul"
|
||||
|
||||
|
||||
# Canonical names for the supported non-gated activation variants.
RELU2_NO_MUL: str = activation_without_mul("relu2")
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")


def apply_moe_activation(
    activation: str,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply a MoE activation function, writing the result into ``output``.

    For *_and_mul activations (silu, gelu, swigluoai):
    - Expects output.size(-1) * 2 == input.size(-1)

    For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
    - Expects output.size(-1) == input.size(-1)

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """
    out_dim = output.size(-1)
    in_dim = input.size(-1)
    if activation.endswith("_no_mul"):
        # Elementwise activation: no gate/up split, widths must match.
        assert out_dim == in_dim, (
            f"{activation} expects equal sizes: {out_dim} vs {in_dim}"
        )
    else:
        # Gated activation: input carries [gate | up], output gets one half.
        assert out_dim * 2 == in_dim, (
            f"{activation} expects 2x ratio: {out_dim * 2} vs {in_dim}"
        )

    # Activations without gated multiplication.
    if activation == SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == RELU2_NO_MUL:
        torch.square(F.relu(input), out=output)
    # Activations with gated multiplication (gate × activation(up)).
    elif activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        torch.ops._C.gelu_and_mul(output, input)
    elif activation == "swigluoai":
        torch.ops._C.swigluoai_and_mul(output, input)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
|
||||
|
||||
|
||||
# Torch custom ops can't deal with outputs aliasing inputs so we need to
|
||||
# disable inplace for torch >= 2.9.
|
||||
# See https://github.com/vllm-project/vllm/issues/26378
|
||||
|
||||
Reference in New Issue
Block a user