[FIX] Add NO_MUL activation support for modular kernel path (#31528)
Signed-off-by: dafrimi <dafrimi@nvidia.com> Signed-off-by: <> Co-authored-by: root <root@gpu-267.slurm-workers-slurm.slurm.svc.cluster.local> Co-authored-by: root <root@gpu-537.slurm-workers-slurm.slurm.svc.cluster.local> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: root <root@pool0-01777.cm.cluster>
This commit is contained in:
201
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
201
tests/kernels/moe/test_triton_moe_no_act_mul.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tests for MoE with non-gated activations (*_no_mul).
|
||||||
|
|
||||||
|
These tests verify that MoE layers work correctly with activations like
|
||||||
|
silu_no_mul, gelu_no_mul, relu2_no_mul where the activation output dimension
|
||||||
|
equals N (not N // 2 like gated activations).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
GELU_NO_MUL,
|
||||||
|
RELU2_NO_MUL,
|
||||||
|
SILU_NO_MUL,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# Parameter grid shared by the tests in this module.
# M is the number of tokens; N and K are the two projection widths
# (w1 has shape (E, N, K), w2 has shape (E, K, N)).
M_SIZES = [1, 16, 64]
N_SIZES = [128, 256]
K_SIZES = [64, 128]
# Number of experts each token is routed to.
TOPK_VALUES = [1, 2]
# Total number of experts in every test layer.
NUM_EXPERTS = 8
# Non-gated activations under test.
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_tensors(
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create test tensors for MoE with non-gated activation.

    For non-gated activations (*_no_mul):
    - w1: (E, N, K) - projects from K to N
    - w2: (E, K, N) - projects from N back to K (note: N, not N//2)

    Args:
        m: Number of tokens.
        n: Output width of w1 (and input width of w2).
        k: Hidden size of the tokens.
        num_experts: Number of experts E.
        topk: Number of experts selected per token.
        dtype: dtype of the hidden states and weights.
        device: Device on which all tensors are created.

    Returns:
        ``(hidden_states, w1, w2, topk_weights, topk_ids)`` with shapes
        (m, k), (E, n, k), (E, k, n), (m, topk) and (m, topk).
        ``topk_weights`` is float32 and uniform (1 / topk); ``topk_ids``
        holds distinct expert indices in ``[0, num_experts)`` per token.

    Raises:
        ValueError: If ``topk`` exceeds ``num_experts``, since a token
            cannot be routed to more distinct experts than exist.
    """
    if topk > num_experts:
        raise ValueError(
            f"topk ({topk}) must not exceed num_experts ({num_experts})"
        )

    hidden_states = torch.randn(m, k, dtype=dtype, device=device)

    # For non-gated: w1 projects K -> N, w2 projects N -> K.
    # Weights are scaled down to keep the outputs in a moderate range.
    w1 = torch.randn(num_experts, n, k, dtype=dtype, device=device) * 0.1
    w2 = torch.randn(num_experts, k, n, dtype=dtype, device=device) * 0.1

    topk_weights = torch.ones(m, topk, dtype=torch.float32, device=device) / topk

    # Pick the top-k of random per-expert scores so that every token gets
    # `topk` *distinct* experts; independent random draws could select the
    # same expert twice for one token.
    router_scores = torch.rand(m, num_experts, device=device)
    topk_ids = torch.topk(router_scores, topk, dim=1).indices

    return hidden_states, w1, w2, topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@pytest.mark.parametrize("m", M_SIZES)
@pytest.mark.parametrize("n", N_SIZES)
@pytest.mark.parametrize("k", K_SIZES)
@pytest.mark.parametrize("topk", TOPK_VALUES)
@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS)
@torch.inference_mode()
def test_triton_experts_no_mul_activation(
    m: int,
    n: int,
    k: int,
    topk: int,
    activation: str,
):
    """Run TritonExperts end to end with a non-gated (*_no_mul) activation.

    First checks the workspace and output shapes reported by
    ``workspace_shapes`` for the activation, then calls ``apply`` on random
    inputs and verifies that the result has shape (m, k), contains no NaN
    or Inf, and is not all zeros.
    """
    hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
        m, n, k, NUM_EXPERTS, topk
    )
    moe_experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    ws1_shape, ws2_shape, out_shape = moe_experts.workspace_shapes(
        M=m,
        N=n,
        K=k,
        topk=topk,
        global_num_experts=NUM_EXPERTS,
        local_num_experts=NUM_EXPERTS,
        expert_tokens_meta=None,
        activation=activation,
    )

    # For a no_mul activation the activation output is N wide (not N // 2),
    # so both workspaces are expected to be max(N, K) wide.
    expected_ws = (m, topk, max(n, k))
    expected_out = (m, k)

    # workspace1 should handle activation_out_dim = N (not N//2)
    assert ws1_shape == expected_ws, (
        f"workspace1 shape mismatch: expected {expected_ws}, got {ws1_shape}"
    )
    # workspace2 should handle max(N, K) for intermediate_cache1/cache3
    assert ws2_shape == expected_ws, (
        f"workspace2 shape mismatch: expected {expected_ws}, got {ws2_shape}"
    )
    assert out_shape == expected_out, (
        f"output shape mismatch: expected {expected_out}, got {out_shape}"
    )

    # The workspaces are handed over as flat buffers holding exactly as many
    # elements as the reported 3-D shapes.
    workspace1, workspace2 = (
        torch.empty(
            dims[0] * dims[1] * dims[2],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        for dims in (ws1_shape, ws2_shape)
    )
    output = hidden_states.new_zeros((m, k))

    moe_experts.apply(
        output=output,
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        global_num_experts=NUM_EXPERTS,
        expert_map=None,
        a1q_scale=None,
        a2_scale=None,
        workspace13=workspace1,
        workspace2=workspace2,
        expert_tokens_meta=None,
        apply_router_weight_on_input=False,
    )

    assert output.shape == expected_out, (
        f"Expected shape {expected_out}, got {output.shape}"
    )
    assert not output.isnan().any(), "Output contains NaN"
    assert not output.isinf().any(), "Output contains Inf"
    assert output.abs().sum() > 0, "Output is all zeros"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_workspace_shapes_no_mul_vs_gated():
    """Test that workspace shapes differ correctly between gated and non-gated.

    Queries ``TritonExperts.workspace_shapes`` once with a non-gated
    activation and once with its gated counterpart, and checks that the last
    dimension of workspace1 follows ``max(activation_out_dim, K)`` while the
    output shape is (M, K) in both cases.
    """
    M, N, K, topk = 64, 256, 128, 2

    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    # Keyword arguments mirror the call in the parametrized test above and
    # keep the expert count tied to the module-level NUM_EXPERTS.
    ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
        M=M,
        N=N,
        K=K,
        topk=topk,
        global_num_experts=NUM_EXPERTS,
        local_num_experts=NUM_EXPERTS,
        expert_tokens_meta=None,
        activation=SILU_NO_MUL,
    )
    ws1_gated, _, out_gated = experts.workspace_shapes(
        M=M,
        N=N,
        K=K,
        topk=topk,
        global_num_experts=NUM_EXPERTS,
        local_num_experts=NUM_EXPERTS,
        expert_tokens_meta=None,
        activation="silu",
    )

    # For no_mul: activation_out_dim = N
    # For gated: activation_out_dim = N // 2
    # workspace1 should use max(activation_out_dim, K)
    activation_out_dim_no_mul = N
    activation_out_dim_gated = N // 2

    assert ws1_no_mul[2] == max(activation_out_dim_no_mul, K), (
        f"no_mul workspace1 last dim should be max({activation_out_dim_no_mul}, {K})"
    )
    assert ws1_gated[2] == max(activation_out_dim_gated, K), (
        f"gated workspace1 last dim should be max({activation_out_dim_gated}, {K})"
    )

    # Output shapes should be the same
    assert out_no_mul == out_gated == (M, K)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.has_device_capability(80),
    reason="Requires compute capability >= 8.0",
)
@torch.inference_mode()
def test_adjust_n_for_activation():
    """Test the adjust_N_for_activation method.

    Gated activations ("silu", "gelu") are expected to map N to N // 2,
    while every non-gated activation in NO_MUL_ACTIVATIONS is expected to
    leave N unchanged.
    """
    experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

    N = 256

    # Gated activations should return N // 2
    for gated in ("silu", "gelu"):
        assert experts.adjust_N_for_activation(N, gated) == N // 2, (
            f"gated activation {gated!r} should map N={N} to {N // 2}"
        )

    # Non-gated activations should return N
    for non_gated in NO_MUL_ACTIVATIONS:
        assert experts.adjust_N_for_activation(N, non_gated) == N, (
            f"non-gated activation {non_gated!r} should keep N={N}"
        )
|
||||||
@@ -305,6 +305,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# FIXME (varun): We should be able to dispatch only from the leader
|
# FIXME (varun): We should be able to dispatch only from the leader
|
||||||
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
||||||
@@ -312,8 +313,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
num_dispatchers = self.num_dispatchers
|
num_dispatchers = self.num_dispatchers
|
||||||
num_experts = local_num_experts
|
num_experts = local_num_experts
|
||||||
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
|
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
|
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
|
||||||
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
|
workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
|
||||||
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||||
return (workspace13, workspace2, output)
|
return (workspace13, workspace2, output)
|
||||||
|
|
||||||
|
|||||||
@@ -355,9 +355,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace1 = (M * topk, max(N, K))
|
workspace1 = (M * topk, max(N, K))
|
||||||
workspace2 = (M * topk, max(N // 2, K))
|
workspace2 = (M * topk, max(activation_out_dim, K))
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
@@ -402,11 +404,17 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
num_dp = self.num_dispatchers
|
num_dp = self.num_dispatchers
|
||||||
assert num_dp is not None
|
assert num_dp is not None
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
|
workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K))
|
||||||
workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K))
|
workspace2 = (
|
||||||
|
self.max_experts_per_worker,
|
||||||
|
M * num_dp,
|
||||||
|
max(activation_out_dim, K),
|
||||||
|
)
|
||||||
output = (self.max_experts_per_worker, M, K)
|
output = (self.max_experts_per_worker, M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
@@ -635,13 +643,15 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace1: tuple[int, ...] = ()
|
workspace1: tuple[int, ...] = ()
|
||||||
workspace2: tuple[int, ...] = ()
|
workspace2: tuple[int, ...] = ()
|
||||||
output: tuple[int, ...] = ()
|
output: tuple[int, ...] = ()
|
||||||
if self.use_batched_format:
|
if self.use_batched_format:
|
||||||
workspace1 = (self.max_experts_per_worker, M, max(N, K))
|
workspace1 = (self.max_experts_per_worker, M, max(N, K))
|
||||||
workspace2 = (self.max_experts_per_worker, M, (N // 2))
|
workspace2 = (self.max_experts_per_worker, M, activation_out_dim)
|
||||||
output = (self.max_experts_per_worker, M, K)
|
output = (self.max_experts_per_worker, M, K)
|
||||||
else:
|
else:
|
||||||
workspace1 = (M * topk, max(2 * N, K))
|
workspace1 = (M * topk, max(2 * N, K))
|
||||||
@@ -896,9 +906,11 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace1 = (M * topk, max(N, K))
|
workspace1 = (M * topk, max(N, K))
|
||||||
workspace2 = (M * topk, max(N // 2, K))
|
workspace2 = (M * topk, max(activation_out_dim, K))
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
assert self.block_shape is not None
|
assert self.block_shape is not None
|
||||||
block_m = self.block_shape[0]
|
block_m = self.block_shape[0]
|
||||||
@@ -151,7 +152,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
)
|
)
|
||||||
assert M_sum % block_m == 0
|
assert M_sum % block_m == 0
|
||||||
|
|
||||||
workspace1 = (M_sum, max(N // 2, K))
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
workspace1 = (M_sum, max(activation_out_dim, K))
|
||||||
workspace2 = (M_sum, max(N, K))
|
workspace2 = (M_sum, max(N, K))
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
@@ -163,11 +165,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
block_k = self.block_shape[1]
|
block_k = self.block_shape[1]
|
||||||
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
|
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
|
||||||
|
|
||||||
|
M_sum, N = input.size()
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
|
||||||
# 1. DeepGemm UE8M0: use packed per-token-group quant
|
# 1. DeepGemm UE8M0: use packed per-token-group quant
|
||||||
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
|
||||||
M_sum, N = input.size()
|
|
||||||
act_out = torch.empty(
|
act_out = torch.empty(
|
||||||
(M_sum, N // 2), dtype=input.dtype, device=input.device
|
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||||
)
|
)
|
||||||
self.activation(activation, act_out, input)
|
self.activation(activation, act_out, input)
|
||||||
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||||
@@ -187,8 +191,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 3. fallback path for non-SiLU activations in non‑UE8M0 cases.
|
# 3. fallback path for non-SiLU activations in non‑UE8M0 cases.
|
||||||
M_sum, N = input.size()
|
act_out = torch.empty(
|
||||||
act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
|
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
|
||||||
|
)
|
||||||
self.activation(activation, act_out, input)
|
self.activation(activation, act_out, input)
|
||||||
return per_token_group_quant_fp8(
|
return per_token_group_quant_fp8(
|
||||||
act_out, block_k, column_major_scales=True, out_q=output
|
act_out, block_k, column_major_scales=True, out_q=output
|
||||||
@@ -254,8 +259,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
|
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
quant_out = _resize_cache(
|
quant_out = _resize_cache(
|
||||||
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
|
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
|
||||||
)
|
)
|
||||||
a2q, a2q_scale = self._act_mul_quant(
|
a2q, a2q_scale = self._act_mul_quant(
|
||||||
input=mm1_out.view(-1, N), output=quant_out, activation=activation
|
input=mm1_out.view(-1, N), output=quant_out, activation=activation
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# We use global_num_experts due to how moe_align_block_size handles
|
# We use global_num_experts due to how moe_align_block_size handles
|
||||||
# expert_maps.
|
# expert_maps.
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# We use global_num_experts due to how moe_align_block_size handles
|
# We use global_num_experts due to how moe_align_block_size handles
|
||||||
# expert_maps.
|
# expert_maps.
|
||||||
|
|||||||
@@ -673,6 +673,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
num_dp = self.num_dispatchers
|
num_dp = self.num_dispatchers
|
||||||
num_experts = local_num_experts
|
num_experts = local_num_experts
|
||||||
@@ -867,12 +868,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
num_dp = self.num_dispatchers
|
num_dp = self.num_dispatchers
|
||||||
num_experts = local_num_experts
|
num_experts = local_num_experts
|
||||||
max_num_tokens = self.max_num_tokens
|
max_num_tokens = self.max_num_tokens
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
workspace2 = (num_experts, max_num_tokens * num_dp, activation_out_dim)
|
||||||
output = (num_experts, max_num_tokens * num_dp, K)
|
output = (num_experts, max_num_tokens * num_dp, K)
|
||||||
return (workspace13, workspace2, output)
|
return (workspace13, workspace2, output)
|
||||||
|
|
||||||
@@ -947,7 +950,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# We can reuse the memory between these because by the time we need
|
# We can reuse the memory between these because by the time we need
|
||||||
# cache3, we're done with cache1
|
# cache3, we're done with cache1
|
||||||
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||||
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
intermediate_cache2 = _resize_cache(
|
||||||
|
workspace2, (E, max_num_tokens, activation_out_dim)
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(bnell): should this be done for any quantized type?
|
# TODO(bnell): should this be done for any quantized type?
|
||||||
if self.quant_config.use_fp8_w8a8:
|
if self.quant_config.use_fp8_w8a8:
|
||||||
@@ -978,7 +984,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# TODO (bnell): use triton utility from batched deep gemm.
|
# TODO (bnell): use triton utility from batched deep gemm.
|
||||||
self.activation(
|
self.activation(
|
||||||
activation,
|
activation,
|
||||||
intermediate_cache2.view(-1, N // 2),
|
intermediate_cache2.view(-1, activation_out_dim),
|
||||||
intermediate_cache1.view(-1, N),
|
intermediate_cache1.view(-1, N),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -640,6 +640,7 @@ class MarlinExperts(MarlinExpertsBase):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# Modular Kernel provisions output buffer from workspace1. However in
|
# Modular Kernel provisions output buffer from workspace1. However in
|
||||||
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
||||||
@@ -768,6 +769,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
num_dispatchers = self.num_dispatchers
|
num_dispatchers = self.num_dispatchers
|
||||||
num_experts = local_num_experts
|
num_experts = local_num_experts
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from collections.abc import Callable
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
@@ -43,7 +42,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
_resize_cache,
|
_resize_cache,
|
||||||
activation_without_mul,
|
apply_moe_activation,
|
||||||
disable_inplace,
|
disable_inplace,
|
||||||
moe_kernel_quantize_input,
|
moe_kernel_quantize_input,
|
||||||
)
|
)
|
||||||
@@ -1957,11 +1956,6 @@ def fused_experts(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
|
||||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
|
||||||
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_config_quant_dtype(
|
def _get_config_quant_dtype(
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
@@ -2094,8 +2088,13 @@ def fused_experts_impl(
|
|||||||
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
|
||||||
|
|
||||||
# This needs separate memory since it's used concurrently with cache1
|
# This needs separate memory since it's used concurrently with cache1
|
||||||
|
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||||
|
N, activation
|
||||||
|
)
|
||||||
intermediate_cache2 = torch.empty(
|
intermediate_cache2 = torch.empty(
|
||||||
(M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
|
(M * top_k_num, activation_out_dim),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hidden_states.dtype == torch.bfloat16:
|
if hidden_states.dtype == torch.bfloat16:
|
||||||
@@ -2235,29 +2234,9 @@ def fused_experts_impl(
|
|||||||
B_bias=w1_bias,
|
B_bias=w1_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Activation function with multiplication
|
apply_moe_activation(
|
||||||
if activation == "silu":
|
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||||
torch.ops._C.silu_and_mul(
|
)
|
||||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
|
||||||
)
|
|
||||||
elif activation == "gelu":
|
|
||||||
torch.ops._C.gelu_and_mul(
|
|
||||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
|
||||||
)
|
|
||||||
elif activation == "swigluoai":
|
|
||||||
# alpha = 1.702, limit = 7.0
|
|
||||||
torch.ops._C.swigluoai_and_mul(
|
|
||||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
|
||||||
)
|
|
||||||
# Activation function without multiplication
|
|
||||||
elif activation == SILU_NO_MUL:
|
|
||||||
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
|
||||||
elif activation == GELU_NO_MUL:
|
|
||||||
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
|
||||||
elif activation == RELU2_NO_MUL:
|
|
||||||
intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N)))
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
|
|
||||||
|
|
||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
A=intermediate_cache2,
|
A=intermediate_cache2,
|
||||||
@@ -2336,8 +2315,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
workspace1 = (M, topk, max(N // 2, K))
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
workspace1 = (M, topk, max(activation_out_dim, K))
|
||||||
workspace2 = (M, topk, max(N, K))
|
workspace2 = (M, topk, max(N, K))
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
@@ -2412,8 +2393,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
# Note that the output tensor might be in workspace1
|
# Note that the output tensor might be in workspace1
|
||||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||||
|
cache2_dim = self.adjust_N_for_activation(N, activation)
|
||||||
intermediate_cache2 = _resize_cache(
|
intermediate_cache2 = _resize_cache(
|
||||||
workspace13, (num_tokens * top_k_num, N // 2)
|
workspace13, (num_tokens * top_k_num, cache2_dim)
|
||||||
)
|
)
|
||||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||||
|
|
||||||
@@ -2565,8 +2547,9 @@ class TritonWNA16Experts(TritonExperts):
|
|||||||
|
|
||||||
# Note that the output tensor might be in workspace1
|
# Note that the output tensor might be in workspace1
|
||||||
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
intermediate_cache2 = _resize_cache(
|
intermediate_cache2 = _resize_cache(
|
||||||
workspace13, (num_tokens * top_k_num, N // 2)
|
workspace13, (num_tokens * top_k_num, activation_out_dim)
|
||||||
)
|
)
|
||||||
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
|
||||||
|
|
||||||
|
|||||||
@@ -323,10 +323,12 @@ class OAITritonExperts(BaseOAITritonExperts):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# workspace are allocated inside the kernel
|
# workspace are allocated inside the kernel
|
||||||
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
workspace1 = (0, 0)
|
workspace1 = (0, 0)
|
||||||
workspace2 = (M * topk, N // 2)
|
workspace2 = (M * topk, activation_out_dim)
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
@@ -415,9 +417,11 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# workspace are allocated inside the kernel
|
# workspace are allocated inside the kernel
|
||||||
workspace1 = (M * topk, N // 2)
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
workspace1 = (M * topk, activation_out_dim)
|
||||||
workspace2 = (M * topk, max(N, K))
|
workspace2 = (M * topk, max(N, K))
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
@@ -443,8 +447,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
|||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
if self.quant_config is None:
|
# Use local variable to help mypy narrow the type after None check
|
||||||
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
quant_config = self.quant_config
|
||||||
|
if quant_config is None:
|
||||||
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
topk_ids = expert_map[topk_ids]
|
topk_ids = expert_map[topk_ids]
|
||||||
@@ -462,12 +468,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
|||||||
# type check, uint8 means mxfp4
|
# type check, uint8 means mxfp4
|
||||||
assert hidden_states.dtype == torch.bfloat16
|
assert hidden_states.dtype == torch.bfloat16
|
||||||
assert (
|
assert (
|
||||||
self.quant_config.w1_bias is None
|
quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||||
or self.quant_config.w1_bias.dtype == torch.float32
|
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
self.quant_config.w2_bias is None
|
quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||||
or self.quant_config.w2_bias.dtype == torch.float32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Shape check, only check non-mxfp4
|
# Shape check, only check non-mxfp4
|
||||||
@@ -485,17 +489,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
|||||||
# Note that the output tensor might be in workspace13
|
# Note that the output tensor might be in workspace13
|
||||||
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
|
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
|
||||||
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
|
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
|
||||||
intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2))
|
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||||
|
intermediate_cache2 = _resize_cache(workspace13, (M * topk, activation_out_dim))
|
||||||
|
|
||||||
gammas = routing_data.gate_scal if routing_data else None
|
gammas = routing_data.gate_scal if routing_data else None
|
||||||
|
|
||||||
matmul_ogs(
|
matmul_ogs(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
self.quant_config.w1_bias,
|
quant_config.w1_bias,
|
||||||
routing_data,
|
routing_data,
|
||||||
gather_indx=gather_indx,
|
gather_indx=gather_indx,
|
||||||
precision_config=self.quant_config.w1_precision,
|
precision_config=quant_config.w1_precision,
|
||||||
gammas=gammas if apply_router_weight_on_input else None,
|
gammas=gammas if apply_router_weight_on_input else None,
|
||||||
fused_activation=None,
|
fused_activation=None,
|
||||||
y=intermediate_cache1,
|
y=intermediate_cache1,
|
||||||
@@ -515,10 +520,10 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
|||||||
matmul_ogs(
|
matmul_ogs(
|
||||||
intermediate_cache2[gather_indx.src_indx],
|
intermediate_cache2[gather_indx.src_indx],
|
||||||
w2,
|
w2,
|
||||||
self.quant_config.w2_bias,
|
quant_config.w2_bias,
|
||||||
routing_data,
|
routing_data,
|
||||||
scatter_indx=scatter_indx,
|
scatter_indx=scatter_indx,
|
||||||
precision_config=self.quant_config.w2_precision,
|
precision_config=quant_config.w2_precision,
|
||||||
gammas=None if apply_router_weight_on_input else gammas,
|
gammas=None if apply_router_weight_on_input else gammas,
|
||||||
y=intermediate_cache3,
|
y=intermediate_cache3,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
_resize_cache,
|
_resize_cache,
|
||||||
|
apply_moe_activation,
|
||||||
count_expert_num_tokens,
|
count_expert_num_tokens,
|
||||||
disable_inplace,
|
disable_inplace,
|
||||||
)
|
)
|
||||||
@@ -542,6 +543,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
"""
|
"""
|
||||||
Compute the shapes for the temporary and final outputs of the two gemms
|
Compute the shapes for the temporary and final outputs of the two gemms
|
||||||
@@ -572,19 +574,31 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def adjust_N_for_activation(N: int, activation: str) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the output dimension for the activation function.
|
||||||
|
|
||||||
|
For *_no_mul activations (e.g. relu2_no_mul),
|
||||||
|
there's no gate/up split, so output size equals input size (N).
|
||||||
|
|
||||||
|
For regular gated activations (e.g., silu, gelu, swigluoai),
|
||||||
|
output size is N // 2 due to gate × activation(up) multiplication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
N: The intermediate size (width of w1/w3 weights).
|
||||||
|
activation: The activation function name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output dimension after activation.
|
||||||
|
"""
|
||||||
|
is_no_mul = activation.endswith("_no_mul")
|
||||||
|
return N if is_no_mul else N // 2
|
||||||
|
|
||||||
def activation(
|
def activation(
|
||||||
self, activation: str, output: torch.Tensor, input: torch.Tensor
|
self, activation: str, output: torch.Tensor, input: torch.Tensor
|
||||||
) -> None:
|
) -> None:
|
||||||
assert output.size(-1) * 2 == input.size(-1)
|
apply_moe_activation(activation, output, input)
|
||||||
if activation == "silu":
|
|
||||||
torch.ops._C.silu_and_mul(output, input)
|
|
||||||
elif activation == "gelu":
|
|
||||||
torch.ops._C.gelu_and_mul(output, input)
|
|
||||||
elif activation == "swigluoai":
|
|
||||||
# alpha = 1.702, limit = 7.0
|
|
||||||
torch.ops._C.swigluoai_and_mul(output, input)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
|
||||||
|
|
||||||
def enable_chunking(self):
|
def enable_chunking(self):
|
||||||
return (
|
return (
|
||||||
@@ -761,6 +775,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Allocate temporary and output buffers for the fused experts op.
|
Allocate temporary and output buffers for the fused experts op.
|
||||||
@@ -796,6 +811,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
# amount of workspace. Mark it None, so we allocate for
|
# amount of workspace. Mark it None, so we allocate for
|
||||||
# the worst-case scenario.
|
# the worst-case scenario.
|
||||||
expert_tokens_meta=None,
|
expert_tokens_meta=None,
|
||||||
|
activation=activation,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -814,6 +830,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get final output shape based on the full M size.
|
# Get final output shape based on the full M size.
|
||||||
@@ -825,6 +842,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We can reuse the memory between cache1 and cache3 because by the
|
# We can reuse the memory between cache1 and cache3 because by the
|
||||||
@@ -1043,6 +1061,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk_idx in range(num_chunks):
|
for chunk_idx in range(num_chunks):
|
||||||
|
|||||||
@@ -299,6 +299,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# Workspaces are managed internally by AITER.
|
# Workspaces are managed internally by AITER.
|
||||||
workspace1 = (0,)
|
workspace1 = (0,)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# Small batch fallback for sm100.
|
# Small batch fallback for sm100.
|
||||||
if self.is_sm100 and M <= 8:
|
if self.is_sm100 and M <= 8:
|
||||||
@@ -50,6 +51,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.experts.workspace_shapes(
|
return self.experts.workspace_shapes(
|
||||||
@@ -60,6 +62,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _select_experts_impl(
|
def _select_experts_impl(
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
@@ -48,6 +49,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.fallback_experts.workspace_shapes(
|
return self.fallback_experts.workspace_shapes(
|
||||||
@@ -58,6 +60,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
|||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _select_experts_impl(
|
def _select_experts_impl(
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
|
activation: str,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# The workspaces for this implementation are managed by flashinfer.
|
# The workspaces for this implementation are managed by flashinfer.
|
||||||
workspace1 = (0,)
|
workspace1 = (0,)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import functools
|
|||||||
from math import prod
|
from math import prod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
@@ -324,6 +325,55 @@ def activation_without_mul(activation: str) -> str:
|
|||||||
return activation + "_no_mul"
|
return activation + "_no_mul"
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical names of the non-gated MoE activations. They are derived through
# activation_without_mul() so the naming scheme is defined in one place.
SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_moe_activation(
    activation: str,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """
    Run the named MoE activation on ``input`` and store the result in ``output``.

    The last-dimension sizes are validated first, based on the name:

    - names ending in ``_no_mul`` (silu_no_mul, gelu_no_mul, relu2_no_mul)
      require ``output.size(-1) == input.size(-1)``;
    - all other names (silu, gelu, swigluoai) require
      ``output.size(-1) * 2 == input.size(-1)``.

    Args:
        activation: The activation function name.
        output: Destination tensor; it is written in place.
        input: Source tensor the activation is applied to.

    Returns:
        The same ``output`` tensor that was passed in.

    Raises:
        AssertionError: If the last-dimension sizes do not match the ratio
            required by the activation name.
        ValueError: If ``activation`` is not one of the supported names.
    """
    out_width = output.size(-1)
    in_width = input.size(-1)

    # Shape validation happens before dispatch, so a bad shape is reported
    # even when the activation name itself turns out to be unsupported.
    if not activation.endswith("_no_mul"):
        assert out_width * 2 == in_width, (
            f"{activation} expects 2x ratio: {out_width * 2} vs {in_width}"
        )
    else:
        assert out_width == in_width, (
            f"{activation} expects equal sizes: {out_width} vs {in_width}"
        )

    if activation == "silu":
        # Gated variants are handled by the fused custom ops, which write
        # directly into ``output``.
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        torch.ops._C.gelu_and_mul(output, input)
    elif activation == "swigluoai":
        torch.ops._C.swigluoai_and_mul(output, input)
    elif activation == SILU_NO_MUL:
        # Non-gated variants are plain elementwise ops on the full width.
        activated = F.silu(input)
        output.copy_(activated)
    elif activation == GELU_NO_MUL:
        activated = F.gelu(input)
        output.copy_(activated)
    elif activation == RELU2_NO_MUL:
        # relu2 == relu(x) ** 2; the square is written straight to ``output``.
        rectified = F.relu(input)
        torch.square(rectified, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
|
||||||
|
|
||||||
|
|
||||||
# Torch custom ops can't deal with outputs aliasing inputs so we need to
|
# Torch custom ops can't deal with outputs aliasing inputs so we need to
|
||||||
# disable inplace for torch >= 2.9.
|
# disable inplace for torch >= 2.9.
|
||||||
# See https://github.com/vllm-project/vllm/issues/26378
|
# See https://github.com/vllm-project/vllm/issues/26378
|
||||||
|
|||||||
Reference in New Issue
Block a user