[Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123)
@@ -197,8 +197,6 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)
        self.topk_indices_dtype = None
        self.moe = moe
        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
        self.max_capture_size = (
            get_current_vllm_config().compilation_config.max_cudagraph_capture_size
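The first hunk trims the constructor. A minimal sketch of the split the commit title describes, assuming the modular-kernel bookkeeping (`self.moe`, `self.topk_indices_dtype`, the optional `fused_experts` kernel) is kept once in the shared base so each quantization method's `__init__` only holds backend-specific state; the class and attribute layout below is illustrative, not the exact vLLM code:

class FusedMoEMethodBaseSketch:
    """Illustrative stand-in for the shared base class after the refactor."""

    def __init__(self, moe):
        # Backend-agnostic state used by the modular kernel path.
        self.moe = moe
        self.topk_indices_dtype = None
        self.fused_experts = None  # optionally an mk.FusedMoEModularKernel


class Mxfp4MoEMethodSketch(FusedMoEMethodBaseSketch):
    """Illustrative stand-in for a subclass such as Mxfp4MoEMethod."""

    def __init__(self, moe, mxfp4_backend, max_capture_size):
        super().__init__(moe)
        # Only backend-specific configuration remains in the subclass.
        self.mxfp4_backend = mxfp4_backend
        self.max_capture_size = max_capture_size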
@@ -815,6 +813,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                    "EP batched experts format"
                )
            else:
                layer.w13_weight = (
                    self.w13_weight_triton_tensor
                    if layer.w13_weight is None
                    else layer.w13_weight
                )
                layer.w2_weight = (
                    self.w2_weight_triton_tensor
                    if layer.w2_weight is None
                    else layer.w2_weight
                )
                assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])

        assert self.moe_quant_config is not None
        if (
            self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
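The hunk above keeps the layer's own weight when it is set and falls back to the Triton-formatted tensor otherwise. A hypothetical helper (not part of vLLM) capturing the same "first non-None wins" pattern that the diff writes with conditional expressions:

def first_non_none(*candidates):
    # Return the first candidate that is not None.
    for candidate in candidates:
        if candidate is not None:
            return candidate
    raise ValueError("expected at least one non-None candidate")

# Usage mirroring the diff: prefer layer.w13_weight, else the Triton tensor.
# w13_weight = first_non_none(layer.w13_weight, self.w13_weight_triton_tensor)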
@@ -838,71 +848,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
            )

    def _route_and_experts(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        w13_weight = (
            self.w13_weight_triton_tensor
            if layer.w13_weight is None
            else layer.w13_weight
        )
        w2_weight = (
            self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight
        )
        assert all([w is not None for w in [w13_weight, w2_weight]])

        return self.fused_experts(
            hidden_states=x,
            w1=w13_weight,
            w2=w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
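The `_route_and_experts` helper removed above is the modular-kernel path in miniature: route tokens with `FusedMoE.select_experts`, then hand the chosen experts to the modular `fused_experts` kernel. A self-contained toy sketch of that two-step decomposition follows; the shapes, the softmax router, and the dense per-token expert loop are invented for illustration and are not the vLLM kernels:

import torch


def toy_select_experts(router_logits, top_k, renormalize=True):
    # Routing step: score experts per token and keep the top-k.
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def toy_fused_experts(x, w1, w2, topk_weights, topk_ids):
    # Expert step: run each selected expert MLP and mix by routing weight.
    out = torch.zeros_like(x)
    for slot in range(topk_ids.shape[-1]):
        expert_ids = topk_ids[:, slot]              # (tokens,)
        gate = topk_weights[:, slot].unsqueeze(-1)  # (tokens, 1)
        hidden = torch.nn.functional.silu(torch.einsum("td,tdi->ti", x, w1[expert_ids]))
        out += gate * torch.einsum("ti,tid->td", hidden, w2[expert_ids])
    return out


# Example: 4 tokens, 8-dim hidden, 4 experts, top-2 routing.
x = torch.randn(4, 8)
w1 = torch.randn(4, 8, 16)   # (num_experts, hidden, intermediate)
w2 = torch.randn(4, 16, 8)   # (num_experts, intermediate, hidden)
router_logits = torch.randn(4, 4)
topk_weights, topk_ids = toy_select_experts(router_logits, top_k=2)
y = toy_fused_experts(x, w1, w2, topk_weights, topk_ids)

Keeping routing and expert dispatch as separate steps is what lets the same routing code drive different expert kernels, which is presumably why this helper is worth sharing rather than repeating per quantization method.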
@@ -930,29 +878,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        if enable_eplb:
            raise NotImplementedError("EPLB is not supported for mxfp4")

        if self.fused_experts is not None:
            return self._route_and_experts(
                layer,
                x,
                router_logits,
                top_k,
                renormalize,
                use_grouped_topk,
                topk_group,
                num_expert_group,
                global_num_experts,
                expert_map,
                custom_routing_function,
                scoring_func,
                e_score_correction_bias,
                apply_router_weight_on_input,
                activation,
                enable_eplb,
                expert_load_view,
                logical_to_physical_map,
                logical_replica_count,
            )

        if self.mxfp4_backend == Mxfp4Backend.MARLIN:
            topk_weights, topk_ids, _ = FusedMoE.select_experts(
                hidden_states=x,
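The last hunk deletes the `if self.fused_experts is not None:` early-exit from `apply`, so the subclass no longer repeats the modular-kernel dispatch. A minimal sketch of where such a branch could live once it is no longer written in every subclass, assuming (per the commit title) it moves to shared code; the names below are illustrative, not the exact vLLM API:

class BaseMethodDispatchSketch:
    """Illustrative only: shared modular-kernel dispatch outside the subclasses."""

    def __init__(self):
        self.fused_experts = None  # installed when a modular kernel is selected

    def apply(self, layer, x, router_logits, top_k):
        if self.fused_experts is not None:
            # Shared path: route tokens, then run the modular kernel.
            return self._route_and_experts(layer, x, router_logits, top_k)
        # Otherwise defer to the backend-specific path (e.g. the Marlin branch above).
        return self.apply_backend(layer, x, router_logits, top_k)

    def _route_and_experts(self, layer, x, router_logits, top_k):
        raise NotImplementedError("sketch only")

    def apply_backend(self, layer, x, router_logits, top_k):
        raise NotImplementedError("sketch only")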