# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
)
from vllm.platforms import current_platform

from .BlockScaledMMLinearKernel import (
    Fp8BlockScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig

class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "Requires ROCm."

        if compute_capability is not None and compute_capability < 90:
            return False, "Requires compute capability 90 and above."

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return False, "Requires `aiter` to be installed."

        if not rocm_aiter_ops.is_linear_enabled():
            return (
                False,
                "Requires setting `VLLM_ROCM_USE_AITER=1` "
                "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                "`VLLM_ROCM_USE_AITER_LINEAR` defaults to True.",
            )
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not c.input_symmetric:
            return False, "Supports symmetric quantization only."
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `AiterInt8ScaledMMLinearKernel` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`,
        where `scale_a * a` and `scale_b * b` are implemented using
        numpy-style broadcasting.

        Currently it only supports per-tensor-per-tensor GEMM and
        per-token-per-channel GEMM through the AITER w8a8 scaled GEMM.
        `AiterInt8ScaledMMLinearKernel` also does not support AITER
        block-scaled GEMM or mixed-precision GEMM.
        """
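        # _get_layer_params returns the quantized weight `w_q` and its scale
        # `w_s`, the static input scale `i_s` (None for dynamic quantization)
        # and zero point `i_zp`, and the azp adjustment `azp_adj` (None for
        # symmetric quantization).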
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic: i_s is None and x_s is computed from x.
        # * static: i_s is a scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)

        assert x_zp is None, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        out_dtype = x.dtype

        assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
        assert bias is None or (
            bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
        )

        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b

        per_tensor_scale_a = x_s.numel() == 1
        per_tensor_scale_b = w_s.numel() == 1
        per_token_scale_a = x_s.numel() == m
        per_channel_scale_b = w_s.numel() == n

        # TODO: Maybe broadcast the per-tensor scale into a per-channel scale
        # if one of the scales is a per-channel scale.
        # For now, only the following are supported:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert (per_tensor_scale_a and per_tensor_scale_b) or (
            per_token_scale_a and per_channel_scale_b
        ), (
            "Currently only per-tensor-per-tensor GEMM "
            "and per-token-per-channel GEMM are supported through the AITER "
            "w8a8 scaled GEMM. `AiterInt8ScaledMMLinearKernel` "
            "does not support AITER block-scaled GEMM."
        )

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassInt8ScaledMMLinearKernel prepares the weight `w_q` in [K, N]
        # format, so it is transposed here.
        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)


class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
    def __init__(self, config: FP8ScaledMMLinearLayerConfig):
        super().__init__(config)
        n, k = config.weight_shape
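
        # Use the Triton block-scaled GEMM only on platforms that do not use
        # the fnuz fp8 format and for which aiter has a tuned Triton config
        # for this (n, k) weight shape; otherwise apply_block_scaled_mm falls
        # back to rocm_aiter_ops.gemm_a8w8_blockscale.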
        self.use_triton = (
            not current_platform.is_fp8_fnuz()
            and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
        )

    @classmethod
    def is_supported(cls, compute_capability=None):
        return (
            rocm_aiter_ops.is_linear_enabled(),
            "Only supported on ROCm platform with aiter package installed.",
        )

    @classmethod
    def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
        can_implement_base, reason = super().can_implement(config)
        if not can_implement_base:
            return can_implement_base, reason

        act_quant_desc = config.activation_quant_key.scale
        if act_quant_desc.group_shape != GroupShape(1, 128):
            return (
                False,
                "Supports only dynamic per-token group activation "
                "quantization with group_shape=(1, 128).",
            )
        return True, None

    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
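        # A and B are the fp8-quantized activation and weight tensors, and
        # As/Bs hold their block-wise scales; dispatch to the Triton
        # block-scaled GEMM when it was selected in __init__, otherwise use
        # the default aiter block-scaled kernel.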
        out_dtype = self.config.out_dtype
        if self.use_triton:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
        else:
            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale

        return gemm_a8w8_blockscale_op(
            A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype
        )