[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
@@ -210,6 +213,9 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig
 
     in_dtype: torch.dtype  # The activation type.
+    quant_dtype: torch.dtype = None
+
+    # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
     BLOCK = "block"
 
 
+def get_quant_config_input_activations(
+        quant_config: Optional[QuantizationConfig]
+) -> Optional[QuantizationArgs]:
+    if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
+            and "Linear" in quant_config.target_scheme_map and
+            "input_activations" in quant_config.target_scheme_map["Linear"]):
+        return quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+    else:
+        return None
+
+
 class FusedMoEMethodBase(QuantizeMethodBase):
 
+    moe: MoEConfig
+
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
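
Aside (not part of the commit): a minimal sketch of the lookup the new helper performs. The quant_config object below is a hand-built stand-in for a compressed-tensors style config; real configs carry QuantizationArgs instances in target_scheme_map, but only the shape of the mapping matters for the lookup.

from types import SimpleNamespace

# Stand-in config: target_scheme_map maps a target ("Linear") to its scheme,
# whose "input_activations" entry describes how activations are quantized.
fake_quant_config = SimpleNamespace(target_scheme_map={
    "Linear": {
        "input_activations": SimpleNamespace(num_bits=8, type="float",
                                             strategy="tensor"),
    }
})

args = get_quant_config_input_activations(fake_quant_config)
assert args is not None and args.num_bits == 8

# Configs without the expected structure (or no config at all) yield None.
assert get_quant_config_input_activations(None) is None
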
@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None
 
+        self.moe = moe
         quant_dtype = None
         act_quant_block_size = None
         from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -297,13 +316,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
                 # For blocked per token: set to
                 #   ceil_div(hidden_dim, block_size) * sizeof(float32)
                 # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
-                    (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
-                    torch.float32.itemsize)),
+                hidden_dim_scale_bytes=(
+                    0 if moe.quant_dtype.itemsize != 1 else
+                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
+                     torch.float32.itemsize)),
             )
 
             # Intranode pplx a2a takes a group name while internode does not.
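
Aside (illustrative numbers, not from the diff): with fp8 activations quant_dtype.itemsize is 1, so each dispatched token carries hidden_dim bytes of data plus ceil_div(hidden_dim, block_size) float32 block scales. Assuming a hidden_dim of 4096 and the default block_size of 128:

import torch

hidden_dim = 4096                      # assumed hidden size, for illustration
block_size = 128                       # MoEConfig.block_size default
quant_dtype = torch.float8_e4m3fn      # itemsize == 1

hidden_dim_bytes = hidden_dim * quant_dtype.itemsize    # 4096 * 1 = 4096
hidden_dim_scale_bytes = (
    0 if quant_dtype.itemsize != 1 else
    ((hidden_dim + block_size - 1) // block_size         # ceil_div -> 32
     * torch.float32.itemsize))                          # 32 * 4 = 128

print(hidden_dim_bytes, hidden_dim_scale_bytes)          # 4096 128

For an unquantized dtype such as bfloat16 (itemsize 2) the scale term is 0, so switching the sizing from in_dtype to quant_dtype only changes behavior once quantized dispatch is actually enabled.
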
@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
             handle = all2all_manager.get_handle(all_to_all_args)
 
+            input_activations = get_quant_config_input_activations(
+                quant_config)
+
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 rank=all2all_manager.rank,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.in_dtype,
+                quant_dtype=moe.quant_dtype,
+                per_act_token=(input_activations.strategy
+                               == QuantizationStrategy.TOKEN
+                               if input_activations is not None else False),
             )
         elif moe.use_deepep_ht_kernels:
             assert moe.dp_size == all2all_manager.dp_world_size
@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         self.topk_indices_dtype = None
         if prepare_finalize is not None:
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-            experts = self.select_gemm_impl(prepare_finalize)
+            experts = self.select_gemm_impl(prepare_finalize, moe)
             self.fused_experts = FusedMoEModularKernel(
                 prepare_finalize,
                 experts,
             )
 
     def select_gemm_impl(
-        self, prepare_finalize: FusedMoEPrepareAndFinalize
-    ) -> FusedMoEPermuteExpertsUnpermute:
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore
 
-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
+    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                         moe: Optional[MoEConfig]):
 
         assert self.fused_experts == fused_experts
 
@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
         activation: str = "silu",
     ):
         super().__init__()
-
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -869,14 +895,24 @@
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
+        # Only support float8 for now.
+        quant_dtype = params_dtype
+        if quant_config is not None:
+            input_activations = get_quant_config_input_activations(
+                quant_config)
+            if (input_activations is not None
+                    and input_activations.num_bits == 8
+                    and input_activations.type == QuantizationType.FLOAT):
+                quant_dtype = torch.float8_e4m3fn
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            # TODO (bnell): this needs to be fixed for quantized types.
             in_dtype=params_dtype,
+            quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
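
Aside (not part of the commit): the input-activation scheme fetched by the helper is consumed in two places, traced below end to end. The input_activations object is a hand-built stand-in carrying only the three attributes the code reads; params_dtype is assumed to be bfloat16.

import torch
from types import SimpleNamespace
from compressed_tensors.quantization import (QuantizationStrategy,
                                             QuantizationType)

# Stand-in for get_quant_config_input_activations(quant_config): an fp8,
# per-token activation scheme.
input_activations = SimpleNamespace(num_bits=8,
                                    type=QuantizationType.FLOAT,
                                    strategy=QuantizationStrategy.TOKEN)
params_dtype = torch.bfloat16

# FusedMoE.__init__: pick the dispatch dtype stored in MoEConfig.quant_dtype.
quant_dtype = params_dtype
if (input_activations is not None and input_activations.num_bits == 8
        and input_activations.type == QuantizationType.FLOAT):
    quant_dtype = torch.float8_e4m3fn

# PplxPrepareAndFinalize: decide whether activation scales are per token.
per_act_token = (input_activations.strategy == QuantizationStrategy.TOKEN
                 if input_activations is not None else False)

print(quant_dtype, per_act_token)  # torch.float8_e4m3fn True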