[MoE Refactor] Move select_experts from FusedMoEQuantMethod -> FusedMoE (#31996)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-01-22 18:21:35 -05:00
committed by GitHub
parent fc56f4a071
commit dc917cceb8
22 changed files with 498 additions and 533 deletions

View File

@@ -22,7 +22,6 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoERouter,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
@@ -968,71 +967,79 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def allow_inplace(self) -> bool:
return True
def apply(
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
result = self.kernel(
assert not self.is_monolithic
return self.kernel(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1045,8 +1052,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
return result
class Fp8OnlineMoEMethod(Fp8MoEMethod):
"""MoE method for online FP8 quantization.