[MoE Refactor] Separate Router into OO Classes (#30623)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoeWeightScaleSupported,
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
@@ -83,13 +83,17 @@ if HAS_TRITON:
        BatchedTritonExperts,
    )
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        GroupedTopk,
        TritonExperts,
        TritonWNA16Experts,
        fused_experts,
        fused_topk,
        get_config_file_name,
    )
    from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
        fused_topk,
    )
    from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
        GroupedTopk,
    )
    from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
        TritonOrDeepGemmExperts,
    )

@@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum):
    RenormalizeNaive = (4,)
    # TopK: TopK (no softmax)
    TopK = (5,)
    # Custom
    Custom = (6,)
    # Simulated
    Simulated = (7,)
    # Unspecified
    Unspecified = 6.0
    Unspecified = 8.0


@dataclass

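A side note on the enum values above: members written as `Custom = (6,)` are equivalent to plain integers, because `IntEnum` unpacks a tuple value into the `int` constructor, and `Unspecified = 8.0` likewise becomes `int(8.0) == 8`. A minimal standalone check (editor's example, not part of the diff):

```python
from enum import IntEnum


class Demo(IntEnum):
    # A tuple value is passed to int(...) as constructor args,
    # so (6,) yields the member value 6.
    Custom = (6,)
    # A float value is coerced by int(), so 8.0 yields 8.
    Unspecified = 8.0


assert Demo.Custom == 6
assert Demo.Unspecified == 8
```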
@@ -13,9 +13,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
@@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
    rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP,
)
@@ -49,7 +44,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config(
    return config


def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

    return topk_weights, topk_indices


def dispatch_topk_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    if use_rocm_aiter:
        return rocm_aiter_ops.topk_softmax
    return vllm_topk_softmax


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    M, _ = hidden_states.size()

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(
        M,
        topk,
        dtype=torch.int32 if indices_type is None else indices_type,
        device=hidden_states.device,
    )
    token_expert_indices = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
    topk_weights, topk_ids = topk_func(
        topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
    )

    return topk_weights, topk_ids, token_expert_indices


def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    n_routed_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    scores_for_choice = scores.view(
        -1, n_routed_experts
    ) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)


# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        super().__init__()
        self.native_impl = grouped_topk
        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.native_impl(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if rocm_aiter_ops.is_fused_moe_enabled():
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            return rocm_aiter_grouped_topk(
                hidden_states,
                gating_output,
                self.topk,
                self.renormalize,
                self.num_expert_group,
                self.topk_group,
                self.scoring_func,
                self.routed_scaling_factor,
                e_score_correction_bias,
                self.num_fused_shared_experts,
            )
        else:
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
    topk_ids: torch.Tensor,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> torch.Tensor:
    """
    Map the logical expert ids to physical expert ids
    and record the expert load metrics.

    This will select a pseudo-random replica for each logical expert.
    Only used for EPLB.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: The expert load view.
        logical_to_physical_map: The logical to physical map.
        logical_replica_count: The logical replica count.

    Returns:
        The physical expert ids.
    """

    # 1. Convert the logical expert ids to physical expert ids
    # Directly select a random replica for each logical expert

    # In case `indices_type` is not `torch.long` or `torch.int`,
    # e.g. `torch.uint32` as required by dispatch/combine kernels
    topk_ids_long = topk_ids.long()
    # Use (token position) modulo (replica count)
    # to deterministically choose a replica
    replica_count = logical_replica_count[topk_ids_long]
    # Flatten-position based index, reshaped back to `topk_ids` shape
    pos_indices = torch.arange(
        topk_ids.numel(), device=topk_ids.device, dtype=torch.long
    ).reshape_as(topk_ids)
    # Compute pseudo-random indices by modulo
    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
    physical_ids = (
        logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1)
    )

    topk_ids = physical_ids

    # 2. Record expert load metrics.

    # TODO(bowen): When using `FusedMoEModularKernel`, this
    # can be done in a more unified way, since
    # `FusedMoEPrepareAndFinalize` will return the expert
    # token count, in some cases directly from the kernel.
    # However, there are currently many code paths that do not use
    # the modular kernel, e.g. those calling `fused_experts`,
    # so we keep the logic here for now.
    #
    # If a later refactor moves all the MoE kernel calls
    # to the modular kernel, we can move this logic there
    # to achieve better efficiency.

    # `expert_load_view`: (num_physical_experts,)

    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
    topk_ids_flatten = topk_ids.flatten()
    expert_load_view.scatter_add_(
        dim=0,
        index=topk_ids_flatten.long(),
        src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
    )
    return topk_ids


def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused kernel path for sigmoid
        topk_values, topk_indices = ops.grouped_topk(
            gating_output,  # raw logits
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            1,  # scoring_func=1 for sigmoid
        )
    elif scoring_func == "softmax":
        # Apply softmax in Python, then use fused kernel
        # TODO: Add support for softmax in kernel
        scores = torch.softmax(gating_output, dim=-1)
        topk_values, topk_indices = ops.grouped_topk(
            scores,  # pre-computed scores
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            0,  # scoring_func=0 (no activation, scores already computed)
        )
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Fused kernel outputs float32 values and int32 indices directly
    return topk_values, topk_indices


def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
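For reference, the semantics that the relocated `fused_topk` kernel path implements can be written in a few lines of plain PyTorch. This is an illustrative sketch by the editor, not vLLM code, and it ignores the preallocated-output calling convention of the fused kernel:

```python
import torch


def reference_topk_softmax(
    gating_output: torch.Tensor, topk: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    # Softmax over the expert dimension, then keep the top-k per token.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        # Rescale so the selected weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


logits = torch.randn(4, 8)  # 4 tokens, 8 experts
w, ids = reference_topk_softmax(logits, topk=2, renormalize=True)
assert torch.allclose(w.sum(dim=-1), torch.ones(4))
```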
@@ -10,13 +10,13 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase,
)

@@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)

logger = init_logger(__name__)

@@ -21,7 +21,7 @@ from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
@@ -31,14 +31,24 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.fused_moe.router.router_factory import (
    create_fused_moe_router,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)
@@ -52,31 +62,6 @@ from vllm.utils.torch_utils import (
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

if current_platform.is_cuda_alike():
    from .fused_moe import eplb_map_to_physical_and_record
else:

    def _eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        # CPU fallback: no EPLB so just return as is
        return topk_ids

    eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)

logger = init_logger(__name__)


@@ -288,23 +273,6 @@ def maybe_roundup_hidden_size(
    return hidden_size


class FusedMoERouterImpl(FusedMoERouter):
    def __init__(self, layer: "FusedMoE"):
        super().__init__()
        self.layer = layer

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return self.layer.routing_method_type

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.layer._select_experts(hidden_states, router_logits)


# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
@@ -440,9 +408,7 @@ class FusedMoE(CustomOp):
        self.layer_name = prefix

        self.enable_eplb = enable_eplb
        self.expert_load_view: torch.Tensor | None = None
        self.logical_to_physical_map: torch.Tensor | None = None
        self.logical_replica_count: torch.Tensor | None = None
        self.eplb_state = EplbLayerState()
        self.expert_placement_strategy: ExpertPlacementStrategy = (
            vllm_config.parallel_config.expert_placement_strategy
        )
@@ -538,6 +504,8 @@ class FusedMoE(CustomOp):
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize

        # TODO(bnell): these attributes are only used by cpu/xpu/mxfp4
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
@@ -547,46 +515,11 @@ class FusedMoE(CustomOp):
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.e_score_correction_bias = e_score_correction_bias
        # TODO(bnell): end attributes

        self.apply_router_weight_on_input = apply_router_weight_on_input
        self.activation = activation

        self._grouped_topk_impl: GroupedTopk | None = None
        if self.use_grouped_topk:
            assert self.num_expert_group is not None
            assert self.topk_group is not None
            self._grouped_topk_impl = GroupedTopk(
                topk=self.top_k,
                renormalize=self.renormalize,
                num_expert_group=self.num_expert_group,
                topk_group=self.topk_group,
                scoring_func=self.scoring_func,
                routed_scaling_factor=self.routed_scaling_factor,
                num_fused_shared_experts=self.num_fused_shared_experts,
            )

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError(
                "Only softmax scoring function is supported for non-grouped topk."
            )

        # ToDo: Better logic to determine the routing method type
        if routing_method_type is not None:
            self.routing_method_type: RoutingMethodType = routing_method_type
        else:
            if scoring_func == "sigmoid":
                if self.use_grouped_topk:
                    self.routing_method_type = RoutingMethodType.DeepSeekV3
                elif self.top_k == 1:
                    self.routing_method_type = RoutingMethodType.Llama4
            elif self.scoring_func == "softmax":
                self.routing_method_type = (
                    RoutingMethodType.Renormalize
                    if not self.renormalize
                    else RoutingMethodType.RenormalizeNaive
                )
            else:
                self.routing_method_type = RoutingMethodType.TopK

        self.moe_config: FusedMoEConfig = FusedMoEConfig(
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
@@ -637,8 +570,7 @@ class FusedMoE(CustomOp):
        # If you plan to add support for more quantization methods,
        # please refer to the implementation in `Fp8MoEMethod`.
        raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. "
|
||||
"EPLB is only supported for FP8 quantization for now."
|
||||
f"EPLB is not supported {self.quant_method.__class__.__name__}."
|
||||
        )

        moe_quant_params = {
@@ -663,7 +595,38 @@ class FusedMoE(CustomOp):
        self.batched_hidden_states: torch.Tensor | None = None
        self.batched_router_logits: torch.Tensor | None = None

        self.router = FusedMoERouterImpl(self)
        # TODO(bnell): in next PR move capture back to layer
        capture: Callable[[torch.Tensor], None] | None = None
        if (
            self.vllm_config.model_config is not None
            and self.vllm_config.model_config.enable_return_routed_experts
        ):
            # In dummy runs, the capturer is not initialized.
            capturer = RoutedExpertsCapturer.get_instance()
            if capturer is not None:
                capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)

        self.router = create_fused_moe_router(
            top_k=top_k,
            global_num_experts=self.global_num_experts,
            eplb_state=self.eplb_state,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            num_fused_shared_experts=self.num_fused_shared_experts,
            enable_eplb=enable_eplb,
            # TODO(bnell): once we can construct the MK at init time, we
            # can make this a value.
            indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
            routing_method_type=routing_method_type,
            capture=capture,
        )
        self.routing_method_type: RoutingMethodType = self.router.routing_method_type

    # Note: maybe_init_modular_kernel should only be called by
    # prepare_communication_buffer_for_model.
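With this change the layer no longer routes inline: it builds a router once through the factory and delegates every forward-pass routing decision to it. A hedged usage sketch of the new flow (editor's example; argument names mirror the call above, but the dispatch rules inside `router_factory.py` are not shown in this diff, and the default fused path may require a CUDA build):

```python
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.router_factory import (
    create_fused_moe_router,
)

# Standalone sketch: a softmax top-2 router over 8 experts, no EPLB.
router = create_fused_moe_router(
    top_k=2,
    global_num_experts=8,
    eplb_state=EplbLayerState(),
    renormalize=True,
    use_grouped_topk=False,
    num_expert_group=None,
    topk_group=None,
    custom_routing_function=None,
    scoring_func="softmax",
    routed_scaling_factor=1.0,
    e_score_correction_bias=None,
    num_fused_shared_experts=0,
    enable_eplb=False,
    indices_type_getter=lambda: torch.int32,
    routing_method_type=None,
    capture=None,
)

hidden_states = torch.randn(4, 16)
router_logits = torch.randn(4, 8)
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)
```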
@@ -1492,9 +1455,9 @@ class FusedMoE(CustomOp):
        This is used later in forward pass, where we get the expert mapping
        and record the load metrics in `expert_load_view`.
        """
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]
        self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx]
        self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx]

    def ensure_moe_quant_config_init(self):
        if self.quant_method.moe_quant_config is None:
@@ -1535,130 +1498,6 @@ class FusedMoE(CustomOp):
            device=torch.cuda.current_device(),
        )

    def _select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        Returns:
            (topk_weights, topk_ids)
            (tuple[torch.Tensor, torch.Tensor]):
                The weights and expert ids.

        **Compatibility**: When EPLB is not enabled, the returned ids are
        equivalent to global logical ids, so should be compatible with
        plain MoE implementations without redundant experts.
        """
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_topk,
            fused_topk_bias,
        )

        if self.enable_eplb:
            if self.quant_method.supports_eplb:
                if self.expert_load_view is None:
                    raise ValueError(
                        "enable_eplb=True requires expert_load_view != None"
                    )
                if self.logical_to_physical_map is None:
                    raise ValueError(
                        "enable_eplb=True requires logical_to_physical_map != None"
                    )
                if self.logical_replica_count is None:
                    raise ValueError(
                        "enable_eplb=True requires logical_replica_count != None"
                    )
            else:
                raise NotImplementedError(
                    f"EPLB is not supported for {self.quant_method.method_name}."
                )

        def valid_grouping() -> bool:
            # Check if num_experts is greater than num_expert_group
            # and is divisible by num_expert_group
            num_experts = router_logits.shape[-1]
            if num_experts <= self.num_expert_group:
                return False
            return num_experts % self.num_expert_group == 0

        indices_type = self.quant_method.topk_indices_dtype

        # Check if we should use a routing simulation strategy
        routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
        if routing_strategy != "":
            topk_weights, topk_ids = RoutingSimulator.simulate_routing(
                hidden_states=hidden_states,
                router_logits=router_logits,
                strategy_name=routing_strategy,
                top_k=self.top_k,
                indices_type=indices_type,
            )

        # DeepSeekv2 uses grouped_top_k
        elif self.use_grouped_topk and valid_grouping():
            assert self._grouped_topk_impl is not None
            topk_weights, topk_ids = self._grouped_topk_impl(
                hidden_states=hidden_states,
                gating_output=router_logits,
                e_score_correction_bias=self.e_score_correction_bias,
            )
        elif self.e_score_correction_bias is not None:
            topk_weights, topk_ids = fused_topk_bias(
                hidden_states=hidden_states,
                gating_output=router_logits,
                e_score_correction_bias=self.e_score_correction_bias.data,
                topk=self.top_k,
                renormalize=self.renormalize,
            )
            if self.routed_scaling_factor != 1.0:
                topk_weights *= self.routed_scaling_factor
        elif self.custom_routing_function is None:
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=self.top_k,
                renormalize=self.renormalize,
                indices_type=indices_type,
            )
        else:
            topk_weights, topk_ids = self.custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=self.top_k,
                renormalize=self.renormalize,
            )

        if self.enable_eplb:
            topk_ids = eplb_map_to_physical_and_record(
                topk_ids=topk_ids,
                expert_load_view=self.expert_load_view,
                logical_to_physical_map=self.logical_to_physical_map,
                logical_replica_count=self.logical_replica_count,
            )

        if (indices_type is not None) and topk_ids.dtype != indices_type:
            topk_ids = topk_ids.to(dtype=indices_type)

        assert topk_ids.dtype == indices_type or indices_type is None

        if (
            self.vllm_config.model_config is not None
            and self.vllm_config.model_config.enable_return_routed_experts
        ):
            # In dummy runs, the capturer is not initialized.
            capturer = RoutedExpertsCapturer.get_instance()
            if capturer is not None:  # in dummy_run it may be None
                capturer.capture(  # noqa
                    layer_id=self.layer_id,
                    topk_ids=topk_ids,
                )

        return topk_weights, topk_ids

    def must_reduce_shared_expert_outputs(self) -> bool:
        """
        The shared_experts are typically computed using the RowParallelLinear
@@ -1761,8 +1600,12 @@ class FusedMoE(CustomOp):
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None
        assert self.batched_hidden_states.dtype == full_hidden_states.dtype
        assert self.batched_router_logits.dtype == full_router_logits.dtype
        assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == full_router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
        )
        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
@@ -2080,15 +1923,8 @@ class FusedMoE(CustomOp):
            f"tp_size={self.tp_size},\n"
            f"ep_size={self.ep_size}, "
            f"reduce_results={self.reduce_results}, "
            f"renormalize={self.renormalize}, "
            f"use_grouped_topk={self.use_grouped_topk}"
        )

        if self.use_grouped_topk:
            s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}"  # noqa: E501

        s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'"  # noqa: E501

        return s

vllm/model_executor/layers/fused_moe/router/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/model_executor/layers/fused_moe/router/base_router.py (new file, 245 lines)
@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable

import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.platforms import current_platform

if current_platform.is_cuda_alike():

    @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
    def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        """
        Map the logical expert ids to physical expert ids
        and record the expert load metrics.

        This will select a pseudo-random replica for each logical expert.
        Only used for EPLB.

        Args:
            topk_ids: The logical expert ids.
            expert_load_view: The expert load view.
            logical_to_physical_map: The logical to physical map.
            logical_replica_count: The logical replica count.

        Returns:
            The physical expert ids.
        """

        # 1. Convert the logical expert ids to physical expert ids
        # Directly select a random replica for each logical expert

        # In case `indices_type` is not `torch.long` or `torch.int`,
        # e.g. `torch.uint32` as required by dispatch/combine kernels
        topk_ids_long = topk_ids.long()
        # Use (token position) modulo (replica count)
        # to deterministically choose a replica
        replica_count = logical_replica_count[topk_ids_long]
        # Flatten-position based index, reshaped back to `topk_ids` shape
        pos_indices = torch.arange(
            topk_ids.numel(), device=topk_ids.device, dtype=torch.long
        ).reshape_as(topk_ids)
        # Compute pseudo-random indices by modulo
        replica_indices = (pos_indices % replica_count).unsqueeze(-1)
        physical_ids = (
            logical_to_physical_map[topk_ids_long]
            .gather(-1, replica_indices)
            .squeeze(-1)
        )

        topk_ids = physical_ids

        # 2. Record expert load metrics.

        # TODO(bowen): When using `FusedMoEModularKernel`, this
        # can be done in a more unified way, since
        # `FusedMoEPrepareAndFinalize` will return the expert
        # token count, in some cases directly from the kernel.
        # However, there are currently many code paths that do not use
        # the modular kernel, e.g. those calling `fused_experts`,
        # so we keep the logic here for now.
        #
        # If a later refactor moves all the MoE kernel calls
        # to the modular kernel, we can move this logic there
        # to achieve better efficiency.

        # `expert_load_view`: (num_physical_experts,)

        # `torch.bincount` is not compilable, so use `scatter_add_` instead.
        topk_ids_flatten = topk_ids.flatten()
        expert_load_view.scatter_add_(
            dim=0,
            index=topk_ids_flatten.long(),
            src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
        )
        return topk_ids
else:

    def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        # CPU fallback: no EPLB, so just return the ids as-is.
        return topk_ids
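The modulo trick above spreads repeated selections of the same logical expert across its replicas. A tiny self-contained illustration (editor's sketch, not vLLM code):

```python
import torch

# 2 logical experts: expert 0 has physical replicas {0, 1}, expert 1 has {2}.
logical_to_physical_map = torch.tensor([[0, 1], [2, 2]])
logical_replica_count = torch.tensor([2, 1])

topk_ids = torch.tensor([[0, 1], [0, 0]])  # logical ids, shape [tokens, topk]
pos = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_idx = (pos % logical_replica_count[topk_ids]).unsqueeze(-1)
physical = logical_to_physical_map[topk_ids].gather(-1, replica_idx).squeeze(-1)

# The three picks of logical expert 0 land on replicas 0, 0 and 1:
assert physical.tolist() == [[0, 2], [0, 1]]
```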


class BaseRouter(FusedMoERouter):
    """
    Base router class that provides common functionality for all router
    implementations.

    This class implements the template method pattern where select_experts() handles
    common pre-processing and post-processing, delegating the actual routing logic
    to the abstract _compute_routing() method.
    """

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        enable_eplb: bool = False,
        # TODO(bnell): Once the MK is constructed at layer init time, we
        # can make this a plain value instead of a callback.
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """
        Note: the indices dtype might not be available at router construction
        time, so we need to supply a callback to get it at runtime. This is
        because the indices type is supplied by modular kernels which are
        created after MoE layer/router construction.
        """
        super().__init__()
        self.top_k = top_k
        self.global_num_experts = global_num_experts
        self.eplb_state = eplb_state
        self.enable_eplb = enable_eplb
        self.indices_type_getter = indices_type_getter
        self.capture: Callable[[torch.Tensor], None] | None = None

    def _validate_eplb_state(self) -> None:
        """Validate that EPLB state is properly initialized if EPLB is enabled."""
        if self.enable_eplb:
            if self.eplb_state.expert_load_view is None:
                raise ValueError("enable_eplb=True requires expert_load_view != None")
            if self.eplb_state.logical_to_physical_map is None:
                raise ValueError(
                    "enable_eplb=True requires logical_to_physical_map != None"
                )
            if self.eplb_state.logical_replica_count is None:
                raise ValueError(
                    "enable_eplb=True requires logical_replica_count != None"
                )

    def _get_indices_type(self) -> torch.dtype | None:
        """Get the desired indices dtype from the getter function."""
        return (
            self.indices_type_getter() if self.indices_type_getter is not None else None
        )

    def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
        """Apply EPLB mapping to convert logical expert IDs to physical expert IDs."""
        if self.enable_eplb:
            assert self.eplb_state.expert_load_view is not None
            assert self.eplb_state.logical_to_physical_map is not None
            assert self.eplb_state.logical_replica_count is not None
            return eplb_map_to_physical_and_record(
                topk_ids=topk_ids,
                expert_load_view=self.eplb_state.expert_load_view,
                logical_to_physical_map=self.eplb_state.logical_to_physical_map,
                logical_replica_count=self.eplb_state.logical_replica_count,
            )
        return topk_ids

    def _convert_indices_dtype(
        self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
    ) -> torch.Tensor:
        """Convert topk_ids to the desired dtype if needed."""
        if (indices_type is not None) and topk_ids.dtype != indices_type:
            topk_ids = topk_ids.to(dtype=indices_type)

        assert topk_ids.dtype == indices_type or indices_type is None
        return topk_ids

    @abstractmethod
    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the actual routing logic.

        This method must be implemented by subclasses to provide the specific
        routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).

        Args:
            hidden_states: Input hidden states
            router_logits: Router logits for expert selection
            indices_type: Desired dtype for expert indices (may be None)

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        raise NotImplementedError

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        This method implements the template method pattern:
        1. Validates EPLB state
        2. Gets indices type
        3. Calls _compute_routing() to get topk_weights and topk_ids
        4. Applies EPLB mapping if enabled
        5. Converts indices dtype if needed

        Returns:
            (topk_weights, topk_ids)
            (tuple[torch.Tensor, torch.Tensor]):
                The weights and expert ids.

        **Compatibility**: When EPLB is not enabled, the returned ids are
        equivalent to global logical ids, so should be compatible with
        plain MoE implementations without redundant experts.
        """
        # Step 1: Validate EPLB state
        self._validate_eplb_state()

        # Step 2: Get indices type.
        indices_type = self._get_indices_type()

        # Step 3: Compute routing (delegated to subclass)
        topk_weights, topk_ids = self._compute_routing(
            hidden_states, router_logits, indices_type
        )

        # Step 4: Apply EPLB mapping
        topk_ids = self._apply_eplb_mapping(topk_ids)

        # Step 5: Convert indices dtype
        topk_ids = self._convert_indices_dtype(topk_ids, indices_type)

        # TODO(bnell): temporary hack until select_experts is moved into FusedMoE
        if self.capture is not None:
            self.capture(topk_ids)

        return topk_weights, topk_ids
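To make the template contract concrete, here is a minimal hypothetical subclass (editor's sketch, not part of the PR): only the routing math goes into `_compute_routing()`, while `select_experts()` supplies EPLB validation, physical-id mapping, dtype conversion, and capture for free:

```python
import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


class ArgmaxRouter(BaseRouter):  # hypothetical, for illustration only
    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.TopK

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Raw top-k over logits; no softmax, no renormalization.
        weights, ids = torch.topk(router_logits.float(), k=self.top_k, dim=-1)
        return weights, ids.to(torch.int32)
```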
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


class CustomRoutingRouter(BaseRouter):
    """Router using a custom user-provided routing function."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        custom_routing_function: Callable,
        renormalize: bool = True,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.custom_routing_function = custom_routing_function
        self.renormalize = renormalize

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.Custom

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using the custom routing function."""
        topk_weights, topk_ids = self.custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
        )

        return topk_weights.to(torch.float32), topk_ids
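A hedged usage sketch of the class above with a toy sigmoid-based `custom_routing_function` (editor's example; the module path for `CustomRoutingRouter` is assumed from the diff layout, since this hunk's file name was lost in extraction):

```python
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
    CustomRoutingRouter,  # module path assumed, not confirmed by the diff
)


def toy_routing(hidden_states, gating_output, topk, renormalize):
    weights, ids = torch.topk(torch.sigmoid(gating_output), k=topk, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids.to(torch.int32)


router = CustomRoutingRouter(
    top_k=2,
    global_num_experts=8,
    eplb_state=EplbLayerState(),
    custom_routing_function=toy_routing,
)
w, ids = router.select_experts(torch.randn(4, 16), torch.randn(4, 8))
```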
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    n_routed_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    scores_for_choice = scores.view(
        -1, n_routed_experts
    ) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
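A tiny numeric check of the idea implemented above (editor's sketch, not vLLM code): the correction bias steers which experts are picked, but the returned weights still come from the unbiased softmax scores.

```python
import torch

scores = torch.tensor([[0.5, 0.3, 0.2]])  # unbiased probabilities
bias = torch.tensor([0.0, 0.0, 0.4])      # push expert 2 up

picked = torch.topk(scores + bias, k=2, dim=-1)[1]  # experts 2 and 0 win
weights = scores.gather(1, picked)                  # but keep raw scores

assert picked.tolist() == [[2, 0]]
assert torch.allclose(weights, torch.tensor([[0.2, 0.5]]))
```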


class FusedTopKBiasRouter(BaseRouter):
    """Router using fused top-k with e_score_correction_bias."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        e_score_correction_bias: torch.Tensor,
        renormalize: bool = True,
        routed_scaling_factor: float = 1.0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.e_score_correction_bias = e_score_correction_bias
        self.renormalize = renormalize
        self.routed_scaling_factor = routed_scaling_factor

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return (
            RoutingMethodType.Renormalize
            if not self.renormalize
            else RoutingMethodType.RenormalizeNaive
        )

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using fused top-k with bias."""
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=self.e_score_correction_bias.data,
            topk=self.top_k,
            renormalize=self.renormalize,
        )

        if self.routed_scaling_factor != 1.0:
            topk_weights *= self.routed_scaling_factor

        return topk_weights, topk_ids
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

    return topk_weights, topk_indices


def dispatch_topk_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    if use_rocm_aiter:
        return rocm_aiter_ops.topk_softmax
    return vllm_topk_softmax


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    M, _ = hidden_states.size()

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(
        M,
        topk,
        dtype=torch.int32 if indices_type is None else indices_type,
        device=hidden_states.device,
    )
    token_expert_indices = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
    topk_weights, topk_ids = topk_func(
        topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
    )

    return topk_weights, topk_ids, token_expert_indices
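Note the kernel-style calling convention above: `fused_topk` preallocates the output tensors, and `vllm_topk_softmax` (or `rocm_aiter_ops.topk_softmax`) fills them in place. A pure-torch stand-in with the same shape of signature (editor's sketch, useful only for reasoning about the contract):

```python
import torch


def torch_topk_softmax(topk_weights, topk_indices, token_expert_indices,
                       gating_output, renormalize):
    # Same contract as the fused kernel: write results into the
    # preallocated buffers and return them. token_expert_indices is
    # left untouched in this simplified sketch.
    scores = torch.softmax(gating_output.float(), dim=-1)
    weights, indices = torch.topk(scores, k=topk_weights.shape[-1], dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    topk_weights.copy_(weights)
    topk_indices.copy_(indices)  # copy_ casts int64 into the buffer dtype
    return topk_weights, topk_indices
```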


class FusedTopKRouter(BaseRouter):
    """Default router using standard fused top-k routing."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        scoring_func: str = "softmax",
        renormalize: bool = True,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.renormalize = renormalize

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return (
            RoutingMethodType.Renormalize
            if not self.renormalize
            else RoutingMethodType.RenormalizeNaive
        )

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using standard fused top-k."""
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            indices_type=indices_type,
        )

        return topk_weights, topk_ids
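Hedged usage sketch for the router above (editor's example; the fused path calls the custom `topk_softmax` kernel, so this needs a CUDA-capable vLLM build):

```python
import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
    FusedTopKRouter,
)

router = FusedTopKRouter(
    top_k=2,
    global_num_experts=8,
    eplb_state=EplbLayerState(),
    renormalize=True,
    indices_type_getter=lambda: torch.int32,
)

hidden = torch.randn(4, 16, device="cuda")
logits = torch.randn(4, 8, device="cuda")
weights, ids = router.select_experts(hidden, logits)
assert ids.dtype == torch.int32
```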
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py (new file, 353 lines)
@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from functools import partial

import torch

from vllm import _custom_ops as ops
from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
    fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform


def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused kernel path for sigmoid
        topk_values, topk_indices = ops.grouped_topk(
            gating_output,  # raw logits
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            1,  # scoring_func=1 for sigmoid
        )
    elif scoring_func == "softmax":
        # Apply softmax in Python, then use fused kernel
        # TODO: Add support for softmax in kernel
        scores = torch.softmax(gating_output, dim=-1)
        topk_values, topk_indices = ops.grouped_topk(
            scores,  # pre-computed scores
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            0,  # scoring_func=0 (no activation, scores already computed)
        )
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Fused kernel outputs float32 values and int32 indices directly
    return topk_values, topk_indices


# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
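
The two-stage selection above first picks the best expert groups, then picks experts only inside those groups. A self-contained toy sketch of that masking logic (illustrative numbers, not the vLLM entry point):

# 8 experts in 4 groups of 2; keep the top-2 groups, then the top-2 experts.
import torch

scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.8, 0.05, 0.4, 0.7]])  # [1, 8]
num_expert_group, topk_group, topk = 4, 2, 2

group_scores = scores.view(1, num_expert_group, -1).max(dim=-1).values  # [1, 4]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(1, num_expert_group, 2).reshape(1, -1)
masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1)
# Group maxes are [0.9, 0.3, 0.8, 0.7]: groups 0 and 2 win, so experts 2 and 3
# (group 1) and 6 and 7 (group 3) are masked out; the final experts are 1 and 4.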
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 models."""

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        super().__init__()
        self.native_impl = grouped_topk
        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.native_impl(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if rocm_aiter_ops.is_fused_moe_enabled():
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            return rocm_aiter_grouped_topk(
                hidden_states,
                gating_output,
                self.topk,
                self.renormalize,
                self.num_expert_group,
                self.topk_group,
                self.scoring_func,
                self.routed_scaling_factor,
                e_score_correction_bias,
                self.num_fused_shared_experts,
            )
        else:
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )
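
A rough usage sketch of the op; constructor values are illustrative DeepSeek-style assumptions and the tensors are those from the earlier sketch. CustomOp dispatches the call to forward_cuda, forward_hip, or forward_native depending on the current platform:

# Hedged sketch, not part of the diff: construct once per layer, then call.
topk_op = GroupedTopk(
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    routed_scaling_factor=2.5,
)
# hidden_states / gating_output / bias as in the earlier sketch.
weights, ids = topk_op(hidden_states, gating_output, e_score_correction_bias=bias)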


class GroupedTopKRouter(BaseRouter):
    """Router using grouped top-k routing (e.g., DeepSeek-V2/V3)."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        num_expert_group: int,
        topk_group: int,
        renormalize: bool = True,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        num_fused_shared_experts: int = 0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
        routing_method_type: RoutingMethodType | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.e_score_correction_bias = e_score_correction_bias
        self.num_fused_shared_experts = num_fused_shared_experts

        # Determine routing method type
        if routing_method_type is not None:
            self._routing_method_type = routing_method_type
        elif scoring_func == "sigmoid":
            self._routing_method_type = RoutingMethodType.DeepSeekV3
        else:
            self._routing_method_type = RoutingMethodType.TopK

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return self._routing_method_type

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using grouped top-k."""

        def valid_grouping() -> bool:
            # Check if num_experts is greater than num_expert_group
            # and is divisible by num_expert_group
            num_experts = router_logits.shape[-1]
            if num_experts <= self.num_expert_group:
                return False
            return num_experts % self.num_expert_group == 0

        if not valid_grouping():
            if self.e_score_correction_bias is not None:
                topk_weights, topk_ids = fused_topk_bias(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    e_score_correction_bias=self.e_score_correction_bias.data,
                    topk=self.top_k,
                    renormalize=self.renormalize,
                )
                if self.routed_scaling_factor != 1.0:
                    topk_weights *= self.routed_scaling_factor
            else:
                topk_weights, topk_ids, token_expert_indices = fused_topk(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    topk=self.top_k,
                    renormalize=self.renormalize,
                    indices_type=indices_type,
                )
            return topk_weights, topk_ids

        # Select grouped_topk implementation
        if rocm_aiter_ops.is_fused_moe_enabled():
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            grouped_topk_impl = partial(
                rocm_aiter_grouped_topk,
                num_fused_shared_experts=self.num_fused_shared_experts,
            )
        else:
            grouped_topk_impl = grouped_topk

        topk_weights, topk_ids = grouped_topk_impl(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.e_score_correction_bias,
        )

        return topk_weights, topk_ids
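
The fallback above means grouped routing is silently skipped when the expert count does not divide evenly into groups. A standalone toy check mirroring valid_grouping(), for illustration only:

def valid_grouping(num_experts: int, num_expert_group: int) -> bool:
    return num_experts > num_expert_group and num_experts % num_expert_group == 0

assert valid_grouping(64, 8)       # 8 groups of 8 experts: grouped top-k path
assert not valid_grouping(64, 10)  # falls back to fused_topk / fused_topk_bias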
vllm/model_executor/layers/fused_moe/router/router_factory.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
    CustomRoutingRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
    FusedTopKBiasRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
    FusedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
    GroupedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
    RoutingSimulatorRouter,
)

EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()


def create_fused_moe_router(
    # common parameters
    top_k: int,
    global_num_experts: int,
    renormalize: bool = True,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    routing_method_type: RoutingMethodType | None = None,
    # grouped topk parameters
    use_grouped_topk: bool = False,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    scoring_func: str = "softmax",
    num_fused_shared_experts: int = 0,
    # grouped topk + fused topk bias parameters
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    # custom routing parameters
    custom_routing_function: Callable | None = None,
    # eplb parameters
    enable_eplb: bool = False,
    eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
    capture: Callable[[torch.Tensor], None] | None = None,
) -> FusedMoERouter:
    """
    Factory function to create the appropriate FusedMoERouter subclass based on
    the provided parameters.

    The selection logic follows this priority order:
    1. RoutingSimulatorRouter - if the VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
    2. GroupedTopKRouter - if use_grouped_topk is True
    3. CustomRoutingRouter - if custom_routing_function is not None
    4. FusedTopKBiasRouter - if e_score_correction_bias is not None
    5. FusedTopKRouter - default fallback

    Common arguments:
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts in the model
        renormalize: Whether to renormalize the routing weights
        indices_type_getter: Function to get the desired indices dtype
        routing_method_type: Optional explicit routing method type

    Grouped topk arguments:
        use_grouped_topk: Whether to use grouped top-k routing
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group (for grouped routing)
        scoring_func: Scoring function to use ("softmax" or "sigmoid")
        num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)

    Grouped topk and fused topk bias arguments:
        routed_scaling_factor: Scaling factor for routed weights
        e_score_correction_bias: Optional bias correction for expert scores

    Custom routing arguments:
        custom_routing_function: Optional custom routing function

    EPLB arguments:
        enable_eplb: Whether EPLB is enabled
        eplb_state: EPLB (Expert Parallelism Load Balancing) state

    Returns:
        An instance of the appropriate FusedMoERouter subclass
    """
    router: BaseRouter

    routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
    if routing_strategy != "":
        router = RoutingSimulatorRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        # TODO(bnell): this is temporary until select_experts is
        # separated from apply.
        router.capture = capture
        return router

    if use_grouped_topk:
        assert custom_routing_function is None
        if num_expert_group is None or topk_group is None:
            raise ValueError(
                "num_expert_group and topk_group must be provided when "
                "use_grouped_topk is True"
            )
        router = GroupedTopKRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            renormalize=renormalize,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            num_fused_shared_experts=num_fused_shared_experts,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
            routing_method_type=routing_method_type,
        )
        router.capture = capture
        return router

    if custom_routing_function is not None:
        router = CustomRoutingRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            custom_routing_function=custom_routing_function,
            renormalize=renormalize,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        router.capture = capture
        return router

    if scoring_func != "softmax":
        raise ValueError(
            "Only softmax scoring function is supported for non-grouped topk."
        )

    if e_score_correction_bias is not None:
        router = FusedTopKBiasRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            e_score_correction_bias=e_score_correction_bias,
            renormalize=renormalize,
            routed_scaling_factor=routed_scaling_factor,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        router.capture = capture
        return router

    router = FusedTopKRouter(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        renormalize=renormalize,
        scoring_func=scoring_func,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )
    router.capture = capture
    return router
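
A hedged sketch of calling the factory from a MoE layer. The keyword values are illustrative assumptions, and `bias`, `hidden_states`, and `router_logits` are stand-in tensors; `select_experts` is the entry point the quantization methods use elsewhere in this diff.

# DeepSeek-style arguments select GroupedTopKRouter; omitting use_grouped_topk,
# custom_routing_function, and e_score_correction_bias falls through to
# FusedTopKRouter.
router = create_fused_moe_router(
    top_k=8,
    global_num_experts=256,
    renormalize=True,
    use_grouped_topk=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    routed_scaling_factor=2.5,
    e_score_correction_bias=bias,
)
topk_weights, topk_ids = router.select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
)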

@@ -1,20 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Token-to-Expert Routing Simulator

This module provides a framework for simulating and testing different
token-to-expert routing strategies for Mixture of Experts (MoE) models.
It supports routing logic customization and includes example implementations
like uniform random routing.
"""

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch

import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter

logger = init_logger(__name__)
@@ -308,3 +304,44 @@ class RoutingSimulator:
            top_k=top_k,
            indices_type=indices_type,
        )


class RoutingSimulatorRouter(BaseRouter):
    """Router that uses routing simulation strategies for testing/debugging."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.Simulated

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Use routing simulator to compute routing."""
        routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
        topk_weights, topk_ids = RoutingSimulator.simulate_routing(
            hidden_states=hidden_states,
            router_logits=router_logits,
            strategy_name=routing_strategy,
            top_k=self.top_k,
            indices_type=indices_type,
        )
        return topk_weights, topk_ids
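
Because the factory only checks that the env var is non-empty, the simulator can be enabled without code changes. A hedged sketch: "uniform_random" is an assumed strategy name, inferred from the module docstring's mention of uniform random routing; consult RoutingSimulator for the actual registered names, and set the variable before engine startup since vllm.envs reads it at access time.

import os

os.environ["VLLM_MOE_ROUTING_SIMULATION_STRATEGY"] = "uniform_random"

# Any router built afterwards is a RoutingSimulatorRouter:
router = create_fused_moe_router(top_k=2, global_num_experts=8)
assert isinstance(router, RoutingSimulatorRouter)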

@@ -20,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat,
    FusedMoEPermuteExpertsUnpermute,

@@ -32,6 +31,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
    make_unquantized_moe_kernel,
    select_unquantized_moe_backend,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

@@ -312,9 +314,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            layer.enable_eplb is not False
            or layer.expert_load_view is not None
            or layer.logical_to_physical_map is not None
            or layer.logical_replica_count is not None
            or layer.eplb_state.expert_load_view is not None
            or layer.eplb_state.logical_to_physical_map is not None
            or layer.eplb_state.logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for CPU.")

@@ -346,9 +348,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if (
            layer.enable_eplb is not False
            or layer.expert_load_view is not None
            or layer.logical_to_physical_map is not None
            or layer.logical_replica_count is not None
            or layer.eplb_state.expert_load_view is not None
            or layer.eplb_state.logical_to_physical_map is not None
            or layer.eplb_state.logical_replica_count is not None
        ):
            raise NotImplementedError("Expert load balancing is not supported for XPU.")
        return layer.ipex_fusion(
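
The CPU/XPU hunks above also show the EPLB bookkeeping moving from loose attributes on the layer onto a single eplb_state object. A hedged sketch of roughly what that state holds (field names taken from the diff; the dataclass shape and defaults are inferred from EMPTY_EPLB_STATE = EplbLayerState() in the factory, not copied from eplb_state.py):

from dataclasses import dataclass
import torch

@dataclass
class EplbLayerState:
    # Per-layer EPLB tensors, all optional until EPLB is enabled (assumption).
    expert_load_view: torch.Tensor | None = None
    logical_to_physical_map: torch.Tensor | None = None
    logical_replica_count: torch.Tensor | None = None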
@@ -10,12 +10,12 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -6,11 +6,11 @@ from typing import Any, Union
import torch
from packaging import version

from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoERouter,
    FusedMoeWeightScaleSupported,
    UnquantizedFusedMoEMethod,
)

@@ -40,7 +41,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    MarlinExperts,
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,

@@ -10,12 +10,12 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (

@@ -23,13 +23,13 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
    FusedMoERouter,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,

@@ -12,11 +12,11 @@ from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -10,12 +10,12 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -8,10 +8,10 @@ from packaging import version
from torch.nn import Module

from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe_router import (
from vllm.model_executor.layers.fused_moe import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
@@ -13,11 +13,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -6,12 +6,12 @@ from typing import Any, Optional
import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEConfig,

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (

@@ -27,7 +28,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    MarlinExperts,
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
    OAITritonExperts,
    UnfusedOAITritonExperts,

@@ -936,9 +936,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.expert_load_view,
            layer.logical_to_physical_map,
            layer.logical_replica_count,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if (

@@ -548,7 +548,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        topk_weights, topk_ids = layer.select_experts(
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
@@ -10,12 +10,12 @@ import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,

@@ -201,6 +201,7 @@ class Ernie4_5_MoeMoE(nn.Module):
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            router_logits_dtype=torch.float32,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -269,6 +269,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
                quant_config=quant_config,
                e_score_correction_bias=self.e_score_correction_bias[0],
                prefix=f"{prefix}.text_experts",
                router_logits_dtype=torch.float32,
            )
        else:
            self.text_experts = Ernie4_5_VLMoeMLP(

@@ -306,6 +307,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
                quant_config=quant_config,
                e_score_correction_bias=self.e_score_correction_bias[1],
                prefix=f"{prefix}.vision_experts",
                router_logits_dtype=torch.float32,
            )
        else:
            self.vision_experts = Ernie4_5_VLMoeMLP(