[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -4,6 +4,11 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
activation_without_mul,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
RoutingMethodType,
|
||||
@@ -27,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
|
||||
ZeroExpertFusedMoE,
|
||||
)
|
||||
@@ -54,6 +58,7 @@ __all__ = [
|
||||
"FusedMoERouter",
|
||||
"FusedMoEConfig",
|
||||
"FusedMoEMethodBase",
|
||||
"MoEActivation",
|
||||
"UnquantizedFusedMoEMethod",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
@@ -63,6 +68,7 @@ __all__ = [
|
||||
"SharedFusedMoE",
|
||||
"ZeroExpertFusedMoE",
|
||||
"activation_without_mul",
|
||||
"apply_moe_activation",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
|
||||
136
vllm/model_executor/layers/fused_moe/activation.py
Normal file
136
vllm/model_executor/layers/fused_moe/activation.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""MoE activation function enum and utilities."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MoEActivation(Enum):
    """Closed set of activation functions selectable for MoE layers.

    Each member's value is the legacy string name of the activation, so
    string-based configuration can be converted with :meth:`from_str`.
    """

    # Gated activations (gate * activation(up)): input is [..., 2*d],
    # output is [..., d].
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated activations (no multiplication with a gate): input and
    # output are both [..., d].
    # NOTE: a non-gated member is recognised purely by the "_no_mul"
    # suffix of its value, so that suffix is mandatory for these names.
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Whether this activation follows the gate*activation(up) pattern.

        A gated activation consumes an input tensor twice as wide as its
        output; the first half is the gate and the second half is the up
        projection. The decision is made solely from the member's value:
        anything ending in "_no_mul" is non-gated.
        """
        if self.value.endswith("_no_mul"):
            return False
        return True

    @property
    def custom_op_name(self) -> str:
        """Name of the matching CustomOp.

        The names refer to the activations in
        vllm/model_executor/layers/activation.py and are looked up in the
        module-level ``_CUSTOM_OP_NAMES`` table.
        """
        return _CUSTOM_OP_NAMES[self]

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart of this activation.

        Members listed in the module-level ``_WITHOUT_MUL`` table map to
        their ``*_NO_MUL`` variant. Every other member (already non-gated,
        or with no such variant) is returned unchanged.
        """
        try:
            return _WITHOUT_MUL[self]
        except KeyError:
            return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Convert a legacy activation name into its enum member.

        Args:
            s: Activation name, compared against each member's value.

        Returns:
            The member whose value equals ``s``.

        Raises:
            ValueError: If no member has value ``s``; the message lists
                all valid names.
        """
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            valid = [m.value for m in cls]
            raise ValueError(
                f"Unknown MoE activation: {s!r}. Valid activations: {valid}"
            )
        return found
|
||||
|
||||
|
||||
# Lookup tables backing MoEActivation.custom_op_name and
# MoEActivation.without_mul. They live at module level because they are keyed
# by the enum members themselves.

# Enum member -> CustomOp name.
# NOTE(review): the *_NO_MUL members map to the same op names as their gated
# counterparts ("silu_and_mul", "gelu_and_mul") - confirm this is intended.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
    # Gated activations
    MoEActivation.SILU: "silu_and_mul",
    MoEActivation.GELU: "gelu_and_mul",
    MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
    MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
    MoEActivation.RELU2: "relu2",
    # Non-gated activations
    MoEActivation.SILU_NO_MUL: "silu_and_mul",
    MoEActivation.GELU_NO_MUL: "gelu_and_mul",
    MoEActivation.RELU2_NO_MUL: "relu2",
}

# Gated member -> its "_no_mul" variant. Only SILU, GELU and RELU2 have one;
# members absent from this table are returned unchanged by without_mul().
_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
    gated: MoEActivation(f"{gated.value}_no_mul")
    for gated in (
        MoEActivation.SILU,
        MoEActivation.GELU,
        MoEActivation.RELU2,
    )
}
|
||||
|
||||
|
||||
def activation_without_mul(activation: str) -> str:
    """Translate an activation name into the name of its non-gated variant.

    String-in/string-out wrapper around ``MoEActivation.without_mul``.

    Args:
        activation: Activation function name (e.g., "silu", "gelu").

    Returns:
        The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul").
        Names with no dedicated non-gated variant come back unchanged.

    Raises:
        ValueError: If ``activation`` is not a known MoE activation name.
    """
    parsed = MoEActivation.from_str(activation)
    non_gated = parsed.without_mul()
    return non_gated.value
|
||||
|
||||
|
||||
def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply an MoE activation to ``input``, writing the result to ``output``.

    Args:
        activation: Which activation to apply.
        output: 2D destination tensor. Its last dimension must be half of
            ``input``'s for gated activations and equal to it otherwise.
        input: 2D source tensor. The RELU2_NO_MUL path applies ReLU to it
            in place, so its contents are modified on that path.

    Returns:
        ``output``, after it has been filled.

    Raises:
        AssertionError: If either tensor is not 2D or the last-dimension
            sizes do not match what the activation expects.
        ValueError: If ``activation`` has no implementation here. No branch
            exists for ``MoEActivation.RELU2``, so it ends up here too.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"

    out_width = output.size(-1)
    in_width = input.size(-1)
    if activation.is_gated:
        assert out_width * 2 == in_width, (
            f"{activation.value} expects 2x ratio: "
            f"{out_width * 2} vs {in_width}"
        )
    else:
        assert out_width == in_width, (
            f"{activation.value} expects equal sizes: "
            f"{out_width} vs {in_width}"
        )

    # Activations with gated multiplication (gate × activation(up))
    if activation is MoEActivation.SILU:
        torch.ops._C.silu_and_mul(output, input)
        return output
    if activation is MoEActivation.GELU:
        torch.ops._C.gelu_and_mul(output, input)
        return output
    if activation is MoEActivation.SWIGLUOAI:
        torch.ops._C.swigluoai_and_mul(output, input)
        return output
    if activation is MoEActivation.SWIGLUSTEP:
        # Imported only when this branch is taken.
        from vllm.model_executor.layers.activation import swiglustep_and_mul_triton

        swiglustep_and_mul_triton(output, input)
        return output

    # Activations without gated multiplication
    if activation is MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
        return output
    if activation is MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
        return output
    if activation is MoEActivation.RELU2_NO_MUL:
        # ReLU is applied to `input` in place, then squared into `output`.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
        return output

    raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -303,8 +304,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SILU
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -338,7 +339,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# FIXME (varun): We should be able to dispatch only from the leader
|
||||
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
||||
@@ -389,7 +390,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_DTYPES,
|
||||
OCP_MX_Scheme,
|
||||
@@ -1132,7 +1133,7 @@ class FusedMoEConfig:
|
||||
intermediate_size_per_partition: int
|
||||
num_local_experts: int
|
||||
num_logical_experts: int
|
||||
activation: str
|
||||
activation: MoEActivation
|
||||
device: torch.device | str
|
||||
routing_method: RoutingMethodType
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch.nn import functional as F
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@@ -36,9 +37,9 @@ def _swigluoai_forward_native(
|
||||
# Map activation names to their native forward functions.
|
||||
# Uses static methods or standalone functions to avoid instantiating CustomOp
|
||||
# classes, which would call get_current_vllm_config() before config is set.
|
||||
_CPU_MOE_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
|
||||
"silu": SiluAndMul.forward_native,
|
||||
"swigluoai": _swigluoai_forward_native,
|
||||
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
|
||||
MoEActivation.SILU: SiluAndMul.forward_native,
|
||||
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,9 +169,9 @@ class SGLFusedMOE:
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"{activation} is not supported."
|
||||
assert activation == MoEActivation.SILU, f"{activation} is not supported."
|
||||
assert not apply_router_weight_on_input
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -235,7 +236,7 @@ class CPUFusedMOE:
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
) -> torch.Tensor:
|
||||
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
|
||||
|
||||
@@ -353,7 +354,7 @@ class CPUFusedMOE:
|
||||
input: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -371,7 +372,7 @@ class CPUFusedMOE:
|
||||
getattr(layer, "w2_bias", None),
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
activation.value,
|
||||
self.isa,
|
||||
skip_weighted,
|
||||
)
|
||||
@@ -383,7 +384,7 @@ class CPUFusedMOE:
|
||||
input: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -419,6 +420,7 @@ def cpu_fused_moe_torch(
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> None:
|
||||
act = MoEActivation.from_str(activation)
|
||||
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
|
||||
|
||||
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
|
||||
@@ -442,7 +444,7 @@ def cpu_fused_moe_torch(
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
|
||||
gate_up = layer.gate_up_linear[i](tokens_for_this_expert) # type: ignore
|
||||
gate_up = _CPU_MOE_ACT_FN[activation](gate_up)
|
||||
gate_up = _CPU_MOE_ACT_FN[act](gate_up)
|
||||
expert_out = layer.down_linear[i](gate_up) # type: ignore
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
@@ -7,6 +7,10 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -25,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
@@ -51,7 +54,7 @@ def run_cutlass_moe_fp8(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor | None,
|
||||
@@ -73,7 +76,7 @@ def run_cutlass_moe_fp8(
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
assert not activation.endswith("_no_mul"), "Only gated activation is supported"
|
||||
assert activation.is_gated, "Only gated activation is supported"
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
@@ -310,8 +313,12 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "gelu", "swigluoai"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
@@ -325,7 +332,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -415,7 +422,7 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
@@ -456,7 +463,7 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
num_dp = self.num_dispatchers
|
||||
assert num_dp is not None
|
||||
@@ -489,7 +496,7 @@ def run_cutlass_moe_fp4(
|
||||
w2_alphas: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
m: int,
|
||||
@@ -612,7 +619,7 @@ def run_cutlass_moe_fp4(
|
||||
blockscale_offsets[:-1],
|
||||
)
|
||||
del rep_a_fp4, rep_a_blockscale
|
||||
if activation == "silu":
|
||||
if activation == MoEActivation.SILU:
|
||||
# Fused SiLU+Mul+NVFP4 quantization
|
||||
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
|
||||
# c3 reuses workspace13 after c1 is consumed.
|
||||
@@ -682,8 +689,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "gelu", "swigluoai"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -716,7 +727,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
@@ -731,7 +742,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None, # unused
|
||||
@@ -776,7 +787,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
w1_scale: torch.Tensor | None,
|
||||
@@ -970,7 +981,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"CutlassExpertsW4A8Fp8 is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
@@ -1005,7 +1016,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M * topk, max(N, K))
|
||||
@@ -1021,7 +1032,7 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -1094,7 +1105,7 @@ def cutlass_moe_w4a8_fp8(
|
||||
s_strides2: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
@@ -1137,7 +1148,7 @@ def cutlass_moe_w4a8_fp8(
|
||||
dtype: torch.int64
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- activation (MoEActivation): The activation function to use.
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -145,8 +146,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "swiglustep"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -171,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.block_shape is not None
|
||||
block_m = self.block_shape[0]
|
||||
@@ -187,7 +188,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
def _act_mul_quant(
|
||||
self, input: torch.Tensor, output: torch.Tensor, activation: str
|
||||
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.block_shape is not None
|
||||
block_k = self.block_shape[1]
|
||||
@@ -210,7 +211,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return a2q, a2q_scale
|
||||
|
||||
# 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel
|
||||
if activation == "silu":
|
||||
if activation == MoEActivation.SILU:
|
||||
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||
return silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
input=input,
|
||||
@@ -235,7 +236,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
|
||||
@@ -76,7 +77,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
) and fallback_cls._supports_quant_scheme(weight_key, activation_key)
|
||||
|
||||
@classmethod
|
||||
def _supports_activation(cls, activation: str) -> bool:
|
||||
def _supports_activation(cls, activation: MoEActivation) -> bool:
|
||||
experts_cls, fallback_cls = cls.get_clses()
|
||||
return experts_cls._supports_activation(
|
||||
activation
|
||||
@@ -138,7 +139,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -159,7 +160,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -72,8 +73,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation == MoEActivation.SILU
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -101,7 +102,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
@@ -135,7 +136,7 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -130,8 +131,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "relu2_no_mul"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -164,7 +165,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
@@ -201,7 +202,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -214,8 +215,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
activation_str_to_value_map = {
|
||||
"silu": ActivationType.Swiglu, # This is the default
|
||||
"relu2_no_mul": ActivationType.Relu2,
|
||||
MoEActivation.SILU: ActivationType.Swiglu, # This is the default
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
assert activation in activation_str_to_value_map, (
|
||||
f"{activation=} missing from {activation_str_to_value_map.keys()=}"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -50,9 +51,9 @@ def _supports_quant_scheme(
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports silu activation only."""
|
||||
return activation in ["silu"]
|
||||
return activation == MoEActivation.SILU
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -698,7 +699,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"NaiveBatchedExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
@@ -730,7 +731,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.num_dispatchers is not None
|
||||
assert self.max_num_tokens is not None
|
||||
@@ -757,7 +758,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -942,14 +943,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
"silu",
|
||||
"gelu",
|
||||
"swigluoai",
|
||||
"silu_no_mul",
|
||||
"gelu_no_mul",
|
||||
"relu2_no_mul",
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -975,7 +976,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.num_dispatchers is not None
|
||||
assert self.max_num_tokens is not None
|
||||
@@ -996,7 +997,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -8,6 +8,10 @@ import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -23,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@@ -59,9 +62,9 @@ def _fused_marlin_moe(
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
activation_func: Callable[
|
||||
[str, torch.Tensor, torch.Tensor], None
|
||||
[MoEActivation, torch.Tensor, torch.Tensor], None
|
||||
] = apply_moe_activation,
|
||||
input_global_scale1: torch.Tensor | None = None,
|
||||
input_global_scale2: torch.Tensor | None = None,
|
||||
@@ -83,7 +86,7 @@ def _fused_marlin_moe(
|
||||
assert hidden_states.ndim == 2
|
||||
M, K = hidden_states.size()
|
||||
N = marlin_moe_intermediate_size(w1, w2)
|
||||
w13_num_shards = 1 if "no_mul" in activation else 2
|
||||
w13_num_shards = 2 if activation.is_gated else 1
|
||||
if workspace is None:
|
||||
workspace = marlin_make_workspace_new(hidden_states.device, 4)
|
||||
|
||||
@@ -215,9 +218,9 @@ def fused_marlin_moe(
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
activation_func: Callable[
|
||||
[str, torch.Tensor, torch.Tensor], None
|
||||
[MoEActivation, torch.Tensor, torch.Tensor], None
|
||||
] = apply_moe_activation,
|
||||
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -377,7 +380,7 @@ def batched_fused_marlin_moe(
|
||||
quant_type_id: int,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
activation: str | None = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_scale1: torch.Tensor | None = None,
|
||||
global_scale2: torch.Tensor | None = None,
|
||||
@@ -579,14 +582,14 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return weight_key in SUPPORTED_W
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
"silu",
|
||||
"gelu",
|
||||
"swigluoai",
|
||||
"silu_no_mul",
|
||||
"gelu_no_mul",
|
||||
"relu2_no_mul",
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -661,7 +664,7 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Modular Kernel provisions output buffer from workspace1. However in
|
||||
# the fused_marlin_moe() function, the final torch.sum(), is defined
|
||||
@@ -692,7 +695,7 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -788,7 +791,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
assert self.num_dispatchers is not None
|
||||
assert self.max_num_tokens is not None
|
||||
@@ -808,7 +811,7 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -17,6 +17,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEConfig,
|
||||
@@ -32,7 +36,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
disable_inplace,
|
||||
moe_kernel_quantize_input,
|
||||
)
|
||||
@@ -1468,6 +1471,7 @@ def outplace_fused_experts_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@@ -1521,7 +1525,7 @@ def fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -1539,7 +1543,7 @@ def fused_experts(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
activation=activation.value,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=quant_config.use_int8_w8a8,
|
||||
@@ -1618,6 +1622,9 @@ def fused_experts_impl(
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Convert string activation to enum for internal use
|
||||
activation_enum = MoEActivation.from_str(activation)
|
||||
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
|
||||
@@ -1692,7 +1699,7 @@ def fused_experts_impl(
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
|
||||
N, activation
|
||||
N, activation_enum
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * top_k_num, activation_out_dim),
|
||||
@@ -1832,7 +1839,7 @@ def fused_experts_impl(
|
||||
)
|
||||
|
||||
apply_moe_activation(
|
||||
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
@@ -1932,8 +1939,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "gelu", "swigluoai", "swiglustep"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -1957,7 +1969,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
workspace1 = (M, topk, max(activation_out_dim, K))
|
||||
@@ -1973,7 +1985,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -2138,7 +2150,7 @@ class TritonWNA16Experts(TritonExperts):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"TritonWNA16Experts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
@@ -2159,7 +2171,7 @@ class TritonWNA16Experts(TritonExperts):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -172,7 +173,7 @@ def triton_kernel_moe_forward(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SWIGLUOAI,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
@@ -211,7 +212,7 @@ def triton_kernel_fused_experts(
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
topk: int,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SWIGLUOAI,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
@@ -222,6 +223,9 @@ def triton_kernel_fused_experts(
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Triton implementation of fused expert computation using OAI kernels."""
|
||||
assert activation == MoEActivation.SWIGLUOAI, (
|
||||
"Only SWIGLUOAI activation is supported"
|
||||
)
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
@@ -379,7 +383,7 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"OAITritonExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
@@ -463,7 +467,7 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspaces are allocated inside the kernel
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
@@ -480,7 +484,7 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -547,7 +551,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspaces are allocated inside the kernel
|
||||
activation_out_dim = self.adjust_N_for_activation(N, activation)
|
||||
@@ -567,7 +571,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.distributed import (
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -500,7 +501,7 @@ class FusedMoE(CustomOp):
|
||||
# TODO(bnell): end attributes
|
||||
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.activation = activation
|
||||
self.activation = MoEActivation.from_str(activation)
|
||||
|
||||
self.router = create_fused_moe_router(
|
||||
top_k=top_k,
|
||||
@@ -554,7 +555,7 @@ class FusedMoE(CustomOp):
|
||||
has_bias=has_bias,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
is_lora_enabled=vllm_config.lora_config is not None,
|
||||
activation=activation,
|
||||
activation=self.activation,
|
||||
device=vllm_config.device_config.device,
|
||||
routing_method=self.routing_method_type,
|
||||
# TODO: in_dtype == out_dtype?
|
||||
|
||||
@@ -12,6 +12,10 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
apply_moe_activation,
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""
|
||||
Whether the kernel supports a particular act function.
|
||||
"""
|
||||
@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Compute the shapes for the temporary and final outputs of the two gemms
|
||||
@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def adjust_N_for_activation(N: int, activation: str) -> int:
|
||||
def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
|
||||
"""
|
||||
Calculate the output dimension for the activation function.
|
||||
|
||||
@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
Args:
|
||||
N: The intermediate size (width of w1/w3 weights).
|
||||
activation: The activation function name.
|
||||
activation: The activation function enum.
|
||||
|
||||
Returns:
|
||||
The output dimension after activation.
|
||||
"""
|
||||
is_no_mul = activation.endswith("_no_mul")
|
||||
return N if is_no_mul else N // 2
|
||||
return N if not activation.is_gated else N // 2
|
||||
|
||||
def activation(
|
||||
self, activation: str, output: torch.Tensor, input: torch.Tensor
|
||||
self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
|
||||
) -> None:
|
||||
apply_moe_activation(activation, output, input)
|
||||
|
||||
@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate temporary and output buffers for the fused experts op.
|
||||
@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
- activation (MoEActivation): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -184,7 +185,7 @@ def rocm_aiter_fused_experts(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
@@ -196,9 +197,13 @@ def rocm_aiter_fused_experts(
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
activation_method = (
|
||||
ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
|
||||
)
|
||||
if activation == MoEActivation.SILU:
|
||||
activation_method = ActivationMethod.SILU
|
||||
elif activation == MoEActivation.GELU:
|
||||
activation_method = ActivationMethod.GELU
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
|
||||
# All AITER Fused MoE kernels are expecting the following datatypes
|
||||
topk_weights = topk_weights.to(torch.float32)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
@@ -322,8 +327,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "gelu"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.GELU]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -347,7 +352,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Workspaces are managed internally by AITER.
|
||||
workspace1 = (0,)
|
||||
@@ -363,7 +368,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -45,7 +46,7 @@ class TritonOrCutlassExperts(FallbackExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Small batch fallback for sm100.
|
||||
if self.is_sm100 and M <= 8:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -45,7 +46,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -64,7 +65,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
raise NotImplementedError(
|
||||
"TrtLlmGenExperts is not yet used by an Oracle. "
|
||||
"This method should not be called."
|
||||
@@ -95,7 +96,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
workspace1 = (0,)
|
||||
@@ -111,7 +112,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
|
||||
@@ -4,7 +4,6 @@ import functools
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -341,65 +340,6 @@ def _validate_scale_shape(
|
||||
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
||||
|
||||
|
||||
def activation_without_mul(activation: str) -> str:
|
||||
return activation + "_no_mul"
|
||||
|
||||
|
||||
RELU2_NO_MUL: str = activation_without_mul("relu2")
|
||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
|
||||
|
||||
def apply_moe_activation(
|
||||
activation: str,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply MoE activation function.
|
||||
|
||||
For *_and_mul activations (silu, gelu, swigluoai):
|
||||
- Expects output.size(-1) * 2 == input.size(-1)
|
||||
|
||||
For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
|
||||
- Expects output.size(-1) == input.size(-1)
|
||||
"""
|
||||
is_no_mul = activation.endswith("_no_mul")
|
||||
if is_no_mul:
|
||||
assert output.size(-1) == input.size(-1), (
|
||||
f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
|
||||
)
|
||||
else:
|
||||
assert output.size(-1) * 2 == input.size(-1), (
|
||||
f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
|
||||
)
|
||||
|
||||
# Activations with gated multiplication (gate × activation(up))
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(output, input)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(output, input)
|
||||
elif activation == "swigluoai":
|
||||
torch.ops._C.swigluoai_and_mul(output, input)
|
||||
elif activation == "swiglustep":
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
|
||||
swiglustep_and_mul_triton(output, input)
|
||||
|
||||
# Activations without gated multiplication
|
||||
elif activation == SILU_NO_MUL:
|
||||
output.copy_(F.silu(input))
|
||||
elif activation == GELU_NO_MUL:
|
||||
output.copy_(F.gelu(input))
|
||||
elif activation == RELU2_NO_MUL:
|
||||
F.relu(input, inplace=True)
|
||||
torch.square(input, out=output)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Torch custom ops can't deal with outputs aliasing inputs so we need to
|
||||
# disable inplace for torch >= 2.9.
|
||||
# See https://github.com/vllm-project/vllm/issues/26378
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -55,8 +56,12 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
return activation in ["silu", "gelu", "swigluoai"]
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
@@ -92,7 +97,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
workspace1 = (0,)
|
||||
workspace2 = (0,)
|
||||
@@ -107,7 +112,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
@@ -129,7 +134,7 @@ class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
n_experts_per_token=topk,
|
||||
activation=activation,
|
||||
activation=activation.value,
|
||||
num_experts=self.moe_config.num_local_experts,
|
||||
ep_rank=self.moe_config.ep_rank,
|
||||
ep_size=self.moe_config.ep_size,
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -622,7 +623,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
assert (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not layer.enable_eplb
|
||||
@@ -649,7 +652,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
|
||||
# EPLB path
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
@@ -1025,7 +1030,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
assert layer.activation == "silu"
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
@@ -2271,19 +2278,21 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||
assert layer.activation in ("silu", "swigluoai", "swiglu"), (
|
||||
"Only SiLU/SwiGLUGU/SwiGLUUG are supported."
|
||||
)
|
||||
assert layer.activation in (
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
|
||||
assert layer.expert_map is None, """expert_map/EP not implemented
|
||||
for CPU dyn-4bit MoE."""
|
||||
|
||||
def _act_kind(s: str) -> int:
|
||||
def _act_kind(s: MoEActivation) -> int:
|
||||
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
|
||||
if s == "swiglu":
|
||||
if s == MoEActivation.SWIGLUSTEP:
|
||||
return 0
|
||||
if s == "swigluoai":
|
||||
if s == MoEActivation.SWIGLUOAI:
|
||||
return 1
|
||||
if s == "silu":
|
||||
if s == MoEActivation.SILU:
|
||||
return 2
|
||||
raise ValueError(f"Unknown activation '{s}'")
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
@@ -965,7 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# TODO(rob): convert this to MK.
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
assert layer.activation == "silu", (
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
apply_moe_activation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -246,16 +250,13 @@ def _fused_moe_gguf(
|
||||
qweight_type2: int,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
activation_enum = MoEActivation.from_str(activation)
|
||||
|
||||
def act(x: torch.Tensor):
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(out, x)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(out, x)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {activation}")
|
||||
apply_moe_activation(activation_enum, out, x)
|
||||
return out
|
||||
|
||||
# lazy import to avoid triggering triton import in CPU backend
|
||||
@@ -637,7 +638,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
@@ -652,7 +652,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids,
|
||||
layer.w13_qweight_type.weight_type,
|
||||
layer.w2_qweight_type.weight_type,
|
||||
layer.activation,
|
||||
layer.activation.value,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -936,7 +937,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
assert layer.activation == "silu", (
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
)
|
||||
assert not layer.renormalize
|
||||
@@ -965,7 +966,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
assert layer.activation in ("silu", "relu2_no_mul"), (
|
||||
assert layer.activation in (
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
int4_w4a16_moe_quant_config,
|
||||
@@ -371,7 +372,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -1141,8 +1142,9 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert layer.activation == "swigluoai", (
|
||||
"Only swiglu_oai activation is supported for XPU MXFP4 MoE"
|
||||
assert layer.activation == MoEActivation.SWIGLUOAI, (
|
||||
"Only swiglu_oai activation is supported for "
|
||||
f"XPU MXFP4 MoE, not {layer.activation}."
|
||||
)
|
||||
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
@@ -438,7 +439,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
elif self.use_marlin:
|
||||
assert layer.activation == "silu", (
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
return fused_marlin_moe(
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -64,9 +65,9 @@ def _supports_quant_scheme(
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: str) -> bool:
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
"""Supports silu activation only."""
|
||||
return activation in ["silu"]
|
||||
return activation in [MoEActivation.SILU]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
@@ -267,7 +268,7 @@ def flashinfer_trtllm_fp4_moe(
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
@@ -297,7 +298,7 @@ def flashinfer_trtllm_fp4_moe(
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
|
||||
assert activation == "silu", (
|
||||
assert activation == MoEActivation.SILU, (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
@@ -365,7 +366,7 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -387,7 +388,7 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
import flashinfer
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
|
||||
assert activation == "silu", (
|
||||
assert activation == MoEActivation.SILU, (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
@@ -88,7 +89,7 @@ def _can_support_mxfp4(
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "swigluoai",
|
||||
activation: MoEActivation = MoEActivation.SWIGLUOAI,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
@@ -101,7 +102,7 @@ def _can_support_mxfp4(
|
||||
or e_score_correction_bias
|
||||
or apply_router_weight_on_input
|
||||
or scoring_func != "softmax"
|
||||
or activation != "swigluoai"
|
||||
or activation != MoEActivation.SWIGLUOAI
|
||||
or expert_load_view
|
||||
or logical_to_physical_map
|
||||
or logical_replica_count
|
||||
|
||||
@@ -33,8 +33,11 @@ from vllm.distributed.communication_op import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
SharedFusedMoE,
|
||||
activation_without_mul,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
|
||||
Reference in New Issue
Block a user