Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -30,12 +30,9 @@ def apply_w8a8_block_int8_linear(
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
|
||||
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
|
||||
output = w8a8_block_int8_matmul(q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=input.dtype)
|
||||
output = w8a8_block_int8_matmul(
|
||||
q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
@@ -43,8 +40,8 @@ def apply_w8a8_block_int8_linear(
|
||||
|
||||
|
||||
def input_to_int8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
x: torch.Tensor, dtype: torch.dtype = torch.int8
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to int8 values with
|
||||
tensor-wise quantization."""
|
||||
iinfo = torch.iinfo(dtype)
|
||||
@@ -78,8 +75,8 @@ def block_dequant(
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
x_dq_block[
|
||||
j * block_n:min((j + 1) * block_n, n),
|
||||
i * block_k:min((i + 1) * block_k, k),
|
||||
j * block_n : min((j + 1) * block_n, n),
|
||||
i * block_k : min((i + 1) * block_k, k),
|
||||
] *= x_s[j][i]
|
||||
|
||||
return x_dq_block
|
||||
@@ -91,15 +88,17 @@ if current_platform.is_rocm():
|
||||
# NOTE: This can be removed when hip.libdevice.round() is available.
|
||||
@core.extern
|
||||
def round_f32(arg0, _builder=None):
|
||||
return core.extern_elementwise("",
|
||||
"", [arg0], {
|
||||
(core.dtype("fp32"), ):
|
||||
("llvm.round", core.dtype("fp32")),
|
||||
(core.dtype("fp64"), ):
|
||||
("llvm.round", core.dtype("fp64")),
|
||||
},
|
||||
is_pure=True,
|
||||
_builder=_builder)
|
||||
return core.extern_elementwise(
|
||||
"",
|
||||
"",
|
||||
[arg0],
|
||||
{
|
||||
(core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")),
|
||||
(core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")),
|
||||
},
|
||||
is_pure=True,
|
||||
_builder=_builder,
|
||||
)
|
||||
|
||||
@triton.jit
|
||||
def round_int8(x):
|
||||
@@ -127,8 +126,7 @@ def _per_token_quant_int8(
|
||||
cols = tl.arange(0, BLOCK)
|
||||
mask = cols < N
|
||||
|
||||
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
|
||||
other=0.0).to(tl.float32)
|
||||
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
|
||||
scale_x = absmax / 127
|
||||
x_q = x * (127 / absmax)
|
||||
@@ -142,15 +140,13 @@ def per_token_quant_int8(x):
|
||||
M = x.numel() // x.shape[-1]
|
||||
N = x.shape[-1]
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
|
||||
scales = torch.empty(x.shape[:-1] + (1, ),
|
||||
device=x.device,
|
||||
dtype=torch.float32)
|
||||
scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
|
||||
assert x.is_contiguous()
|
||||
_per_token_quant_int8[(M, )](
|
||||
_per_token_quant_int8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
scales,
|
||||
@@ -229,8 +225,9 @@ def per_token_group_quant_int8(
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.shape[-1] % group_size == 0, (
|
||||
"the last dimension of `x` cannot be divisible by `group_size`"
|
||||
)
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
@@ -239,15 +236,15 @@ def per_token_group_quant_int8(
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size, ),
|
||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
# prefer CUDA kernel if available
|
||||
if current_platform.is_cuda():
|
||||
torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps,
|
||||
float(int8_min),
|
||||
float(int8_max))
|
||||
torch.ops._C.per_token_group_quant_int8(
|
||||
x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max)
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
M = x.numel() // group_size
|
||||
@@ -257,7 +254,7 @@ def per_token_group_quant_int8(
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
_per_token_group_quant_int8[(M, )](
|
||||
_per_token_group_quant_int8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
@@ -333,20 +330,15 @@ def _w8a8_block_int8_matmul(
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||
|
||||
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
|
||||
None] * b_s[None, :]
|
||||
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
@@ -365,8 +357,9 @@ def _w8a8_block_int8_matmul(
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
def get_w8a8_block_int8_configs(
|
||||
N: int, K: int, block_n: int, block_k: int
|
||||
) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
|
||||
@@ -382,7 +375,8 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
|
||||
|
||||
config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
||||
)
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info(
|
||||
@@ -395,8 +389,10 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
|
||||
# If no optimized configuration is available, we will use the default
|
||||
# configuration
|
||||
logger.warning(
|
||||
("Using default W8A8 Block INT8 kernel config. Performance might "
|
||||
"be sub-optimal! Config file not found at %s"),
|
||||
(
|
||||
"Using default W8A8 Block INT8 kernel config. Performance might "
|
||||
"be sub-optimal! Config file not found at %s"
|
||||
),
|
||||
config_file_path,
|
||||
)
|
||||
return None
|
||||
@@ -441,7 +437,7 @@ def w8a8_block_int8_matmul(
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
|
||||
@@ -462,8 +458,9 @@ def w8a8_block_int8_matmul(
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
_w8a8_block_int8_matmul[grid](
|
||||
A,
|
||||
|
||||
Reference in New Issue
Block a user