diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 32c77b9a0..bd4a737ca 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -24,6 +24,10 @@ from vllm.utils.deep_gemm import ( per_block_cast_to_fp8, should_use_deepgemm_for_fp8_linear, ) +from vllm.utils.flashinfer import ( + flashinfer_fp8_blockscale_gemm, + has_flashinfer_fp8_blockscale_gemm, +) from vllm.utils.import_utils import has_deep_gemm if current_platform.get_device_capability() < (9, 0): @@ -205,3 +209,50 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 + + +@pytest.mark.skipif( + current_platform.is_fp8_fnuz(), + reason="This platform supports e4m3fnuz, not e4m3fn.", +) +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) +@torch.inference_mode() +def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed): + if not has_flashinfer_fp8_blockscale_gemm(): + pytest.skip( + "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)" + ) + # only aligned sizes + if K % 128 != 0 or N % 64 != 0: + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + + torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max = fp8_info.max + + A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + + A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False) + + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) + + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + out = 
flashinfer_fp8_blockscale_gemm( + input=A_bf16, + weight=B_fp8, + input_scale=None, + weight_scale=Bs, + out_dtype=out_dtype, + ) + + rel_diff = torch.mean( + torch.abs(out.to(torch.bfloat16) - ref_out.to(torch.bfloat16)) + ) / torch.mean(torch.abs(ref_out.to(torch.bfloat16))) + assert rel_diff < 0.001 diff --git a/vllm/envs.py b/vllm/envs.py index dadb8c8a2..9b06aadfc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -168,6 +168,7 @@ if TYPE_CHECKING: "relax", ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True + VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False @@ -1206,6 +1207,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) ), + # Allow use of FlashInfer FP8 block-scale GEMM for linear layers. + # This uses TensorRT-LLM kernels and requires SM90+ (Hopper). + "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) + ), # Allow use of FlashInfer MoE kernels for fused moe ops. 
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 880b6a89c..1c9e36e02 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -38,6 +38,11 @@ from vllm.utils.deep_gemm import ( should_use_deepgemm_for_fp8_linear, transform_sf_into_required_layout, ) +from vllm.utils.flashinfer import ( + flashinfer_fp8_blockscale_gemm, + is_flashinfer_fp8_blockscale_gemm_supported, + should_use_flashinfer_for_blockscale_fp8_gemm, +) from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -229,6 +234,112 @@ direct_register_custom_op( ) + + +def _flashinfer_fp8_blockscale_gemm_impl( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + """ + Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection. + + This function switches between two optimized kernels based on the input batch size: + - For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization. + - For larger batches (M >= 32): Uses the official DeepGEMM kernel. + + The conditional logic must use torch.cond() instead of a simple if-else statement + to maintain compatibility with torch.compile graph compilation. + + This batch-size-dependent selection is essential for maintaining model accuracy. + Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1 + when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy + drop. 
+ + Args: + input: Input tensor of shape (batch_size, input_dim) in FP8 format + weight: Weight tensor of shape (output_dim, input_dim) in FP8 format + weight_scale: Scale factors for weight quantization (per-group) + group_size: Quantization group size for the weight tensor + use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization + + Returns: + Output tensor of shape (batch_size, output_dim) in bfloat16 format + """ + + def run_flashinfer_deepgemm_swapAB( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + return flashinfer_fp8_blockscale_gemm( + input=input, + weight=weight, + weight_scale=weight_scale, + out_dtype=torch.bfloat16, + ) + + def run_deepgemm( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + q_input, input_scale = per_token_group_quant_fp8( + input, + group_size=group_size, + column_major_scales=True, + use_ue8m0=use_deep_gemm_e8m0, + ) + output = torch.empty( + (q_input.shape[0], weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + fp8_gemm_nt( + (q_input, input_scale), + (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, + ) + return output + + condition = input.shape[0] < 32 + + # PyTorch's torch.compile cannot handle input-dependent control flow in standard + # Python conditionals. torch.cond() explicitly registers both code paths in the + # computation graph, allowing torch.compile to capture both branches. + # without torch.cond, the M < 32 condition won't be able to be captured by torch + # compile + return torch.cond( + condition, + run_flashinfer_deepgemm_swapAB, + run_deepgemm, + (input, weight, weight_scale), + ) + + +def _flashinfer_fp8_blockscale_gemm_fake( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + group_size: int, + use_deep_gemm_e8m0: bool, +) -> torch.Tensor: + """ + Required fake/meta implementation for torch.compile graph tracing. 
+ """ + return torch.empty( + input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device + ) + + +direct_register_custom_op( + "flashinfer_fp8_blockscale_gemm", + _flashinfer_fp8_blockscale_gemm_impl, + fake_impl=_flashinfer_fp8_blockscale_gemm_fake, +) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: @@ -249,6 +360,7 @@ class W8A8BlockFp8LinearOp: self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() + self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported() # Get the correct blockscale mul and input quant operations. # We can't use _dispatch_w8a8_blockscale_op to figure out if we want @@ -284,7 +396,14 @@ class W8A8BlockFp8LinearOp: output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear( + if should_use_flashinfer_for_blockscale_fp8_gemm( + self.is_flashinfer_supported, output_dtype, input_2d, weight + ) and should_use_deepgemm_for_fp8_linear( + output_dtype, weight, self.is_deep_gemm_supported + ): + output = self._run_flashinfer(input_2d, weight, weight_scale) + + elif should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): output = self._run_deepgemm(input_2d, weight, weight_scale) @@ -412,6 +531,29 @@ class W8A8BlockFp8LinearOp: input_2d.dtype, ) + def _run_flashinfer( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Run FlashInfer FP8 block-scale GEMM. + + This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels + and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper). 
+ """ + # Now call FlashInfer with BF16 input + FP8 weight, input will be + # quantized with FlashInfer kernel (W8A8) + output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm( + input=input_2d, # BF16 input + weight=weight, # FP8 weight + weight_scale=weight_scale, # Weight scales + group_size=self.act_quant_group_shape.col, + use_deep_gemm_e8m0=self.use_deep_gemm_e8m0, + ) + return output + def _dispatch_w8a8_blockscale_op( self, use_cutlass: bool, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 6bbe02348..0804add23 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -540,6 +540,59 @@ def flashinfer_scaled_fp8_mm( return output +flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper( + "flashinfer.gemm", "fp8_blockscale_gemm_sm90" +) + + +@functools.cache +def has_flashinfer_fp8_blockscale_gemm() -> bool: + """Return `True` if FlashInfer block-scale FP8 GEMM is available.""" + return ( + has_flashinfer() + and current_platform.is_device_capability(90) + and hasattr(_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90") + ) + + +@functools.cache +def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: + """Return `True` if FlashInfer block-scale FP8 GEMM is supported.""" + return ( + envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER + and has_flashinfer_fp8_blockscale_gemm() + ) + + +def should_use_flashinfer_for_blockscale_fp8_gemm( + is_flashinfer_supported: bool, + output_dtype: torch.dtype, + input: torch.Tensor, + weight: torch.Tensor, +): + if not is_flashinfer_supported: + return False + + # Verify DeepGEMM N/K dims requirements + # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul + # test inside kernels/quatization/test_block_fp8.py + N_MULTIPLE = 64 + K_MULTIPLE = 128 + + weight_dtype = weight.dtype + input_dtype = input.dtype + + should_use_flashinfer = ( + output_dtype == torch.bfloat16 + and input_dtype == torch.bfloat16 + and weight_dtype == torch.float8_e4m3fn + and weight.shape[0] % 
N_MULTIPLE == 0 + and weight.shape[1] % K_MULTIPLE == 0 + ) + + return should_use_flashinfer + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -556,10 +609,14 @@ __all__ = [ "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", + "has_flashinfer_fp8_blockscale_gemm", "has_nvidia_artifactory", "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", + "flashinfer_fp8_blockscale_gemm", + "should_use_flashinfer_for_blockscale_fp8_gemm", + "is_flashinfer_fp8_blockscale_gemm_supported", ]