[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
(cherry picked from commit 3e6a1e1686)
This commit is contained in:
@@ -29,6 +29,81 @@ else:
|
||||
from torch.library import impl_abstract as register_fake
|
||||
|
||||
|
||||
# scaled_fp4_quant functional + out variant for torch.compile buffer management
|
||||
|
||||
|
||||
def create_fp4_scale_tensor(
    m: int,
    n: int,
    device: torch.device,
    is_sf_swizzled_layout: bool,
) -> torch.Tensor:
    """
    Build the uninitialized scale-factor buffer for scaled_fp4_quant.

    One scale is kept per group of 16 elements along the last dimension,
    so a row of ``n`` inputs needs ``n // 16`` scales.

    Linear layout (is_sf_swizzled_layout=False):
        a ``(m, n // 16)`` tensor of dtype uint8.

    Swizzled layout (is_sf_swizzled_layout=True):
        rounded sizes are used to store the swizzled scales. Due to the
        requirement of the Tensor Core, the minimum tile is 128x4 for the
        scales, so the rows are padded to a multiple of 128 and the scale
        columns to a multiple of 4. The scales (in float8_e4m3fn) are then
        packed four at a time into an int32, which gives a
        ``(round_up(m, 128), round_up(n // 16, 4) // 4)`` int32 tensor. More:
        https://docs.nvidia.com/cuda/parallel-thread-execution/
        #tcgen05-mma-scale-factor-b-layout-4x

    Args:
        m: number of rows of the (flattened) input.
        n: size of the last input dimension.
        device: device on which the buffer is allocated.
        is_sf_swizzled_layout: selects the swizzled (padded + packed) layout.

    Returns:
        A tensor from ``torch.empty``; its contents are uninitialized.
    """
    from vllm.utils.math_utils import round_up

    block_size = 16
    scales_per_row = n // block_size

    # Linear layout: one uint8 slot per scale, no padding.
    if not is_sf_swizzled_layout:
        return torch.empty((m, scales_per_row), device=device, dtype=torch.uint8)

    # Swizzled layout: pad to the 128x4 tile, then pack 4 scales per int32.
    padded_rows = round_up(m, 128)
    padded_cols = round_up(scales_per_row, 4)
    packed_shape = (padded_rows, padded_cols // 4)
    return torch.empty(packed_shape, device=device, dtype=torch.int32)
|
||||
|
||||
|
||||
def create_fp4_output_tensors(
    m: int,
    n: int,
    device: torch.device,
    is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Allocate the pair of buffers that scaled_fp4_quant writes into.

    Must match the C++ scaled_fp4_quant_func allocation exactly.

    Args:
        m: number of rows of the (flattened) input.
        n: size of the last input dimension.
        device: device on which both buffers are allocated.
        is_sf_swizzled_layout: forwarded to create_fp4_scale_tensor to pick
            the scale layout.

    Returns:
        ``(quantized_output, output_scale)`` where ``quantized_output`` is an
        uninitialized ``(m, n // 2)`` uint8 tensor and ``output_scale`` is the
        tensor returned by create_fp4_scale_tensor.
    """
    quantized_shape = (m, n // 2)
    quantized = torch.empty(quantized_shape, device=device, dtype=torch.uint8)
    scales = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout)
    return quantized, scales
|
||||
|
||||
|
||||
# Register the fake (meta) implementations only when the compiled extension
# actually exposes the scaled_fp4_quant op.
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):

    @register_fake("_C::scaled_fp4_quant")
    def _scaled_fp4_quant_fake(
        input: torch.Tensor,
        input_scale: torch.Tensor,
        is_sf_swizzled_layout: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Fake implementation of the functional ``_C::scaled_fp4_quant`` op.

        Treats ``input`` as a 2-D matrix whose column count is its last
        dimension and whose row count is the product of all other
        dimensions, then returns freshly allocated output buffers of the
        matching shapes. ``input_scale`` is not read here.
        """
        last_dim = input.shape[-1]
        num_rows = input.numel() // last_dim
        return create_fp4_output_tensors(
            num_rows, last_dim, input.device, is_sf_swizzled_layout
        )

    @register_fake("_C::scaled_fp4_quant.out")
    def _scaled_fp4_quant_out_fake(
        input: torch.Tensor,
        input_scale: torch.Tensor,
        is_sf_swizzled_layout: bool,
        *,
        output: torch.Tensor,
        output_scale: torch.Tensor,
    ) -> None:
        """
        Fake implementation of the ``_C::scaled_fp4_quant.out`` overload.

        The caller supplies ``output`` and ``output_scale``, so there is
        nothing to allocate or compute here; the function returns None.
        """
        return None
|
||||
|
||||
|
||||
# paged attention ops
|
||||
def paged_attention_v1(
|
||||
out: torch.Tensor,
|
||||
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
|
||||
input = input.reshape(other_dims, input.shape[-1])
|
||||
m, n = input.shape
|
||||
block_size = 16
|
||||
device = input.device
|
||||
|
||||
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
||||
assert input.dtype in (torch.float16, torch.bfloat16), (
|
||||
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
|
||||
input, input_global_scale
|
||||
)
|
||||
else:
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
if is_sf_swizzled_layout:
|
||||
# We use the rounded values to store the swizzled values. Due to the
|
||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
|
||||
# Pre-allocate and call .out variant (same behavior as old in-place API)
|
||||
output, output_scale = create_fp4_output_tensors(
|
||||
m, n, input.device, is_sf_swizzled_layout
|
||||
)
|
||||
torch.ops._C.scaled_fp4_quant.out(
|
||||
input,
|
||||
input_global_scale,
|
||||
is_sf_swizzled_layout,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
|
||||
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
|
||||
pass
|
||||
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
|
||||
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=output_quant,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user