Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,26 +18,24 @@ SCALE = [0.1, 2.1]
|
||||
|
||||
def opcheck_int8_quant_static(output, input, scale, azp=None):
|
||||
if azp is None:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
scale = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
scale = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
if symmetric:
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
azp = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
azp = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
def test_dynamic_scaled_int8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
@@ -68,30 +67,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
def test_dynamic_scaled_int8_azp_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") * 1000 - 300
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
|
||||
|
||||
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
|
||||
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
|
||||
|
||||
# calculate scale and azp, and adjust the range
|
||||
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
|
||||
torch.int32)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32)
|
||||
|
||||
torch_out = ((x / scales).round() + azps).clamp(
|
||||
int8_traits.min, int8_traits.max).to(torch.int8)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max(
|
||||
) <= int8_traits.max
|
||||
torch_out = (
|
||||
((x / scales).round() + azps)
|
||||
.clamp(int8_traits.min, int8_traits.max)
|
||||
.to(torch.int8)
|
||||
)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max
|
||||
|
||||
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
|
||||
|
||||
if (not torch.allclose(scales_out, scales)):
|
||||
if not torch.allclose(scales_out, scales):
|
||||
print(torch.argmax(torch.abs(scales_out - scales)))
|
||||
torch.testing.assert_close(scales_out, scales)
|
||||
# big atol to account for rounding errors
|
||||
@@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float) -> None:
|
||||
def test_static_scaled_int8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
|
||||
out1 = (x / scale_arg).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
out1 = (
|
||||
(x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
|
||||
)
|
||||
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
|
||||
assert scale2 is scale_arg
|
||||
|
||||
@@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@pytest.mark.parametrize("azp", [-255, 54])
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float, azp: int) -> None:
|
||||
def test_static_scaled_int8_azp_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
scale: float,
|
||||
azp: int,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") * 1000 - 300
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
|
||||
|
||||
out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
out1 = (
|
||||
((x / scale).round() + azp)
|
||||
.clamp(int8_traits.min, int8_traits.max)
|
||||
.to(torch.int8)
|
||||
)
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
|
||||
|
||||
out2, scale2, azp2 = scaled_int8_quant(x,
|
||||
scale_arg,
|
||||
azp_arg,
|
||||
symmetric=False)
|
||||
out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False)
|
||||
assert scale2 is scale_arg
|
||||
assert azp2 is azp_arg
|
||||
|
||||
@@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
|
||||
int32_traits = torch.iinfo(torch.int32)
|
||||
val = float(int32_traits.max if is_max else int32_traits.min)
|
||||
|
||||
x_vals = [[
|
||||
nextafter(val, inf), val + 1, val, val - 1,
|
||||
nextafter(val, -inf)
|
||||
]]
|
||||
x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]]
|
||||
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
|
||||
|
||||
# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
|
||||
|
||||
Reference in New Issue
Block a user