[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)

Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
Terry Gao
2026-03-16 15:51:46 -07:00
committed by GitHub
parent 7961486a9b
commit 3e6a1e1686
12 changed files with 213 additions and 44 deletions

View File

@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a); std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_scale,
torch::Tensor const& input_scale, bool is_sf_swizzled_layout);
bool is_sf_swizzled_layout);
// Out variant of NVFP4 block quantization: quantizes `input` into the
// caller-provided `output` and `output_scale` tensors and returns nothing.
// The functional counterpart, scaled_fp4_quant_func, allocates both tensors
// itself and returns them as a tuple.
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant( void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,

View File

@@ -16,6 +16,8 @@
#include <torch/all.h> #include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor const& output_scale_offset_by_experts); torch::Tensor const& output_scale_offset_by_experts);
#endif #endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) { bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
} }
// Functional variant of NVFP4 block quantization.
//
// Allocates the packed FP4 output and its scale-factor tensor on the input's
// device, fills them through scaled_fp4_quant_out, and returns them as
// (output, output_sf).
//
// The input is handled as an (m, n) problem, where n is the size of the last
// dimension and m is numel / n:
//   output    : (m, n / 2) uint8 -- two fp4 values packed into one byte.
//   output_sf : swizzled layout -> int32 tensor with the shape returned by
//                                  vllm::computeSwizzledSFShape(m, n);
//               otherwise       -> (m, n / CVT_FP4_SF_VEC_SIZE) uint8.
//
// Raises (via TORCH_CHECK) when the last dimension is empty or is not a
// multiple of CVT_FP4_SF_VEC_SIZE.
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
    torch::Tensor const& input, torch::Tensor const& input_sf,
    bool is_sf_swizzled_layout) {
  int64_t const n = input.size(-1);
  // Validate before dividing: an empty last dimension would turn
  // `numel() / n` into an integer division by zero, and a last dimension that
  // is not a whole number of scale blocks would silently allocate truncated
  // buffers. The Python wrapper asserts the same precondition, but this op can
  // also be invoked directly.
  TORCH_CHECK(n > 0 && n % CVT_FP4_SF_VEC_SIZE == 0,
              "scaled_fp4_quant: last dim has to be a non-zero multiple of ",
              CVT_FP4_SF_VEC_SIZE, ", but got ", n, ".");
  int64_t const m = input.numel() / n;

  auto const device = input.device();

  // Two fp4 values packed into a uint8
  auto output = torch::empty(
      {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));

  torch::Tensor output_sf;
  if (is_sf_swizzled_layout) {
    auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
    output_sf = torch::empty(
        {sf_m, sf_n},
        torch::TensorOptions().device(device).dtype(torch::kInt32));
  } else {
    output_sf = torch::empty(
        {m, n / CVT_FP4_SF_VEC_SIZE},
        torch::TensorOptions().device(device).dtype(torch::kUInt8));
  }

  scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
                       output_sf);
  return {output, output_sf};
}
void scaled_fp4_experts_quant( void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,

View File

@@ -18,6 +18,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh" #include "../../cuda_vec_utils.cuh"
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
return round_up(m, ROW_TILE); return round_up(m, ROW_TILE);
} }
// Shape of the scale-factor (SF) output tensor for the swizzled layout.
//
// With scale_cols = n / CVT_FP4_SF_VEC_SIZE, the returned pair is
//   first  = round_up(m, 128)
//   second = round_up(scale_cols, 4) / 4
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
                                                          int64_t n) {
  constexpr int64_t kRowMultiple = 128;
  constexpr int64_t kColMultiple = 4;

  int64_t const scale_cols = n / CVT_FP4_SF_VEC_SIZE;
  int64_t const padded_rows = round_up(m, kRowMultiple);
  int64_t const padded_cols = round_up(scale_cols, kColMultiple);

  return std::make_pair(padded_rows, padded_cols / kColMultiple);
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val; uint32_t val;

View File

@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor. // Compute NVFP4 block quantized tensor.
ops.def( ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input," "scaled_fp4_quant(Tensor input,"
" Tensor! output_scale, Tensor input_scale, bool " " Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()"); "is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization. // Compute NVFP4 experts quantization.
ops.def( ops.def(

View File

@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self): def ops_in_model_before(self):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default, torch.ops._C.scaled_fp4_quant.out,
] ]

View File

@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
torch.testing.assert_close(scale_ans, scale_ref) torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.parametrize(
    "shape",
    [(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)],
)
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@torch.inference_mode()
def test_python_util_matches_cpp_allocation(
    shape: tuple[int, int],
    is_sf_swizzled_layout: bool,
) -> None:
    """
    Verify that the Python utility (create_fp4_output_tensors) allocates
    tensors with the same shapes and dtypes as the C++ functional variant
    (scaled_fp4_quant_func).

    Only allocation metadata (shape and dtype) is compared; the quantized
    values themselves are not checked here.
    """
    from vllm._custom_ops import create_fp4_output_tensors

    # Pass the device explicitly instead of calling torch.set_default_device:
    # the default device is process-global state that would otherwise stay
    # changed for every test that runs after this one.
    device = torch.device("cuda:0")
    m, n = shape
    input_tensor = torch.randn((m, n), dtype=torch.bfloat16, device=device)
    input_scale = torch.tensor([1.0], dtype=torch.float32, device=device)

    # C++ functional variant allocates internally
    cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant(
        input_tensor, input_scale, is_sf_swizzled_layout
    )

    # Python utility
    py_out, py_scale = create_fp4_output_tensors(m, n, device, is_sf_swizzled_layout)

    assert py_out.shape == cpp_out.shape, (
        f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}"
    )
    assert py_out.dtype == cpp_out.dtype, (
        f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}"
    )
    assert py_scale.shape == cpp_scale.shape, (
        f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}"
    )
    assert py_scale.dtype == cpp_scale.dtype, (
        f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}"
    )
@pytest.mark.parametrize("pad_shape", PAD_SHAPES) @pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode() @torch.inference_mode()
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:

View File

@@ -29,6 +29,81 @@ else:
from torch.library import impl_abstract as register_fake from torch.library import impl_abstract as register_fake
# scaled_fp4_quant functional + out variant for torch.compile buffer management
def create_fp4_scale_tensor(
    m: int,
    n: int,
    device: torch.device,
    is_sf_swizzled_layout: bool,
) -> torch.Tensor:
    """
    Allocate (without initializing) the scale buffer that scaled_fp4_quant
    writes into.

    Non-swizzled layout: a uint8 tensor of shape (m, n // 16), i.e. one
    scale per block of 16 input values.

    Swizzled layout: the scales are stored in rounded-up storage. Due to
    the requirement of the Tensor Core, the minimum tile is 128x4 for the
    scales, so the scale grid is padded to multiples of 128 (rows) and
    4 (cols). The scales (in float8_e4m3fn) are then packed into an int32
    for every 4 values, which gives an int32 tensor of shape
    (round_up(m, 128), round_up(n // 16, 4) // 4). More:
    https://docs.nvidia.com/cuda/parallel-thread-execution/
    #tcgen05-mma-scale-factor-b-layout-4x
    """
    from vllm.utils.math_utils import round_up

    block_size = 16
    scale_cols = n // block_size

    if not is_sf_swizzled_layout:
        return torch.empty((m, scale_cols), device=device, dtype=torch.uint8)

    padded_rows = round_up(m, 128)
    padded_cols = round_up(scale_cols, 4)
    packed_shape = (padded_rows, padded_cols // 4)
    return torch.empty(packed_shape, device=device, dtype=torch.int32)
def create_fp4_output_tensors(
    m: int,
    n: int,
    device: torch.device,
    is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Allocate the pair of buffers scaled_fp4_quant writes into and return
    them as (quantized_output, output_scale).

    quantized_output is an uninitialized uint8 tensor of shape (m, n // 2);
    output_scale comes from create_fp4_scale_tensor.

    Shapes and dtypes must stay identical to what the C++
    scaled_fp4_quant_func allocates.
    """
    packed_shape = (m, n // 2)
    return (
        torch.empty(packed_shape, device=device, dtype=torch.uint8),
        create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout),
    )
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):
    # Only register when the compiled extension actually exposes the op.

    @register_fake("_C::scaled_fp4_quant")
    def _scaled_fp4_quant_fake(
        input: torch.Tensor,
        input_scale: torch.Tensor,
        is_sf_swizzled_layout: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fake implementation of the functional variant: returns freshly
        allocated (output, output_scale) tensors sized from the input,
        without running any quantization.
        """
        last_dim = input.shape[-1]
        # All leading dimensions are collapsed into a single row count.
        num_rows = input.numel() // last_dim
        return create_fp4_output_tensors(
            num_rows, last_dim, input.device, is_sf_swizzled_layout
        )

    @register_fake("_C::scaled_fp4_quant.out")
    def _scaled_fp4_quant_out_fake(
        input: torch.Tensor,
        input_scale: torch.Tensor,
        is_sf_swizzled_layout: bool,
        *,
        output: torch.Tensor,
        output_scale: torch.Tensor,
    ) -> None:
        """
        Fake implementation of the out variant: the caller supplies both
        result tensors, so there is nothing to allocate or return.
        """
        return None
# page attention ops # page attention ops
def paged_attention_v1( def paged_attention_v1(
out: torch.Tensor, out: torch.Tensor,
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
input = input.reshape(other_dims, input.shape[-1]) input = input.reshape(other_dims, input.shape[-1])
m, n = input.shape m, n = input.shape
block_size = 16 block_size = 16
device = input.device
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
assert input.dtype in (torch.float16, torch.bfloat16), ( assert input.dtype in (torch.float16, torch.bfloat16), (
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
input, input_global_scale input, input_global_scale
) )
else: else:
# Two fp4 values will be packed into an uint8. # Pre-allocate and call .out variant (same behavior as old in-place API)
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) output, output_scale = create_fp4_output_tensors(
if is_sf_swizzled_layout: m, n, input.device, is_sf_swizzled_layout
# We use the rounded values to store the swizzled values. Due to the )
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales. torch.ops._C.scaled_fp4_quant.out(
# So, we first pad the scales to multiples of 128 and 4. Then, the scales input,
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More: input_global_scale,
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x is_sf_swizzled_layout,
round_up = lambda x, y: (x + y - 1) // y * y output=output,
rounded_m = round_up(m, 128) output_scale=output_scale,
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
else:
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
torch.ops._C.scaled_fp4_quant(
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
) )
output_scale = output_scale.view(torch.float8_e4m3fn) output_scale = output_scale.view(torch.float8_e4m3fn)

View File

@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
result_silu_mul = self.silu_and_mul_matcher(input) result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized( at = auto_functionalized(
self.QUANT_OP, self.QUANT_OP,
output=result,
input=result_silu_mul, input=result_silu_mul,
output_scale=output_scale,
input_scale=scale, input_scale=scale,
is_sf_swizzled_layout=True, is_sf_swizzled_layout=True,
output=result,
output_scale=output_scale,
) )
return at[1], at[2] return at[1], at[2]

View File

@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
pass pass
if hasattr(torch.ops._C, "scaled_fp4_quant"): if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
# Max size of the input tensor per world size per device capability # Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce # to use flashinfer fused allreduce
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
rms = self.rmsnorm_matcher(all_reduce, weight) rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms, input=rms,
output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
is_sf_swizzled_layout=True, is_sf_swizzled_layout=True,
output=quant_result,
output_scale=output_scale,
) )
# quant_out, allreduce_output, output_scale # quant_out, allreduce_output, output_scale
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
output=quant_result,
input=rms, input=rms,
output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
is_sf_swizzled_layout=True, is_sf_swizzled_layout=True,
output=quant_result,
output_scale=output_scale,
) )
# quant_out, allreduce_output, output_scale # quant_out, allreduce_output, output_scale

View File

@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
) )
at2 = auto_functionalized( at2 = auto_functionalized(
self.QUANT_OP, self.QUANT_OP,
output=output_quant,
input=attn_out_view, input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale, input_scale=input_scale,
is_sf_swizzled_layout=True, is_sf_swizzled_layout=True,
output=output_quant,
output_scale=output_scale,
) )
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view return at2[1], output_scale_view

View File

@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
if current_platform.is_cuda(): if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501

View File

@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
if current_platform.is_cuda(): if current_platform.is_cuda():
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501