[Kernels] Isolate modular kernel code from FusedMoEMethodBase subclasses. (#27123)

Author: bnellnm
Date: 2025-11-04 08:59:45 -05:00
Committed by: GitHub
Parent: e4ee658672
Commit: 938772af03
16 changed files with 271 additions and 311 deletions
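
This change removes the modular-kernel state that previously lived on every FusedMoEMethodBase subclass (fused_experts, topk_indices_dtype). init_prepare_finalize becomes maybe_init_modular_kernel and returns a FusedMoEModularKernel instead of mutating the method, and the new FusedMoEModularMethod wrapper owns the kernel and its apply() path. When a kernel is produced, FusedMoE swaps its quant_method for that wrapper. A minimal sketch of the per-layer swap, mirroring FusedMoE.maybe_init_modular_kernel further down; `layer` is an illustrative FusedMoE instance, not code from this diff:

    # Illustrative sketch only; the real call site is FusedMoE.maybe_init_modular_kernel.
    kernel = layer.quant_method.maybe_init_modular_kernel(layer)
    if kernel is not None:
        # The original method is kept as old_quant_method so capability
        # queries (supports_eplb, allow_inplace) still reach it.
        layer.quant_method = FusedMoEModularMethod(layer.quant_method, kernel)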


@@ -117,10 +117,8 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.fused_experts: FusedMoEModularKernel | None = None
self.topk_indices_dtype = None
@abstractmethod
def create_weights(
@@ -245,9 +243,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
else:
return None
# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
def maybe_init_modular_kernel(
self, layer: torch.nn.Module
) -> FusedMoEModularKernel | None:
assert self.moe is not None
# We must get the quant config here so that the layer is
@@ -261,17 +259,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
)
assert self.topk_indices_dtype is None
assert self.fused_experts is None, (
f"Attempt to override experts for {id(self)}!"
)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, layer)
self.fused_experts = FusedMoEModularKernel(
return FusedMoEModularKernel(
prepare_finalize,
experts,
layer.shared_experts,
)
else:
return None
def select_gemm_impl(
self,
@@ -292,8 +287,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
@property
def using_modular_kernel(self) -> bool:
return self.fused_experts is not None
def topk_indices_dtype(self) -> torch.dtype | None:
return None
@property
def supports_eplb(self) -> bool:
return False
@property
def allow_inplace(self) -> bool:
return False
@abstractmethod
def apply(
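
The supports_eplb and allow_inplace properties added above replace isinstance checks at call sites (see the EPLB hunk in FusedMoE.__init__ further down), and topk_indices_dtype turns from mutable state into a read-only property. A hedged sketch of how a subclass is expected to opt in; MyQuantMoEMethod is a made-up name, and a real subclass also implements the abstract create_weights/apply:

    # Hypothetical subclass, for illustration only (not part of this commit).
    class MyQuantMoEMethod(FusedMoEMethodBase):
        @property
        def supports_eplb(self) -> bool:
            # Allows FusedMoE to accept enable_eplb for this method.
            return True

        @property
        def allow_inplace(self) -> bool:
            # Lets FusedMoEModularMethod.apply run the modular kernel in place.
            return True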
@@ -322,6 +325,138 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def __init__(
self,
old_quant_method: FusedMoEMethodBase,
fused_experts: FusedMoEModularKernel,
):
super().__init__(old_quant_method.moe)
# Find better way to copy attributes? Should we even copy attributes?
# self.__dict__.update(old_quant_method.__dict__)
self.moe_quant_config = old_quant_method.moe_quant_config
self.fused_experts = fused_experts
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
not fused_experts.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@property
def topk_indices_dtype(self) -> torch.dtype | None:
return self.fused_experts.prepare_finalize.topk_indices_dtype()
@property
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return self.moe_quant_config
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Is getattr needed?
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if enable_eplb:
if self.supports_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
else:
raise NotImplementedError(
"EPLB is not supported for "
f"{self.old_quant_method.__class__.__name__}."
)
topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
)
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=None if self.disable_expert_map else expert_map,
)
if zero_expert_num != 0 and zero_expert_type is not None:
assert not isinstance(result, tuple), (
"Shared + zero experts are mutually exclusive not yet supported"
)
return result, zero_expert_result
else:
return result
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
@@ -378,6 +513,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
self.flashinfer_cutlass_moe = None # type: ignore
@property
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
@@ -650,7 +793,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
)
if self.rocm_aiter_moe_enabled:
assert self.fused_experts is None
result = self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -671,21 +813,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif self.fused_experts is not None:
result = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
assert fused_experts is not None
result = fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -1267,7 +1395,7 @@ class FusedMoE(CustomOp):
"Only softmax scoring function is supported for non-grouped topk."
)
moe = FusedMoEConfig(
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
@@ -1279,24 +1407,26 @@ class FusedMoE(CustomOp):
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
)
self.moe_config: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
self.quant_config = quant_config
def _get_quant_method() -> FusedMoEMethodBase:
"""
Helper method to ensure self.quant_method is never None and
of the proper type.
"""
quant_method = None
if self.quant_config is not None:
quant_method = self.quant_config.get_quant_method(self, prefix)
if quant_method is None:
quant_method = UnquantizedFusedMoEMethod(self.moe_config)
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: QuantizeMethodBase | None = None
quant_method = (
UnquantizedFusedMoEMethod(moe)
if quant_config is None
else quant_config.get_quant_method(self, prefix)
)
if quant_method is None:
quant_method = UnquantizedFusedMoEMethod(moe)
assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method
self.quant_method: FusedMoEMethodBase = _get_quant_method()
if not self.moe_config.is_act_and_mul:
# Avoid circular import
@@ -1305,7 +1435,7 @@ class FusedMoE(CustomOp):
)
if not isinstance(
quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
):
raise NotImplementedError(
"is_act_and_mul=False is supported only for unquantized "
@@ -1316,20 +1446,18 @@ class FusedMoE(CustomOp):
"is_act_and_mul=False is supported only for CUDA for now"
)
if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
"EPLB is only supported for FP8 quantization for now."
)
if self.enable_eplb and not self.quant_method.supports_eplb:
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError(
f"EPLB is not supported {self.quant_method.__class__.__name__}. "
"EPLB is only supported for FP8 quantization for now."
)
moe_quant_params = {
"num_experts": self.local_num_experts,
@@ -1353,6 +1481,15 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
mk = self.quant_method.maybe_init_modular_kernel(self)
if mk is not None:
self.quant_method = FusedMoEModularMethod(self.quant_method, mk)
@property
def shared_experts(self) -> torch.nn.Module | None:
return None
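
For context, a hedged sketch of how the swap above is expected to be driven model-wide. The note refers to prepare_communication_buffer_for_model; the traversal helper and import path below are assumptions for illustration, not part of this commit:

    import torch
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    def swap_in_modular_kernels(model: torch.nn.Module) -> None:
        # Assumed to run only after weight loading and post-processing,
        # as the note above requires.
        for module in model.modules():
            if isinstance(module, FusedMoE):
                # May replace module.quant_method with a FusedMoEModularMethod.
                module.maybe_init_modular_kernel()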
@@ -2167,7 +2304,7 @@ class FusedMoE(CustomOp):
"""
assert self.quant_method is not None
return (
self.quant_method.fused_experts is not None
isinstance(self.quant_method, FusedMoEModularMethod)
and self.quant_method.fused_experts.output_is_reduced()
)
@@ -2403,7 +2540,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
not isinstance(self.quant_method, FusedMoEModularMethod)
and self.shared_experts is not None
)
@@ -2430,8 +2567,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts
)
do_naive_dispatch_combine: bool = (
self.dp_size > 1 and not self.quant_method.using_modular_kernel
do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
self.quant_method, FusedMoEModularMethod
)
# If there are shared experts but we are not using a modular kernel, the