[Perf][Kernels] Enable FlashInfer DeepGEMM swapAB on SM90 (for W8A8 Linear Op) (#29213)
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
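A note on the "swapAB" in the title (an interpretation, not spelled out in the commit): for small batch sizes the GEMM is computed as the swapped, transposed problem weight @ input.T rather than input @ weight.T, so the large output dimension occupies the long axis of the kernel's tiles. The two formulations are numerically equivalent, as this small PyTorch sketch (illustrative only, not part of the diff) shows:

    import torch

    M, K, N = 8, 256, 512  # small batch M, hidden size K, output size N (example sizes)
    x = torch.randn(M, K, dtype=torch.float64)  # activations
    w = torch.randn(N, K, dtype=torch.float64)  # weight stored as (out_features, in_features)

    out_regular = x @ w.t()        # usual problem:   (M, K) @ (K, N) -> (M, N)
    out_swapped = (w @ x.t()).t()  # swapped problem: (N, K) @ (K, M) -> (N, M), transposed back

    torch.testing.assert_close(out_regular, out_swapped)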
@@ -38,6 +38,11 @@ from vllm.utils.deep_gemm import (
     should_use_deepgemm_for_fp8_linear,
     transform_sf_into_required_layout,
 )
+from vllm.utils.flashinfer import (
+    flashinfer_fp8_blockscale_gemm,
+    is_flashinfer_fp8_blockscale_gemm_supported,
+    should_use_flashinfer_for_blockscale_fp8_gemm,
+)
 from vllm.utils.torch_utils import direct_register_custom_op
 
 logger = init_logger(__name__)
@@ -229,6 +234,112 @@ direct_register_custom_op(
 )
 
 
+def _flashinfer_fp8_blockscale_gemm_impl(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
+    use_deep_gemm_e8m0: bool,
+) -> torch.Tensor:
+    """
+    Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
+
+    This function switches between two optimized kernels based on the input batch size:
+    - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
+    - For larger batches (M >= 32): Uses the official DeepGEMM kernel.
+
+    The conditional logic must use torch.cond() instead of a simple if-else statement
+    to maintain compatibility with torch.compile graph compilation.
+
+    This batch-size-dependent selection is essential for maintaining model accuracy.
+    Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
+    when FlashInfer's DeepGEMM is used for M >= 32; restricting it to M < 32 fixes the
+    accuracy drop.
+
+    Args:
+        input: Input tensor of shape (batch_size, input_dim) in BF16 (quantized to FP8 internally)
+        weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
+        weight_scale: Scale factors for weight quantization (per-group)
+        group_size: Quantization group size for the weight tensor
+        use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
+
+    Returns:
+        Output tensor of shape (batch_size, output_dim) in bfloat16 format
+    """
+
+    def run_flashinfer_deepgemm_swapAB(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        return flashinfer_fp8_blockscale_gemm(
+            input=input,
+            weight=weight,
+            weight_scale=weight_scale,
+            out_dtype=torch.bfloat16,
+        )
+
+    def run_deepgemm(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        q_input, input_scale = per_token_group_quant_fp8(
+            input,
+            group_size=group_size,
+            column_major_scales=True,
+            use_ue8m0=use_deep_gemm_e8m0,
+        )
+        output = torch.empty(
+            (q_input.shape[0], weight.shape[0]),
+            dtype=torch.bfloat16,
+            device=q_input.device,
+        )
+        fp8_gemm_nt(
+            (q_input, input_scale),
+            (weight, weight_scale),
+            output,
+            is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
+        )
+        return output
+
+    condition = input.shape[0] < 32
+
+    # PyTorch's torch.compile cannot handle input-dependent control flow in standard
+    # Python conditionals: with a plain if/else, only one branch would be captured in
+    # the compiled graph. torch.cond() explicitly registers both code paths in the
+    # computation graph, so torch.compile can capture both branches and the M < 32
+    # selection still happens at runtime.
+    return torch.cond(
+        condition,
+        run_flashinfer_deepgemm_swapAB,
+        run_deepgemm,
+        (input, weight, weight_scale),
+    )
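The return above uses torch.cond so that both kernel paths live in a single traced graph. A minimal, self-contained sketch of the same pattern (the stand-in matmul functions below are illustrative, not the vLLM kernels; it assumes a recent PyTorch release that exposes torch.cond):

    import torch

    def small_batch_kernel(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Stand-in for the FlashInfer DeepGEMM swapAB branch.
        return (w @ x.t()).t()

    def large_batch_kernel(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Stand-in for the plain DeepGEMM branch.
        return x @ w.t()

    def gemm_dispatch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Both branches are recorded when this is traced by torch.compile; a plain
        # `if x.shape[0] < 32:` would leave only one branch in the captured graph.
        return torch.cond(x.shape[0] < 32, small_batch_kernel, large_batch_kernel, (x, w))

    x = torch.randn(8, 64)
    w = torch.randn(128, 64)
    out = gemm_dispatch(x, w)  # M = 8 < 32, so the small-batch branch runs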
+
+
+def _flashinfer_fp8_blockscale_gemm_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
+    use_deep_gemm_e8m0: bool,
+) -> torch.Tensor:
+    """
+    Required fake/meta implementation for torch.compile graph tracing.
+    """
+    return torch.empty(
+        input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
+    )
+
+
+direct_register_custom_op(
+    "flashinfer_fp8_blockscale_gemm",
+    _flashinfer_fp8_blockscale_gemm_impl,
+    fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
+)
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 class W8A8BlockFp8LinearOp:
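Aside on the fake_impl registered above: it plays the same role as a "fake"/meta function in PyTorch's public custom-op API, describing only the output's shape, dtype, and device so torch.compile can trace the op without running the real kernel. A standalone sketch of that general pattern with torch.library (the op name and functions are made up for illustration, a recent PyTorch with torch.library.custom_op is assumed, and vLLM's direct_register_custom_op helper is assumed to wrap comparable machinery):

    import torch

    # Real implementation: runs on actual tensors.
    @torch.library.custom_op("demo::scaled_linear", mutates_args=())
    def scaled_linear(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
        return (x @ w.t()) * scale

    # Fake (meta) implementation: only shape/dtype/device, no computation.
    @scaled_linear.register_fake
    def _(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
        return x.new_empty((x.shape[0], w.shape[0]))

    x = torch.randn(4, 16)
    w = torch.randn(8, 16)
    out = torch.ops.demo.scaled_linear(x, w, 2.0)  # shape (4, 8)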
@@ -249,6 +360,7 @@ class W8A8BlockFp8LinearOp:
         self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
+        self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
 
         # Get the correct blockscale mul and input quant operations.
         # We can't use _dispatch_w8a8_blockscale_op to figure out if we want
@@ -284,7 +396,14 @@ class W8A8BlockFp8LinearOp:
         output_shape = [*input.shape[:-1], weight.shape[0]]
         output_dtype = input.dtype
 
-        if should_use_deepgemm_for_fp8_linear(
+        if should_use_flashinfer_for_blockscale_fp8_gemm(
+            self.is_flashinfer_supported, output_dtype, input_2d, weight
+        ) and should_use_deepgemm_for_fp8_linear(
+            output_dtype, weight, self.is_deep_gemm_supported
+        ):
+            output = self._run_flashinfer(input_2d, weight, weight_scale)
+
+        elif should_use_deepgemm_for_fp8_linear(
             output_dtype, weight, self.is_deep_gemm_supported
         ):
             output = self._run_deepgemm(input_2d, weight, weight_scale)
@@ -412,6 +531,29 @@ class W8A8BlockFp8LinearOp:
             input_2d.dtype,
         )
 
+    def _run_flashinfer(
+        self,
+        input_2d: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Run FlashInfer FP8 block-scale GEMM.
+
+        This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
+        and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
+        """
+        # Call FlashInfer with BF16 input + FP8 weight; the input is quantized
+        # to FP8 by the FlashInfer kernel (W8A8).
+        output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
+            input=input_2d,  # BF16 input
+            weight=weight,  # FP8 weight
+            weight_scale=weight_scale,  # Weight scales
+            group_size=self.act_quant_group_shape.col,
+            use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
+        )
+        return output
+
     def _dispatch_w8a8_blockscale_op(
         self,
         use_cutlass: bool,
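For context on the weight/weight_scale pair consumed by _run_flashinfer above: in block-scale FP8 schemes the weight is quantized tile by tile, with one scale per tile (128x128 tiles are common for DeepSeek-style checkpoints; the tile size below is an illustrative assumption, not taken from this diff). A rough standalone sketch of producing such a weight and its scales:

    import torch

    def quantize_weight_blockwise(w: torch.Tensor, block: int = 128):
        """Quantize an (N, K) weight to FP8 with one scale per (block x block) tile."""
        n, k = w.shape
        assert n % block == 0 and k % block == 0
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3
        tiles = w.reshape(n // block, block, k // block, block)
        amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
        scale = amax / fp8_max                       # one scale per tile
        q = (tiles / scale).to(torch.float8_e4m3fn)  # quantized weight
        return q.reshape(n, k), scale.squeeze(-1).squeeze(1)  # FP8 (N, K), scales (N/block, K/block)

    w = torch.randn(256, 512)
    w_fp8, w_scale = quantize_weight_blockwise(w)
    print(w_fp8.dtype, w_scale.shape)  # torch.float8_e4m3fn torch.Size([2, 4])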