Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -30,12 +30,9 @@ def apply_w8a8_block_int8_linear(
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
output = w8a8_block_int8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
output = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
)
if bias is not None:
output = output + bias
@@ -43,8 +40,8 @@ def apply_w8a8_block_int8_linear(
def input_to_int8(
x: torch.Tensor,
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to int8 values with
tensor-wise quantization."""
iinfo = torch.iinfo(dtype)
@@ -78,8 +75,8 @@ def block_dequant(
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
@@ -91,15 +88,17 @@ if current_platform.is_rocm():
# NOTE: This can be removed when hip.libdevice.round() is available.
@core.extern
def round_f32(arg0, _builder=None):
return core.extern_elementwise("",
"", [arg0], {
(core.dtype("fp32"), ):
("llvm.round", core.dtype("fp32")),
(core.dtype("fp64"), ):
("llvm.round", core.dtype("fp64")),
},
is_pure=True,
_builder=_builder)
return core.extern_elementwise(
"",
"",
[arg0],
{
(core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")),
(core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")),
},
is_pure=True,
_builder=_builder,
)
@triton.jit
def round_int8(x):
@@ -127,8 +126,7 @@ def _per_token_quant_int8(
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
@@ -142,15 +140,13 @@ def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M, )](
_per_token_quant_int8[(M,)](
x,
x_q,
scales,
@@ -229,8 +225,9 @@ def per_token_group_quant_int8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` cannot be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
@@ -239,15 +236,15 @@ def per_token_group_quant_int8(
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size, ),
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
# prefer CUDA kernel if available
if current_platform.is_cuda():
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
float(int8_min),
float(int8_max))
torch.ops._C.per_token_group_quant_int8(
x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max)
)
return x_q, x_s
M = x.numel() // group_size
@@ -257,7 +254,7 @@ def per_token_group_quant_int8(
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M, )](
_per_token_group_quant_int8[(M,)](
x,
x_q,
x_s,
@@ -333,20 +330,15 @@ def _w8a8_block_int8_matmul(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
None] * b_s[None, :]
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -365,8 +357,9 @@ def _w8a8_block_int8_matmul(
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[dict[int, Any]]:
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
@@ -382,7 +375,8 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
@@ -395,8 +389,10 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
("Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"),
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
config_file_path,
)
return None
@@ -441,7 +437,7 @@ def w8a8_block_int8_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
@@ -462,8 +458,9 @@ def w8a8_block_int8_matmul(
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
_w8a8_block_int8_matmul[grid](
A,