[MoE Refactor] Move select_experts from FusedMoEQuantMethod -> FusedMoE (#31996)

Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm
2026-01-22 18:21:35 -05:00
committed by GitHub
parent fc56f4a071
commit dc917cceb8
22 changed files with 498 additions and 533 deletions
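At a glance, the refactor replaces the single apply(layer, router, x, router_logits) entry point on quant methods with two: non-monolithic methods now receive pre-selected experts via apply(layer, x, topk_weights, topk_ids), while monolithic backends keep routing internal behind apply_monolithic. A minimal sketch of the new control flow, using stand-in objects rather than vLLM source:

import torch

def moe_forward(layer, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
    # Sketch only: `layer` stands in for a FusedMoE instance.
    if layer.quant_method.is_monolithic:
        # Monolithic kernels (FlashInfer TRTLLM, CPU/XPU, ...) do their own
        # routing, so they still receive the raw logits.
        return layer.quant_method.apply_monolithic(
            layer=layer, x=x, router_logits=router_logits
        )
    # Everyone else gets expert selection done once, here in the layer.
    topk_weights, topk_ids = layer.router.select_experts(
        hidden_states=x, router_logits=router_logits
    )
    if layer.capture is not None:  # routed-experts capture also moved here
        layer.capture(topk_ids)
    return layer.quant_method.apply(
        layer=layer, x=x, topk_weights=topk_weights, topk_ids=topk_ids
    )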

View File

@@ -210,7 +210,6 @@ def fused_marlin_moe(
bias2: torch.Tensor | None,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
@@ -250,8 +249,6 @@ def fused_marlin_moe(
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor|None): The output of the gating
operation (before softmax).
- g_idx1 (torch.Tensor|None): The first set of act_order indices.
- g_idx2 (torch.Tensor|None): The second set of act_order indices.
- sort_indices1 (torch.Tensor|None): The first act_order input
@@ -292,8 +289,6 @@ def fused_marlin_moe(
topk = topk_ids.size(1)
# Check constraints.
if gating_output is not None:
assert gating_output.size(0) == M, "Number of tokens mismatch"
assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -381,7 +376,6 @@ def batched_fused_marlin_moe(
bias2: torch.Tensor | None,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor | None,
quant_type_id: int,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
@@ -718,7 +712,6 @@ class MarlinExperts(MarlinExpertsBase):
bias2=self.w2_bias,
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
gating_output=None,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_scale1=self.g1_alphas,
@@ -833,7 +826,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
bias2=self.w2_bias,
w1_scale=self.w1_scale,
w2_scale=self.w2_scale,
gating_output=None,
quant_type_id=self.quant_type_id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
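The gating_output parameter, which the kernel only used for a shape assertion, is dropped from fused_marlin_moe and batched_fused_marlin_moe. Callers route first and pass only topk_weights/topk_ids, as in this hedged sketch of the updated AWQ-style call site (abbreviated arguments, assumed import path):

import torch
# Import path assumed for the sketch.
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe

def marlin_apply(layer, x: torch.Tensor, topk_weights: torch.Tensor,
                 topk_ids: torch.Tensor, quant_type_id: int) -> torch.Tensor:
    # topk_* are produced upstream by FusedMoE via router.select_experts();
    # the kernel no longer sees (or asserts on) the raw gating logits.
    return fused_marlin_moe(
        x,
        layer.w13_qweight,
        layer.w2_qweight,
        getattr(layer, "w13_bias", None),
        getattr(layer, "w2_bias", None),
        layer.w13_scales,
        layer.w2_scales,
        topk_weights,
        topk_ids,
        quant_type_id=quant_type_id,
    )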

View File

@@ -14,9 +14,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
@@ -108,11 +105,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def method_name(self) -> str:
return self.__class__.__name__
@abstractmethod
@property
def is_monolithic(self) -> bool:
return False
# @abstractmethod
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
# @abstractmethod
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
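The base class now exposes two entry points gated by is_monolithic (the @abstractmethod decorators stay commented out during the transition, since each subclass implements only one of the pair). A minimal sketch of a subclass under the new interface, using stand-in names:

import torch

class SketchMoEMethod:  # stand-in for a FusedMoEMethodBase subclass
    @property
    def is_monolithic(self) -> bool:
        # True for backends that fuse routing into the kernel; those
        # implement apply_monolithic() instead of apply().
        return False

    def apply(
        self,
        layer,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError  # experts already selected by FusedMoE

    def apply_monolithic(
        self, layer, x: torch.Tensor, router_logits: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError  # receives raw logits, routes internally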

View File

@@ -16,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
logger = init_logger(__name__)
@@ -40,6 +37,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
not self.fused_experts.supports_expert_map(),
)
self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@staticmethod
@@ -94,16 +92,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
result = self.fused_experts(
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -115,5 +108,3 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else layer.expert_map,
)
return result

View File

@@ -522,8 +522,7 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
# TODO(bnell): in next PR move capture back to layer
capture: Callable[[torch.Tensor], None] | None = None
self.capture: Callable[[torch.Tensor], None] | None = None
if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
@@ -531,7 +530,9 @@ class FusedMoE(CustomOp):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:
capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids)
self.capture = lambda topk_ids: capturer.capture(
self.layer_id, topk_ids
)
self.router = create_fused_moe_router(
top_k=top_k,
@@ -550,7 +551,6 @@ class FusedMoE(CustomOp):
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
capture=capture,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type
@@ -1673,12 +1673,27 @@ class FusedMoE(CustomOp):
staged_router_logits.copy_(router_logits, non_blocking=True)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
router=self.router,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self,
x=staged_hidden_states,
router_logits=staged_router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=staged_hidden_states,
router_logits=staged_router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=staged_hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
@@ -1810,15 +1825,20 @@ class FusedMoE(CustomOp):
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = (
dispatch_res
)
(
orig_hidden_states,
router_logits,
extra_tensors_combined,
) = dispatch_res
hidden_states_combined = (
hidden_states_combined,
orig_hidden_states,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
orig_hidden_states = hidden_states_combined
else:
orig_hidden_states = hidden_states
# Run shared experts before matrix multiply.
# because the matrix multiply may modify the hidden_states.
@@ -1840,14 +1860,33 @@ class FusedMoE(CustomOp):
)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
router=self.router,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
)
x = hidden_states_combined if do_naive_dispatch_combine else hidden_states
# TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
# Figure out nicer way to do this.
x_orig = orig_hidden_states if do_naive_dispatch_combine else hidden_states
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=self,
x=x,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
hidden_states=x_orig,
router_logits=router_logits,
)
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signature of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
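In the non-staged path above, x can be a (quantized, scale) tuple under the fp4 FlashInfer hack (#30014), so expert selection reads x_orig, which is always a plain tensor. A runnable toy of that distinction (the packed pair is hypothetical):

import torch

def routing_input(x, x_orig: torch.Tensor) -> torch.Tensor:
    # select_experts only needs the pre-quantization hidden states for
    # metadata, and it must never see a tuple.
    assert isinstance(x_orig, torch.Tensor)
    return x_orig

hs = torch.randn(8, 16)
packed = (hs.half(), torch.ones(8))  # hypothetical quantized/scale pair
assert routing_input(packed, hs) is hs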

View File

@@ -127,7 +127,6 @@ class BaseRouter(FusedMoERouter):
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture: Callable[[torch.tensor], None] | None = None
def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
@@ -238,8 +237,4 @@ class BaseRouter(FusedMoERouter):
# Step 5: Convert indices dtype
topk_ids = self._convert_indices_dtype(topk_ids, indices_type)
# TODO(bnell): temporary hack until select_experts is moved into FusedMoE
if self.capture is not None:
self.capture(topk_ids)
return topk_weights, topk_ids

View File

@@ -55,4 +55,6 @@ class CustomRoutingRouter(BaseRouter):
renormalize=self.renormalize,
)
return topk_weights.to(torch.float32), topk_ids
return topk_weights.to(torch.float32), topk_ids.to(
torch.int32 if indices_type is None else indices_type
)

View File

@@ -124,7 +124,9 @@ def fused_topk_bias(
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
return topk_weights.to(torch.float32), topk_indices.to(
torch.int32 if indices_type is None else indices_type
)
class FusedTopKBiasRouter(BaseRouter):
@@ -176,6 +178,7 @@ class FusedTopKBiasRouter(BaseRouter):
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type,
)
if self.routed_scaling_factor != 1.0:
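Both router fixes above make topk_ids honor indices_type, defaulting to int32 (previously CustomRoutingRouter skipped the conversion entirely and fused_topk_bias hard-coded int32). A runnable toy of the dtype contract:

import torch

def convert_topk(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    indices_type: torch.dtype | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return (
        topk_weights.to(torch.float32),
        topk_ids.to(torch.int32 if indices_type is None else indices_type),
    )

weights, ids = torch.topk(torch.rand(4, 8), k=2)  # ids are int64 by default
w32, ids32 = convert_topk(weights, ids, None)
assert ids32.dtype == torch.int32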

View File

@@ -6,7 +6,6 @@ import torch
import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter,
)
@@ -49,7 +48,6 @@ def create_fused_moe_router(
# eplb parameters
enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
capture: Callable[[torch.tensor], None] | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
@@ -90,21 +88,16 @@ def create_fused_moe_router(
Returns:
An instance of the appropriate FusedMoERouter subclass
"""
router: BaseRouter
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
router = RoutingSimulatorRouter(
return RoutingSimulatorRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
# TODO(bnell): this is temporary until select_experts is
# separated from apply.
router.capture = capture
return router
if use_grouped_topk:
assert custom_routing_function is None
@@ -113,7 +106,7 @@ def create_fused_moe_router(
"num_expert_group and topk_group must be provided when "
"use_grouped_topk is True"
)
router = GroupedTopKRouter(
return GroupedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
@@ -127,11 +120,9 @@ def create_fused_moe_router(
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
if custom_routing_function is not None:
router = CustomRoutingRouter(
return CustomRoutingRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
@@ -140,11 +131,9 @@ def create_fused_moe_router(
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
if e_score_correction_bias is not None:
router = FusedTopKBiasRouter(
return FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
@@ -155,10 +144,8 @@ def create_fused_moe_router(
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
router = FusedTopKRouter(
return FusedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
@@ -167,5 +154,3 @@ def create_fused_moe_router(
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
router.capture = capture
return router
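With the capture hook gone, every branch of create_fused_moe_router returns its router directly. A hedged pure-Python summary of the selection precedence (string names only, not the vLLM API):

def pick_router_cls(
    routing_strategy: str,
    use_grouped_topk: bool,
    custom_routing_function,
    e_score_correction_bias,
) -> str:
    # Mirrors the branch order in create_fused_moe_router.
    if routing_strategy != "":
        return "RoutingSimulatorRouter"
    if use_grouped_topk:
        return "GroupedTopKRouter"
    if custom_routing_function is not None:
        return "CustomRoutingRouter"
    if e_score_correction_bias is not None:
        return "FusedTopKBiasRouter"
    return "FusedTopKRouter"

assert pick_router_cls("", False, None, None) == "FusedTopKRouter"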

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import torch.nn.functional as F
@@ -31,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
@@ -66,6 +64,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEModularKernel | None = None
self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()
@property
def is_monolithic(self) -> bool:
return self._is_monolithic
@property
def supports_eplb(self) -> bool:
@@ -212,7 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
self.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
use_prepack=True,
@@ -244,11 +247,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
assert packed_w2_weight.size() == layer.w2_weight.size()
layer.w2_weight.copy_(packed_w2_weight)
layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
self.cpu_fused_moe: Callable = cpu_fused_moe.SGLFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike():
self._setup_kernel(
layer=layer,
@@ -259,15 +262,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
router=router,
layer=layer,
x=x,
router_logits=router_logits,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
@@ -282,18 +285,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
result = self.kernel(
assert self.kernel is not None
return self.kernel(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -306,24 +303,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map,
)
return result
def forward_cpu(
def forward_monolithic_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for CPU.")
return layer.cpu_fused_moe(
return self.cpu_fused_moe(
layer,
x,
layer.use_grouped_topk,
@@ -342,21 +328,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.activation,
)
def forward_xpu(
def forward_monolithic_xpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
layer.enable_eplb is not False
or layer.eplb_state.expert_load_view is not None
or layer.eplb_state.logical_to_physical_map is not None
or layer.eplb_state.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for XPU.")
return layer.ipex_fusion(
return self.ipex_fusion(
x,
layer.use_grouped_topk,
layer.top_k,
@@ -368,8 +346,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
if current_platform.is_cpu():
forward_native = forward_cpu
forward_native: Callable = forward_monolithic_cpu
apply_monolithic = forward_monolithic_cpu
elif current_platform.is_xpu():
forward_native = forward_xpu
forward_native = forward_monolithic_xpu
apply_monolithic = forward_monolithic_xpu
else:
forward_native = forward_cuda
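On CPU and XPU the unquantized method is monolithic, so the platform forward doubles as apply_monolithic; on CUDA-alike platforms forward_cuda consumes pre-selected experts. A toy mirror of the class-body binding (the _is_cpu flag is an assumption for the sketch):

class PlatformBoundMethod:
    def forward_monolithic_cpu(self, layer, x, router_logits):
        return x  # stand-in body

    def forward_cuda(self, layer, x, topk_weights, topk_ids):
        return x  # stand-in body

    # Evaluated at class-definition time, like the current_platform checks:
    _is_cpu = True  # assumption for the sketch
    if _is_cpu:
        forward_native = forward_monolithic_cpu
        apply_monolithic = forward_monolithic_cpu
    else:
        forward_native = forward_cuda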

View File

@@ -10,7 +10,6 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -762,15 +761,10 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe(
x,
layer.w13_qweight,
@@ -779,7 +773,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),

View File

@@ -6,7 +6,6 @@ from typing import Any, Union
import torch
from packaging import version
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -499,16 +498,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
# TODO(bnell): Do these need to be called on the hot path?
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)

View File

@@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEActivationFormat,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoERouter,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
@@ -126,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
layer_name: str,
) -> "CompressedTensorsMoEMethod":
) -> FusedMoEMethodBase:
# FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map()
@@ -345,19 +344,10 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)
assert self.kernel is not None
return self.kernel(
x,
@@ -639,41 +629,47 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale,
)
def apply(
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
if (
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert layer.activation == "silu", "Only SiLU activation is supported."
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
@@ -1059,70 +1055,73 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
block_shape=self.weight_block_size,
)
def apply(
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `FlashInfer TRTLLM FP8 MoE`."
)
assert layer.activation == "silu"
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
assert layer.activation == "silu"
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.kernel is not None
result = self.kernel(
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1137,8 +1136,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
@property
def supports_eplb(self) -> bool:
return True
@@ -1257,17 +1254,12 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -1621,15 +1613,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe(
x,
layer.w13_weight_packed,
@@ -1638,7 +1625,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
@@ -1873,17 +1859,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts(
x,
layer.w13_weight_packed,
@@ -2172,10 +2153,13 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
# fused_experts; quant config is not needed.
return None
def apply(
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
@@ -2489,19 +2473,15 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
)
assert self.moe_quant_config is not None
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_w4a8_fp8,
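The quantized methods in this file, like the fp8 and ModelOpt ones further down, all derive is_monolithic from the selected backend rather than hard-coding it. A minimal toy of the pattern (enum values are stand-ins for Fp8MoeBackend/NvFp4MoeBackend):

from enum import Enum, auto

class Backend(Enum):  # stand-in enum
    FLASHINFER_TRTLLM = auto()
    CUTLASS = auto()

class BackendDrivenMethod:
    def __init__(self, backend: Backend):
        self.backend = backend

    @property
    def is_monolithic(self) -> bool:
        # TRTLLM fuses routing into the kernel, so it takes the
        # apply_monolithic path; other backends use apply().
        return self.backend is Backend.FLASHINFER_TRTLLM

assert BackendDrivenMethod(Backend.CUTLASS).is_monolithic is False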

View File

@@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -138,17 +137,12 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts(
x,
layer.w13_weight,

View File

@@ -22,7 +22,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
@@ -968,71 +967,79 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def allow_inplace(self) -> bool:
return True
def apply(
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
result = self.kernel(
assert not self.is_monolithic
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1045,8 +1052,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.

View File

@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -633,9 +632,9 @@ class GGUFMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported."
if layer.apply_router_weight_on_input:
@@ -644,10 +643,6 @@ class GGUFMoEMethod(FusedMoEMethodBase):
"fused GGUF MoE method."
)
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_moe_gguf(
x,
layer.w13_qweight,

View File

@@ -10,7 +10,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -898,15 +897,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe(
x,
layer.w13_qweight,
@@ -915,7 +909,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),

View File

@@ -8,9 +8,6 @@ from packaging import version
from torch.nn import Module
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -384,10 +381,13 @@ class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
) -> FusedMoEQuantConfig | None:
return None
def apply(
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: torch.nn.Module,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:

View File

@@ -13,7 +13,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -945,42 +944,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale,
)
def apply(
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
assert not layer.renormalize
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
# Expert selection
topk_weights, topk_ids = router.select_experts(
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
assert not layer.renormalize
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
@@ -990,7 +996,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
assert self.kernel is not None
result = self.kernel(
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1003,8 +1009,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod
@@ -1629,40 +1633,47 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
def apply(
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
assert self.is_monolithic
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
):
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids = router.select_experts(
hidden_states=x_routing,
router_logits=router_logits,
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb

View File

@@ -6,7 +6,6 @@ from typing import Any, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
@@ -365,17 +364,13 @@ class MoeWNA16Method(FusedMoEMethodBase):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_experts(
x,

View File

@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def allow_inplace(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
)
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
return fused_marlin_moe(
x,
layer.w13_weight,
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_bias,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=None,
@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
)
return output
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
_ = flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
experts_start_id=ep_rank_start,
)
def apply(
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:

View File

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
@@ -351,15 +350,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
@@ -388,7 +382,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
@@ -544,15 +537,10 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
@@ -753,15 +741,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def apply(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
)
if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,