[bugfix] Fix static asymmetric quantization case (#10334)

Signed-off-by: Daniël de Kok <me@danieldk.eu>
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Daniël de Kok <me@danieldk.eu>
This commit is contained in:
Luka Govedič
2024-11-14 20:35:11 -05:00
committed by GitHub
parent 972112d82f
commit bf2ddc6610
5 changed files with 58 additions and 15 deletions

View File

@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
assert torch_out.min() >= int8_traits.min and torch_out.max(
) <= int8_traits.max
ops_out = torch.empty_like(x, dtype=torch.int8)
scales_out = torch.empty_like(scales, dtype=torch.float32)
azp_out = torch.empty_like(azps, dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
if (not torch.allclose(scales_out, scales)):
print(torch.argmax(torch.abs(scales_out - scales)))
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
out1 = (x / scale_arg).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2, _, _ = scaled_int8_quant(x, scale_arg)
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
assert scale2 is scale_arg
# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
out2, scale2, azp2 = scaled_int8_quant(x,
scale_arg,
azp_arg,
symmetric=False)
assert scale2 is scale_arg
assert azp2 is azp_arg
# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
val_i8 = int8_traits.max if is_max else int8_traits.min
expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
out = torch.empty_like(expected)
torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
torch.testing.assert_close(expected, out, atol=0, rtol=0)