[Perf] Fix and reapply move apply w8a8 block fp8 linear to class (#25696)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola
2025-10-02 21:35:13 +02:00
committed by GitHub
parent 3d5f1c8640
commit 502640c3f9
13 changed files with 412 additions and 200 deletions

View File

@@ -13,8 +13,9 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
GroupShape, group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
@@ -24,6 +25,7 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
@@ -42,15 +46,17 @@ def cutlass_scaled_mm(
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
is_hopper: Optional[bool] = None,
) -> torch.Tensor:
if is_hopper is None:
is_hopper = current_platform.is_device_capability(90)
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None
and current_platform.is_device_capability(90) else Bs.T)
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -96,115 +102,251 @@ if current_platform.is_rocm():
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
def dispatch_w8a8_blockscale_func(
    use_cutlass: bool, use_aiter_and_is_supported: bool
) -> Callable[[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        list[int],
        torch.dtype,
], torch.Tensor]:
    """Pick the block-scaled W8A8 matmul implementation.

    Args:
        use_cutlass: When true, the CUTLASS path is selected; this takes
            priority over the AITER flag.
        use_aiter_and_is_supported: When true (and ``use_cutlass`` is false),
            the ROCm AITER custom op is selected.

    Returns:
        ``cutlass_scaled_mm``, the ``rocm_aiter_gemm_w8a8_blockscale`` custom
        op, or ``w8a8_block_fp8_matmul`` as the fallback.
    """
    if use_cutlass:
        selected = cutlass_scaled_mm
    elif use_aiter_and_is_supported:
        selected = torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
    else:
        selected = w8a8_block_fp8_matmul
    return selected
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
    qx: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Forward all arguments unchanged to ``w8a8_triton_block_scaled_mm``.

    This thin wrapper exists so the call can be registered as a custom op;
    ``block_size`` stays a plain ``list[int]`` (see the GroupShape TODO).
    """
    result = w8a8_triton_block_scaled_mm(
        qx,
        weight,
        x_scale,
        weight_scale,
        block_size,
        output_dtype,
    )
    return result
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
# Register the Triton wrapper as a custom op named
# "w8a8_triton_block_scaled_mm_func" (called elsewhere in this file as
# torch.ops.vllm.w8a8_triton_block_scaled_mm_func), together with its fake
# implementation, which only produces an output of the right shape and dtype.
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
)
def _padded_cutlass(
    qx: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Run ``cutlass_scaled_mm`` on a row-padded copy of the activations.

    The first dimension of ``qx`` is rounded up to a multiple of 4 (zero
    rows are appended) and ``x_scale`` is padded with ones to match. The
    GEMM is invoked with ``is_hopper=True`` and the padding rows are sliced
    off the result, so the output has ``qx.shape[0]`` rows.
    """
    pad_multiple = 4
    num_rows = qx.shape[0]
    # Round up to the next multiple of pad_multiple (no-op if already aligned).
    padded_rows = -(-num_rows // pad_multiple) * pad_multiple

    # Zero-filled activation buffer holding the original rows at the top.
    qx_buf = torch.zeros([padded_rows, *qx.shape[1:]],
                         device=qx.device,
                         dtype=qx.dtype)
    qx_buf[:num_rows].copy_(qx)

    # The scale buffer is allocated with the row dimension last and then
    # permuted, so the padded scales are laid out column-major.
    scale_buf = torch.ones([*x_scale.shape[1:], padded_rows],
                           device=x_scale.device,
                           dtype=x_scale.dtype)
    scale_buf = scale_buf.permute(-1, -2)
    scale_buf[:x_scale.shape[0]].copy_(x_scale)

    padded_out = cutlass_scaled_mm(qx_buf, weight, scale_buf, weight_scale,
                                   block_size, output_dtype, True)
    return padded_out[:num_rows]
def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
# Register the padded CUTLASS path as a custom op named "padded_cutlass"
# (called elsewhere in this file as torch.ops.vllm.padded_cutlass), together
# with its fake implementation.
direct_register_custom_op(
"padded_cutlass",
_padded_cutlass,
fake_impl=_padded_cutlass_fake,
)
def _fp8_gemm_nt_op(q_input: torch.Tensor, input_scale: torch.Tensor,
                    weight: torch.Tensor, weight_scale: torch.Tensor,
                    output: torch.Tensor, use_deep_gemm_e8m0: bool) -> None:
    """Call ``fp8_gemm_nt``, writing the result into ``output``.

    Each operand is passed to ``fp8_gemm_nt`` as a ``(tensor, scale)`` pair,
    and ``use_deep_gemm_e8m0`` is forwarded as ``is_deep_gemm_e8m0_used``.
    Nothing is returned.
    """
    lhs = (q_input, input_scale)
    rhs = (weight, weight_scale)
    fp8_gemm_nt(lhs, rhs, output, is_deep_gemm_e8m0_used=use_deep_gemm_e8m0)
def _fp8_gemm_nt_op_fake(q_input: torch.Tensor, input_scale: torch.Tensor,
weight: torch.Tensor, weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool) -> None:
return None
# Register the DeepGEMM wrapper as a custom op named "fp8_gemm_nt_op"
# (called elsewhere in this file as torch.ops.vllm.fp8_gemm_nt_op).
# "output" is declared in mutates_args because the op returns None and
# receives the result buffer as an argument.
direct_register_custom_op(
"fp8_gemm_nt_op",
_fp8_gemm_nt_op,
mutates_args=["output"],
fake_impl=_fp8_gemm_nt_op_fake,
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
class W8A8BlockFp8LinearOp:
"""
This class executes a blocked FP8 linear layer using DeepGEMM when it is
supported for the given weight, and otherwise the CUTLASS, ROCm AITER or
Triton block-scaled matmul selected at construction time.
"""
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
def __init__(
    self,
    weight_group_shape: GroupShape,
    act_quant_group_shape: GroupShape,
    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
    use_aiter_and_is_supported: bool = False,
):
    """Cache platform capabilities and choose the matmul/quant operations.

    Args:
        weight_group_shape: Group (block) shape of the weight scales.
        act_quant_group_shape: Group shape used to quantize activations.
        cutlass_block_fp8_supported: Whether to dispatch to the CUTLASS path.
        use_aiter_and_is_supported: Whether to dispatch to the ROCm AITER
            path when CUTLASS is not selected.
    """
    self.weight_group_shape = weight_group_shape
    self.act_quant_group_shape = act_quant_group_shape

    # Platform/feature queries are evaluated once here and stored.
    self.is_deep_gemm_supported = is_deep_gemm_supported()
    self.is_hopper = current_platform.is_device_capability(90)
    self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()

    # Get the correct blockscale mul and input quant operations.
    # We can't use _dispatch_w8a8_blockscale_op to figure out if we want
    # to use deepgemm because we don't know the shape of weights (and
    # whether deepgemm supports it) at the init time.
    blockscale_op, quant_op = self._dispatch_w8a8_blockscale_op(
        cutlass_block_fp8_supported, use_aiter_and_is_supported)
    self.w8a8_blockscale_op = blockscale_op
    self.input_quant_op = quant_op

    # The DeepGEMM path has its own activation quantizer, created only
    # when DeepGEMM is supported.
    if self.is_deep_gemm_supported:
        self.deepgemm_input_quant_op = QuantFP8(
            False,
            self.act_quant_group_shape,
            column_major_scales=True,
            use_ue8m0=self.use_deep_gemm_e8m0)
    else:
        self.deepgemm_input_quant_op = None
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
)
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
self.is_deep_gemm_supported):
output = self._run_deepgemm(input_2d, weight, weight_scale)
else:
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_deepgemm(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty((q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device)
fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
torch.ops.vllm.fp8_gemm_nt_op(q_input, input_scale, weight,
weight_scale, output,
self.use_deep_gemm_e8m0)
return output
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
# pad first dimension to be divisible by 4 due to
# cutlass blockwise gemm limitation for hopper
num_pad = 4 - (input_2d.shape[0] % 4)
if num_pad > 0:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, num_pad),
"constant", 0)
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
def _run_cutlass(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
if self.is_hopper:
return torch.ops.vllm.padded_cutlass(q_input, weight, input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False)
return cutlass_scaled_mm(q_input, weight,
input_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, False)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
def _run_aiter(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize the activations with AITER and run the AITER block GEMM.

    Args:
        input_2d: Unquantized activations; made contiguous before
            quantization.
        weight: Quantized weight tensor.
        weight_scale: Block scales for ``weight``.

    Returns:
        The output of ``torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale``,
        requested in the dtype of ``input_2d``.
    """
    # The AITER quantizer used here is the per-1x128 variant, so any other
    # activation group shape would not match.
    assert self.act_quant_group_shape == GroupShape(1, 128)
    q_input, input_scale = aiter_per1x128_quant(
        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
    # Pass the block size as a plain list: the block-scaled matmul callables
    # in this file take ``list[int]``, and _run_cutlass converts the
    # GroupShape the same way.
    return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
        q_input, weight, input_scale, weight_scale,
        list(self.weight_group_shape), input_2d.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_triton(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize the activations and run the Triton block-scaled matmul.

    Args:
        input_2d: Unquantized activations.
        weight: Quantized weight tensor.
        weight_scale: Block scales for ``weight``.

    Returns:
        The output of ``torch.ops.vllm.w8a8_triton_block_scaled_mm_func``,
        requested in the dtype of ``input_2d``.
    """
    assert self.input_quant_op is not None
    q_input, input_scale = self.input_quant_op(input_2d)
    # The registered op is annotated with ``block_size: list[int]``, so pass
    # a plain list rather than the GroupShape itself, matching what
    # _run_cutlass does.
    return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
        q_input, weight, input_scale, weight_scale,
        list(self.weight_group_shape), input_2d.dtype)
def apply_w8a8_block_fp8_linear_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    """Fake implementation of ``apply_w8a8_block_fp8_linear``.

    Returns an uninitialized tensor with the dtype and device of ``input``
    whose shape is ``input``'s shape with the last dimension replaced by
    ``weight.shape[0]``. All other arguments are ignored.
    """
    leading_dims = input.shape[:-1]
    out_features = weight.shape[0]
    return input.new_empty([*leading_dims, out_features])
# Register apply_w8a8_block_fp8_linear as a custom op (with its fake
# implementation and no mutated arguments) on every platform except CPU.
if not current_platform.is_cpu():
direct_register_custom_op(
op_name="apply_w8a8_block_fp8_linear",
op_func=apply_w8a8_block_fp8_linear,
mutates_args=[],
fake_impl=apply_w8a8_block_fp8_linear_fake,
)
def _dispatch_w8a8_blockscale_op(
    self,
    use_cutlass: bool,
    use_aiter_and_is_supported: bool,
) -> tuple[Callable[[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
], torch.Tensor], Optional[QuantFP8]]:
    """Choose the block-scaled matmul runner and its activation quantizer.

    Returns:
        A ``(runner, quant_op)`` pair:

        * CUTLASS: ``self._run_cutlass`` with a ``QuantFP8`` that produces
          column-major scales.
        * AITER (only when CUTLASS is not selected): ``self._run_aiter`` and
          ``None``, since that runner quantizes its input itself.
        * Otherwise: ``self._run_triton`` with a ``QuantFP8`` that does not
          produce column-major scales.
    """
    if use_cutlass:
        runner, column_major = self._run_cutlass, True
    elif use_aiter_and_is_supported:
        return self._run_aiter, None
    else:
        runner, column_major = self._run_triton, False

    quant_op = QuantFP8(False,
                        self.act_quant_group_shape,
                        column_major_scales=column_major,
                        use_ue8m0=False)
    return runner, quant_op
def input_to_float8(
@@ -456,7 +598,7 @@ def per_token_group_quant_fp8(
@triton.jit
def _w8a8_block_fp8_matmul(
def _w8a8_triton_block_scaled_mm(
# Pointers to inputs and output
A,
B,
@@ -581,7 +723,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None
def w8a8_block_fp8_matmul(
def w8a8_triton_block_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -641,7 +783,7 @@ def w8a8_block_fp8_matmul(
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
_w8a8_triton_block_scaled_mm[grid](
A,
B,
C,
@@ -924,25 +1066,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           cutlass_block_fp8_supported: bool,
                           use_aiter_and_is_supported: bool) -> torch.Tensor:
    """Run the block-wise FP8 linear custom op with a layer's parameters.

    The weight, weight scale, input scale and block size are read from
    ``layer``; ``layer.weight_block_size`` must not be ``None``. The two
    boolean flags are forwarded unchanged to the op.
    """
    assert layer.weight_block_size is not None
    block_linear_op = torch.ops.vllm.apply_w8a8_block_fp8_linear
    return block_linear_op(
        input=input,
        weight=layer.weight,
        block_size=layer.weight_block_size,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
        use_aiter_and_is_supported=use_aiter_and_is_supported,
    )
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape