From 3e6a1e1686958dcd7eff1438bc5418b8d56daa30 Mon Sep 17 00:00:00 2001 From: Terry Gao <32590313+tianrengao@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:51:46 -0700 Subject: [PATCH] [Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389) Signed-off-by: tianrengao --- csrc/ops.h | 12 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 37 +++++- csrc/quantization/fp4/nvfp4_utils.cuh | 13 +++ csrc/torch_bindings.cpp | 19 +++- .../distributed/test_fusion_all_reduce.py | 2 +- .../kernels/quantization/test_nvfp4_quant.py | 46 ++++++++ vllm/_custom_ops.py | 106 ++++++++++++++---- .../passes/fusion/act_quant_fusion.py | 4 +- .../passes/fusion/allreduce_rms_fusion.py | 10 +- .../passes/fusion/attn_quant_fusion.py | 4 +- .../passes/fusion/matcher_utils.py | 2 +- .../passes/fusion/rms_quant_fusion.py | 2 +- 12 files changed, 213 insertions(+), 44 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 921d6484d..299650be7 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, std::vector cutlass_sparse_compress(torch::Tensor const& a); -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_scale, - torch::Tensor const& input_scale, - bool is_sf_swizzled_layout); +std::tuple scaled_fp4_quant_func( + torch::Tensor const& input, torch::Tensor const& input_scale, + bool is_sf_swizzled_layout); + +void scaled_fp4_quant_out(torch::Tensor const& input, + torch::Tensor const& input_scale, + bool is_sf_swizzled_layout, torch::Tensor& output, + torch::Tensor& output_scale); void scaled_fp4_experts_quant( torch::Tensor& output, torch::Tensor& output_scale, diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 650b9da8a..8b5a1fd22 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -16,6 +16,8 @@ #include +#include "nvfp4_utils.cuh" + #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, @@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( torch::Tensor const& output_scale_offset_by_experts); #endif -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_sf, torch::Tensor const& input_sf, - bool is_sf_swizzled_layout) { +void scaled_fp4_quant_out(torch::Tensor const& input, + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout, torch::Tensor& output, + torch::Tensor& output_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, @@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); } +std::tuple scaled_fp4_quant_func( + torch::Tensor const& input, torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { + int64_t n = input.size(-1); + int64_t m = input.numel() / n; + auto device = input.device(); + + // Two fp4 values packed into a uint8 + auto output = torch::empty( + {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8)); + + torch::Tensor output_sf; + if (is_sf_swizzled_layout) { + auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n); + output_sf = torch::empty( + {sf_m, sf_n}, + torch::TensorOptions().device(device).dtype(torch::kInt32)); + } else { + output_sf = torch::empty( + {m, n / CVT_FP4_SF_VEC_SIZE}, + torch::TensorOptions().device(device).dtype(torch::kUInt8)); + } + + scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output, + output_sf); + return {output, output_sf}; +} + void scaled_fp4_experts_quant( torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index c1df1860c..0c04f0108 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -18,6 +18,7 @@ #include #include +#include #include "../../cuda_vec_utils.cuh" @@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) { return round_up(m, ROW_TILE); } +// Compute the shape of the swizzled SF output tensor. +// Returns (rounded_m, rounded_n / 4) where: +// rounded_m = round_up(m, 128) +// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4) +inline std::pair computeSwizzledSFShape(int64_t m, + int64_t n) { + int64_t rounded_m = round_up(m, static_cast(128)); + int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE; + int64_t rounded_n = round_up(scale_n, static_cast(4)); + return {rounded_m, rounded_n / 4}; +} + // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { uint32_t val; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d98e987d9..aadc9fe33 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute NVFP4 block quantized tensor. ops.def( - "scaled_fp4_quant(Tensor! output, Tensor input," - " Tensor! output_scale, Tensor input_scale, bool " - "is_sf_swizzled_layout) -> ()"); - ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + "scaled_fp4_quant(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout) -> (Tensor, Tensor)"); + ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func); + + // Out variant + // TODO: Add {at::Tag::out_variant} tag and update all call sites + // to use the functional variant once vLLM upgrades PyTorch. + // See pytorch/pytorch#176117. + ops.def( + "scaled_fp4_quant.out(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) " + "-> ()"); + ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out); // Compute NVFP4 experts quantization. ops.def( diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index fe50081e5..92e7402c0 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.scaled_fp4_quant.default, + torch.ops._C.scaled_fp4_quant.out, ] diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 1d2f9d413..e2db59758 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -159,6 +159,52 @@ def test_quantize_to_fp4( torch.testing.assert_close(scale_ans, scale_ref) +@pytest.mark.parametrize( + "shape", + [(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)], +) +@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) +@torch.inference_mode() +def test_python_util_matches_cpp_allocation( + shape: tuple[int, int], + is_sf_swizzled_layout: bool, +) -> None: + """ + Verify that the Python utility (create_fp4_output_tensors) allocates + tensors with the same shapes and dtypes as the C++ functional variant + (scaled_fp4_quant_func). + """ + from vllm._custom_ops import create_fp4_output_tensors + + torch.set_default_device("cuda:0") + m, n = shape + input_tensor = torch.randn((m, n), dtype=torch.bfloat16) + input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0") + + # C++ functional variant allocates internally + cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant( + input_tensor, input_scale, is_sf_swizzled_layout + ) + + # Python utility + py_out, py_scale = create_fp4_output_tensors( + m, n, torch.device("cuda:0"), is_sf_swizzled_layout + ) + + assert py_out.shape == cpp_out.shape, ( + f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}" + ) + assert py_out.dtype == cpp_out.dtype, ( + f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}" + ) + assert py_scale.shape == cpp_scale.shape, ( + f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}" + ) + assert py_scale.dtype == cpp_scale.dtype, ( + f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}" + ) + + @pytest.mark.parametrize("pad_shape", PAD_SHAPES) @torch.inference_mode() def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fdc468d3b..63f347d89 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -29,6 +29,81 @@ else: from torch.library import impl_abstract as register_fake +# scaled_fp4_quant functional + out variant for torch.compile buffer management + + +def create_fp4_scale_tensor( + m: int, + n: int, + device: torch.device, + is_sf_swizzled_layout: bool, +) -> torch.Tensor: + """ + Allocate the output scale tensor for scaled_fp4_quant. + + When is_sf_swizzled_layout=True, we use rounded values to store the + swizzled scales. Due to the requirement of the Tensor Core, the minimum + tile is 128x4 for the scales. So, we first pad the scales to multiples + of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are + packed into an int32 for every 4 values. More: + https://docs.nvidia.com/cuda/parallel-thread-execution/ + #tcgen05-mma-scale-factor-b-layout-4x + """ + from vllm.utils.math_utils import round_up + + block_size = 16 + if is_sf_swizzled_layout: + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + return torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) + else: + return torch.empty((m, n // block_size), device=device, dtype=torch.uint8) + + +def create_fp4_output_tensors( + m: int, + n: int, + device: torch.device, + is_sf_swizzled_layout: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Allocate both output tensors for scaled_fp4_quant: + (quantized_output, output_scale). + + Must match the C++ scaled_fp4_quant_func allocation exactly. + """ + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout) + return output, output_scale + + +if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"): + + @register_fake("_C::scaled_fp4_quant") + def _scaled_fp4_quant_fake( + input: torch.Tensor, + input_scale: torch.Tensor, + is_sf_swizzled_layout: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + n = input.shape[-1] + m = input.numel() // n + return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout) + + @register_fake("_C::scaled_fp4_quant.out") + def _scaled_fp4_quant_out_fake( + input: torch.Tensor, + input_scale: torch.Tensor, + is_sf_swizzled_layout: bool, + *, + output: torch.Tensor, + output_scale: torch.Tensor, + ) -> None: + return None + + # page attention ops def paged_attention_v1( out: torch.Tensor, @@ -1644,7 +1719,6 @@ def scaled_fp4_quant( input = input.reshape(other_dims, input.shape[-1]) m, n = input.shape block_size = 16 - device = input.device assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." assert input.dtype in (torch.float16, torch.bfloat16), ( @@ -1658,26 +1732,16 @@ def scaled_fp4_quant( input, input_global_scale ) else: - # Two fp4 values will be packed into an uint8. - output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) - if is_sf_swizzled_layout: - # We use the rounded values to store the swizzled values. Due to the - # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. - # So, we first pad the scales to multiples of 128 and 4. Then, the scales - # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(m, 128) - scale_n = n // block_size - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( - (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 - ) - else: - output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8) - - torch.ops._C.scaled_fp4_quant( - output, input, output_scale, input_global_scale, is_sf_swizzled_layout + # Pre-allocate and call .out variant (same behavior as old in-place API) + output, output_scale = create_fp4_output_tensors( + m, n, input.device, is_sf_swizzled_layout + ) + torch.ops._C.scaled_fp4_quant.out( + input, + input_global_scale, + is_sf_swizzled_layout, + output=output, + output_scale=output_scale, ) output_scale = output_scale.view(torch.float8_e4m3fn) diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index e14100384..911775f69 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): result_silu_mul = self.silu_and_mul_matcher(input) at = auto_functionalized( self.QUANT_OP, - output=result, input=result_silu_mul, - output_scale=output_scale, input_scale=scale, is_sf_swizzled_layout=True, + output=result, + output_scale=output_scale, ) return at[1], at[2] diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 44dc3d67b..f141a7c17 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -47,7 +47,7 @@ if find_spec("flashinfer"): pass if hasattr(torch.ops._C, "scaled_fp4_quant"): - STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default + STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out # Max size of the input tensor per world size per device capability # to use flashinfer fused allreduce @@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, - output=quant_result, input=rms, - output_scale=output_scale, input_scale=input_global_scale, is_sf_swizzled_layout=True, + output=quant_result, + output_scale=output_scale, ) # quant_out, allreduce_output, output_scale @@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, - output=quant_result, input=rms, - output_scale=output_scale, input_scale=input_global_scale, is_sf_swizzled_layout=True, + output=quant_result, + output_scale=output_scale, ) # quant_out, allreduce_output, output_scale diff --git a/vllm/compilation/passes/fusion/attn_quant_fusion.py b/vllm/compilation/passes/fusion/attn_quant_fusion.py index 5e6bf28c0..0e1b846af 100644 --- a/vllm/compilation/passes/fusion/attn_quant_fusion.py +++ b/vllm/compilation/passes/fusion/attn_quant_fusion.py @@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ) at2 = auto_functionalized( self.QUANT_OP, - output=output_quant, input=attn_out_view, - output_scale=output_scale, input_scale=input_scale, is_sf_swizzled_layout=True, + output=output_quant, + output_scale=output_scale, ) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 03f680552..ec36c12d1 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501 if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py index 2d084783d..95ce7b22e 100644 --- a/vllm/compilation/passes/fusion/rms_quant_fusion.py +++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py @@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default + QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out if current_platform.is_cuda(): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501