Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -44,24 +44,27 @@ def ref_int8_scaled_mm(
|
||||
):
|
||||
if azp is not None:
|
||||
a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32)
|
||||
output = torch.mm((scale_a * a.to(dtype=torch.float32)),
|
||||
(scale_b * b.to(dtype=torch.float32)))
|
||||
output = torch.mm(
|
||||
(scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
|
||||
)
|
||||
if bias is not None:
|
||||
output += bias.float()
|
||||
|
||||
return output.to(dtype=output_type)
|
||||
|
||||
|
||||
def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_tensor_a_quant: bool,
|
||||
per_tensor_b_quant: bool,
|
||||
use_azp: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu"):
|
||||
def onednn_int8_gemm_test_helper(
|
||||
primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_tensor_a_quant: bool,
|
||||
per_tensor_b_quant: bool,
|
||||
use_azp: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
):
|
||||
# Test for a oneDNN kernel with per-tensor / per-token activation
|
||||
# quantization and per-tensor / per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
@@ -70,8 +73,8 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1)
|
||||
b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
if use_azp:
|
||||
azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5
|
||||
@@ -82,7 +85,7 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
azp_adj = None
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
@@ -105,20 +108,21 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
# To test runtime bias setting
|
||||
out = torch.zeros((m, n), dtype=out_dtype)
|
||||
ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None)
|
||||
baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None,
|
||||
out_dtype)
|
||||
baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, out_dtype)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def onednn_gemm_test_helper(primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
use_bias: bool,
|
||||
use_stride: bool,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu"):
|
||||
def onednn_gemm_test_helper(
|
||||
primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
use_bias: bool,
|
||||
use_stride: bool,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
):
|
||||
if use_stride:
|
||||
a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
|
||||
a = a[:, :k]
|
||||
@@ -128,7 +132,7 @@ def onednn_gemm_test_helper(primitive_cache_size: int,
|
||||
b = torch.rand((n, k), dtype=dtype, device=device) * 1.5
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=dtype) * 5
|
||||
bias = torch.rand((n,), device=device, dtype=dtype) * 5
|
||||
bias_f32 = bias.float()
|
||||
else:
|
||||
bias = None
|
||||
@@ -140,16 +144,18 @@ def onednn_gemm_test_helper(primitive_cache_size: int,
|
||||
)
|
||||
|
||||
out = ops.onednn_mm(handler, a, bias)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(),
|
||||
bias_f32).to(dtype=a.dtype)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(), bias_f32).to(
|
||||
dtype=a.dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline)
|
||||
|
||||
if use_bias:
|
||||
# To test runtime bias setting
|
||||
out = ops.onednn_mm(handler, a, None)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(),
|
||||
None).to(dtype=a.dtype)
|
||||
baseline = torch.nn.functional.linear(a.float(), b.float(), None).to(
|
||||
dtype=a.dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user