Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
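The changes below are mechanical reformatting: yapf's hanging-indent wrapping is replaced by ruff's Black-style wrapping (one argument per line, a trailing comma, and the closing parenthesis on its own line), string quotes are normalized to double quotes, and wrapped imports gain trailing commas. A minimal sketch of the two styles, loosely modeled on the apply_weights signatures in this diff (names, annotations, and bodies here are simplified for illustration, not copied from the PR):

# Before: yapf-style hanging indent, arguments aligned under the opening parenthesis.
def apply_weights_old(self, layer, x,
                      bias=None):
    return None


# After: ruff/Black-style wrapping with a magic trailing comma and the
# closing parenthesis on its own line.
def apply_weights_new(
    self,
    layer,
    x,
    bias=None,
):
    return None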
@@ -24,7 +24,6 @@ class MPLinearLayerConfig:
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:

@@ -32,16 +31,17 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name

@@ -58,31 +58,34 @@ class MPLinearKernel(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
def _transform_param(
self, layer: torch.nn.Module, name: Optional[str], fn: Callable
) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
)
def _get_weight_params(
self, layer: torch.nn.Module) -> tuple[
torch.Tensor,  # w_q
torch.Tensor,  # w_s
Optional[torch.Tensor],  # w_zp,
Optional[torch.Tensor]  # w_gidx
]:
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor,  # w_q
torch.Tensor,  # w_s
Optional[torch.Tensor],  # w_zp,
Optional[torch.Tensor],  # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),

@@ -5,23 +5,33 @@ from typing import Optional
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import (  # noqa: E501
AllSparkLinearKernel)
AllSparkLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import (  # noqa: E501
BitBLASLinearKernel)
BitBLASLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import (  # noqa: E501
ConchLinearKernel)
ConchLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import (  # noqa: E501
CutlassW4A8LinearKernel)
CutlassW4A8LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import (  # noqa: E501
Dynamic4bitLinearKernel)
Dynamic4bitLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
ExllamaLinearKernel)
ExllamaLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
MacheteLinearKernel)
MacheteLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (  # noqa: E501
MarlinLinearKernel)
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import (  # noqa: E501
MPLinearKernel, MPLinearLayerConfig)
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.platforms import current_platform
# in priority/performance order (when available)

@@ -38,11 +48,11 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
config: MPLinearLayerConfig, compute_capability: Optional[int] = None
) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:

@@ -69,14 +79,18 @@ def choose_mp_linear_kernel(
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
f" {kernel.__name__} disabled by environment variable"
)
continue
if (compute_capability is not None
and kernel.get_min_capability() > compute_capability):
if (
compute_capability is not None
and kernel.get_min_capability() > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute "
f" capability is {compute_capability}")
f" capability is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config)

@@ -84,10 +98,10 @@ def choose_mp_linear_kernel(
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
"Failed to find a kernel that can implement the "
"WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
)

@@ -8,22 +8,21 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"

@@ -35,7 +34,8 @@ class AllSparkLinearKernel(MPLinearKernel):
c.partition_weight_shape[1],  # out_features
c.group_size,
c.weight_type,
c.act_type)
c.act_type,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}

@@ -49,8 +49,8 @@ class AllSparkLinearKernel(MPLinearKernel):
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args['sm_count'] = sm_count
gemm_args['sm_version'] = sm_version
gemm_args["sm_count"] = sm_count
gemm_args["sm_version"] = sm_version
self.gemm_args = gemm_args

@@ -59,43 +59,42 @@ class AllSparkLinearKernel(MPLinearKernel):
old_scale_param = getattr(layer, self.w_s_name)
assert isinstance(old_weight_param, BasevLLMParameter)
permute_param_layout_(old_weight_param,
input_dim=0,
output_dim=1,
packed_dim=0)
permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)
assert isinstance(old_scale_param, BasevLLMParameter)
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param = torch.nn.Parameter(old_weight_param.data,
requires_grad=False)
new_weight_param.data = new_weight_param.data.t().contiguous().view(
dtype=torch.uint8)
new_weight_param = torch.nn.Parameter(
old_weight_param.data, requires_grad=False
)
new_weight_param.data = (
new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
)
new_weight_param.data = new_weight_param.data.t().contiguous()
new_scale_param = torch.nn.Parameter(old_scale_param.data,
requires_grad=False)
new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)
# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param.data, new_scale_param.data, _ = \
ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None,
c.zero_points)
new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None, c.zero_points
)
replace_parameter(layer, self.w_q_name, new_weight_param.data)
replace_parameter(layer, self.w_s_name, new_scale_param.data)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args
w_q, w_s, _, _ = self._get_weight_params(layer)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
output = ops.allspark_w8a16_gemm(
a=reshaped_x,

@@ -104,11 +103,12 @@ class AllSparkLinearKernel(MPLinearKernel):
b_qzeros=None,
n=c.partition_weight_shape[1],
group_size=c.group_size,
sm_count=gemm_args['sm_count'],
sm_version=gemm_args['sm_version'],
sm_count=gemm_args["sm_count"],
sm_version=gemm_args["sm_version"],
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp=c.zero_points,
n32k16_reorder=True)
n32k16_reorder=True,
)
if bias is not None:
output.add_(bias)  # In-place add

@@ -10,10 +10,16 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
BITBLAS_OPTIMIZE_FEATURES,
BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION,
bitblas_make_empty_g_idx,
bitblas_sort_g_idx,
check_bitblas_supports_shape,
query_bitblas_supported_quant_types,
unpack_gptq_qweight,
unpack_gptq_qzeros,
)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -21,7 +27,6 @@ logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"

@@ -44,8 +49,9 @@ class BitBLASLinearKernel(MPLinearKernel):
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
super().__init__(
c, w_q_param_name, w_s_param_name, w_zp_param_name, w_gidx_param_name
)
def repack_bitblas_from_gptq(
self,

@@ -54,19 +60,18 @@ class BitBLASLinearKernel(MPLinearKernel):
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype)  # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
qweight = b_q_weight.T.contiguous().view(quant_config.torch_storage_dtype)  # type: ignore[union-attr]
intweight = unpack_gptq_qweight(qweight, quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None:  # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform(  # type: ignore[attr-defined]
intweight.cpu()).cuda()
intweight.cpu()
).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()

@@ -90,9 +95,14 @@ class BitBLASLinearKernel(MPLinearKernel):
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype  # type: ignore[union-attr]
).contiguous())
)
)
.to(qweight.device)
.to(
quant_config.torch_storage_dtype  # type: ignore[union-attr]
)
.contiguous()
)
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))

@@ -103,41 +113,50 @@ class BitBLASLinearKernel(MPLinearKernel):
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if version.parse(bitblas.__version__) < version.parse(
MINIMUM_BITBLAS_VERSION):
MINIMUM_BITBLAS_VERSION
):
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
)
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
return (
False,
"bitblas is not installed. Please install bitblas "
"by running `pip install bitblas>="
f"{MINIMUM_BITBLAS_VERSION}`",
)
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
return False, (
f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}"
)
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return False, (
f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}"
)
return check_bitblas_supports_shape(
c.partition_weight_shape[1],  # out_features
c.partition_weight_shape[0],  # in_features
c.full_weight_shape[0],  # in_features
c.group_size)
c.group_size,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}

@@ -149,14 +168,15 @@ class BitBLASLinearKernel(MPLinearKernel):
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if getattr(self, "w_gidx_name", None) is None:
self.w_gidx_name: str = "g_idx"
if getattr(self, "w_zp_name", None) is None:
self.w_zp_name: str = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:

@@ -169,13 +189,11 @@ class BitBLASLinearKernel(MPLinearKernel):
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else  # type: ignore[union-attr]
layer.qzeros,  # type: ignore[union-attr]
))
bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else layer.qzeros,  # type: ignore[union-attr]
)
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:

@@ -212,6 +230,7 @@ class BitBLASLinearKernel(MPLinearKernel):
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False

@@ -248,30 +267,33 @@ class BitBLASLinearKernel(MPLinearKernel):
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
matmul_config, enable_tuning
)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
global_operator_cache.load_from_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET
)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
BITBLAS_DATABASE_PATH, BITBLAS_TARGET
)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
f"BitBLAS Operator {config} tuned and saved to database."
)
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "

@@ -287,7 +309,7 @@ class BitBLASLinearKernel(MPLinearKernel):
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
out_shape = x.shape[:-1] + (output_size_per_partition,)
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
args.append(layer.qzeros)

@@ -297,5 +319,6 @@ class BitBLASLinearKernel(MPLinearKernel):
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead"
)
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)

@@ -6,44 +6,49 @@ from typing import Final, Optional
import torch
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
scalar_types.uint8b128
scalar_types.uint4,
scalar_types.uint8,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
class ConchLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
error_msg = f"Weight type ({c.weight_type}) not supported by "\
"ConchLinearKernel, supported types are: " \
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
error_msg = (
f"Weight type ({c.weight_type}) not supported by "
"ConchLinearKernel, supported types are: "
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
)
return False, error_msg
if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
error_msg = f"Group size ({c.group_size}) not supported by "\
"ConchLinearKernel, supported group sizes are: " \
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
error_msg = (
f"Group size ({c.group_size}) not supported by "
"ConchLinearKernel, supported group sizes are: "
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
)
return False, error_msg
if find_spec("conch") is None:
error_msg = "conch-triton-kernels is not installed, please "\
"install it via `pip install conch-triton-kernels` "\
"and try again!"
error_msg = (
"conch-triton-kernels is not installed, please "
"install it via `pip install conch-triton-kernels` "
"and try again!"
)
return False, error_msg
return True, None

@@ -52,7 +57,6 @@ class ConchLinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)

@@ -68,10 +72,12 @@ class ConchLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from conch.ops.quantization.gemm import mixed_precision_gemm
w_q, w_s, w_zp, _ = self._get_weight_params(layer)

@@ -7,10 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

@@ -18,26 +16,22 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# dynamic per-tok fp8 activation quantization
self.quant_fp8 = QuantFP8(static=False,
group_shape=GroupShape.PER_TOKEN)
self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "CUTLASS W4A8 requires compute capability of 90 "\
"(Hopper)"
return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"
if c.act_type != torch.float8_e4m3fn:
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

@@ -49,8 +43,11 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
return False, "Zero points not supported by CUTLASS W4A8"
if c.weight_type != scalar_types.int4:
return False, f"Quant type ({c.weight_type}) not supported by "\
"CUTLASS W4A8, only supported int4"
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"CUTLASS W4A8, only supported int4",
)
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128:

@@ -58,12 +55,16 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
in_features, out_features = c.partition_weight_shape
if in_features % 128 or out_features % 128:
return False, "K and N must be divisible by 128, got "\
f"{c.partition_weight_shape}"
return (
False,
f"K and N must be divisible by 128, got {c.partition_weight_shape}",
)
if c.out_type != torch.bfloat16:
return False, "Only bfloat16 output type currently supported"\
f"got {c.out_type=}"
return (
False,
f"Only bfloat16 output type currently supportedgot {c.out_type=}",
)
return True, None

@@ -71,13 +72,11 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(
x.data.t().contiguous().t())
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x
def transform_w_s(x):

@@ -92,24 +91,28 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_s_name, transform_w_s)
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_ch_s = layer.weight_chan_scale
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
x_2d, act_scales = self.quant_fp8(x_2d)
output = ops.cutlass_w4a8_mm(a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=w_ch_s)
output = ops.cutlass_w4a8_mm(
a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=w_ch_s,
)
if bias is not None:
output.add_(bias)  # In-place add

@@ -20,37 +20,45 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
return 1
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cpu():
return False, "Only CPU is supported"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Unsupported quant type {c.weight_type}"
if current_platform.get_cpu_architecture(
) == CpuArchEnum.ARM and c.act_type not in [
if (
current_platform.get_cpu_architecture() == CpuArchEnum.ARM
and c.act_type
not in [
torch.float32,
]:
return False, "Dynamic4bitLinearKernel on Arm requires"\
" Float32 activations"
]
):
return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations"
if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
# Attempt to retrieve the operation
_ = torch.ops.aten._dyn_quant_matmul_4bit
except AttributeError:
return False, f"PyTorch {torch.__version__} does not support"\
" _dyn_quant_matmul_4bit. Install a newer version"
return (
False,
f"PyTorch {torch.__version__} does not support"
" _dyn_quant_matmul_4bit. Install a newer version",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
packed_weight = getattr(layer, self.w_q_name)
packed_weight = packed_weight.add(8)
uint8_packed = (packed_weight[::, 1::2] << 4
| packed_weight[::, ::2]).to(torch.uint8)
uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
torch.uint8
)
scales = getattr(layer, self.w_s_name)
block_size = c.group_size

@@ -71,22 +79,34 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
# Repack weights as per kernel requirement
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_packed, scales, layer.bias, block_size,
c.partition_weight_shape[0], c.partition_weight_shape[1])
replace_parameter(layer, self.w_q_name,
torch.nn.Parameter(w, requires_grad=False))
uint8_packed,
scales,
layer.bias,
block_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
)
setattr(layer, self.w_s_name, None)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q = getattr(layer, self.w_q_name)
output = torch.ops.aten._dyn_quant_matmul_4bit(
x_2d, w_q, c.group_size, c.partition_weight_shape[0],
c.partition_weight_shape[1])
x_2d,
w_q,
c.group_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
return output.reshape(out_shape)

@@ -7,9 +7,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -25,31 +25,41 @@ class ExllamaLinearKernel(MPLinearKernel):
return 60
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\
"when the input features are partitioned across "\
"devices"
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Exllama, "
"when the input features are partitioned across "
"devices",
)
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
return False, "Output features must be a multiple of the pack " \
"factor (32 / num_bits) so that we can correctly " \
"pack the zero points"
return (
False,
"Output features must be a multiple of the pack "
"factor (32 / num_bits) so that we can correctly "
"pack the zero points",
)
if c.act_type != torch.float16:
return False, "Exllama only supports float16 activations"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Quant type ({c.weight_type}) not supported by "\
"Exllama, supported types are: "\
f"{cls.SUPPORTED_QUANT_TYPES}"
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Exllama, supported types are: "
f"{cls.SUPPORTED_QUANT_TYPES}",
)
if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
return True, None

@@ -70,21 +80,23 @@ class ExllamaLinearKernel(MPLinearKernel):
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros = torch.full((groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device)
zeros = torch.full(
(groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device,
)
else:
raise NotImplementedError(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference")
zeros = pack_quantized_values_into_int32(zeros,
c.weight_type,
packed_dim=1)
setattr(layer, self.w_zp_name,
torch.nn.Parameter(zeros, requires_grad=False))
"inference"
)
zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
setattr(
layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
)
if c.has_g_idx:

@@ -96,10 +108,9 @@ class ExllamaLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int,
device=device),
requires_grad=False)
empty_g_idx = torch.nn.Parameter(
torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
)
setattr(layer, self.w_gidx_name, empty_g_idx)
def transform_w_q(x):

@@ -122,21 +133,24 @@ class ExllamaLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
c.weight_type.size_bits)
output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
)
if bias is not None:
output.add_(bias)

@@ -8,26 +8,27 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
check_machete_supports_shape, query_machete_supported_group_sizes,
query_machete_supported_quant_types)
check_machete_supports_shape,
query_machete_supported_group_sizes,
query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
# Machete uses CUTLASS, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Machete only supported on CUDA"

@@ -35,25 +36,33 @@ class MacheteLinearKernel(MPLinearKernel):
if not current_platform.is_device_capability(90):
return False, "Machete requires compute capability of 90 (Hopper)"
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Machete, "
"when the input features are partitioned across "
"devices",
)
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"
if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Machete, supported types are: "
f"{query_machete_supported_quant_types(c.zero_points)}",
)
if c.group_size not in query_machete_supported_group_sizes(c.act_type):
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{query_machete_supported_group_sizes(c.act_type)}"
return (
False,
f"Group size ({c.group_size}) not supported by "
"Machete, supported group sizes are: "
f"{query_machete_supported_group_sizes(c.act_type)}",
)
return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])
return check_machete_supports_shape(
c.partition_weight_shape[0], c.partition_weight_shape[1]
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}

@@ -64,30 +73,33 @@ class MacheteLinearKernel(MPLinearKernel):
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
.to(torch.int)
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if c.act_type in [torch.float16, torch.bfloat16] \
and c.partition_weight_shape[0] % 8 == 0:
if (
c.act_type in [torch.float16, torch.bfloat16]
and c.partition_weight_shape[0] % 8 == 0
):
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=0
)
x_perm = x_unpacked[perm, :]
x.data = pack_quantized_values_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type)
x.data = pack_quantized_values_into_int32(
x_perm, c.weight_type, packed_dim=0
)
x.data = ops.machete_prepack_B(
x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type,
)
return x
def transform_w_s(x):

@@ -99,9 +111,9 @@ class MacheteLinearKernel(MPLinearKernel):
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=1
)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()

@@ -113,15 +125,17 @@ class MacheteLinearKernel(MPLinearKernel):
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
if c.has_g_idx:
x_2d = self.act_perm(x_2d)

@@ -131,12 +145,14 @@ class MacheteLinearKernel(MPLinearKernel):
else:
w_zp = None
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size)
output = ops.machete_mm(
a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size,
)
if bias is not None:
output.add_(bias)  # In-place add

@@ -7,46 +7,58 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types,
unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_sort_g_idx,
marlin_zero_points,
query_marlin_supported_quant_types,
unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
# Marlin uses inline PTX, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Marlin only supported on CUDA"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, f"Quant type ({c.weight_type}) not supported by"\
f" Marlin, supported types are: {quant_types}"
return (
False,
f"Quant type ({c.weight_type}) not supported by"
f" Marlin, supported types are: {quant_types}",
)
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Marlin, supported group sizes are: "\
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
return (
False,
f"Group size ({c.group_size}) not supported by "
"Marlin, supported group sizes are: "
f"{MARLIN_SUPPORTED_GROUP_SIZES}",
)
return check_marlin_supports_shape(
c.partition_weight_shape[1],  # out_features
c.partition_weight_shape[0],  # in_features
c.full_weight_shape[0],  # in_features
c.group_size)
c.group_size,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}

@@ -55,7 +67,7 @@ class MarlinLinearKernel(MPLinearKernel):
device = getattr(layer, self.w_q_name).device
c = self.config
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.

@@ -71,25 +83,30 @@ class MarlinLinearKernel(MPLinearKernel):
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits)
x.data = ops.gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size)
x.data = marlin_permute_scales(
x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
)
return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:

@@ -97,16 +114,24 @@ class MarlinLinearKernel(MPLinearKernel):
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (c.partition_weight_shape[0] //
c.group_size if c.group_size != -1 else 1)
self._transform_param(layer, self.w_zp_name, lambda x: \
marlin_zero_points(
unpack_cols(x.t(), c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1]),
grouped_k = (
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
)
self._transform_param(
layer,
self.w_zp_name,
lambda x: marlin_zero_points(
unpack_cols(
x.t(),
c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1],
),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits))
num_bits=c.weight_type.size_bits,
),
)
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
self._transform_param(layer, self.w_q_name, transform_w_q)

@@ -115,10 +140,12 @@ class MarlinLinearKernel(MPLinearKernel):
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

@@ -136,4 +163,5 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias)
bias=bias,
)

@@ -16,7 +16,6 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:

@@ -24,13 +23,18 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
w_s_param_name: str, i_s_param_name: str,
i_zp_param_name: str, azp_adj_param_name: str) -> None:
def __init__(
self,
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name

@@ -44,20 +48,23 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module) -> tuple[
torch.Tensor,  # weight
torch.Tensor,  # weight_scale
Optional[torch.Tensor],  # input_scale,
Optional[torch.Tensor],  # input_zp
Optional[torch.Tensor],  # azp_adj
]:
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor,  # weight
torch.Tensor,  # weight_scale
Optional[torch.Tensor],  # input_scale,
Optional[torch.Tensor],  # input_zp
Optional[torch.Tensor],  # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),

@@ -5,17 +5,24 @@ import os
from typing import Optional
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel)
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel)
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel)
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel)
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)

@@ -28,19 +35,18 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None
config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None
) -> type[ScaledMMLinearKernel]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer
config (ScaledMMLinearLayerConfig): Description of the linear layer
to be implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
Raises:

@@ -57,22 +63,25 @@ def choose_scaled_mm_linear_kernel(
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
f" {kernel.__name__} disabled by environment variable"
)
continue
# If the current platform uses compute_capability,
# make sure the kernel supports the compute cability.
if compute_capability is not None:
kernel_min_capability = kernel.get_min_capability()
if (kernel_min_capability is not None
and kernel_min_capability > compute_capability):
if (
kernel_min_capability is not None
and kernel_min_capability > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel_min_capability}, current compute capability "
f"is {compute_capability}")
f"is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config)

@@ -80,10 +89,10 @@ def choose_scaled_mm_linear_kernel(
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f' {kernel.__name__} cannot implement due to: {failure_reason}'
|
||||
f" {kernel.__name__} cannot implement due to: {failure_reason}"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "\
|
||||
"ScaledMM linear layer. Reasons: \n"
|
||||
+ '\n'.join(failure_reasons))
|
||||
"Failed to find a kernel that can implement the "
|
||||
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
||||
)
|
||||
|
||||
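Editor's note: a minimal usage sketch for the kernel chooser reformatted above. It is illustrative only and not part of this commit; it assumes the three boolean fields referenced as `self.config.*` elsewhere in this diff are the constructor arguments of `ScaledMMLinearLayerConfig`, and it raises ValueError on platforms where no listed kernel can implement the config.

# Editor-added sketch (not part of this diff): picking a ScaledMM kernel.
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,
)

# Assumed constructor fields; they mirror the self.config.* accesses in this diff.
cfg = ScaledMMLinearLayerConfig(
    is_channelwise=True,            # per-channel weight scales
    is_static_input_scheme=False,   # dynamic activation quantization
    input_symmetric=True,           # no activation zero points
)

# Walks _POSSIBLE_KERNELS for the current platform in priority order and
# returns the first class whose can_implement(cfg) succeeds.
kernel_cls = choose_scaled_mm_linear_kernel(cfg, compute_capability=90)
print(kernel_cls.__name__)
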
@@ -22,7 +22,6 @@ def rocm_aiter_gemm_w8a8_impl(
    bias: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:

    from aiter import gemm_a8w8_CK

    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
@@ -40,7 +39,6 @@ def rocm_aiter_gemm_w8a8_fake(
    bias: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:

    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
@@ -56,50 +54,53 @@ if current_platform.is_rocm():


class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_rocm():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "currently supported on non-ROCm platform.")
                "AiterScaledMMLinearKernel requires `aiter` which is not "
                + "currently supported on non-ROCm platform.",
            )

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "installed on ROCm.")
                "AiterScaledMMLinearKernel requires `aiter` which is not "
                + "installed on ROCm.",
            )
        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
        if not (
            envs.VLLM_ROCM_USE_AITER_LINEAR \
            and envs.VLLM_ROCM_USE_AITER
        ):
            return (False, "AiterScaledMMLinearKernel is disabled. " +
                    "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
                    "`VLLM_ROCM_USE_AITER_LINEAR` default is True.")
        if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
            return (
                False,
                "AiterScaledMMLinearKernel is disabled. "
                + "Enable by setting `VLLM_ROCM_USE_AITER=1` "
                + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
            )

        if not c.input_symmetric:
            return (False,
                    "AiterScaledMMLinearKernel only supports symmetric " +
                    "quantization.")
            return (
                False,
                "AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        `AiterScaledMMLinearKernel` implements a fused version of
            `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
@@ -116,29 +117,27 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, ("AiterScaledMMLinearKernel only supports"
                           " symmetric quantization.")
        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
                                               i_s,
                                               i_zp,
                                               symmetric=symmetric)
        assert symmetric, (
            "AiterScaledMMLinearKernel only supports symmetric quantization."
        )
        x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)

        assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
                              " symmetric quantization.")
        assert x_zp is None, (
            "AiterScaledMMLinearKernel only supports symmetric quantization."
        )
        out_dtype = x.dtype

        assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
        assert bias is None or bias.shape[0] == w_q.shape[
            1] and bias.dtype == out_dtype
        assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
        assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype

        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b

        per_tensor_scale_a = (x_s.numel() == 1)
        per_tensor_scale_b = (w_s.numel() == 1)
        per_token_scale_a = (x_s.numel() == m)
        per_channel_scale_b = (w_s.numel() == n)
        per_tensor_scale_a = x_s.numel() == 1
        per_tensor_scale_b = w_s.numel() == 1
        per_token_scale_a = x_s.numel() == m
        per_channel_scale_b = w_s.numel() == n

        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
@@ -146,16 +145,19 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert ((per_tensor_scale_a and per_tensor_scale_b)
                or (per_token_scale_a and per_channel_scale_b)), (
            "Currently only support per-tensor-per-tensor GEMM " +
            " and per-token-per-channel GEMM through AITER"
            " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
            "does not support AITER block scaled GEMM.")
        assert (per_tensor_scale_a and per_tensor_scale_b) or (
            per_token_scale_a and per_channel_scale_b
        ), (
            "Currently only support per-tensor-per-tensor GEMM "
            + " and per-token-per-channel GEMM through AITER"
            " w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
            + "does not support AITER block scaled GEMM."
        )

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
        return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
                                                   bias, out_dtype)
        return torch.ops.vllm.rocm_aiter_gemm_w8a8(
            x_q, w_q.t(), x_s, w_s, bias, out_dtype
        )

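Editor's note: the docstring above describes the fused math, `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`. The plain-PyTorch reference below is illustrative only and not part of this commit; the function name and shapes are assumptions made to match the per-token/per-channel case permitted by the asserts above, not a vLLM API.

# Editor-added reference sketch (not part of this diff): the unfused math that the
# fused scaled-int8 GEMM computes, written with a plain dequantize-then-matmul.
from typing import Optional

import torch


def scaled_mm_reference(
    a_q: torch.Tensor,      # [M, K] int8 activations
    b_q: torch.Tensor,      # [K, N] int8 weights
    scale_a: torch.Tensor,  # per-tensor [1] or per-token [M, 1] float scales
    scale_b: torch.Tensor,  # per-tensor [1] or per-channel [1, N] float scales
    out_dtype: torch.dtype = torch.float16,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Dequantize, multiply, then cast; the fused kernels fold the scales
    # into the GEMM epilogue instead of materializing float operands.
    out = (a_q.float() * scale_a) @ (b_q.float() * scale_b)
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)


# Per-token activation scales with per-channel weight scales, one of the two
# supported scale layouts noted in the asserts above.
M, K, N = 4, 32, 16
a_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)
b_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)
scale_a = torch.rand(M, 1) * 0.01
scale_b = torch.rand(1, N) * 0.01
ref = scaled_mm_reference(a_q, b_q, scale_a, scale_b, out_dtype=torch.bfloat16)
print(ref.shape, ref.dtype)  # torch.Size([4, 16]) torch.bfloat16
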
@@ -9,24 +9,22 @@ from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)
    convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
                                   ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


class CPUScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cpu():
            return False, "CPUScaledMM requires running on CPU."

@@ -36,9 +34,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        weight = getattr(layer, self.w_q_name)
        dtype = weight.dtype
        N, K = weight.size()
        if (current_platform.get_cpu_architecture() == CpuArchEnum.X86
                and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric
                and check_cpu_sgl_kernel(N, K, dtype)):
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.X86
            and envs.VLLM_CPU_SGL_KERNEL
            and self.config.input_symmetric
            and check_cpu_sgl_kernel(N, K, dtype)
        ):
            self.linear_method = self._apply_weights_sgl
            self.process_weights_for_sgl(layer)
        else:
@@ -50,8 +51,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        # Transpose to [K, N] for convenience
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer, self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False))
            layer,
            self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )

        # WEIGHT SCALE
        # oneDNN kernels support only per-tensor and per-channel.
@@ -60,11 +63,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if self.config.is_static_input_scheme:
@@ -72,8 +76,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):

            if self.config.input_symmetric:
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False))
                    layer,
                    self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, self.i_zp_name, None)
            else:
                input_zero_point = getattr(layer, self.i_zp_name)
@@ -84,16 +90,17 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max -
                                                   int8_traits.min)
                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(scale, requires_grad=False))
                    layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )

                azp = (int8_traits.min -
                       range_min / scale).round().to(dtype=torch.int32)
                replace_parameter(layer, self.i_zp_name,
                                  torch.nn.Parameter(azp, requires_grad=False))
                azp = (
                    (int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
                )
                replace_parameter(
                    layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )

            else:
                setattr(layer, self.i_s_name, None)
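Editor's note: the asymmetric static-input branch above folds several per-partition (scale, zero-point) pairs into a single pair via range_min/range_max. The worked example below is illustrative only and not part of this commit; the input values are made up.

# Editor-added sketch (not part of this diff): folding two static asymmetric
# int8 input quantization params into one (scale, azp) pair, as above.
import torch

int8_min, int8_max = -128, 127
input_scale = torch.tensor([0.02, 0.05])   # per-partition scales (made up)
azps = torch.tensor([10.0, -4.0])          # per-partition zero points (made up)

range_max = (input_scale * (int8_max - azps)).max()   # 6.55
range_min = (input_scale * (int8_min - azps)).min()   # -6.20

scale = (range_max - range_min) / (int8_max - int8_min)        # 0.05
azp = (int8_min - range_min / scale).round().to(torch.int32)   # -4

print(float(scale), int(azp))
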
@@ -105,14 +112,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        # s_a * s_b * [(A - zp_a)B] + bias =
        # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
        # s_a * GEMM_output - s_a * zp_a * adj + bias
        if not (self.config.input_symmetric
                and self.config.is_static_input_scheme):
        if not (self.config.input_symmetric and self.config.is_static_input_scheme):
            weight = getattr(layer, self.w_q_name)
            weight_scale = getattr(layer, self.w_s_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
            azp_adj = azp_adj * weight_scale.squeeze()
            setattr(layer, self.azp_adj_name,
                    torch.nn.Parameter(azp_adj, requires_grad=False))
            setattr(
                layer,
                self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )
        else:
            setattr(layer, self.azp_adj_name, None)

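Editor's note: the zero-point algebra in the comment above can be sanity-checked outside vLLM. The sketch below is illustrative only and not part of this commit; it verifies that subtracting the activation zero point before the GEMM equals applying the precomputed column-sum adjustment afterwards.

# Editor-added sketch (not part of this diff): check the azp_adj identity
#   (A - zp_a) @ B == A @ B - zp_a * B.sum(dim=0)
import torch

torch.manual_seed(0)
A = torch.randint(-128, 128, (4, 8)).float()    # int8-valued activations
B = torch.randint(-128, 128, (8, 16)).float()   # int8-valued weights, [K, N]
zp_a = 7.0                                      # activation zero point

lhs = (A - zp_a) @ B                  # asymmetric GEMM done directly
adj = B.sum(dim=0, keepdim=True)      # the azp_adj column sums (before scaling)
rhs = A @ B - zp_a * adj              # symmetric GEMM plus correction

assert torch.allclose(lhs, rhs)
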
@@ -135,34 +144,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        weight = getattr(layer, self.w_q_name)
        packed_weight = torch.ops._C.convert_weight_packed(weight)
        replace_parameter(
            layer, self.w_q_name,
            torch.nn.Parameter(packed_weight, requires_grad=False))
            layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
        )

        if layer.bias is not None:
            bias = layer.bias
            layer.register_parameter(
                "bias_fp32",
                torch.nn.Parameter(bias.float().data, requires_grad=False))
                "bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
            )

        # WEIGHT SCALE
        # CPU SGL kernels only support per-channel.
        # For per-tensor quant, convert to the per-channel case.
        weight_scale = getattr(layer, self.w_s_name)
        if not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)
        setattr(layer, self.azp_adj_name, None)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.linear_method(
            layer,
            x,
@@ -170,31 +182,33 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
        )

    def _apply_weights_onednn(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
            x, i_s, i_zp, self.config.input_symmetric)
            x, i_s, i_zp, self.config.input_symmetric
        )

        m = x.size(0)
        n = self.dnnl_handler.n
        out = torch.empty((m, n), dtype=x.dtype)
        ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj,
                             bias)
        ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)

        return out

    def _apply_weights_sgl(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_weight_params(layer)
        return torch.ops._C.int8_scaled_mm_with_quant(
            x,

@@ -8,23 +8,20 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)
    convert_to_channelwise,
)
from vllm.platforms import current_platform

from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
                                   ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:

    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cuda():
            return False, "CutlassScaledMM requires running on CUDA."

@@ -35,8 +32,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
        # Cutlass kernels need transposed weight.
        weight = getattr(layer, self.w_q_name)
        replace_parameter(
            layer, self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False))
            layer,
            self.w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )

        # WEIGHT SCALE
        # Cutlass kernels support only per-tensor and per-channel.
@@ -45,11 +44,12 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if self.config.is_static_input_scheme:
@@ -57,8 +57,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):

            if self.config.input_symmetric:
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False))
                    layer,
                    self.i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, self.i_zp_name, None)
            else:
                input_zero_point = getattr(layer, self.i_zp_name)
@@ -69,17 +71,16 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max -
                                                   int8_traits.min)
                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, self.i_s_name,
                    torch.nn.Parameter(scale, requires_grad=False))
                    layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )

                # AZP loaded as int8 but used as int32
                azp = (int8_traits.min -
                       range_min / scale).to(dtype=torch.int32)
                replace_parameter(layer, self.i_zp_name,
                                  torch.nn.Parameter(azp, requires_grad=False))
                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
                replace_parameter(
                    layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )

            else:
                setattr(layer, self.i_s_name, None)
@@ -97,41 +98,44 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, self.i_zp_name) * azp_adj
            setattr(layer, self.azp_adj_name,
                    torch.nn.Parameter(azp_adj, requires_grad=False))
            setattr(
                layer,
                self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )
        else:
            setattr(layer, self.azp_adj_name, None)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
                                               i_s,
                                               i_zp,
                                               symmetric=symmetric)
        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=symmetric
        )

        if x_zp is not None:
            # Currently, static is always per-tensor and dynamic is per-token
            static = i_zp is not None
            azp = None if static else x_zp
            return ops.cutlass_scaled_mm_azp(x_q,
                                             w_q,
                                             scale_a=x_s,
                                             scale_b=w_s,
                                             out_dtype=x.dtype,
                                             azp_adj=azp_adj,
                                             azp=azp,
                                             bias=bias)
        return ops.cutlass_scaled_mm(x_q,
                                     w_q,
                                     scale_a=x_s,
                                     scale_b=w_s,
                                     out_dtype=x.dtype,
                                     bias=bias)
            return ops.cutlass_scaled_mm_azp(
                x_q,
                w_q,
                scale_a=x_s,
                scale_b=w_s,
                out_dtype=x.dtype,
                azp_adj=azp_adj,
                azp=azp,
                bias=bias,
            )
        return ops.cutlass_scaled_mm(
            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
        )

@@ -12,30 +12,32 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if current_platform.is_cpu():
            return (
                False,
                "TritonScaledMMLinearKernel requires Triton which is not " +
                "currently supported on CPU.")
                "TritonScaledMMLinearKernel requires Triton which is not "
                + "currently supported on CPU.",
            )
        if not c.input_symmetric:
            return (False,
                    "TritonScaledMMLinearKernel only supports symmetric " +
                    "quantization.")
            return (
                False,
                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return super().apply_weights(layer, x, bias)

@@ -9,25 +9,23 @@ from functorch.experimental.control_flow import cond # noqa: F401

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)
    convert_to_channelwise,
)
from vllm.platforms import current_platform

from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
                                   ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig


class XLAScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "TPU platform does have a concept of compute capability, "
            "this method should not be called.")
            "this method should not be called."
        )

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:

    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_tpu():
            return False, "ScaledMMXLA requires running on TPU."

@@ -46,8 +44,9 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
        # WEIGHT
        # [out, in] (different than cutlass_scaled_mm)
        weight = getattr(layer, self.w_q_name)
        replace_parameter(layer, self.w_q_name,
                          torch.nn.Parameter(weight.data, requires_grad=False))
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
        )

        # WEIGHT SCALE
        # XLA kernels support only per-tensor and per-channel.
@@ -56,14 +55,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, self.w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)

        # [out_channel,] (different than cutlass_scaled_mm)
        weight_scale = weight_scale.squeeze(-1)
        replace_parameter(
            layer, self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False))
            layer,
            self.w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # Only support symmetric dynamic activation quantization.
        setattr(layer, self.i_s_name, None)
@@ -74,8 +74,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
        # to specialize the graph since bias is not dynamic.
        warnings.filterwarnings(
            "ignore",
            message=
            "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches."  # noqa: E501
            message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.",  # noqa: E501
        )

    def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
@@ -84,14 +83,17 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
    def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
        return x + bias

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_weight_params(layer)

        # Required to register custom ops.
        import torch_xla.experimental.custom_kernel  # noqa: F401

        out = torch.ops.xla.quantized_matmul_int8(
            x,
            w_q,
