Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -15,18 +15,26 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, group_broadcast)
GroupShape,
group_broadcast,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter)
CUTLASS_BLOCK_FP8_SUPPORTED,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
logger = init_logger(__name__)
@@ -56,7 +64,8 @@ def cutlass_scaled_mm(
out_dtype=output_dtype,
scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
scale_b=Bs if block_size is not None and is_hopper else Bs.T,
)
def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -80,7 +89,6 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
@@ -93,9 +101,11 @@ if current_platform.is_rocm():
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
if (
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()
):
import aiter as rocm_aiter
from aiter import get_hip_quant
@@ -113,8 +123,9 @@ def _w8a8_triton_block_scaled_mm_func(
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
block_size, output_dtype)
return w8a8_triton_block_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _w8a8_triton_block_scaled_mm_fake(
@@ -125,9 +136,9 @@ def _w8a8_triton_block_scaled_mm_fake(
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
@@ -147,22 +158,24 @@ def _padded_cutlass(
) -> torch.Tensor:
pad_multiple = 4
dim = qx.shape[0]
padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - (
dim % pad_multiple)
padded = (
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
)
padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0:qx.shape[0], ...].copy_(qx)
padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_x_scale = torch.ones(padded_x_scale_shape,
device=x_scale.device,
dtype=x_scale.dtype).permute(-1, -2)
padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale)
padded_x_scale = torch.ones(
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale,
block_size, output_dtype, True)
return output[0:qx.shape[0], ...]
output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
)
return output[0 : qx.shape[0], ...]
def _padded_cutlass_fake(
@@ -173,9 +186,9 @@ def _padded_cutlass_fake(
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
@@ -185,18 +198,30 @@ direct_register_custom_op(
)
def _fp8_gemm_nt_op(q_input: torch.Tensor, input_scale: torch.Tensor,
weight: torch.Tensor, weight_scale: torch.Tensor,
output: torch.Tensor, use_deep_gemm_e8m0: bool) -> None:
fp8_gemm_nt((q_input, input_scale), (weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0)
def _fp8_gemm_nt_op(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
fp8_gemm_nt(
(q_input, input_scale),
(weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
)
def _fp8_gemm_nt_op_fake(q_input: torch.Tensor, input_scale: torch.Tensor,
weight: torch.Tensor, weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool) -> None:
def _fp8_gemm_nt_op_fake(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
return None
@@ -233,15 +258,21 @@ class W8A8BlockFp8LinearOp:
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = \
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
self.deepgemm_input_quant_op = (QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=self.use_deep_gemm_e8m0) if self.is_deep_gemm_supported
else None)
self.w8a8_blockscale_op, self.input_quant_op = (
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported
)
)
self.deepgemm_input_quant_op = (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=self.use_deep_gemm_e8m0,
)
if self.is_deep_gemm_supported
else None
)
def apply(
self,
@@ -257,8 +288,9 @@ class W8A8BlockFp8LinearOp:
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
self.is_deep_gemm_supported):
if should_use_deepgemm_for_fp8_linear(
output_dtype, weight, self.is_deep_gemm_supported
):
output = self._run_deepgemm(input_2d, weight, weight_scale)
else:
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
@@ -275,12 +307,14 @@ class W8A8BlockFp8LinearOp:
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty((q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device)
torch.ops.vllm.fp8_gemm_nt_op(q_input, input_scale, weight,
weight_scale, output,
self.use_deep_gemm_e8m0)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
torch.ops.vllm.fp8_gemm_nt_op(
q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
)
return output
def _run_cutlass(
@@ -292,15 +326,24 @@ class W8A8BlockFp8LinearOp:
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
if self.is_hopper:
return torch.ops.vllm.padded_cutlass(q_input, weight, input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype)
return torch.ops.vllm.padded_cutlass(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype,
)
else:
return cutlass_scaled_mm(q_input, weight,
input_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, False)
return cutlass_scaled_mm(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype,
False,
)
def _run_aiter(
self,
@@ -310,10 +353,16 @@ class W8A8BlockFp8LinearOp:
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
q_input, weight, input_scale, weight_scale,
self.weight_group_shape, input_2d.dtype)
q_input,
weight,
input_scale,
weight_scale,
self.weight_group_shape,
input_2d.dtype,
)
def _run_triton(
self,
@@ -324,34 +373,52 @@ class W8A8BlockFp8LinearOp:
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
q_input, weight, input_scale, weight_scale,
self.weight_group_shape, input_2d.dtype)
q_input,
weight,
input_scale,
weight_scale,
self.weight_group_shape,
input_2d.dtype,
)
def _dispatch_w8a8_blockscale_op(
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
) -> tuple[Callable[[
) -> tuple[
Callable[
[
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
torch.Tensor,
torch.Tensor,
torch.Tensor,
], torch.Tensor], Optional[QuantFP8]]:
],
Optional[QuantFP8],
]:
if use_cutlass:
return self._run_cutlass, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False))
return self._run_cutlass, (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False,
)
)
if use_aiter_and_is_supported:
return self._run_aiter, None
return self._run_triton, (QuantFP8(False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False))
return self._run_triton, (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
)
)
def input_to_float8(
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
x: torch.Tensor, dtype: Optional[torch.dtype] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values "
"with tensor-wise quantization."""
@@ -410,8 +477,9 @@ def _per_token_group_quant_fp8(
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (
row_g_id.to(tl.int64) * group_size
)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
@@ -465,8 +533,9 @@ def _per_token_group_quant_fp8_colmajor(
row_g_id = g_id % groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (
row_g_id.to(tl.int64) * group_size
)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
@@ -478,8 +547,7 @@ def _per_token_group_quant_fp8_colmajor(
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
tl.int64)
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(tl.int64)
y_s_ptr += y_s_ptr_offset
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
@@ -523,9 +591,10 @@ def per_token_group_quant_fp8(
if use_ue8m0 is None:
use_ue8m0 = is_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
assert x.shape[-1] % group_size == 0, (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
f"by `group_size` {group_size}"
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
@@ -539,18 +608,18 @@ def per_token_group_quant_fp8(
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
# prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0)
torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
)
return x_q, x_s
# TRITON FALLBACK
@@ -561,7 +630,7 @@ def per_token_group_quant_fp8(
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
@@ -578,7 +647,7 @@ def per_token_group_quant_fp8(
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M, )](
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
@@ -656,12 +725,8 @@ def _w8a8_triton_block_scaled_mm(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
@@ -687,8 +752,9 @@ def _w8a8_triton_block_scaled_mm(
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
block_k: int) -> Optional[dict[int, Any]]:
def get_w8a8_block_fp8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
@@ -703,7 +769,8 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
@@ -759,7 +826,7 @@ def w8a8_triton_block_scaled_mm(
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
@@ -780,8 +847,9 @@ def w8a8_triton_block_scaled_mm(
}
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
_w8a8_triton_block_scaled_mm[grid](
A,
@@ -811,9 +879,9 @@ def w8a8_triton_block_scaled_mm(
def requant_weight_ue8m0_inplace(
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: Sequence[int] = (128, 128),
weight: torch.Tensor,
weight_scale: torch.Tensor,
block_size: Sequence[int] = (128, 128),
) -> None:
"""Re-quantise *weight* so that its per-block scaling factors are in the
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
@@ -830,8 +898,9 @@ def requant_weight_ue8m0_inplace(
return
if weight.dtype != torch.float8_e4m3fn:
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
f"{weight.dtype} instead.")
raise ValueError(
f"Expected *weight* to be torch.float8_e4m3fn, got {weight.dtype} instead."
)
from vllm.utils.deep_gemm import per_block_cast_to_fp8
@@ -860,8 +929,9 @@ def requant_weight_ue8m0_inplace(
s_exp = s_exp[:m_cur, :k_cur]
w_dq = w_q.to(torch.float32) * s_exp
# Re-quantise using power-of-two scaling (UE8M0).
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k],
use_ue8m0=True)
w_requant, s_requant = per_block_cast_to_fp8(
w_dq, [block_m, block_k], use_ue8m0=True
)
# Write back the results in-place.
w_q.copy_(w_requant)
@@ -871,28 +941,39 @@ def requant_weight_ue8m0_inplace(
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm and only for FP8_FNUZ
and at the moment are MI300 series"""
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
if (
envs.VLLM_ROCM_FP8_PADDING
and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0
):
num_pad = 256 // weight.element_size()
import torch.nn.functional as F
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
output_size: int, input_size_per_partition: int,
output_partition_sizes: list[int],
block_size: list[int]) -> None:
def validate_fp8_block_shape(
layer: torch.nn.Module,
input_size: int,
output_size: int,
input_size_per_partition: int,
output_partition_sizes: list[int],
block_size: list[int],
) -> None:
"""Validate block quantization shapes for tensor parallelism."""
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -900,15 +981,18 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
block_n, block_k = block_size[0], block_size[1]
# Required by row parallel
if (tp_size > 1 and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
if (
tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0
):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} "
f"is not divisible by weight quantization block_k = {block_k}.")
f"is not divisible by weight quantization block_k = {block_k}."
)
# Required by column parallel or enabling merged weights
is_tp_split = (tp_size > 1
and output_size // sum(output_partition_sizes) == tp_size)
is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size
is_merged_gemm = len(output_partition_sizes) > 1
if is_tp_split or is_merged_gemm:
sizes_to_check = output_partition_sizes
@@ -921,33 +1005,44 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")
f"weight quantization block_n = {block_n}."
)
def create_fp8_weight_parameter(
output_size_per_partition: int, input_size_per_partition: int,
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
output_size_per_partition: int,
input_size_per_partition: int,
weight_loader: Optional[Callable],
) -> torch.nn.Parameter:
"""Create FP8 weight parameter."""
from vllm.model_executor.parameter import ModelWeightParameter
return ModelWeightParameter(data=torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
return ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
def create_fp8_scale_parameter(
parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
input_size_per_partition: int, block_size: Optional[list[int]],
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
parameter_type: torch.nn.Parameter,
output_partition_sizes: list[int],
input_size_per_partition: int,
block_size: Optional[list[int]],
weight_loader: Optional[Callable],
) -> torch.nn.Parameter:
"""Create scale parameter based on quantization strategy."""
if parameter_type == ChannelQuantScaleParameter:
scale = parameter_type(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
scale = parameter_type(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
elif parameter_type == BlockQuantScaleParameter:
assert block_size is not None
block_n, block_k = block_size[0], block_size[1]
@@ -963,9 +1058,10 @@ def create_fp8_scale_parameter(
weight_loader=weight_loader,
)
elif parameter_type == PerTensorScaleParameter:
scale = parameter_type(data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
scale = parameter_type(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
else:
raise ValueError(f"Unknown parameter type: {parameter_type}")
@@ -974,14 +1070,15 @@ def create_fp8_scale_parameter(
def create_fp8_input_scale(
output_partition_sizes: list[int],
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
output_partition_sizes: list[int], weight_loader: Optional[Callable]
) -> torch.nn.Parameter:
"""Create input scale parameter for static activation quantization."""
from vllm.model_executor.parameter import PerTensorScaleParameter
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
return scale
@@ -990,15 +1087,18 @@ def process_fp8_weight_tensor_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: list[int],
input_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Process weights for tensor-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
weight=weight, weight_scale=weight_scale, input_scale=input_scale
)
# Requantize with max scale
weight_scale, weight = requantize_with_max_scale(
@@ -1014,15 +1114,17 @@ def process_fp8_weight_tensor_strategy(
def process_fp8_weight_channel_strategy(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Process weights for channel-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz)
normalize_e4m3fn_to_e4m3fnuz,
)
if current_platform.is_fp8_fnuz():
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
weight=weight, weight_scale=weight_scale, input_scale=input_scale
)
return weight, weight_scale, input_scale
@@ -1033,37 +1135,48 @@ def process_fp8_weight_block_strategy(
) -> tuple[torch.Tensor, torch.Tensor]:
"""Process weights for block-wise quantization strategy."""
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz)
normalize_e4m3fn_to_e4m3fnuz,
)
if current_platform.is_fp8_fnuz():
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale)
weight=weight, weight_scale=weight_scale
)
weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
cutlass_block_fp8_supported: bool):
def maybe_post_process_fp8_weight_block(
layer: torch.nn.Module, cutlass_block_fp8_supported: bool
):
assert layer.weight_block_size is not None
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear,
)
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
layer.orig_dtype, layer.weight)
layer.orig_dtype, layer.weight
)
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(layer.weight.data,
layer.weight_scale.data, block_sz)
requant_weight_ue8m0_inplace(
layer.weight.data, layer.weight_scale.data, block_sz
)
# SM90 Block FP8 CUTLASS requires row-major weight scales
elif (current_platform.is_device_capability(90)
and cutlass_block_fp8_supported and not should_use_deepgemm):
elif (
current_platform.is_device_capability(90)
and cutlass_block_fp8_supported
and not should_use_deepgemm
):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(), requires_grad=False)
layer.weight_scale.data.T.contiguous(), requires_grad=False
)
def expert_weight_is_col_major(x: torch.Tensor) -> bool: