[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
@@ -12,6 +12,10 @@ import torch
 import vllm.envs as envs
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.activation import (
+    MoEActivation,
+    apply_moe_activation,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
-    apply_moe_activation,
     count_expert_num_tokens,
     disable_inplace,
 )
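The activation module added to the imports above is not shown in this diff. As a rough sketch of the shape it implies, assuming the enum keeps the old strings as values and that the non-SILU members exist (only MoEActivation.SILU, is_gated, and apply_moe_activation are visible in this commit):

from enum import Enum

import torch
import torch.nn.functional as F


class MoEActivation(Enum):
    # Only SILU is referenced in this diff; the other members and the string
    # values are assumptions for illustration.
    SILU = "silu"                # gated: silu(gate) * up
    GELU = "gelu"                # gated
    SILU_NO_MUL = "silu_no_mul"  # non-gated
    GELU_NO_MUL = "gelu_no_mul"  # non-gated

    @property
    def is_gated(self) -> bool:
        # The old string API encoded non-gated variants with a "_no_mul" suffix.
        return not self.value.endswith("_no_mul")


def apply_moe_activation(
    activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
) -> None:
    # Write the activated result into the preallocated output buffer.
    if activation.is_gated:
        d = input.shape[-1] // 2
        gate, up = input[..., :d], input[..., d:]
        fn = F.silu if activation == MoEActivation.SILU else F.gelu
        output.copy_(fn(gate) * up)
    else:
        fn = F.silu if activation == MoEActivation.SILU_NO_MUL else F.gelu
        output.copy_(fn(input))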
@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
 
     @staticmethod
     @abstractmethod
-    def _supports_activation(activation: str) -> bool:
+    def _supports_activation(activation: MoEActivation) -> bool:
         """
         Whether the kernel supports a particular act function.
         """
@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: ExpertTokensMetadata | None,
-        activation: str,
+        activation: MoEActivation,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         raise NotImplementedError
 
     @staticmethod
-    def adjust_N_for_activation(N: int, activation: str) -> int:
+    def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
         """
         Calculate the output dimension for the activation function.
 
@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
 
         Args:
             N: The intermediate size (width of w1/w3 weights).
-            activation: The activation function name.
+            activation: The activation function enum.
 
         Returns:
             The output dimension after activation.
         """
-        is_no_mul = activation.endswith("_no_mul")
-        return N if is_no_mul else N // 2
+        return N if not activation.is_gated else N // 2
 
     def activation(
-        self, activation: str, output: torch.Tensor, input: torch.Tensor
+        self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
     ) -> None:
         apply_moe_activation(activation, output, input)
 
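The hunk above swaps the string-suffix check for the enum's is_gated property while keeping the sizing rule: gated activations combine the gate and up halves of the first GEMM's output, so their output width is N // 2, while non-gated ("_no_mul") variants keep N. A standalone restatement of that rule (not vLLM code):

def adjusted_n(n: int, gated: bool) -> int:
    # Mirrors adjust_N_for_activation: gated activations halve the width.
    return n // 2 if gated else n

assert adjusted_n(4096, gated=True) == 2048   # e.g. MoEActivation.SILU
assert adjusted_n(4096, gated=False) == 4096  # e.g. a *_NO_MUL variant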
@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str,
+        activation: MoEActivation,
         global_num_experts: int,
         expert_map: torch.Tensor | None,
         a1q_scale: torch.Tensor | None,
@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: ExpertTokensMetadata | None,
-        activation: str,
+        activation: MoEActivation,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Allocate temporary and output buffers for the fused experts op.
@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str,
+        activation: MoEActivation,
         global_num_experts: int,
         local_num_experts: int,
         expert_map: torch.Tensor | None,
@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str = "silu",
+        activation: MoEActivation = MoEActivation.SILU,
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         - topk_weights (torch.Tensor): The topk weights applied at the end of
           the layer.
         - topk_ids (torch.Tensor): A map of row to expert id.
-        - activation (str): The activation function to apply after the first
+        - activation (MoEActivation): The activation function to apply after the first
           MoE layer.
         - global_num_experts (int): The total number of experts in the global
           expert space.
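At call sites the visible change is the parameter type and default on FusedMoEModularKernel.forward. A hypothetical before/after, with the leading tensor arguments elided and their names assumed (only w2, topk_weights, topk_ids, and the activation default appear in this diff):

from vllm.model_executor.layers.fused_moe.activation import MoEActivation

# before: out = kernel.forward(hidden_states, w1, w2, topk_weights, topk_ids,
#                              activation="silu")
# after:
out = kernel.forward(hidden_states, w1, w2, topk_weights, topk_ids,
                     activation=MoEActivation.SILU)

# If the enum values mirror the old strings (as assumed in the sketch above),
# string-based configs can be converted at the boundary with MoEActivation("silu").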