[MoE Refactor][5/N] Isolate zero expert to LongCatFlash (#28891)
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com> Signed-off-by: Dongjie Zou <85092850+baonudesifeizhai@users.noreply.github.com> Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
This commit is contained in:
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
|
||||
envs.environment_variables[env_name] = lambda s=strategy: s
|
||||
|
||||
# Test the select_experts method
|
||||
topk_weights, topk_ids, _ = fused_moe.select_experts(
|
||||
topk_weights, topk_ids = fused_moe.select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,9 @@ from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
|
||||
ZeroExpertFusedMoE,
|
||||
)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: dict[str, Any] | None = None
|
||||
@@ -54,6 +57,7 @@ __all__ = [
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"RoutingMethodType",
|
||||
"SharedFusedMoE",
|
||||
"ZeroExpertFusedMoE",
|
||||
"activation_without_mul",
|
||||
"override_config",
|
||||
"get_config",
|
||||
|
||||
@@ -92,7 +92,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -110,10 +110,4 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_map=None if self.disable_expert_map else layer.expert_map,
|
||||
)
|
||||
|
||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||
assert not isinstance(result, tuple), (
|
||||
"Shared + zero experts are mutually exclusive not yet supported"
|
||||
)
|
||||
return result, zero_expert_result
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
)
|
||||
@@ -350,8 +349,6 @@ class FusedMoE(CustomOp):
|
||||
num_redundant_experts: int = 0,
|
||||
has_bias: bool = False,
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: int | None = 0,
|
||||
zero_expert_type: str | None = None,
|
||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||
n_shared_experts: int | None = None,
|
||||
routing_method_type: int | None = None,
|
||||
@@ -409,8 +406,6 @@ class FusedMoE(CustomOp):
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
self.logical_num_experts = num_experts
|
||||
self.zero_expert_num = zero_expert_num
|
||||
self.zero_expert_type = zero_expert_type
|
||||
|
||||
# Expert mapping used in self.load_weights
|
||||
self.expert_mapping = expert_mapping
|
||||
@@ -1525,15 +1520,15 @@ class FusedMoE(CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Route the input hidden states to the top-k experts based on the
|
||||
router logits.
|
||||
|
||||
Returns:
|
||||
(topk_weights, topk_ids, zero_expert_result)
|
||||
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
The weights, expert ids, and zero expert computation result.
|
||||
(topk_weights, topk_ids)
|
||||
(tuple[torch.Tensor, torch.Tensor]):
|
||||
The weights and expert ids.
|
||||
|
||||
**Compatibility**: When EPLB is not enabled, the returned ids are
|
||||
equivalent to global logical ids, so should be compatible with
|
||||
@@ -1655,23 +1650,7 @@ class FusedMoE(CustomOp):
|
||||
|
||||
assert topk_ids.dtype == indices_type or indices_type is None
|
||||
|
||||
# Compute zero expert result if needed
|
||||
if (
|
||||
self.zero_expert_num is not None
|
||||
and self.zero_expert_num > 0
|
||||
and self.zero_expert_type is not None
|
||||
and self.global_num_experts is not None
|
||||
):
|
||||
zero_expert_result = zero_experts_compute_triton(
|
||||
expert_indices=topk_ids,
|
||||
expert_scales=topk_weights,
|
||||
num_experts=self.global_num_experts,
|
||||
zero_expert_type=self.zero_expert_type,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
else:
|
||||
zero_expert_result = None
|
||||
return topk_weights, topk_ids, zero_expert_result
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
@@ -1736,14 +1715,7 @@ class FusedMoE(CustomOp):
|
||||
fused_output = torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits, self.layer_name
|
||||
)
|
||||
if self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(fused_output, tuple)
|
||||
fused_output, zero_expert_result = fused_output
|
||||
return (reduce_output(fused_output) + zero_expert_result)[
|
||||
..., :og_hidden_states
|
||||
]
|
||||
else:
|
||||
return reduce_output(fused_output)[..., :og_hidden_states]
|
||||
return reduce_output(fused_output)[..., :og_hidden_states]
|
||||
else:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
@@ -1841,13 +1813,6 @@ class FusedMoE(CustomOp):
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
if self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is None
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
if zero_expert_result is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
|
||||
if not skip_result_store:
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
@@ -2030,9 +1995,6 @@ class FusedMoE(CustomOp):
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
@@ -2051,9 +2013,6 @@ class FusedMoE(CustomOp):
|
||||
final_hidden_states[0],
|
||||
combine_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return (combine_output(final_hidden_states), zero_expert_result)
|
||||
else:
|
||||
return combine_output(final_hidden_states)
|
||||
|
||||
|
||||
@@ -295,7 +295,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, zero_expert_result = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -336,13 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
|
||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||
assert not isinstance(result, tuple), (
|
||||
"Shared + zero experts are mutually exclusive not yet supported"
|
||||
)
|
||||
return result, zero_expert_result
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
|
||||
189
vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py
Normal file
189
vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
|
||||
class ZeroExpertFusedMoE(FusedMoE):
    """
    A FusedMoE operation that also computes the results of zero experts.
    Zero experts perform identity operations (scaled pass-through) instead
    of full MLP computations.

    This class uses memoization to avoid redundant routing computation:
    routing is computed once in `forward()` and reused for both the zero
    expert computation and the main FusedMoE forward pass (through a
    `custom_routing_function` that simply returns the memoized result).
    """

    def __init__(
        self,
        zero_expert_num: int,
        zero_expert_type: str,
        router: nn.Module,
        **kwargs,
    ):
        """
        Args:
            zero_expert_num: Number of zero experts handled by this layer.
            zero_expert_type: Kind of zero expert, forwarded unchanged to
                `zero_experts_compute_triton`.
            router: The full router (its outputs include the zero experts).
                If it exposes `e_score_correction_bias`, that bias is used
                for routing and a slice of it is handed to the base class.
            **kwargs: Forwarded to `FusedMoE.__init__`. Must not carry a
                non-None `custom_routing_function`.
        """
        # ZeroExpertFusedMoE manages its own custom_routing_function for memoization
        assert kwargs.get("custom_routing_function") is None, (
            "ZeroExpertFusedMoE does not support external custom_routing_function. "
            "It manages its own for routing memoization."
        )

        # Automatically slice router's e_score_correction_bias to only include
        # real experts (not zero_experts) for the base FusedMoE.
        # The full bias will be used temporarily in forward() for routing.
        # A missing *or None* router bias leaves kwargs untouched instead of
        # failing on `None[:n]`.
        router_bias = getattr(router, "e_score_correction_bias", None)
        if router_bias is not None and "num_experts" in kwargs:
            num_real_experts = kwargs["num_experts"]
            user_bias = kwargs.get("e_score_correction_bias")

            # Use router's bias if:
            # 1. User didn't provide bias, or
            # 2. User provided full bias (same size as router)
            if user_bias is None or user_bias.shape[0] == router_bias.shape[0]:
                kwargs["e_score_correction_bias"] = router_bias[:num_real_experts]

        # FusedMoE no longer accepts zero_expert_num/zero_expert_type.
        # We handle zero experts ourselves in forward().
        super().__init__(**kwargs)

        # Store the actual zero_expert_num and zero_expert_type for our own use
        self._actual_zero_expert_num = zero_expert_num
        self._actual_zero_expert_type = zero_expert_type
        self._router = router  # Full router (includes zero experts)

        # Expose zero_expert_num and zero_expert_type as attributes for
        # compatibility with quantization methods that check these attributes
        self.zero_expert_num = 0
        self.zero_expert_type = None

        # Memoization state for routing results; only populated while the
        # base-class forward pass is running (see forward()).
        self._memoized_topk_weights: torch.Tensor | None = None
        self._memoized_topk_ids: torch.Tensor | None = None

        # Create custom_routing_function to reuse memoized routing results
        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
            """Return memoized `topk_weights` and `topk_ids`."""
            if self._memoized_topk_weights is None or self._memoized_topk_ids is None:
                raise RuntimeError(
                    "ZeroExpertFusedMoE: routing results not memoized. "
                    "Call select_experts first to compute routing."
                )
            return self._memoized_topk_weights, self._memoized_topk_ids

        self.custom_routing_function = custom_routing_function

    @contextmanager
    def _temporarily_set_attrs(self, **attrs):
        """
        Temporarily set attributes using object.__setattr__ and restore them.

        The previous values are captured with `getattr` before anything is
        changed and are written back in a `finally` clause, so they are
        restored even if the body of the `with` block raises.

        This bypasses nn.Module.__setattr__ to avoid Dynamo tracing issues.
        When PyTorch Dynamo traces the forward pass, it cannot handle
        nn.Module.__setattr__ calls (which include parameter registration logic),
        resulting in "Unsupported" errors. Using object.__setattr__ directly
        sets the attribute without triggering nn.Module's custom __setattr__,
        allowing Dynamo to trace the code successfully.
        """
        originals = {key: getattr(self, key) for key in attrs}
        try:
            for key, value in attrs.items():
                object.__setattr__(self, key, value)
            yield
        finally:
            for key, value in originals.items():
                object.__setattr__(self, key, value)

    def _compute_zero_expert_result(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | None:
        """
        Compute zero expert results using pre-computed routing.

        Returns None when no zero experts are configured (count is None or
        not positive, or the zero expert type is None).
        """
        if (
            self._actual_zero_expert_num is None
            or self._actual_zero_expert_num <= 0
            or self._actual_zero_expert_type is None
        ):
            return None

        # NOTE(review): copies of the routing tensors are passed, presumably so
        # the helper cannot modify the tensors that are reused for the real
        # experts afterwards -- confirm against zero_experts_compute_triton.
        return zero_experts_compute_triton(
            expert_indices=topk_ids.clone(),
            expert_scales=topk_weights.clone(),
            num_experts=self.logical_num_experts,
            zero_expert_type=self._actual_zero_expert_type,
            hidden_states=hidden_states,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,  # Full logits including zero experts
    ) -> torch.Tensor:
        """
        Forward pass with zero expert support and routing memoization.

        Args:
            hidden_states: Input hidden states
            router_logits: Full router logits (including zero experts)

        Returns:
            Combined output from real experts and zero experts
        """
        # Prepare temporary attribute overrides for routing computation
        temp_attrs = {
            "custom_routing_function": None,  # Disable for first routing
        }
        # Same guard as in __init__: only override the bias when the router
        # actually exposes one (hasattr is also False for a None router).
        if hasattr(self._router, "e_score_correction_bias"):
            temp_attrs["e_score_correction_bias"] = self._router.e_score_correction_bias

        # Compute routing with temporary attributes
        # Pass full router_logits (including zero experts) so that zero experts
        # can be properly identified in topk_ids
        with self._temporarily_set_attrs(**temp_attrs):
            topk_weights, topk_ids = self.select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,  # Full logits (includes zero experts)
            )

        # Compute zero expert result if needed
        zero_expert_result = self._compute_zero_expert_result(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        # Slice router_logits for real experts only
        router_logits_sliced = router_logits[..., : self.logical_num_experts]

        # Compute real expert results (will reuse memoized routing via
        # custom_routing_function).
        # The memoized routing is scoped to this `with` block so that it is
        # cleared again even when the base-class forward raises; otherwise the
        # module would keep stale routing tensors alive.
        # zero_expert_num is already 0, so FusedMoE won't handle zero experts
        with self._temporarily_set_attrs(
            _memoized_topk_weights=topk_weights,
            _memoized_topk_ids=topk_ids,
        ):
            fused_out = super().forward(
                hidden_states=hidden_states,
                router_logits=router_logits_sliced,
            )

        # Combine results
        # Both zero_expert_result and fused_out are computed from the same
        # hidden_states, so they should be on the same device.
        if zero_expert_result is not None:
            fused_out = fused_out + zero_expert_result

        return fused_out
|
||||
@@ -764,7 +764,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -500,7 +500,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -574,7 +574,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1166,7 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1403,7 +1403,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1765,7 +1765,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
f"{layer.activation} not supported for Marlin MoE."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1991,7 +1991,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -2607,7 +2607,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
|
||||
)
|
||||
assert self.moe_quant_config is not None
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -142,7 +142,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -1292,13 +1292,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
select_result = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, zero_expert_result = select_result
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_fused_experts,
|
||||
@@ -1353,13 +1351,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||
assert not isinstance(result, tuple), (
|
||||
"Shared + zero experts are mutually exclusive not yet supported"
|
||||
)
|
||||
return result, zero_expert_result
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
|
||||
@@ -639,7 +639,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"fused GGUF MoE method."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -900,7 +900,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -796,7 +796,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -1599,7 +1599,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
x_routing, _ = x
|
||||
else:
|
||||
x_routing = x
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x_routing,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -370,7 +370,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert layer.activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -896,7 +896,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -990,7 +990,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
):
|
||||
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
|
||||
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -338,7 +338,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -530,7 +530,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
@@ -738,7 +738,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -359,7 +359,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = layer.select_experts(
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, ZeroExpertFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -179,7 +179,7 @@ class FlashConfig(PretrainedConfig):
|
||||
self.intermediate_size = (
|
||||
self.ffn_hidden_size
|
||||
if hasattr(self, "ffn_hidden_size")
|
||||
else self.intermediate_size
|
||||
else intermediate_size
|
||||
)
|
||||
if hasattr(self, "moe_intermediate_size"):
|
||||
self.moe_intermediate_size = self.moe_intermediate_size
|
||||
@@ -280,10 +280,6 @@ class LongcatMoe(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.zero_expert_num = config.zero_expert_num
|
||||
self.zero_expert_type = config.zero_expert_type
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.enable_eplb = enable_eplb
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.rounter_params_dtype = params_dtype
|
||||
if config.router_dtype == "float32":
|
||||
@@ -291,25 +287,27 @@ class LongcatMoe(nn.Module):
|
||||
|
||||
self.router = LongcatRouter(
|
||||
config=config,
|
||||
zero_expert_num=self.zero_expert_num,
|
||||
zero_expert_num=config.zero_expert_num,
|
||||
rounter_params_dtype=self.rounter_params_dtype,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
assert config.zero_expert_num is not None
|
||||
assert config.zero_expert_type is not None
|
||||
self.experts = ZeroExpertFusedMoE(
|
||||
zero_expert_num=config.zero_expert_num,
|
||||
zero_expert_type=config.zero_expert_type,
|
||||
router=self.router,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
reduce_results=True,
|
||||
params_dtype=params_dtype,
|
||||
e_score_correction_bias=self.router.e_score_correction_bias,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
zero_expert_num=self.zero_expert_num,
|
||||
zero_expert_type=self.zero_expert_type,
|
||||
enable_eplb=self.enable_eplb,
|
||||
enable_eplb=enable_eplb,
|
||||
routed_scaling_factor=config.routed_scaling_factor,
|
||||
)
|
||||
|
||||
@@ -317,11 +315,34 @@ class LongcatMoe(nn.Module):
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
router_logits = self.router(hidden_states.to(self.rounter_params_dtype))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states, router_logits=router_logits
|
||||
# Align to FusedMoE padded hidden size to avoid dim mismatch
|
||||
padded_hidden = self.experts.hidden_size
|
||||
if hidden_dim < padded_hidden:
|
||||
hidden_states_padded = torch.nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, padded_hidden - hidden_dim),
|
||||
mode="constant",
|
||||
value=0.0,
|
||||
)
|
||||
else:
|
||||
hidden_states_padded = hidden_states
|
||||
|
||||
router_logits_full = self.router(
|
||||
hidden_states_padded.to(self.rounter_params_dtype)
|
||||
)
|
||||
|
||||
# ZeroExpertFusedMoE handles routing memoization and zero expert computation
|
||||
# internally. Pass full router_logits (including zero experts) so that
|
||||
# zero experts can be properly identified in routing.
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states_padded,
|
||||
router_logits=router_logits_full, # Full logits (includes zero experts)
|
||||
)
|
||||
|
||||
# Crop back to original hidden dimension if padded earlier
|
||||
if padded_hidden != hidden_dim:
|
||||
final_hidden_states = final_hidden_states[..., :hidden_dim]
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
@@ -419,6 +440,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn[0](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
llama_4_scaling=None,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm[0](
|
||||
@@ -438,6 +460,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn[1](
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
llama_4_scaling=None,
|
||||
)
|
||||
hidden_states, residual = self.post_attention_layernorm[1](
|
||||
hidden_states, residual
|
||||
|
||||
Reference in New Issue
Block a user