[Feature] Non-contiguous Support for FP8 Quantization (#21961)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -194,3 +194,36 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
ref_y,
|
||||
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
|
||||
dtype))
|
||||
|
||||
# non-contiguous input with padding
|
||||
m, n, padded_stride = 975, 512, 576
|
||||
padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") *
|
||||
13).to(dtype)
|
||||
x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1)
|
||||
|
||||
assert not x_nc.is_contiguous()
|
||||
assert x_nc.stride(0) == padded_stride
|
||||
|
||||
# dynamic quantization
|
||||
ref_y_nc, inv_scale_nc = ops.scaled_fp8_quant(x_nc, None)
|
||||
ref_y_nc = per_tensor_dequantize(ref_y_nc, inv_scale_nc, dtype)
|
||||
|
||||
# reference dynamic quantization
|
||||
y_nc = quantize_ref(x_nc, inv_scale_nc)
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
|
||||
|
||||
# static quantization
|
||||
y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc)
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype))
|
||||
|
||||
# padding after non-contiguous input quantization
|
||||
y_nc_pad, _ = ops.scaled_fp8_quant(x_nc,
|
||||
inv_scale_nc,
|
||||
num_token_padding=m + 10)
|
||||
assert y_nc_pad.shape[0] == m + 10
|
||||
torch.testing.assert_close(
|
||||
ref_y_nc,
|
||||
per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]),
|
||||
inv_scale_nc, dtype))
|
||||
|
||||
Reference in New Issue
Block a user