[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: ElizaWszola
Date: 2025-06-07 03:26:11 +02:00
Committed by: GitHub
Parent: 6e0cd10f72
Commit: 84166fee97
26 changed files with 918 additions and 409 deletions


@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@@ -210,6 +213,7 @@ class MoEConfig:
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
+quant_dtype: torch.dtype = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
+def get_quant_config_input_activations(
+        quant_config: Optional[QuantizationConfig]
+) -> Optional[QuantizationArgs]:
+    if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
+            and "Linear" in quant_config.target_scheme_map and
+            "input_activations" in quant_config.target_scheme_map["Linear"]):
+        return quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+    else:
+        return None
class FusedMoEMethodBase(QuantizeMethodBase):
moe: MoEConfig
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
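
The new get_quant_config_input_activations helper digs the input-activation QuantizationArgs out of a compressed-tensors style config by looking under the "Linear" target scheme. A minimal sketch of the same lookup against a plain dict (the map layout below is illustrative, not the exact compressed-tensors schema):

```python
from typing import Any, Optional

def lookup_input_activations(
        target_scheme_map: dict[str, dict[str, Any]]) -> Optional[Any]:
    # Same shape of lookup as get_quant_config_input_activations: return the
    # "input_activations" args of the "Linear" scheme, if present.
    linear_scheme = target_scheme_map.get("Linear", {})
    return linear_scheme.get("input_activations")

# Illustrative config: fp8 dynamic per-token input activations.
scheme_map = {
    "Linear": {
        "input_activations": {"num_bits": 8, "type": "float", "strategy": "token"},
    }
}
print(lookup_input_activations(scheme_map))
```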
@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.moe = moe
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -297,13 +316,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
-hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
-hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
-    (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
-    torch.float32.itemsize)),
+hidden_dim_scale_bytes=(
+    0 if moe.quant_dtype.itemsize != 1 else
+    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
+     torch.float32.itemsize)),
)
# Intranode pplx a2a takes a group name while internode does not.
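
The scale-byte sizing only matters when the dtype shipped through the all-to-all is one byte wide (fp8); otherwise no scales travel with the tokens. A small worked restatement of the expression above (hidden_dim=4096 and block_size=128 are assumed example values, not taken from the diff):

```python
import torch

def hidden_dim_scale_bytes(hidden_dim: int, block_size: int,
                           quant_dtype: torch.dtype) -> int:
    # Non-fp8 wire dtypes (itemsize != 1) send no scales alongside the tokens.
    if quant_dtype.itemsize != 1:
        return 0
    # Blocked per-token scaling: one float32 scale per block of the hidden dim.
    num_blocks = (hidden_dim + block_size - 1) // block_size  # ceil_div
    return num_blocks * torch.float32.itemsize

print(hidden_dim_scale_bytes(4096, 128, torch.float8_e4m3fn))  # 32 * 4 = 128
print(hidden_dim_scale_bytes(4096, 128, torch.bfloat16))       # 0
```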
@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle = all2all_manager.get_handle(all_to_all_args)
+input_activations = get_quant_config_input_activations(
+    quant_config)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
-quant_dtype=moe.in_dtype,
+quant_dtype=moe.quant_dtype,
+per_act_token=(input_activations.strategy
+               == QuantizationStrategy.TOKEN
+               if input_activations is not None else False),
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
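
per_act_token ends up True only when the resolved input-activation scheme asks for per-token scales; any other strategy, or no activation quantization at all, leaves it False. Restated as a small predicate (QuantizationArgs and QuantizationStrategy are the compressed-tensors types imported at the top of the file):

```python
from typing import Optional

from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy)

def per_act_token_quant(input_activations: Optional[QuantizationArgs]) -> bool:
    # Mirrors the expression passed to PplxPrepareAndFinalize above.
    if input_activations is None:
        return False
    return input_activations.strategy == QuantizationStrategy.TOKEN
```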
@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.topk_indices_dtype = None
if prepare_finalize is not None:
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-experts = self.select_gemm_impl(prepare_finalize)
+experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
-    self, prepare_finalize: FusedMoEPrepareAndFinalize
-) -> FusedMoEPermuteExpertsUnpermute:
+    self, prepare_finalize: FusedMoEPrepareAndFinalize,
+    moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
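
For context on where select_gemm_impl fits: FusedMoEModularKernel pairs a prepare/finalize stage (quantization plus dispatch and combine across ranks, here the PPLX all-to-all) with an experts stage (the grouped GEMMs, e.g. the CUTLASS path this PR wires up), and select_gemm_impl now also receives the MoEConfig so the experts stage can be built from it. A toy sketch of that composition, with stand-in names rather than the vLLM classes:

```python
class ToyModularMoE:
    """Illustration only: pairs a prepare/finalize stage with an experts stage."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize  # quantize + dispatch/combine
        self.experts = experts                    # grouped expert GEMMs

    def __call__(self, hidden_states, topk_weights, topk_ids):
        dispatched = self.prepare_finalize.prepare(hidden_states, topk_ids)
        expert_out = self.experts.apply(dispatched, topk_weights, topk_ids)
        return self.prepare_finalize.finalize(expert_out, topk_weights, topk_ids)
```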
@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
-def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
+def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                     moe: Optional[MoEConfig]):
assert self.fused_experts == fused_experts
@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
+# Only support float8 for now.
+quant_dtype = params_dtype
+if quant_config is not None:
+    input_activations = get_quant_config_input_activations(
+        quant_config)
+    if (input_activations is not None
+            and input_activations.num_bits == 8
+            and input_activations.type == QuantizationType.FLOAT):
+        quant_dtype = torch.float8_e4m3fn
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
+quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
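
In the FusedMoE constructor the activation dtype (in_dtype, still params_dtype) is now kept separate from the dtype actually shipped through the all-to-all (quant_dtype), which drops to fp8 only when the quant config requests 8-bit float input activations. A compact restatement of that decision with a simplified stand-in for QuantizationArgs:

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class InputActs:
    """Simplified stand-in for compressed-tensors QuantizationArgs."""
    num_bits: int
    type: str  # "float" or "int"

def pick_quant_dtype(params_dtype: torch.dtype,
                     input_activations: Optional[InputActs]) -> torch.dtype:
    # Only 8-bit float input activations switch the wire dtype to fp8;
    # everything else keeps the unquantized activation dtype.
    if (input_activations is not None and input_activations.num_bits == 8
            and input_activations.type == "float"):
        return torch.float8_e4m3fn
    return params_dtype

print(pick_quant_dtype(torch.bfloat16, InputActs(num_bits=8, type="float")))
print(pick_quant_dtype(torch.bfloat16, None))
```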