[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by: tianrengao <terrygao87@gmail.com>
This commit is contained in:
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "shape",
    [(32, 4096), (128, 4096), (1, 64), (127, 1024), (256, 16384)],
)
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@torch.inference_mode()
def test_python_util_matches_cpp_allocation(
    shape: tuple[int, int],
    is_sf_swizzled_layout: bool,
) -> None:
    """
    Check that the Python-side allocator (create_fp4_output_tensors) produces
    output/scale tensors whose shapes and dtypes agree with those allocated
    internally by the C++ functional variant (scaled_fp4_quant_func).
    """
    from vllm._custom_ops import create_fp4_output_tensors

    device = torch.device("cuda:0")
    torch.set_default_device("cuda:0")
    rows, cols = shape
    activations = torch.randn((rows, cols), dtype=torch.bfloat16)
    global_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda:0")

    # Let the C++ functional variant allocate its outputs internally.
    cpp_out, cpp_scale = torch.ops._C.scaled_fp4_quant(
        activations, global_scale, is_sf_swizzled_layout
    )

    # Allocate the same outputs via the Python utility.
    py_out, py_scale = create_fp4_output_tensors(
        rows, cols, device, is_sf_swizzled_layout
    )

    # The two allocation paths must agree on both shape and dtype for the
    # quantized output as well as the scale tensor.
    assert py_out.shape == cpp_out.shape, (
        f"Output shape mismatch: Python {py_out.shape} vs C++ {cpp_out.shape}"
    )
    assert py_out.dtype == cpp_out.dtype, (
        f"Output dtype mismatch: Python {py_out.dtype} vs C++ {cpp_out.dtype}"
    )
    assert py_scale.shape == cpp_scale.shape, (
        f"Scale shape mismatch: Python {py_scale.shape} vs C++ {cpp_scale.shape}"
    )
    assert py_scale.dtype == cpp_scale.dtype, (
        f"Scale dtype mismatch: Python {py_scale.dtype} vs C++ {cpp_scale.dtype}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
|
||||
Reference in New Issue
Block a user