[Kernels] Enable FlashInfer FP8 Blockscale on SM90 (for TEP DSR1) (#27134)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Duncan Moss
2025-11-14 08:02:44 -08:00
committed by GitHub
parent 511a6b611d
commit 3f8a874065
4 changed files with 179 additions and 68 deletions

View File

@@ -3,6 +3,7 @@
from collections.abc import Callable
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -122,10 +123,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# prefer FlashInfer backends when available and enabled on supported GPUs
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
@@ -134,14 +138,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
@@ -641,6 +645,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
@@ -1012,8 +1026,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
return None
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
# Wire block-scale flag through prepare/finalize when using CUTLASS
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
self.moe
self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
@@ -1062,9 +1083,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
self.moe_quant_config,
use_deepseek_fp8_block_scale=self.block_quant,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
@@ -1251,16 +1274,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
assert not self.block_quant
assert not renormalize and custom_routing_function is not None
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
result = flashinfer_cutlass_moe_fp8(
if not self.block_quant:
assert not renormalize and custom_routing_function is not None
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
# Delegate to CUTLASS FlashInfer path; function already bound with
# use_deepseek_fp8_block_scale for block-quant when applicable
result = self.flashinfer_moe_fn(
x,
layer,
topk_weights,