Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -15,18 +15,26 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, group_broadcast)
|
||||
GroupShape,
|
||||
group_broadcast,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter)
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear)
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -56,7 +64,8 @@ def cutlass_scaled_mm(
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
|
||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
|
||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
@@ -80,7 +89,6 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
@@ -93,9 +101,11 @@ if current_platform.is_rocm():
|
||||
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||
)
|
||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz()):
|
||||
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz()
|
||||
):
|
||||
import aiter as rocm_aiter
|
||||
from aiter import get_hip_quant
|
||||
|
||||
@@ -113,8 +123,9 @@ def _w8a8_triton_block_scaled_mm_func(
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
|
||||
block_size, output_dtype)
|
||||
return w8a8_triton_block_scaled_mm(
|
||||
qx, weight, x_scale, weight_scale, block_size, output_dtype
|
||||
)
|
||||
|
||||
|
||||
def _w8a8_triton_block_scaled_mm_fake(
|
||||
@@ -125,9 +136,9 @@ def _w8a8_triton_block_scaled_mm_fake(
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((qx.size(0), weight.size(0)),
|
||||
dtype=output_dtype,
|
||||
device=qx.device)
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -147,22 +158,24 @@ def _padded_cutlass(
|
||||
) -> torch.Tensor:
|
||||
pad_multiple = 4
|
||||
dim = qx.shape[0]
|
||||
padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - (
|
||||
dim % pad_multiple)
|
||||
padded = (
|
||||
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
|
||||
)
|
||||
|
||||
padded_shape = [padded, *qx.shape[1:]]
|
||||
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
|
||||
padded_qx[0:qx.shape[0], ...].copy_(qx)
|
||||
padded_qx[0 : qx.shape[0], ...].copy_(qx)
|
||||
|
||||
padded_x_scale_shape = [*x_scale.shape[1:], padded]
|
||||
padded_x_scale = torch.ones(padded_x_scale_shape,
|
||||
device=x_scale.device,
|
||||
dtype=x_scale.dtype).permute(-1, -2)
|
||||
padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale)
|
||||
padded_x_scale = torch.ones(
|
||||
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
|
||||
).permute(-1, -2)
|
||||
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
|
||||
|
||||
output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale,
|
||||
block_size, output_dtype, True)
|
||||
return output[0:qx.shape[0], ...]
|
||||
output = cutlass_scaled_mm(
|
||||
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
|
||||
)
|
||||
return output[0 : qx.shape[0], ...]
|
||||
|
||||
|
||||
def _padded_cutlass_fake(
|
||||
@@ -173,9 +186,9 @@ def _padded_cutlass_fake(
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((qx.size(0), weight.size(0)),
|
||||
dtype=output_dtype,
|
||||
device=qx.device)
|
||||
return torch.empty(
|
||||
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
|
||||
)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -185,18 +198,30 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op(q_input: torch.Tensor, input_scale: torch.Tensor,
|
||||
weight: torch.Tensor, weight_scale: torch.Tensor,
|
||||
output: torch.Tensor, use_deep_gemm_e8m0: bool) -> None:
|
||||
fp8_gemm_nt((q_input, input_scale), (weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0)
|
||||
def _fp8_gemm_nt_op(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
fp8_gemm_nt(
|
||||
(q_input, input_scale),
|
||||
(weight, weight_scale),
|
||||
output,
|
||||
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||
)
|
||||
|
||||
|
||||
def _fp8_gemm_nt_op_fake(q_input: torch.Tensor, input_scale: torch.Tensor,
|
||||
weight: torch.Tensor, weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool) -> None:
|
||||
def _fp8_gemm_nt_op_fake(
|
||||
q_input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
use_deep_gemm_e8m0: bool,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -233,15 +258,21 @@ class W8A8BlockFp8LinearOp:
|
||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||
# to use deepgemm because we don't know the shape of weights (and
|
||||
# whether deepgemm supports it) at the init time.
|
||||
self.w8a8_blockscale_op, self.input_quant_op = \
|
||||
self._dispatch_w8a8_blockscale_op(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
self.deepgemm_input_quant_op = (QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0) if self.is_deep_gemm_supported
|
||||
else None)
|
||||
self.w8a8_blockscale_op, self.input_quant_op = (
|
||||
self._dispatch_w8a8_blockscale_op(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported
|
||||
)
|
||||
)
|
||||
self.deepgemm_input_quant_op = (
|
||||
QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
if self.is_deep_gemm_supported
|
||||
else None
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -257,8 +288,9 @@ class W8A8BlockFp8LinearOp:
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
|
||||
self.is_deep_gemm_supported):
|
||||
if should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype, weight, self.is_deep_gemm_supported
|
||||
):
|
||||
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
||||
else:
|
||||
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
|
||||
@@ -275,12 +307,14 @@ class W8A8BlockFp8LinearOp:
|
||||
) -> torch.Tensor:
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
|
||||
output = torch.empty((q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device)
|
||||
torch.ops.vllm.fp8_gemm_nt_op(q_input, input_scale, weight,
|
||||
weight_scale, output,
|
||||
self.use_deep_gemm_e8m0)
|
||||
output = torch.empty(
|
||||
(q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
device=q_input.device,
|
||||
)
|
||||
torch.ops.vllm.fp8_gemm_nt_op(
|
||||
q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
|
||||
)
|
||||
return output
|
||||
|
||||
def _run_cutlass(
|
||||
@@ -292,15 +326,24 @@ class W8A8BlockFp8LinearOp:
|
||||
assert self.input_quant_op is not None
|
||||
q_input, input_scale = self.input_quant_op(input_2d)
|
||||
if self.is_hopper:
|
||||
return torch.ops.vllm.padded_cutlass(q_input, weight, input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype)
|
||||
return torch.ops.vllm.padded_cutlass(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype,
|
||||
)
|
||||
else:
|
||||
return cutlass_scaled_mm(q_input, weight,
|
||||
input_scale, weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype, False)
|
||||
return cutlass_scaled_mm(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype,
|
||||
False,
|
||||
)
|
||||
|
||||
def _run_aiter(
|
||||
self,
|
||||
@@ -310,10 +353,16 @@ class W8A8BlockFp8LinearOp:
|
||||
) -> torch.Tensor:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
q_input, input_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
|
||||
)
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
|
||||
q_input, weight, input_scale, weight_scale,
|
||||
self.weight_group_shape, input_2d.dtype)
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
self.weight_group_shape,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
@@ -324,34 +373,52 @@ class W8A8BlockFp8LinearOp:
|
||||
assert self.input_quant_op is not None
|
||||
q_input, input_scale = self.input_quant_op(input_2d)
|
||||
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
||||
q_input, weight, input_scale, weight_scale,
|
||||
self.weight_group_shape, input_2d.dtype)
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
self.weight_group_shape,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _dispatch_w8a8_blockscale_op(
|
||||
self,
|
||||
use_cutlass: bool,
|
||||
use_aiter_and_is_supported: bool,
|
||||
) -> tuple[Callable[[
|
||||
) -> tuple[
|
||||
Callable[
|
||||
[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
], torch.Tensor], Optional[QuantFP8]]:
|
||||
],
|
||||
Optional[QuantFP8],
|
||||
]:
|
||||
if use_cutlass:
|
||||
return self._run_cutlass, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=False))
|
||||
return self._run_cutlass, (
|
||||
QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
)
|
||||
if use_aiter_and_is_supported:
|
||||
return self._run_aiter, None
|
||||
return self._run_triton, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False))
|
||||
return self._run_triton, (
|
||||
QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
x: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None
|
||||
x: torch.Tensor, dtype: Optional[torch.dtype] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to float8 values
|
||||
with tensor-wise quantization."""
|
||||
@@ -410,8 +477,9 @@ def _per_token_group_quant_fp8(
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
# Ensure offset calculations use int64 to prevent overflow
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
|
||||
group_size)
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (
|
||||
row_g_id.to(tl.int64) * group_size
|
||||
)
|
||||
y_ptr += y_ptr_offset
|
||||
|
||||
y_q_ptr_offset = g_id.to(tl.int64) * group_size
|
||||
@@ -465,8 +533,9 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
# Ensure offset calculations use int64 to prevent overflow
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
|
||||
group_size)
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (
|
||||
row_g_id.to(tl.int64) * group_size
|
||||
)
|
||||
y_ptr += y_ptr_offset
|
||||
|
||||
y_q_ptr_offset = g_id.to(tl.int64) * group_size
|
||||
@@ -478,8 +547,7 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
scale_col = g_id % blocks_per_row
|
||||
scale_row = g_id // blocks_per_row
|
||||
# Ensure offset calculation uses int64 for y_s_ptr
|
||||
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
|
||||
tl.int64)
|
||||
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(tl.int64)
|
||||
y_s_ptr += y_s_ptr_offset
|
||||
|
||||
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
||||
@@ -523,9 +591,10 @@ def per_token_group_quant_fp8(
|
||||
if use_ue8m0 is None:
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
assert x.shape[-1] % group_size == 0, (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
f"by `group_size` {group_size}")
|
||||
f"by `group_size` {group_size}"
|
||||
)
|
||||
assert x.stride(-1) == 1, "`x` groups must be contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
@@ -539,18 +608,18 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# Allocate the scale tensor in either row- or column-major format.
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device,
|
||||
dtype=torch.float32).permute(-1, -2)
|
||||
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||
else:
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
# prefer CUDA kernel if available
|
||||
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||
if current_platform.is_cuda() and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
|
||||
fp8_min, fp8_max, use_ue8m0)
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
# TRITON FALLBACK
|
||||
@@ -561,7 +630,7 @@ def per_token_group_quant_fp8(
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
if column_major_scales:
|
||||
_per_token_group_quant_fp8_colmajor[(M, )](
|
||||
_per_token_group_quant_fp8_colmajor[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
@@ -578,7 +647,7 @@ def per_token_group_quant_fp8(
|
||||
num_stages=num_stages,
|
||||
)
|
||||
else:
|
||||
_per_token_group_quant_fp8[(M, )](
|
||||
_per_token_group_quant_fp8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
@@ -656,12 +725,8 @@ def _w8a8_triton_block_scaled_mm(
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
@@ -687,8 +752,9 @@ def _w8a8_triton_block_scaled_mm(
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
def get_w8a8_block_fp8_configs(
|
||||
N: int, K: int, block_n: int, block_k: int
|
||||
) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
@@ -703,7 +769,8 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501
|
||||
|
||||
config_file_path = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
|
||||
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
|
||||
)
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
logger.info(
|
||||
@@ -759,7 +826,7 @@ def w8a8_triton_block_scaled_mm(
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C_shape = A.shape[:-1] + (N,)
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
|
||||
@@ -780,8 +847,9 @@ def w8a8_triton_block_scaled_mm(
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
return (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
_w8a8_triton_block_scaled_mm[grid](
|
||||
A,
|
||||
@@ -811,9 +879,9 @@ def w8a8_triton_block_scaled_mm(
|
||||
|
||||
|
||||
def requant_weight_ue8m0_inplace(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: Sequence[int] = (128, 128),
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: Sequence[int] = (128, 128),
|
||||
) -> None:
|
||||
"""Re-quantise *weight* in place so that its per-block scaling factors are in the
|
||||
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels.
|
||||
@@ -830,8 +898,9 @@ def requant_weight_ue8m0_inplace(
|
||||
return
|
||||
|
||||
if weight.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
|
||||
f"{weight.dtype} instead.")
|
||||
raise ValueError(
|
||||
f"Expected *weight* to be torch.float8_e4m3fn, got {weight.dtype} instead."
|
||||
)
|
||||
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
@@ -860,8 +929,9 @@ def requant_weight_ue8m0_inplace(
|
||||
s_exp = s_exp[:m_cur, :k_cur]
|
||||
w_dq = w_q.to(torch.float32) * s_exp
|
||||
# Re-quantise using power-of-two scaling (UE8M0).
|
||||
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k],
|
||||
use_ue8m0=True)
|
||||
w_requant, s_requant = per_block_cast_to_fp8(
|
||||
w_dq, [block_m, block_k], use_ue8m0=True
|
||||
)
|
||||
|
||||
# Write back the results in-place.
|
||||
w_q.copy_(w_requant)
|
||||
@@ -871,28 +941,39 @@ def requant_weight_ue8m0_inplace(
|
||||
def check_aiter_fp8_linear_support() -> bool:
|
||||
"""AITER is only supported on ROCm, only for FP8_FNUZ,
|
||||
and at the moment only on the MI300 series."""
|
||||
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz())
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz()
|
||||
)
|
||||
|
||||
|
||||
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
can benefit from tensors located far enough from one another in memory"""
|
||||
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
|
||||
and weight.stride(-1) == 1
|
||||
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
|
||||
if (
|
||||
envs.VLLM_ROCM_FP8_PADDING
|
||||
and current_platform.is_rocm()
|
||||
and weight.stride(-1) == 1
|
||||
and (weight.stride(-2) * weight.element_size()) % 512 == 0
|
||||
):
|
||||
num_pad = 256 // weight.element_size()
|
||||
import torch.nn.functional as F
|
||||
|
||||
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
torch.cuda.empty_cache()
|
||||
return weight
|
||||
|
||||
|
||||
def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
|
||||
output_size: int, input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
block_size: list[int]) -> None:
|
||||
def validate_fp8_block_shape(
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
block_size: list[int],
|
||||
) -> None:
|
||||
"""Validate block quantization shapes for tensor parallelism."""
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
|
||||
@@ -900,15 +981,18 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
# Required by row parallel
|
||||
if (tp_size > 1 and input_size // input_size_per_partition == tp_size
|
||||
and input_size_per_partition % block_k != 0):
|
||||
if (
|
||||
tp_size > 1
|
||||
and input_size // input_size_per_partition == tp_size
|
||||
and input_size_per_partition % block_k != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = {input_size_per_partition} "
|
||||
f"is not divisible by weight quantization block_k = {block_k}.")
|
||||
f"is not divisible by weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
# Required by column parallel or enabling merged weights
|
||||
is_tp_split = (tp_size > 1
|
||||
and output_size // sum(output_partition_sizes) == tp_size)
|
||||
is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size
|
||||
is_merged_gemm = len(output_partition_sizes) > 1
|
||||
if is_tp_split or is_merged_gemm:
|
||||
sizes_to_check = output_partition_sizes
|
||||
@@ -921,33 +1005,44 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int,
|
||||
raise ValueError(
|
||||
f"Weight output_partition_size = "
|
||||
f"{output_partition_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
f"weight quantization block_n = {block_n}."
|
||||
)
|
||||
|
||||
|
||||
def create_fp8_weight_parameter(
|
||||
output_size_per_partition: int, input_size_per_partition: int,
|
||||
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
weight_loader: Optional[Callable],
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create FP8 weight parameter."""
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
|
||||
return ModelWeightParameter(data=torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
return ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
|
||||
def create_fp8_scale_parameter(
|
||||
parameter_type: torch.nn.Parameter, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int, block_size: Optional[list[int]],
|
||||
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
|
||||
parameter_type: torch.nn.Parameter,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
block_size: Optional[list[int]],
|
||||
weight_loader: Optional[Callable],
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create scale parameter based on quantization strategy."""
|
||||
if parameter_type == ChannelQuantScaleParameter:
|
||||
scale = parameter_type(data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
scale = parameter_type(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
elif parameter_type == BlockQuantScaleParameter:
|
||||
assert block_size is not None
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
@@ -963,9 +1058,10 @@ def create_fp8_scale_parameter(
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
elif parameter_type == PerTensorScaleParameter:
|
||||
scale = parameter_type(data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
scale = parameter_type(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown parameter type: {parameter_type}")
|
||||
|
||||
@@ -974,14 +1070,15 @@ def create_fp8_scale_parameter(
|
||||
|
||||
|
||||
def create_fp8_input_scale(
|
||||
output_partition_sizes: list[int],
|
||||
weight_loader: Optional[Callable]) -> torch.nn.Parameter:
|
||||
output_partition_sizes: list[int], weight_loader: Optional[Callable]
|
||||
) -> torch.nn.Parameter:
|
||||
"""Create input scale parameter for static activation quantization."""
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
|
||||
scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
return scale
|
||||
|
||||
@@ -990,15 +1087,18 @@ def process_fp8_weight_tensor_strategy(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: list[int],
|
||||
input_scale: Optional[torch.Tensor] = None
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Process weights for tensor-wise quantization strategy."""
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
|
||||
weight=weight, weight_scale=weight_scale, input_scale=input_scale
|
||||
)
|
||||
|
||||
# Requantize with max scale
|
||||
weight_scale, weight = requantize_with_max_scale(
|
||||
@@ -1014,15 +1114,17 @@ def process_fp8_weight_tensor_strategy(
|
||||
def process_fp8_weight_channel_strategy(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Process weights for channel-wise quantization strategy."""
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz)
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale, input_scale=input_scale)
|
||||
weight=weight, weight_scale=weight_scale, input_scale=input_scale
|
||||
)
|
||||
|
||||
return weight, weight_scale, input_scale
|
||||
|
||||
@@ -1033,37 +1135,48 @@ def process_fp8_weight_block_strategy(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Process weights for block-wise quantization strategy."""
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz)
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale)
|
||||
weight=weight, weight_scale=weight_scale
|
||||
)
|
||||
|
||||
weight = _maybe_pad_fp8_weight(weight)
|
||||
return weight, weight_scale
|
||||
|
||||
|
||||
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
|
||||
cutlass_block_fp8_supported: bool):
|
||||
def maybe_post_process_fp8_weight_block(
|
||||
layer: torch.nn.Module, cutlass_block_fp8_supported: bool
|
||||
):
|
||||
assert layer.weight_block_size is not None
|
||||
|
||||
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
|
||||
should_use_deepgemm_for_fp8_linear)
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
|
||||
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
|
||||
# requantize the weight and input to the specific scale
|
||||
# at the same time.
|
||||
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
|
||||
layer.orig_dtype, layer.weight)
|
||||
layer.orig_dtype, layer.weight
|
||||
)
|
||||
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(layer.weight.data,
|
||||
layer.weight_scale.data, block_sz)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.weight.data, layer.weight_scale.data, block_sz
|
||||
)
|
||||
# SM90 Block FP8 CUTLASS requires row-major weight scales
|
||||
elif (current_platform.is_device_capability(90)
|
||||
and cutlass_block_fp8_supported and not should_use_deepgemm):
|
||||
elif (
|
||||
current_platform.is_device_capability(90)
|
||||
and cutlass_block_fp8_supported
|
||||
and not should_use_deepgemm
|
||||
):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False)
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user