[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Author: Duncan Moss
Date: 2025-11-14 08:02:44 -08:00
Committed by: GitHub
Parent: 511a6b611d
Commit: 3f8a874065
4 changed files with 179 additions and 68 deletions

@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
     create_flashinfer_prepare_finalize,
 )
+from vllm.platforms import current_platform
 
 
 logger = init_logger(__name__)
@@ -190,17 +191,22 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
 def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
-    moe: FusedMoEConfig | None,
+    moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
 ) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
-    return create_flashinfer_prepare_finalize(use_dp)
+    # Propagate block-scale flag so prepare/finalize can skip act quantization
+    # and inform the kernel to consume per-block weight scales.
+    return create_flashinfer_prepare_finalize(
+        use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
+    )
 
 
 def select_cutlass_fp8_gemm_impl(
     moe: FusedMoEConfig | None,
     quant_config: FusedMoEQuantConfig,
     out_dtype: torch.dtype | None = None,
+    use_deepseek_fp8_block_scale: bool = False,
 ) -> mk.FusedMoEPermuteExpertsUnpermute:
     """Return a GEMM *experts* implementation for fused-MoE layers"""
@@ -212,12 +218,14 @@ def select_cutlass_fp8_gemm_impl(
             ep_size=moe.moe_parallel_config.ep_size,
             tp_rank=moe.moe_parallel_config.tp_rank,
             tp_size=moe.moe_parallel_config.tp_size,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         )
 
     assert out_dtype is not None, "If moe config is None, out_dtype must be passed"
     return FlashInferExperts(
         out_dtype=out_dtype,
         quant_config=quant_config,
+        use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
     )
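
A hedged call sketch for the experts selector above: with moe=None the fallback branch runs, so out_dtype must be supplied. quant_config would come from the layer's quant method, select_cutlass_fp8_gemm_impl lives in the patched module (path not shown in this diff), and running it requires vllm with FlashInfer installed.

import torch


def make_block_scale_experts(quant_config):
    # moe=None takes the fallback branch above, hence out_dtype is required.
    return select_cutlass_fp8_gemm_impl(
        moe=None,
        quant_config=quant_config,
        out_dtype=torch.bfloat16,
        use_deepseek_fp8_block_scale=True,
    )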
@@ -231,14 +239,22 @@ def flashinfer_cutlass_moe_fp8(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
     apply_router_weight_on_input: bool = False,
+    use_deepseek_fp8_block_scale: bool = False,
+    moe: FusedMoEConfig | None = None,
 ) -> torch.Tensor:
     quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
     assert quant_config is not None
 
+    # Construct modular kernel with block-scale support when requested.
     fused_experts = mk.FusedMoEModularKernel(
-        build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
+        build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+            moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
+        ),
         select_cutlass_fp8_gemm_impl(
-            moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype
+            moe=moe,
+            quant_config=quant_config,
+            out_dtype=hidden_states.dtype,
+            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         ),
     )
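
Putting the two stages together, a minimal assembly sketch that mirrors the wiring above. The modular_kernel import path and the presence of the two helper functions in the patched module are assumptions; moe and quant_config stand for the layer's FusedMoEConfig and FusedMoEQuantConfig, and the sketch only runs with vllm and FlashInfer installed.

import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk  # path assumed


def build_block_scale_kernel(moe, quant_config, act_dtype: torch.dtype = torch.bfloat16):
    """Wire the block-scale flag into both stages, as the diff above does."""
    return mk.FusedMoEModularKernel(
        build_flashinfer_fp8_cutlass_moe_prepare_finalize(
            moe=moe, use_deepseek_fp8_block_scale=True
        ),
        select_cutlass_fp8_gemm_impl(
            moe=moe,
            quant_config=quant_config,
            out_dtype=act_dtype,
            use_deepseek_fp8_block_scale=True,
        ),
    )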
@@ -258,7 +274,10 @@ def flashinfer_cutlass_moe_fp8(
 def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
     flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-    if flashinfer_moe_backend == "throughput":
+    # On SM90 always use CUTLASS; the TensorRT-LLM (latency) kernels target SM100.
+    if flashinfer_moe_backend == "throughput" or current_platform.is_device_capability(
+        90
+    ):
         return FlashinferMoeBackend.CUTLASS
     elif flashinfer_moe_backend == "latency":
         return FlashinferMoeBackend.TENSORRT_LLM
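
For completeness, a small user-side sketch of how the backend is chosen. The two env values come from the branches above; with this change, SM90 devices resolve to CUTLASS regardless of the setting, since the condition also checks device capability 90. Default behavior when the variable is unset is not shown in this diff.

import os

# "throughput" selects the CUTLASS backend; "latency" selects TensorRT-LLM
# (overridden to CUTLASS on SM90 by the condition above).
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"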