[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
@@ -13,8 +13,9 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
group_broadcast)
|
||||
GroupShape, group_broadcast)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
@@ -24,6 +25,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
|
||||
|
||||
# We need to pass in the is_hopper flag as argument because the function
|
||||
# current_platform.is_device_capability() is not supported by Torch compiler.
|
||||
def cutlass_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@@ -42,15 +46,17 @@ def cutlass_scaled_mm(
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
is_hopper: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
if is_hopper is None:
|
||||
is_hopper = current_platform.is_device_capability(90)
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
|
||||
scale_b=Bs if block_size is not None
|
||||
and current_platform.is_device_capability(90) else Bs.T)
|
||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
@@ -98,122 +104,189 @@ if current_platform.is_rocm():
|
||||
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
|
||||
|
||||
|
||||
def dispatch_w8a8_blockscale_func(
|
||||
use_cutlass: bool, use_aiter_and_is_supported: bool
|
||||
) -> Callable[[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
list[int],
|
||||
torch.dtype,
|
||||
], torch.Tensor]:
|
||||
if use_cutlass:
|
||||
return cutlass_scaled_mm
|
||||
if (use_aiter_and_is_supported):
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
|
||||
return w8a8_block_fp8_matmul
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
def _w8a8_triton_block_scaled_mm_func(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
|
||||
block_size, output_dtype)
|
||||
|
||||
|
||||
def _w8a8_triton_block_scaled_mm_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((qx.size(0), weight.size(0)),
|
||||
dtype=output_dtype,
|
||||
device=qx.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
mutates_args=[],
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
dispatch_key="CUDA",
|
||||
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
class W8A8BlockFp8LinearOp:
|
||||
"""
|
||||
This class executes a Blocked FP8 linear layer using cutlass if supported
|
||||
and torch.scaled_mm otherwise.
|
||||
"""
|
||||
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
|
||||
def __init__(
|
||||
self,
|
||||
weight_group_shape: GroupShape,
|
||||
act_quant_group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
):
|
||||
self.weight_group_shape = weight_group_shape
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||
# to use deepgemm because we don't know the shape of weights (and
|
||||
# whether deepgemm supports it) at the init time.
|
||||
self.w8a8_blockscale_op, self.input_quant_op = \
|
||||
self._dispatch_w8a8_blockscale_op(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
self.deepgemm_input_quant_op = (QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
|
||||
else None)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True,
|
||||
)
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
|
||||
self.is_deep_gemm_supported):
|
||||
output = self._run_deepgemm(input, weight, weight_scale)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def _run_deepgemm(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# ensure DeepGEMM-backed custom op is registered before use
|
||||
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
||||
|
||||
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
|
||||
return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
|
||||
q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=output_dtype)
|
||||
if bias is not None:
|
||||
output += bias
|
||||
return output.to(dtype=output_dtype).view(*output_shape)
|
||||
self.weight_group_shape,
|
||||
output_dtype=input_2d.dtype)
|
||||
|
||||
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
if cutlass_block_fp8_supported:
|
||||
num_pad = 0
|
||||
if current_platform.is_device_capability(90):
|
||||
# pad first dimension to be divisible by 4 due to
|
||||
# cutlass blockwise gemm limitation for hopper
|
||||
num_pad = 4 - (input_2d.shape[0] % 4)
|
||||
if num_pad > 0:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, num_pad),
|
||||
"constant", 0)
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
if num_pad > 0:
|
||||
output = output[:-num_pad]
|
||||
else:
|
||||
if use_aiter_and_is_supported:
|
||||
q_input, x_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
def _run_cutlass(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.input_quant_op is not None
|
||||
if self.is_hopper:
|
||||
# We pad unconditionally (even if shape is already divisible by 4)
|
||||
# to support dynamic shape for input_2d.shape[0] in torch.compile
|
||||
x = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, -input_2d.shape[0] % 4))
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False)
|
||||
x = input_2d
|
||||
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
q_input, x_scale = self.input_quant_op(x)
|
||||
output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype, self.is_hopper)
|
||||
output = output[0:input_2d.shape[0], ...]
|
||||
return output
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
def _run_aiter(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
q_input, x_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
|
||||
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
|
||||
input_2d.dtype)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.input_quant_op is not None
|
||||
q_input, x_scale = self.input_quant_op(input_2d)
|
||||
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
||||
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
|
||||
input_2d.dtype)
|
||||
|
||||
def apply_w8a8_block_fp8_linear_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
direct_register_custom_op(
|
||||
op_name="apply_w8a8_block_fp8_linear",
|
||||
op_func=apply_w8a8_block_fp8_linear,
|
||||
mutates_args=[],
|
||||
fake_impl=apply_w8a8_block_fp8_linear_fake,
|
||||
)
|
||||
def _dispatch_w8a8_blockscale_op(
|
||||
self,
|
||||
use_cutlass: bool,
|
||||
use_aiter_and_is_supported: bool,
|
||||
) -> tuple[Callable[[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
], torch.Tensor], Optional[QuantFP8]]:
|
||||
if use_cutlass:
|
||||
return self._run_cutlass, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=False))
|
||||
if use_aiter_and_is_supported:
|
||||
return self._run_aiter, None
|
||||
return self._run_triton, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False))
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
@@ -465,7 +538,7 @@ def per_token_group_quant_fp8(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
def _w8a8_triton_block_scaled_mm(
|
||||
# Pointers to inputs and output
|
||||
A,
|
||||
B,
|
||||
@@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
return None
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
def w8a8_triton_block_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
@@ -650,7 +723,7 @@ def w8a8_block_fp8_matmul(
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
|
||||
_w8a8_block_fp8_matmul[grid](
|
||||
_w8a8_triton_block_scaled_mm[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
@@ -997,25 +1070,6 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
cutlass_block_fp8_supported: bool,
|
||||
use_aiter_and_is_supported: bool) -> torch.Tensor:
|
||||
"""Apply block-wise FP8 linear operation."""
|
||||
assert layer.weight_block_size is not None
|
||||
|
||||
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
|
||||
input=input,
|
||||
weight=layer.weight,
|
||||
block_size=layer.weight_block_size,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
assert x.dim() == 3
|
||||
b, m, n = x.shape
|
||||
|
||||
Reference in New Issue
Block a user