Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -5,8 +5,7 @@ from typing import Optional, Union
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -17,25 +16,31 @@ FP8_DTYPE = current_platform.fp8_dtype()
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
return torch.as_tensor(x, dtype=torch.float32, device="cuda")
def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> tuple[torch.tensor, torch.tensor]:
def ref_dynamic_per_token_quant(
x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None
) -> tuple[torch.tensor, torch.tensor]:
assert quant_dtype in [torch.int8, FP8_DTYPE]
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else qtype_traits.min
qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
qtype_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.max
)
qtype_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.min
)
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
@@ -56,15 +61,13 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
else:
assert quant_dtype == FP8_DTYPE
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits_min,
qtype_traits_max).to(quant_dtype)
torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype)
return torch_out, scales
@@ -72,16 +75,20 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> tuple[torch.tensor, torch.tensor]:
def ref_dynamic_per_tensor_fp8_quant(
x: torch.tensor,
) -> tuple[torch.tensor, torch.tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.max
fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \
and current_platform.is_fp8_fnuz() \
else fp8_traits.min
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)
@@ -92,9 +99,12 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
return ref_out, ref_scale.view((1, ))
ref_out = (
(as_float32_tensor(x) * ref_iscale)
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1,))
def native_w8a8_block_matmul(
@@ -126,7 +136,7 @@ def native_w8a8_block_matmul(
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
@@ -137,19 +147,19 @@ def native_w8a8_block_matmul(
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
@@ -163,14 +173,14 @@ def native_w8a8_block_matmul(
return C
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
"be divisible by `group_size`")
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` must be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
@@ -178,28 +188,25 @@ def native_per_token_group_quant_fp8(x,
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` must be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
@@ -208,13 +215,13 @@ def native_per_token_group_quant_int8(x,
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = (
(x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype)
) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
@@ -229,9 +236,9 @@ def per_block_cast_to_int8(
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded = torch.zeros(
(round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
@@ -269,8 +276,9 @@ def batched_dequant(
assert t.shape[0] == scale.shape[0]
out = torch.empty_like(t, dtype=out_dtype)
for e in range(t.shape[0]):
out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant,
out_dtype)
out[e] = dequant(
t[e], scale[e], block_shape, per_act_token_quant, out_dtype
)
return out
return t.to(out_dtype)
@@ -294,15 +302,17 @@ def native_batched_masked_quant_matmul(
num_tokens = num_expert_tokens_cpu[e]
if A.dtype.itemsize == 1 and block_shape is not None:
assert A_scale is not None and B_scale is not None
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
block_shape, C.dtype)
tmp = native_w8a8_block_matmul(
A[e], B[e], A_scale[e], B_scale[e], block_shape, C.dtype
)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
elif A.dtype.itemsize == 1 and block_shape is None:
assert A_scale is not None and B_scale is not None
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
C[e, :num_tokens, :] = (
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(
C.dtype
)
else:
assert A_scale is None
assert B_scale is None