[MoE Refactor] Move select_experts from FusedMoEQuantMethod -> FusedMoE (#31996)
Signed-off-by: Bill Nell <bnell@redhat.com>
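The hunks below move expert selection out of the quantization methods: FusedMoE/FusedMoERouter now computes the routing, and non-monolithic quant methods such as Mxfp4MoEMethod receive the precomputed topk_weights and topk_ids in apply(), while backends that do their own routing expose is_monolithic and an apply_monolithic() entry point that consumes the raw router_logits. As a minimal, self-contained sketch of the shapes select_experts produces, assuming a plain softmax-plus-top-k router (the naive_select_experts helper below is a hypothetical stand-in for FusedMoERouter.select_experts, which also supports grouped top-k, score correction bias, and custom routing functions):

import torch

def naive_select_experts(router_logits: torch.Tensor, top_k: int):
    # Hedged stand-in for router.select_experts: softmax + top-k only.
    # It illustrates the tensors the quant methods' apply() now receives.
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids  # both shaped [num_tokens, top_k]

logits = torch.randn(4, 8)                 # 4 tokens, 8 experts
w, ids = naive_select_experts(logits, top_k=2)
print(w.shape, ids.shape)                  # torch.Size([4, 2]) torch.Size([4, 2])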
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
@@ -890,22 +889,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
    def allow_inplace(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        return (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
            or self.mxfp4_backend == Mxfp4Backend.TRITON
        )

    def apply(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert not self.is_monolithic
        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids = router.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            return fused_marlin_moe(
                x,
                layer.w13_weight,
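With is_monolithic and the two entry points shown in this hunk, the calling side can dispatch roughly as sketched below. This is a hedged illustration of the dispatch only, not the actual FusedMoE.forward (which additionally handles EPLB, expert-parallel communication, and shared experts); the forward_moe helper and its duck-typed arguments are hypothetical.

import torch

def forward_moe(layer, quant_method, router, x: torch.Tensor, router_logits: torch.Tensor):
    # Hypothetical dispatch: monolithic backends consume raw logits and route
    # themselves; the others get routing results computed on the layer side.
    if quant_method.is_monolithic:
        # Signature per Mxfp4MoEMethod.apply_monolithic below; some methods
        # (e.g. IpexMxfp4MoEMethod) also take the router.
        return quant_method.apply_monolithic(layer, x, router_logits)

    topk_weights, topk_ids = router.select_experts(
        hidden_states=x,
        router_logits=router_logits,
    )
    return quant_method.apply(
        layer, router, x, router_logits, topk_weights, topk_ids
    )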
@@ -914,7 +917,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                layer.w2_bias,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                global_scale1=None,
@@ -942,6 +944,98 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        assert (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        )
        from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

        # Backend-specific preparation
        if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
            from flashinfer import mxfp8_quantize

            x_quant, x_scale = mxfp8_quantize(x, True, 32)

            fake_input_scale = torch.ones(self.num_experts, device=x.device)
            quant_scales = [
                layer.w13_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
                layer.w2_weight_scale.contiguous().view(torch.int32),
                fake_input_scale,
            ]

            fi_input = x_quant
            extra_kwargs = dict(
                use_mxfp8_act_scaling=True,
                input_sf=x_scale,
                fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
            )
        elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
            assert x.dtype == torch.bfloat16

            quant_scales = [
                layer.w13_weight_scale,
                layer.w2_weight_scale,
            ]

            fi_input = x
            extra_kwargs = dict(
                use_w4_group_scaling=True,
                fc1_expert_weights=layer.w13_weight,
                fc2_expert_weights=layer.w2_weight,
            )

        output = torch.empty_like(x, dtype=torch.bfloat16)

        flashinfer_cutlass_fused_moe(
            input=fi_input,
            token_selected_experts=topk_ids.to(torch.int).contiguous(),
            token_final_scales=topk_weights,
            output_dtype=torch.bfloat16,
            output=output,
            quant_scales=quant_scales,
            fc1_expert_biases=layer.w13_bias,
            fc2_expert_biases=layer.w2_bias,
            swiglu_alpha=layer.gemm1_alpha,
            swiglu_beta=layer.gemm1_beta,
            swiglu_limit=layer.gemm1_clamp_limit,
            tp_size=self.moe.tp_size,
            tp_rank=self.moe.tp_rank,
            ep_size=self.moe.ep_size,
            ep_rank=self.moe.ep_rank,
            tune_max_num_tokens=max(self.max_capture_size, 1),
            **extra_kwargs,
        )

        return output

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.is_monolithic

        if layer.enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        assert _can_support_mxfp4(
            layer.use_grouped_topk,
            layer.topk_group,
            layer.num_expert_group,
            layer.expert_map,
            layer.custom_routing_function,
            layer.e_score_correction_bias,
            layer.apply_router_weight_on_input,
            layer.scoring_func,
            layer.activation,
            layer.eplb_state.expert_load_view,
            layer.eplb_state.logical_to_physical_map,
            layer.eplb_state.logical_replica_count,
        ), "MXFP4 are not supported with this configuration."

        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
            or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
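In both FlashInfer CUTLASS branches above, the routing results reach the kernel as token_selected_experts (int32 expert indices) and token_final_scales (the per-token weights); only the activation preparation differs between the SM100 MXFP8 path and the SM90 bfloat16 path. Below is a hedged, torch-only sketch of that preparation; fake_mxfp8_quantize is a placeholder for flashinfer.mxfp8_quantize and only mimics its (quantized tensor, scale) return shape, not real MXFP8 quantization.

import torch

def prepare_flashinfer_inputs(x, topk_weights, topk_ids, use_mxfp8_act: bool):
    # Placeholder for flashinfer.mxfp8_quantize; the real call returns an
    # MXFP8 tensor plus block scale factors.
    def fake_mxfp8_quantize(t):
        return t, torch.ones(t.shape[0], device=t.device)

    if use_mxfp8_act:  # SM100_FI_MXFP4_MXFP8_CUTLASS: quantize activations
        fi_input, input_sf = fake_mxfp8_quantize(x)
    else:              # SM90_FI_MXFP4_BF16: bf16 activations pass through
        assert x.dtype == torch.bfloat16
        fi_input, input_sf = x, None

    # Routing results in the layout the fused kernel expects.
    token_selected_experts = topk_ids.to(torch.int).contiguous()  # [T, K] int32
    token_final_scales = topk_weights                             # [T, K]
    return fi_input, input_sf, token_selected_experts, token_final_scales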
@@ -988,75 +1082,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                tune_max_num_tokens=max(self.max_capture_size, 1),
            )[0]
            return trtllm_gen_output
        elif (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
            or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
        ):
            from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe

            topk_weights, topk_ids = router.select_experts(
                hidden_states=x,
                router_logits=router_logits,
            )

            # Backend-specific preparation
            if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
                from flashinfer import mxfp8_quantize

                x_quant, x_scale = mxfp8_quantize(x, True, 32)

                fake_input_scale = torch.ones(self.num_experts, device=x.device)
                quant_scales = [
                    layer.w13_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                    layer.w2_weight_scale.contiguous().view(torch.int32),
                    fake_input_scale,
                ]

                fi_input = x_quant
                extra_kwargs = dict(
                    use_mxfp8_act_scaling=True,
                    input_sf=x_scale,
                    fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
                    fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
                )
            elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
                assert x.dtype == torch.bfloat16

                quant_scales = [
                    layer.w13_weight_scale,
                    layer.w2_weight_scale,
                ]

                fi_input = x
                extra_kwargs = dict(
                    use_w4_group_scaling=True,
                    fc1_expert_weights=layer.w13_weight,
                    fc2_expert_weights=layer.w2_weight,
                )

            output = torch.empty_like(x, dtype=torch.bfloat16)
            _ = flashinfer_cutlass_fused_moe(
                input=fi_input,
                token_selected_experts=topk_ids.to(torch.int).contiguous(),
                token_final_scales=topk_weights,
                output_dtype=torch.bfloat16,
                output=output,
                quant_scales=quant_scales,
                fc1_expert_biases=layer.w13_bias,
                fc2_expert_biases=layer.w2_bias,
                swiglu_alpha=layer.gemm1_alpha,
                swiglu_beta=layer.gemm1_beta,
                swiglu_limit=layer.gemm1_clamp_limit,
                tp_size=self.moe.tp_size,
                tp_rank=self.moe.tp_rank,
                ep_size=self.moe.ep_size,
                ep_rank=self.moe.ep_rank,
                tune_max_num_tokens=max(self.max_capture_size, 1),
                **extra_kwargs,
            )

            return output
        elif self.mxfp4_backend == Mxfp4Backend.TRITON:
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
                triton_kernel_moe_forward,
@@ -1119,10 +1144,13 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
            experts_start_id=ep_rank_start,
        )

    def apply(
    @property
    def is_monolithic(self) -> bool:
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        router: FusedMoERouter,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
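The IpexMxfp4MoEMethod hunk above swaps apply() for the monolithic interface: is_monolithic returns True and apply_monolithic receives the router together with the raw router_logits, so the method can perform expert selection itself before invoking its backend kernel. A hedged sketch of that shape follows; MonolithicMethodSketch and run_backend_kernel are hypothetical stand-ins, not the IPEX implementation.

import torch

def run_backend_kernel(x, topk_weights, topk_ids):
    # Hypothetical placeholder for the fused expert computation
    # (the real method dispatches to IPEX fused-MoE kernels).
    return torch.zeros_like(x)

class MonolithicMethodSketch:
    """Hypothetical monolithic quant method; illustrates the interface only."""

    @property
    def is_monolithic(self) -> bool:
        return True

    def apply_monolithic(self, layer, router, x, router_logits) -> torch.Tensor:
        # The method owns routing: it calls back into the router (or a
        # backend-specific routing kernel) before running the fused MoE op.
        topk_weights, topk_ids = router.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )
        return run_backend_kernel(x, topk_weights, topk_ids)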