[bugfix] Fix static asymmetric quantization case (#10334)
Signed-off-by: Daniël de Kok <me@danieldk.eu>
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Daniël de Kok <me@danieldk.eu>
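The hunks below replace direct `torch.ops._C` kernel calls with the `scaled_int8_quant` wrapper, exercising the asymmetric (`symmetric=False`) paths this fix covers. For orientation, here is the dequantization relation all of these tests rely on, as a minimal standalone sketch (the scale and zero-point values are illustrative only):

```python
import torch

# Asymmetric int8 quantization: q = round(x / scale) + azp, saturated to
# [-128, 127], so that dequantization recovers x ≈ scale * (q - azp).
# azp is implicitly 0 on the symmetric path.
scale, azp = 0.05, -3  # illustrative static scale and zero point
x = torch.tensor([1.0, -2.5, 0.2])

q = ((x / scale).round() + azp).clamp(-128, 127).to(torch.int8)
x_hat = scale * (q.float() - azp)

assert torch.allclose(x, x_hat, atol=scale)  # exact to one quantization step
```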
@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     assert torch_out.min() >= int8_traits.min and torch_out.max(
     ) <= int8_traits.max
 
-    ops_out = torch.empty_like(x, dtype=torch.int8)
-    scales_out = torch.empty_like(scales, dtype=torch.float32)
-    azp_out = torch.empty_like(azps, dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+    ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
 
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
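In the dynamic case above, `scales` and `azps` are derived from per-token min/max statistics. A standard per-token reference for that computation, as a hedged sketch (the helper name and epsilon guard are ours, not the test's code):

```python
import torch

def dynamic_azp_quant_ref(x: torch.Tensor):
    # Per-token (row-wise) statistics over the hidden dimension.
    x_min = x.min(dim=-1, keepdim=True).values
    x_max = x.max(dim=-1, keepdim=True).values
    # Map [x_min, x_max] onto the full int8 range [-128, 127].
    scale = ((x_max - x_min) / 255.0).clamp_min(1e-7)  # guard constant rows
    # The zero point shifts the grid so x_min quantizes to -128.
    azp = (-128 - x_min / scale).round().to(torch.int32)
    q = ((x / scale).round() + azp).clamp(-128, 127).to(torch.int8)
    return q, scale, azp
```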
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 
     out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                          int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale_arg)
+    out2, scale2, _ = scaled_int8_quant(x, scale_arg)
+    assert scale2 is scale_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
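The new `assert scale2 is scale_arg` checks tensor identity, not value equality: on the static path the wrapper is expected to hand back the caller's scale tensor rather than an equal copy. The distinction, in a tiny illustration:

```python
import torch

scale_arg = torch.tensor([0.05], dtype=torch.float32)
copy = scale_arg.clone()

assert torch.equal(copy, scale_arg)  # same values...
assert copy is not scale_arg         # ...but a different tensor object
```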
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 
     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
     scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
     azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
 
-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+    out2, scale2, azp2 = scaled_int8_quant(x,
+                                           scale_arg,
+                                           azp_arg,
+                                           symmetric=False)
+    assert scale2 is scale_arg
+    assert azp2 is azp_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
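Taken together, the asserts pin down the contract of the static asymmetric path: dispatch to the zero-point kernel and return the caller's own scale and azp tensors. A hypothetical pure-PyTorch stand-in for that contract (a sketch of the expected behavior, not vLLM's implementation):

```python
import torch
from typing import Optional, Tuple

def scaled_int8_quant_ref(
    x: torch.Tensor,
    scale: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Static path only: the caller supplies scale (and azp when asymmetric)
    # and gets the very same tensors back alongside the quantized output.
    zp = azp if not symmetric else torch.zeros_like(scale, dtype=torch.int32)
    q = ((x / scale).round() + zp).clamp(-128, 127).to(torch.int8)
    return q, scale, (azp if not symmetric else None)
```

Under this contract, `assert scale2 is scale_arg` and `assert azp2 is azp_arg` in the hunk above hold by construction.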
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
     val_i8 = int8_traits.max if is_max else int8_traits.min
     expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
 
-    out = torch.empty_like(expected)
-    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
     torch.testing.assert_close(expected, out, atol=0, rtol=0)
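The saturating-cast test keeps `atol=0` because inputs far outside the representable range must clamp to exactly the int8 extremes. A minimal sketch of that expectation (the input value and constants are constructed for illustration):

```python
import torch

int8_traits = torch.iinfo(torch.int8)
scale, azp = 0.1, 0  # illustrative

# A value far above scale * int8_max must saturate to exactly 127.
x = torch.tensor([[1e6]], dtype=torch.float32)
q = ((x / scale).round() + azp).clamp(int8_traits.min,
                                      int8_traits.max).to(torch.int8)
assert q.item() == int8_traits.max
```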