[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)

Signed-off-by: tianrengao <terrygao87@gmail.com>
(cherry picked from commit 3e6a1e1686)
Author: Terry Gao
Date: 2026-03-16 15:51:46 -07:00
Committed by: khluu
Parent: cdcffafef8
Commit: eeabf740bb
12 changed files with 213 additions and 44 deletions

View File

@@ -29,6 +29,81 @@ else:
     from torch.library import impl_abstract as register_fake

+# scaled_fp4_quant functional + out variant for torch.compile buffer management
+def create_fp4_scale_tensor(
+    m: int,
+    n: int,
+    device: torch.device,
+    is_sf_swizzled_layout: bool,
+) -> torch.Tensor:
+    """
+    Allocate the output scale tensor for scaled_fp4_quant.
+
+    When is_sf_swizzled_layout=True, we use rounded values to store the
+    swizzled scales. Due to the requirement of the Tensor Core, the minimum
+    tile is 128x4 for the scales. So, we first pad the scales to multiples
+    of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are
+    packed into an int32 for every 4 values. More:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+    """
+    from vllm.utils.math_utils import round_up
+
+    block_size = 16
+    if is_sf_swizzled_layout:
+        rounded_m = round_up(m, 128)
+        scale_n = n // block_size
+        rounded_n = round_up(scale_n, 4)
+        return torch.empty(
+            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
+        )
+    else:
+        return torch.empty((m, n // block_size), device=device, dtype=torch.uint8)
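For intuition, here is a small sketch of the shapes this helper produces for a toy input; the arithmetic simply mirrors the rounding rules in the docstring (the concrete numbers are illustrative, not from the commit):

import torch

# Illustrative only: the shape arithmetic behind create_fp4_scale_tensor
# for m=5 rows and n=64 columns.
m, n = 5, 64
block_size = 16
scale_n = n // block_size            # 4 fp8 scales per row, one per 16-wide block

# Swizzled layout: rows padded to 128, scale columns padded to 4,
# then four float8_e4m3fn scales packed into each int32.
rounded_m = (m + 127) // 128 * 128   # 128
rounded_n = (scale_n + 3) // 4 * 4   # 4
swizzled = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
assert swizzled.shape == (128, 1)

# Linear layout: no padding, one uint8 scale per 16-element block.
linear = torch.empty((m, scale_n), dtype=torch.uint8)
assert linear.shape == (5, 4)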
+def create_fp4_output_tensors(
+    m: int,
+    n: int,
+    device: torch.device,
+    is_sf_swizzled_layout: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Allocate both output tensors for scaled_fp4_quant:
+    (quantized_output, output_scale).
+
+    Must match the C++ scaled_fp4_quant_func allocation exactly.
+    """
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+    output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout)
+    return output, output_scale
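A quick usage sketch of the pair this returns (CPU device purely for illustration; the real kernel requires CUDA, but the allocation itself is device-agnostic, and the example assumes vLLM is importable):

import torch

# Hypothetical call: buffers allocated up front, exactly as the C++
# scaled_fp4_quant_func would otherwise allocate them internally.
out, out_scale = create_fp4_output_tensors(
    m=5, n=64, device=torch.device("cpu"), is_sf_swizzled_layout=True
)
print(out.shape, out.dtype)              # torch.Size([5, 32]) torch.uint8
print(out_scale.shape, out_scale.dtype)  # torch.Size([128, 1]) torch.int32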
+if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):
+
+    @register_fake("_C::scaled_fp4_quant")
+    def _scaled_fp4_quant_fake(
+        input: torch.Tensor,
+        input_scale: torch.Tensor,
+        is_sf_swizzled_layout: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        n = input.shape[-1]
+        m = input.numel() // n
+        return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout)
+
+    @register_fake("_C::scaled_fp4_quant.out")
+    def _scaled_fp4_quant_out_fake(
+        input: torch.Tensor,
+        input_scale: torch.Tensor,
+        is_sf_swizzled_layout: bool,
+        *,
+        output: torch.Tensor,
+        output_scale: torch.Tensor,
+    ) -> None:
+        return None
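Why the fakes matter: torch.compile traces custom ops with FakeTensors, so every op needs a meta implementation that reports output metadata without launching a kernel. A hedged sketch of what tracing effectively does (FakeTensorMode is standard PyTorch; the example assumes a CUDA build with the _C extension loaded and the fakes above registered):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch only: inside FakeTensorMode no real memory is touched. Calling the
# functional op dispatches to _scaled_fp4_quant_fake, which computes shapes
# via create_fp4_output_tensors, letting the compiler plan buffers.
with FakeTensorMode():
    x = torch.empty(5, 64, dtype=torch.bfloat16, device="cuda")
    gs = torch.empty((), dtype=torch.float32, device="cuda")
    out, out_scale = torch.ops._C.scaled_fp4_quant(x, gs, True)
    assert out.shape == (5, 32)            # packed fp4 payload
    assert out_scale.dtype == torch.int32  # swizzled scales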
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
     input = input.reshape(other_dims, input.shape[-1])
     m, n = input.shape
     block_size = 16
-    device = input.device
     assert n % block_size == 0, f"last dim has to be a multiple of 16, but got {n}."
     assert input.dtype in (torch.float16, torch.bfloat16), (
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
             input, input_global_scale
         )
     else:
-        # Two fp4 values will be packed into an uint8.
-        output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
-        if is_sf_swizzled_layout:
-            # We use the rounded values to store the swizzled values. Due to the
-            # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
-            # So, we first pad the scales to multiples of 128 and 4. Then, the scales
-            # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
-            # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
-            round_up = lambda x, y: (x + y - 1) // y * y
-            rounded_m = round_up(m, 128)
-            scale_n = n // block_size
-            rounded_n = round_up(scale_n, 4)
-            output_scale = torch.empty(
-                (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
-            )
-        else:
-            output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
-        torch.ops._C.scaled_fp4_quant(
-            output, input, output_scale, input_global_scale, is_sf_swizzled_layout
-        )
+        # Pre-allocate and call .out variant (same behavior as old in-place API)
+        output, output_scale = create_fp4_output_tensors(
+            m, n, input.device, is_sf_swizzled_layout
+        )
+        torch.ops._C.scaled_fp4_quant.out(
+            input,
+            input_global_scale,
+            is_sf_swizzled_layout,
+            output=output,
+            output_scale=output_scale,
+        )
     output_scale = output_scale.view(torch.float8_e4m3fn)
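From the caller's perspective nothing changes: the wrapper still allocates, invokes the kernel, and returns the packed payload plus scales. A hedged usage sketch (assumes a CUDA build of vLLM with the _C extension; the import path is vLLM's usual custom-ops module):

import torch
from vllm import _custom_ops as ops  # where this wrapper lives in vLLM

x = torch.randn(5, 64, dtype=torch.bfloat16, device="cuda")
# A single fp32 global scale tensor, as the kernel expects.
input_global_scale = torch.tensor(448.0, dtype=torch.float32, device="cuda")

out, out_scale = ops.scaled_fp4_quant(x, input_global_scale)
print(out.shape)        # torch.Size([5, 32]): two fp4 values per uint8
print(out_scale.dtype)  # torch.float8_e4m3fn, after the final .view(...)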

View File

@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
         result_silu_mul = self.silu_and_mul_matcher(input)
         at = auto_functionalized(
             self.QUANT_OP,
-            output=result,
             input=result_silu_mul,
-            output_scale=output_scale,
             input_scale=scale,
             is_sf_swizzled_layout=True,
+            output=result,
+            output_scale=output_scale,
         )
         return at[1], at[2]
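The kwarg reordering in this and the following pattern files is mechanical: the .out overload takes the real inputs first and the pre-allocated buffers as keyword-only arguments, so the traced pattern must list them in that order. A hedged sketch of the contract (requires vLLM's _C extension; in practice this call is emitted by the pattern matcher during tracing rather than invoked eagerly, and the variable names here are hypothetical placeholders):

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

# Hypothetical traced values and pre-allocated buffers.
hidden = torch.randn(5, 64, dtype=torch.bfloat16, device="cuda")
global_scale = torch.tensor(448.0, dtype=torch.float32, device="cuda")
quant_buf = torch.empty(5, 32, dtype=torch.uint8, device="cuda")
scale_buf = torch.empty(128, 1, dtype=torch.int32, device="cuda")

# auto_functionalized wraps the mutating .out op in a pure call: element 0
# is the op's (None) return, elements 1..n are the functionalized buffers.
at = auto_functionalized(
    torch.ops._C.scaled_fp4_quant.out,
    input=hidden,
    input_scale=global_scale,
    is_sf_swizzled_layout=True,
    output=quant_buf,
    output_scale=scale_buf,
)
quant_out, scale_out = at[1], at[2]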

View File

@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
         pass

 if hasattr(torch.ops._C, "scaled_fp4_quant"):
-    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
+    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out

 # Max size of the input tensor per world size per device capability
 # to use flashinfer fused allreduce
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
         rms = self.rmsnorm_matcher(all_reduce, weight)
         quant_out_tuple = auto_functionalized(
             STATIC_FP4_QUANT_OP,
-            output=quant_result,
             input=rms,
-            output_scale=output_scale,
             input_scale=input_global_scale,
             is_sf_swizzled_layout=True,
+            output=quant_result,
+            output_scale=output_scale,
         )

         # quant_out, allreduce_output, output_scale
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
         rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
         quant_out_tuple = auto_functionalized(
             STATIC_FP4_QUANT_OP,
-            output=quant_result,
             input=rms,
-            output_scale=output_scale,
             input_scale=input_global_scale,
             is_sf_swizzled_layout=True,
+            output=quant_result,
+            output_scale=output_scale,
         )

         # quant_out, allreduce_output, output_scale

View File

@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
         )
         at2 = auto_functionalized(
             self.QUANT_OP,
-            output=output_quant,
             input=attn_out_view,
-            output_scale=output_scale,
             input_scale=input_scale,
             is_sf_swizzled_layout=True,
+            output=output_quant,
+            output_scale=output_scale,
         )
         output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
         return at2[1], output_scale_view

View File

@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
 }

 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
+    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out  # noqa: E501

 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501

View File

@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
 }

 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
+    QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out

 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
     QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501