[Feature] Non-contiguous Support for FP8 Quantization (#21961)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Wentao Ye
2025-08-05 05:36:43 -04:00
committed by GitHub
parent 05fae02175
commit 4771df7b2b
4 changed files with 207 additions and 201 deletions

View File

@@ -194,3 +194,36 @@ def test_scaled_fp8_quant(dtype) -> None:
ref_y,
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
dtype))
# --- non-contiguous input with row padding -------------------------------
# Build a (m, padded_stride) tensor and slice off the first n columns, so
# x_nc has shape (m, n) but row stride padded_stride > n, i.e. it is a
# non-contiguous view. This exercises the FP8 quant kernel's handling of
# strided inputs.
m, n, padded_stride = 975, 512, 576
padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") *
13).to(dtype)
x_nc = padded_tensor[:, :n]  # shape (m, n) with stride (padded_stride, 1)
assert not x_nc.is_contiguous()
assert x_nc.stride(0) == padded_stride
# dynamic quantization (scale computed by the kernel from x_nc)
# NOTE(review): the `ref_`-prefixed names below hold the kernel-under-test
# output, while `y_nc` holds the reference implementation's output — the
# naming convention looks inverted; confirm before reusing these names.
ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None)
ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype)
# reference dynamic quantization: quantize with the kernel's own scale and
# check both round-trips agree after dequantization
y_nc = quantize_ref(x_nc, inv_scale_nc)
torch.testing.assert_close(
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
# static quantization: pass the precomputed scale back in; result must
# match the dynamic-quantization round-trip
y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc)
torch.testing.assert_close(
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
# padding after non-contiguous input quantization: output is padded to
# m + 10 rows; only the first m rows carry quantized data
y_nc_pad, _ = ops.scaled_fp8_quant(x_nc,
inv_scale_nc,
num_token_padding=m + 10)
assert y_nc_pad.shape[0] == m + 10
# narrow back to the first m rows before comparing against the unpadded
# round-trip result
torch.testing.assert_close(
ref_y_nc,
per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]),
inv_scale_nc, dtype))