[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
commit 5963b98b46
parent e6585ddb45
68 changed files with 2698 additions and 2526 deletions
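
FusedMoEQuantConfig is no longer carried by FusedMoEConfig; each FusedMoEMethodBase subclass now builds it through a new abstract hook, get_fused_moe_quant_config(layer), which is resolved lazily (in init_prepare_finalize, or via FusedMoE.ensure_moe_quant_config on the forward path) so the config is only constructed after the layer's weights have been loaded and post-processed. A minimal sketch of the new contract follows; it assumes FusedMoEMethodBase is importable from vllm.model_executor.layers.fused_moe.layer (the file changed here) and omits the other abstract methods a real subclass must implement:

from typing import Optional

import torch

from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase


class ExampleMoEMethod(FusedMoEMethodBase):
    """Sketch only: create_weights, apply, etc. are omitted."""

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        # Called once the layer is fully initialized (weights loaded and
        # post-processed), so per-layer scales/biases may be read off `layer`.
        # This example simply opts out of quantization.
        return FUSED_MOE_UNQUANTIZED_CONFIG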


@@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig, FusedMoEParallelConfig)
+ FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
+ FusedMoEQuantConfig, biased_moe_quant_config)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
@@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
# TODO(bnell): also pass quant_config?
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
- self.fused_experts: Optional[Callable] = None
+ self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
+ self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
@abstractmethod
@@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@staticmethod
def _maybe_make_prepare_finalize(
- moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
+ moe: FusedMoEConfig,
+ quant_config: Optional[FusedMoEQuantConfig],
+ ) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
+ # TODO: could allow this now
+ assert not moe.use_flashinfer_cutlass_kernels, \
+ "Must be created in modelopt.py"
if moe.use_pplx_kernels:
+ assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
- moe.quant_dtype,
- per_act_token_quant=moe.per_act_token_quant,
- block_shape=moe.block_shape,
+ quant_config.quant_dtype,
+ per_act_token_quant=quant_config.per_act_token_quant,
+ block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
@@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
elif moe.use_deepep_ll_kernels:
+ assert quant_config is not None
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
@@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
- # Note : We may want to use FP8 dispatch even otherwise just to
- # reduce datamovement
- use_fp8_dispatch = (moe.quant_config is not None
- and moe.quant_config.quant_dtype
- == current_platform.fp8_dtype()
- and moe.quant_config.block_shape
- == DEEPEP_QUANT_BLOCK_SHAPE)
+ # Note: We may want to use FP8 dispatch just to reduce
+ # data movement.
+ use_fp8_dispatch = (
+ quant_config.quant_dtype == current_platform.fp8_dtype()
+ and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
@@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def maybe_make_prepare_finalize(
- self,
- moe: FusedMoEConfig,
- ) -> Optional[FusedMoEPrepareAndFinalize]:
- if moe.moe_parallel_config.use_all2all_kernels:
- return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
+ self) -> Optional[FusedMoEPrepareAndFinalize]:
+ if self.moe.moe_parallel_config.use_all2all_kernels:
+ return FusedMoEMethodBase._maybe_make_prepare_finalize(
+ self.moe, self.moe_quant_config)
else:
return None
@@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# prepare_communication_buffer_for_model.
def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
- prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
+ # We must get the quant config here so that the layer is
+ # completely initialized, i.e. all weights loaded and post
+ # processed.
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+ prepare_finalize = self.maybe_make_prepare_finalize()
if prepare_finalize is not None:
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
@@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
- experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
+ experts = self.select_gemm_impl(prepare_finalize, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
@@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
- moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
@@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
+ @abstractmethod
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
+ raise NotImplementedError
@abstractmethod
def apply(
self,
@@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
- self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
+ def maybe_make_prepare_finalize(
+ self) -> Optional[FusedMoEPrepareAndFinalize]:
+ if self.rocm_aiter_moe_enabled:
+ return None
+ else:
+ return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
- # TODO(bnell): Remove. Every layer should have an moe config object.
- moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
+ assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
+ quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
- return TritonExperts()
+ return TritonExperts(self.moe_quant_config)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
- if self.has_bias:
+ if self.moe.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
@@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
- if self.has_bias:
+ if self.moe.has_bias:
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
@@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count,
)
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
+ if self.moe.has_bias:
+ return biased_moe_quant_config(
+ layer.w13_bias,
+ layer.w2_bias,
+ )
+ else:
+ return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
self,
layer: torch.nn.Module,
@@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count=logical_replica_count)
if self.rocm_aiter_moe_enabled:
+ assert self.fused_experts is None
return self.rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
elif self.fused_experts is not None:
- if self.has_bias:
+ if self.moe.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
@@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
- w1_bias=layer.w13_bias if self.has_bias else None,
- w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
+ quant_config=self.moe_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
@@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
# since model_config is not set in the pytest test.
model_dtype = params_dtype
- moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
- experts_per_token=top_k,
- hidden_dim=hidden_size,
- num_local_experts=self.local_num_experts,
- moe_parallel_config=self.moe_parallel_config,
- in_dtype=model_dtype,
- max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
- quant_config=quant_config,
- has_bias=has_bias)
+ moe = FusedMoEConfig(
+ num_experts=self.global_num_experts,
+ experts_per_token=top_k,
+ hidden_dim=hidden_size,
+ num_local_experts=self.local_num_experts,
+ moe_parallel_config=self.moe_parallel_config,
+ in_dtype=model_dtype,
+ max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
+ has_bias=has_bias,
+ )
self.moe_config = moe
+ self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts
@@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
+ # TODO(bnell): flashinfer uses non-batched format.
+ # Does it really need a batched buffer?
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.use_flashinfer_cutlass_kernels):
@@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
@property
def use_flashinfer_cutlass_kernels(self):
- return self.moe_config.use_flashinfer_cutlass_kernels
+ return (self.moe_quant_config is not None
+ and self.moe_quant_config.quant_dtype == "nvfp4"
+ and self.moe_config.use_flashinfer_cutlass_kernels)
def update_expert_map(self):
# ep_size and ep_rank should already be updated
@@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
+ def ensure_moe_quant_config(self):
+ if self.quant_method.moe_quant_config is None:
+ self.quant_method.moe_quant_config = (
+ self.quant_method.get_fused_moe_quant_config(self))
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
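
A note on the ensure_moe_quant_config hook added above: it memoizes the config on the quant method so that forward paths which never go through init_prepare_finalize (no all2all prepare/finalize) still resolve it exactly once, after weight loading. A standalone sketch of that lazy-memoization idiom, with the surrounding vLLM classes replaced by stubs for illustration:

from typing import Optional


class _Method:
    """Stub standing in for a FusedMoEMethodBase subclass."""

    def __init__(self) -> None:
        self.moe_quant_config: Optional[object] = None

    def get_fused_moe_quant_config(self, layer: object) -> Optional[object]:
        # Real subclasses read scales/biases off `layer` here.
        return object()


class _Layer:
    """Stub standing in for FusedMoE."""

    def __init__(self) -> None:
        self.quant_method = _Method()

    def ensure_moe_quant_config(self) -> None:
        # Resolve once, only after the layer is fully initialized.
        if self.quant_method.moe_quant_config is None:
            self.quant_method.moe_quant_config = (
                self.quant_method.get_fused_moe_quant_config(self))


layer = _Layer()
layer.ensure_moe_quant_config()  # first call builds and caches the config
layer.ensure_moe_quant_config()  # subsequent calls are no-ops
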
@@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
+ self.ensure_moe_quant_config()
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
@@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
+ self.ensure_moe_quant_config()
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
- use_flashinfer_cutlass_kernels = (
- self.dp_size > 1
- and self.moe_config.use_flashinfer_cutlass_kernels)
+ _use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
+ self.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
- or use_flashinfer_cutlass_kernels):
+ or _use_flashinfer_cutlass_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (