[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)
Signed-off-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
    moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
) -> mk.FusedMoEPrepareAndFinalize:
    """Create a FlashInfer CUTLASS fused-MoE prepare/finalize kernel.

    Args:
        moe: Fused-MoE configuration, or ``None`` when no parallel config
            is available (data-parallel mode is then disabled).
        use_deepseek_fp8_block_scale: When True, prepare/finalize skips
            activation quantization and the kernel consumes per-block
            (DeepSeek-style) FP8 weight scales. Defaults to False for
            backward compatibility.

    Returns:
        A ``mk.FusedMoEPrepareAndFinalize`` instance.
    """
    # Enable data-parallel handling only when a config exists and dp_size > 1.
    use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
    # Propagate block-scale flag so prepare/finalize can skip act quantization
    # and inform the kernel to consume per-block weight scales.
    return create_flashinfer_prepare_finalize(
        use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
    )
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: FusedMoEConfig | None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
@@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
tp_size=moe.moe_parallel_config.tp_size,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
|
||||
return FlashInferExperts(
|
||||
out_dtype=out_dtype,
|
||||
quant_config=quant_config,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_deepseek_fp8_block_scale: bool = False,
|
||||
moe: FusedMoEConfig | None = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||
assert quant_config is not None
|
||||
|
||||
# Construct modular kernel with block-scale support when requested.
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
|
||||
),
|
||||
select_cutlass_fp8_gemm_impl(
|
||||
moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
|
||||
moe=moe,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype,
|
||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8(
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
    """Select the FlashInfer MoE backend.

    Reads ``VLLM_FLASHINFER_MOE_BACKEND`` and returns:
      * ``CUTLASS`` for the ``"throughput"`` setting, or unconditionally on
        SM90 devices (CUTLASS covers both SM90/SM100 generations).
      * ``TENSORRT_LLM`` for the ``"latency"`` setting.

    NOTE(review): no fallback for other env values is visible here —
    presumably handled below or the implicit ``None`` return is acceptable;
    confirm against the full function body.
    """
    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
    # Prefer CUTLASS on SM90 to cover both SM90/SM100 generations
    if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
        90
    ):
        return FlashinferMoeBackend.CUTLASS
    elif flashinfer_moe_backend == "latency":
        return FlashinferMoeBackend.TENSORRT_LLM
|
||||
|
||||
Reference in New Issue
Block a user