[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
12
csrc/ops.h
12
csrc/ops.h
@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_scale);
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
||||
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
#endif
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_sf, torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
||||
torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
||||
torch::Tensor& output_sf) {
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_sf,
|
||||
bool is_sf_swizzled_layout) {
|
||||
int64_t n = input.size(-1);
|
||||
int64_t m = input.numel() / n;
|
||||
auto device = input.device();
|
||||
|
||||
// Two fp4 values packed into a uint8
|
||||
auto output = torch::empty(
|
||||
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
|
||||
torch::Tensor output_sf;
|
||||
if (is_sf_swizzled_layout) {
|
||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||
output_sf = torch::empty(
|
||||
{sf_m, sf_n},
|
||||
torch::TensorOptions().device(device).dtype(torch::kInt32));
|
||||
} else {
|
||||
output_sf = torch::empty(
|
||||
{m, n / CVT_FP4_SF_VEC_SIZE},
|
||||
torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
||||
}
|
||||
|
||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||
output_sf);
|
||||
return {output, output_sf};
|
||||
}
|
||||
|
||||
void scaled_fp4_experts_quant(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <utility>
|
||||
|
||||
#include "../../cuda_vec_utils.cuh"
|
||||
|
||||
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
|
||||
return round_up(m, ROW_TILE);
|
||||
}
|
||||
|
||||
// Compute the shape of the swizzled SF output tensor.
|
||||
// Returns (rounded_m, rounded_n / 4) where:
|
||||
// rounded_m = round_up(m, 128)
|
||||
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
|
||||
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
|
||||
int64_t n) {
|
||||
int64_t rounded_m = round_up(m, static_cast<int64_t>(128));
|
||||
int64_t scale_n = n / CVT_FP4_SF_VEC_SIZE;
|
||||
int64_t rounded_n = round_up(scale_n, static_cast<int64_t>(4));
|
||||
return {rounded_m, rounded_n / 4};
|
||||
}
|
||||
|
||||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
|
||||
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
|
||||
uint32_t val;
|
||||
|
||||
@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
" Tensor! output_scale, Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> ()");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
"scaled_fp4_quant(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
|
||||
|
||||
// Out variant
|
||||
// TODO: Add {at::Tag::out_variant} tag and update all call sites
|
||||
// to use the functional variant once vLLM upgrades PyTorch.
|
||||
// See pytorch/pytorch#176117.
|
||||
ops.def(
|
||||
"scaled_fp4_quant.out(Tensor input,"
|
||||
" Tensor input_scale, bool "
|
||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||
"-> ()");
|
||||
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
ops.def(
|
||||
|
||||
@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.scaled_fp4_quant.default,
|
||||
torch.ops._C.scaled_fp4_quant.out,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)],
|
||||
)
|
||||
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_python_util_matches_cpp_allocation(
|
||||
shape: tuple[int, int],
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the Python utility (create_fp4_output_tensors) allocates
|
||||
tensors with the same shapes and dtypes as the C++ functional variant
|
||||
(scaled_fp4_quant_func).
|
||||
"""
|
||||
from vllm._custom_ops import create_fp4_output_tensors
|
||||
|
||||
torch.set_default_device("cuda:0")
|
||||
m, n = shape
|
||||
input_tensor = torch.randn((m, n), dtype=torch.bfloat16)
|
||||
input_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0")
|
||||
|
||||
# C++ functional variant allocates internally
|
||||
cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant(
|
||||
input_tensor, input_scale, is_sf_swizzled_layout
|
||||
)
|
||||
|
||||
# Python utility
|
||||
py_out, py_scale = create_fp4_output_tensors(
|
||||
m, n, torch.device("cuda:0"), is_sf_swizzled_layout
|
||||
)
|
||||
|
||||
assert py_out.shape == cpp_out.shape, (
|
||||
f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}"
|
||||
)
|
||||
assert py_out.dtype == cpp_out.dtype, (
|
||||
f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}"
|
||||
)
|
||||
assert py_scale.shape == cpp_scale.shape, (
|
||||
f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}"
|
||||
)
|
||||
assert py_scale.dtype == cpp_scale.dtype, (
|
||||
f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
|
||||
@@ -29,6 +29,81 @@ else:
|
||||
from torch.library import impl_abstract as register_fake
|
||||
|
||||
|
||||
# scaled_fp4_quant functional + out variant for torch.compile buffer management
|
||||
|
||||
|
||||
def create_fp4_scale_tensor(
|
||||
m: int,
|
||||
n: int,
|
||||
device: torch.device,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Allocate the output scale tensor for scaled_fp4_quant.
|
||||
|
||||
When is_sf_swizzled_layout=True, we use rounded values to store the
|
||||
swizzled scales. Due to the requirement of the Tensor Core, the minimum
|
||||
tile is 128x4 for the scales. So, we first pad the scales to multiples
|
||||
of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are
|
||||
packed into an int32 for every 4 values. More:
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/
|
||||
#tcgen05-mma-scale-factor-b-layout-4x
|
||||
"""
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
block_size = 16
|
||||
if is_sf_swizzled_layout:
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
return torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
return torch.empty((m, n // block_size), device=device, dtype=torch.uint8)
|
||||
|
||||
|
||||
def create_fp4_output_tensors(
|
||||
m: int,
|
||||
n: int,
|
||||
device: torch.device,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Allocate both output tensors for scaled_fp4_quant:
|
||||
(quantized_output, output_scale).
|
||||
|
||||
Must match the C++ scaled_fp4_quant_func allocation exactly.
|
||||
"""
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
output_scale = create_fp4_scale_tensor(m, n, device, is_sf_swizzled_layout)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
if hasattr(torch.ops, "_C") and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
|
||||
@register_fake("_C::scaled_fp4_quant")
|
||||
def _scaled_fp4_quant_fake(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n = input.shape[-1]
|
||||
m = input.numel() // n
|
||||
return create_fp4_output_tensors(m, n, input.device, is_sf_swizzled_layout)
|
||||
|
||||
@register_fake("_C::scaled_fp4_quant.out")
|
||||
def _scaled_fp4_quant_out_fake(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool,
|
||||
*,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# page attention ops
|
||||
def paged_attention_v1(
|
||||
out: torch.Tensor,
|
||||
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
|
||||
input = input.reshape(other_dims, input.shape[-1])
|
||||
m, n = input.shape
|
||||
block_size = 16
|
||||
device = input.device
|
||||
|
||||
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
||||
assert input.dtype in (torch.float16, torch.bfloat16), (
|
||||
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
|
||||
input, input_global_scale
|
||||
)
|
||||
else:
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
if is_sf_swizzled_layout:
|
||||
# We use the rounded values to store the swizzled values. Due to the
|
||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
|
||||
# Pre-allocate and call .out variant (same behavior as old in-place API)
|
||||
output, output_scale = create_fp4_output_tensors(
|
||||
m, n, input.device, is_sf_swizzled_layout
|
||||
)
|
||||
torch.ops._C.scaled_fp4_quant.out(
|
||||
input,
|
||||
input_global_scale,
|
||||
is_sf_swizzled_layout,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
|
||||
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
|
||||
pass
|
||||
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.out
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=quant_result,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
|
||||
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
output=output_quant,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user