[Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842)

This commit is contained in:
Tyler Michael Smith
2024-07-30 16:37:01 -04:00
committed by GitHub
parent 052b6f8ca4
commit d7a299edaa
3 changed files with 17 additions and 14 deletions

View File

@@ -123,7 +123,7 @@ def test_scaled_fp8_quant(dtype) -> None:
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Padding
-    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,