[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -22,7 +22,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig, biased_moe_quant_config)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||
@@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
# TODO(bnell): also pass quant_config?
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.moe = moe
|
||||
self.fused_experts: Optional[Callable] = None
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.fused_experts: Optional[FusedMoEModularKernel] = None
|
||||
self.topk_indices_dtype = None
|
||||
|
||||
@abstractmethod
|
||||
@@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@staticmethod
|
||||
def _maybe_make_prepare_finalize(
|
||||
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||
|
||||
# TODO: could allow this now
|
||||
assert not moe.use_flashinfer_cutlass_kernels, \
|
||||
"Must be created in modelopt.py"
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
assert quant_config is not None
|
||||
|
||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
moe.max_num_tokens,
|
||||
moe.hidden_dim,
|
||||
moe.in_dtype,
|
||||
moe.quant_dtype,
|
||||
per_act_token_quant=moe.per_act_token_quant,
|
||||
block_shape=moe.block_shape,
|
||||
quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
all_to_all_args = dict(
|
||||
@@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
)
|
||||
|
||||
elif moe.use_deepep_ll_kernels:
|
||||
assert quant_config is not None
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||
token_hidden_size=moe.hidden_dim,
|
||||
@@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
all2all_manager.world_size)
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
# Note : We may want to use FP8 dispatch even otherwise just to
|
||||
# reduce datamovement
|
||||
use_fp8_dispatch = (moe.quant_config is not None
|
||||
and moe.quant_config.quant_dtype
|
||||
== current_platform.fp8_dtype()
|
||||
and moe.quant_config.block_shape
|
||||
== DEEPEP_QUANT_BLOCK_SHAPE)
|
||||
# Note: We may want to use FP8 dispatch just to reduce
|
||||
# data movement.
|
||||
use_fp8_dispatch = (
|
||||
quant_config.quant_dtype == current_platform.fp8_dtype()
|
||||
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
|
||||
|
||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||
handle,
|
||||
@@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
return prepare_finalize
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if moe.moe_parallel_config.use_all2all_kernels:
|
||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
||||
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
# prepare_communication_buffer_for_model.
|
||||
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||
assert self.moe is not None
|
||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||
|
||||
# We must get the quant config here so that the layer is
|
||||
# completely initialized, i.e. all weights loaded and post
|
||||
# processed.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
|
||||
prepare_finalize = self.maybe_make_prepare_finalize()
|
||||
|
||||
if prepare_finalize is not None:
|
||||
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
|
||||
@@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
assert self.fused_experts is None, \
|
||||
f"Attempt to override experts for {id(self)}!"
|
||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
||||
experts = self.select_gemm_impl(prepare_finalize, layer)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
@@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
@@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize")
|
||||
|
||||
@abstractmethod
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
@@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.has_bias = self.moe.has_bias
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
@@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
else:
|
||||
self.rocm_aiter_fused_experts = None # type: ignore
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=self.moe.max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
return TritonExperts()
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
@@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
@@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
)
|
||||
else:
|
||||
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert self.fused_experts is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
elif self.fused_experts is not None:
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
raise ValueError(
|
||||
"FusedMoEModularKernel does not support bias.")
|
||||
return self.fused_experts(
|
||||
@@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
has_bias=has_bias)
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
@@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
|
||||
# TODO(bnell): flashinfer uses non-batched format.
|
||||
# Does it really need a batched buffer?
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.use_flashinfer_cutlass_kernels):
|
||||
@@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return self.moe_config.use_flashinfer_cutlass_kernels
|
||||
return (self.moe_quant_config is not None
|
||||
and self.moe_quant_config.quant_dtype == "nvfp4"
|
||||
and self.moe_config.use_flashinfer_cutlass_kernels)
|
||||
|
||||
def update_expert_map(self):
|
||||
# ep_size and ep_rank should already be updated
|
||||
@@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
|
||||
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
||||
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
||||
|
||||
def ensure_moe_quant_config(self):
|
||||
if self.quant_method.moe_quant_config is None:
|
||||
self.quant_method.moe_quant_config = (
|
||||
self.quant_method.get_fused_moe_quant_config(self))
|
||||
|
||||
@staticmethod
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
|
||||
assert (
|
||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(
|
||||
@@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
use_flashinfer_cutlass_kernels = (
|
||||
self.dp_size > 1
|
||||
and self.moe_config.use_flashinfer_cutlass_kernels)
|
||||
_use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
|
||||
self.use_flashinfer_cutlass_kernels)
|
||||
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or use_flashinfer_cutlass_kernels):
|
||||
or _use_flashinfer_cutlass_kernels):
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
|
||||
Reference in New Issue
Block a user