[Misc][Refactor] Add FusedMoERouter object (#30519)
Signed-off-by: Bill Nell <bnell@redhat.com>
@@ -11,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
@@ -48,6 +51,7 @@ def get_config() -> dict[str, Any] | None:

 __all__ = [
     "FusedMoE",
+    "FusedMoERouter",
     "FusedMoEConfig",
     "FusedMoEMethodBase",
     "UnquantizedFusedMoEMethod",

@@ -10,6 +10,9 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
@@ -109,6 +112,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

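For downstream implementers, the contract change is that every `FusedMoEMethodBase.apply` override now receives the router explicitly and should call `router.select_experts(...)` instead of reaching back into the layer. Below is a minimal sketch of an override under the new signature, assuming the import paths introduced in this commit; `MyMoEMethod` and the stubbed expert computation are placeholders, not part of the change.

    import torch

    from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
        FusedMoEMethodBase,
    )
    from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


    class MyMoEMethod(FusedMoEMethodBase):  # placeholder subclass for illustration
        def apply(
            self,
            layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
            router: FusedMoERouter,
            x: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> torch.Tensor:
            # Expert selection now goes through the router object ...
            topk_weights, topk_ids = router.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )
            # ... while the expert computation still reads weights off `layer`;
            # a real method would invoke its fused kernel here instead.
            assert topk_weights.shape == topk_ids.shape
            return x
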
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel,
     FusedMoEPrepareAndFinalize,
@@ -88,10 +89,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

vllm/model_executor/layers/fused_moe/fused_moe_router.py (new file, +40 lines)
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
+
+class FusedMoERouter(ABC):
+    """
+    FusedMoERouter is an abstract class that provides a 'select_experts'
+    method that is used for routing hidden states based on router logits.
+    """
+
+    @property
+    @abstractmethod
+    def routing_method_type(self) -> RoutingMethodType:
+        raise NotImplementedError
+
+    @abstractmethod
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Route the input hidden states to the top-k experts based on the
+        router logits.
+
+        Returns:
+            (topk_weights, topk_ids)
+            (tuple[torch.Tensor, torch.Tensor]):
+                The weights and expert ids computation result.
+
+        **Compatibility**: When EPLB is not enabled, the returned ids are
+        equivalent to global logical ids, so should be compatible with
+        plain MoE implementations without redundant experts.
+        """
+        raise NotImplementedError

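To make the interface concrete, a minimal hypothetical implementation might look like the following: a plain softmax top-k router that renormalizes the selected weights. This is an illustrative sketch only; `SoftmaxTopKRouter` is not part of this commit, and the specific `RoutingMethodType` member used is an assumption.

    import torch

    from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
    from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter


    class SoftmaxTopKRouter(FusedMoERouter):
        """Plain softmax top-k routing, weights renormalized over the top-k."""

        def __init__(self, top_k: int):
            super().__init__()
            self.top_k = top_k

        @property
        def routing_method_type(self) -> RoutingMethodType:
            return RoutingMethodType.TopK  # assumed member name, illustration only

        def select_experts(
            self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # [num_tokens, num_experts] logits -> per-token probabilities.
            scores = torch.softmax(router_logits.float(), dim=-1)
            topk_weights, topk_ids = torch.topk(scores, self.top_k, dim=-1)
            # Renormalize so each token's selected weights sum to one.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
            return topk_weights, topk_ids
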
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     init_aiter_topK_meta_data,
 )
@@ -284,6 +285,23 @@ def maybe_roundup_hidden_size(
     return hidden_size


+class FusedMoERouterImpl(FusedMoERouter):
+    def __init__(self, layer: "FusedMoE"):
+        super().__init__()
+        self.layer = layer
+
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return self.layer.routing_method_type
+
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.layer._select_experts(hidden_states, router_logits)
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -339,7 +357,7 @@ class FusedMoE(CustomOp):
         is_sequence_parallel=False,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
-        routing_method_type: int | None = None,
+        routing_method_type: RoutingMethodType | None = None,
         router_logits_dtype: torch.dtype | None = None,
     ):
         super().__init__()
@@ -529,7 +547,7 @@ class FusedMoE(CustomOp):

         # ToDo: Better logic to determine the routing method type
         if routing_method_type is not None:
-            self.routing_method_type = routing_method_type
+            self.routing_method_type: RoutingMethodType = routing_method_type
         else:
             if scoring_func == "sigmoid":
                 if self.use_grouped_topk:
@@ -640,6 +658,8 @@ class FusedMoE(CustomOp):
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None

+        self.router = FusedMoERouterImpl(self)
+
         # Note: maybe_init_modular_kernel should only be called by
         # prepare_communication_buffer_for_model.
         # This is called after all weight loading and post-processing, so it
@@ -1503,7 +1523,7 @@ class FusedMoE(CustomOp):
             device=torch.cuda.current_device(),
         )

-    def select_experts(
+    def _select_experts(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1772,6 +1792,7 @@ class FusedMoE(CustomOp):
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
+            router=self.router,
             x=staged_hidden_states,
             router_logits=staged_router_logits,
         )
@@ -1944,6 +1965,7 @@ class FusedMoE(CustomOp):
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
+            router=self.router,
             x=hidden_states_combined
             if do_naive_dispatch_combine
             else hidden_states,

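Taken together, the call flow after this change is: `FusedMoE` constructs a `FusedMoERouterImpl` wrapping itself, passes it to `quant_method.apply(layer=..., router=...)`, and the quant method calls `router.select_experts`, which delegates back to the layer's now-private `_select_experts`. A condensed, self-contained sketch of that delegation chain with stand-in classes (the real `FusedMoE`, `FusedMoERouterImpl`, and quant methods are far larger):

    import torch


    class TinyRouter:
        """Plays the role of FusedMoERouterImpl: forwards to the owning layer."""

        def __init__(self, layer: "TinyMoELayer"):
            self.layer = layer

        def select_experts(self, hidden_states, router_logits):
            return self.layer._select_experts(hidden_states, router_logits)


    class TinyMoELayer:
        """Plays the role of FusedMoE: owns routing state, builds its router."""

        def __init__(self, top_k: int):
            self.top_k = top_k
            # Mirrors `self.router = FusedMoERouterImpl(self)` in FusedMoE.
            self.router = TinyRouter(self)

        def _select_experts(self, hidden_states, router_logits):
            scores = torch.softmax(router_logits, dim=-1)
            return torch.topk(scores, self.top_k, dim=-1)

        def forward(self, x, router_logits):
            # The quant method receives the router explicitly instead of
            # calling layer.select_experts itself.
            return quant_method_apply(self, self.router, x, router_logits)


    def quant_method_apply(layer, router, x, router_logits):
        topk_weights, topk_ids = router.select_experts(x, router_logits)
        # A real quant method would run the fused expert kernels here.
        return topk_weights, topk_ids


    layer = TinyMoELayer(top_k=2)
    weights, ids = layer.forward(torch.randn(4, 8), torch.randn(4, 16))
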
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEActivationFormat,
     FusedMoEPermuteExpertsUnpermute,
@@ -285,10 +286,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def apply(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         return self.forward(
+            router=router,
             layer=layer,
             x=x,
             router_logits=router_logits,
@@ -306,10 +309,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_cuda(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -332,6 +336,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_cpu(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -365,6 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_xpu(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -759,12 +760,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -10,7 +10,11 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -495,12 +499,13 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -40,6 +40,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     MarlinExperts,
     fused_marlin_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
     convert_to_fp8_moe_kernel_format,
@@ -458,6 +459,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -484,7 +486,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
             x_routing, _ = x
         else:
             x_routing = x
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x_routing,
             router_logits=router_logits,
         )
@@ -926,10 +928,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1066,12 +1069,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1426,6 +1430,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1433,7 +1438,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             f"{layer.activation} not supported for Marlin MoE."
         )

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1677,12 +1682,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1978,6 +1984,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:
@@ -2290,6 +2297,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ):
@@ -2298,7 +2306,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
         )
         assert self.moe_quant_config is not None
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
@@ -137,12 +138,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     RoutingMethodType,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
     Fp8MoeBackend,
@@ -997,6 +998,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1051,7 +1053,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -16,7 +16,11 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -629,6 +633,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -639,7 +644,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
             "fused GGUF MoE method."
         )

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -895,12 +896,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert layer.activation == "silu", "Only SiLU activation is supported."

-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -9,6 +9,9 @@ from torch.nn import Module

 from vllm._ipex_ops import ipex_ops as ops
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -384,6 +387,7 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
     def apply(
         self,
         layer: torch.nn.Module,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:

@@ -14,8 +14,10 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention.layer import Attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -200,7 +202,9 @@ class ModelOptQuantConfigBase(QuantizationConfig):
             quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
             return quant_method
         elif isinstance(layer, FusedMoE):
-            quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
+            quant_method = self.FusedMoEMethodCls(
+                quant_config=self, moe_config=layer.moe_config
+            )
             if getattr(quant_method, "backend", "") == "marlin":
                 quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
             return quant_method
@@ -720,14 +724,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptFp8Config,
-        layer: FusedMoE,
+        moe_config: FusedMoEConfig,
     ) -> None:
-        super().__init__(layer.moe_config)
+        super().__init__(moe_config)
         self.quant_config = quant_config
         assert self.quant_config.is_checkpoint_fp8_serialized
         self.fp8_backend = select_fp8_moe_backend(
             block_quant=False,
-            tp_size=layer.moe_parallel_config.tp_size,
+            tp_size=moe_config.moe_parallel_config.tp_size,
             with_lora_support=self.moe.is_lora_enabled,
         )
         self.kernel: mk.FusedMoEModularKernel | None = None
@@ -935,6 +939,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -961,7 +966,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )

-        topk_weights, topk_ids = layer.select_experts(
+        # Expert selection
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -1325,9 +1331,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptNvFp4Config,
-        layer: FusedMoE,
+        moe_config: FusedMoEConfig,
     ) -> None:
-        super().__init__(layer.moe_config)
+        super().__init__(moe_config)
         self.quant_config = quant_config
         self.nvfp4_backend = select_nvfp4_moe_backend()
         # TODO: move this type of check into the oracle.
@@ -1597,6 +1603,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -1621,7 +1628,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             x_routing, _ = x
         else:
             x_routing = x
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x_routing,
             router_logits=router_logits,
         )

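A second change threaded through this file: the ModelOpt MoE methods are now constructed from a `FusedMoEConfig` (`moe_config=layer.moe_config`) rather than from the `FusedMoE` layer itself, so construction depends only on config-level facts such as the TP size. A stand-in sketch of the pattern, with hypothetical names:

    from dataclasses import dataclass


    @dataclass
    class TinyParallelConfig:
        tp_size: int


    @dataclass
    class TinyMoEConfig:
        moe_parallel_config: TinyParallelConfig


    class TinyFp8MoEMethod:  # hypothetical stand-in for ModelOptFp8MoEMethod
        def __init__(self, moe_config: TinyMoEConfig) -> None:
            # Only config-level facts are needed up front, mirroring
            # tp_size=moe_config.moe_parallel_config.tp_size in the diff.
            self.tp_size = moe_config.moe_parallel_config.tp_size


    method = TinyFp8MoEMethod(TinyMoEConfig(TinyParallelConfig(tp_size=2)))
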
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     int4_w4a16_moe_quant_config,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEConfig,
@@ -364,13 +365,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         from vllm.model_executor.layers.fused_moe import fused_experts

         assert layer.activation == "silu", "Only SiLU activation is supported."
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     MarlinExperts,
     fused_marlin_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
     UnfusedOAITritonExperts,
@@ -891,6 +892,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -898,7 +900,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError("EPLB is not supported for mxfp4")

         if self.mxfp4_backend == Mxfp4Backend.MARLIN:
-            topk_weights, topk_ids = layer.select_experts(
+            topk_weights, topk_ids = router.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
             )
@@ -992,7 +994,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         ):
             from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

-            topk_weights, topk_ids = layer.select_experts(
+            topk_weights, topk_ids = router.select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
             )
@@ -1119,7 +1121,8 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):

     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor:

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
     FusedMoEConfig,
     FusedMoEMethodBase,
+    FusedMoERouter,
     FusedMoeWeightScaleSupported,
 )
 from vllm.model_executor.layers.fused_moe.config import (
@@ -350,10 +351,11 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
@@ -542,6 +544,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
@@ -750,10 +753,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )

@@ -15,7 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE,
+    FusedMoEMethodBase,
+)
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -356,10 +360,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: FusedMoE,
+        router: FusedMoERouter,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids = layer.select_experts(
+        topk_weights, topk_ids = router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )