[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola
2025-09-23 21:03:10 +02:00
committed by GitHub
parent 8c1c81a3de
commit 63400259d0
14 changed files with 341 additions and 201 deletions

View File

@@ -13,8 +13,9 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
GroupShape, group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
@@ -24,6 +25,7 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
    is_hopper: Optional[bool] = None,
) -> torch.Tensor:
    """Run a block-scaled FP8 GEMM through the CUTLASS custom op.

    Args:
        A: quantized activation matrix.
        B: quantized weight matrix (transposed before the call).
        As: per-group activation scales (`scale_a`).
        Bs: per-group weight scales (`scale_b`).
        block_size: weight quantization block shape.
        output_dtype: dtype of the returned tensor.
        is_hopper: whether the device is SM90; passed as an argument
            because `current_platform.is_device_capability()` is not
            supported by the Torch compiler.

    Returns:
        The GEMM result in ``output_dtype``.
    """
    if is_hopper is None:
        is_hopper = current_platform.is_device_capability(90)
    return ops.cutlass_scaled_mm(
        A,
        B.T,
        out_dtype=output_dtype,
        scale_a=As,
        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
        scale_b=Bs if block_size is not None and is_hopper else Bs.T)
def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -98,122 +104,189 @@ if current_platform.is_rocm():
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
def dispatch_w8a8_blockscale_func(
    use_cutlass: bool, use_aiter_and_is_supported: bool
) -> Callable[[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    list[int],
    torch.dtype,
], torch.Tensor]:
    """Choose the W8A8 block-scale matmul implementation.

    Preference order: CUTLASS when supported, then the ROCm AITER
    custom op, otherwise the Triton fallback kernel.
    """
    if use_cutlass:
        chosen = cutlass_scaled_mm
    elif use_aiter_and_is_supported:
        chosen = torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
    else:
        chosen = w8a8_block_fp8_matmul
    return chosen
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
    qx: torch.Tensor,
    weight: torch.Tensor,
    x_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Custom-op wrapper forwarding to the Triton block-scaled matmul."""
    args = (qx, weight, x_scale, weight_scale, block_size, output_dtype)
    return w8a8_triton_block_scaled_mm(*args)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
# Register the Triton block-scaled matmul as a torch custom op so it can
# be traced opaquely (the fake impl above supplies shapes for tracing).
# NOTE(review): dispatch_key="CUDA" — presumably the kernel is CUDA-only;
# confirm against the ROCm path handled elsewhere in this file.
direct_register_custom_op(
    "w8a8_triton_block_scaled_mm_func",
    _w8a8_triton_block_scaled_mm_func,
    mutates_args=[],
    fake_impl=_w8a8_triton_block_scaled_mm_fake,
    dispatch_key="CUDA",
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: list[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
class W8A8BlockFp8LinearOp:
"""
This class executes a Blocked FP8 linear layer using cutlass if supported
and torch.scaled_mm otherwise.
"""
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
def __init__(
self,
weight_group_shape: GroupShape,
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = \
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
self.deepgemm_input_quant_op = (QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
else None)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
)
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
self.is_deep_gemm_supported):
output = self._run_deepgemm(input, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_deepgemm(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize the activation and run the DeepGEMM block-scaled matmul.

    Note: the diff rendering interleaved the removed
    `w8a8_block_fp8_matmul_deepgemm` call with the added
    `w8a8_deepgemm_block_scaled_mm` call; this is the added version.
    """
    # ensure DeepGEMM-backed custom op is registered before use
    import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
    assert self.deepgemm_input_quant_op is not None
    q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
    return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
        q_input,
        weight,
        x_scale,
        weight_scale,
        self.weight_group_shape,
        output_dtype=input_2d.dtype)
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
if cutlass_block_fp8_supported:
num_pad = 0
if current_platform.is_device_capability(90):
# pad first dimension to be divisible by 4 due to
# cutlass blockwise gemm limitation for hopper
num_pad = 4 - (input_2d.shape[0] % 4)
if num_pad > 0:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, num_pad),
"constant", 0)
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
def _run_cutlass(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
if self.is_hopper:
# We pad unconditionally (even if shape is already divisible by 4)
# to support dynamic shape for input_2d.shape[0] in torch.compile
x = torch.nn.functional.pad(input_2d,
(0, 0, 0, -input_2d.shape[0] % 4))
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False)
x = input_2d
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
q_input, x_scale = self.input_quant_op(x)
output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, self.is_hopper)
output = output[0:input_2d.shape[0], ...]
return output
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_aiter(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize per 1x128 group via AITER and run its blockscale GEMM."""
    # AITER's quant kernel is hard-wired to 1x128 activation groups.
    assert self.act_quant_group_shape == GroupShape(1, 128)
    gemm_input, gemm_scale = aiter_per1x128_quant(
        input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
    return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
        gemm_input, weight, gemm_scale, weight_scale,
        self.weight_group_shape, input_2d.dtype)
def _run_triton(
    self,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    """Quantize the activation and call the Triton block-scaled matmul op."""
    assert self.input_quant_op is not None
    quantized, act_scale = self.input_quant_op(input_2d)
    return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
        quantized, weight, act_scale, weight_scale,
        self.weight_group_shape, input_2d.dtype)
def apply_w8a8_block_fp8_linear_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
    use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
    """Fake (meta) impl: produce only the output shape/dtype, no math."""
    out_shape = list(input.shape[:-1]) + [weight.shape[0]]
    return torch.empty(out_shape, dtype=input.dtype, device=input.device)
# Expose apply_w8a8_block_fp8_linear as a torch custom op, paired with its
# fake impl for shape inference under tracing.
# NOTE(review): registration is skipped on CPU — presumably the custom-op
# dispatcher path is GPU-only; confirm against the CPU execution path.
if not current_platform.is_cpu():
    direct_register_custom_op(
        op_name="apply_w8a8_block_fp8_linear",
        op_func=apply_w8a8_block_fp8_linear,
        mutates_args=[],
        fake_impl=apply_w8a8_block_fp8_linear_fake,
    )
def _dispatch_w8a8_blockscale_op(
    self,
    use_cutlass: bool,
    use_aiter_and_is_supported: bool,
) -> tuple[Callable[[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
], torch.Tensor], Optional[QuantFP8]]:
    """Pick the blockscale matmul and its matching activation quant op.

    Returns ``(matmul_fn, input_quant_op)``; the quant op is ``None``
    for the AITER path, which performs its own quantization.
    """
    if use_cutlass:
        # CUTLASS wants column-major activation scales.
        quant = QuantFP8(False,
                         self.act_quant_group_shape,
                         column_major_scales=True,
                         use_ue8m0=False)
        return self._run_cutlass, quant
    if use_aiter_and_is_supported:
        return self._run_aiter, None
    # Triton fallback: row-major activation scales.
    quant = QuantFP8(False,
                     self.act_quant_group_shape,
                     column_major_scales=False,
                     use_ue8m0=False)
    return self._run_triton, quant
def input_to_float8(
@@ -465,7 +538,7 @@ def per_token_group_quant_fp8(
@triton.jit
def _w8a8_block_fp8_matmul(
def _w8a8_triton_block_scaled_mm(
# Pointers to inputs and output
A,
B,
@@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None
def w8a8_block_fp8_matmul(
def w8a8_triton_block_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -650,7 +723,7 @@ def w8a8_block_fp8_matmul(
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid](
_w8a8_triton_block_scaled_mm[grid](
A,
B,
C,
@@ -997,25 +1070,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           cutlass_block_fp8_supported: bool,
                           use_aiter_and_is_supported: bool) -> torch.Tensor:
    """Apply block-wise FP8 linear operation."""
    assert layer.weight_block_size is not None
    call_kwargs = dict(
        input=input,
        weight=layer.weight,
        block_size=layer.weight_block_size,
        weight_scale=layer.weight_scale,
        input_scale=layer.input_scale,
        bias=bias,
        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
        use_aiter_and_is_supported=use_aiter_and_is_supported,
    )
    return torch.ops.vllm.apply_w8a8_block_fp8_linear(**call_kwargs)
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape