Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit d6953beb91
parent 17edd8a807
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
1508 changed files with 115244 additions and 94146 deletions
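The bulk of the change is mechanical: ruff's black-compatible formatter replaces yapf's continuation-aligned style (hanging indents with one argument per line and trailing commas, double-quoted strings), and ruff's isort-compatible import sorting replaces isort itself (wrapped imports become parenthesized lists with trailing commas). The following sketch illustrates the before/after pattern on a hypothetical ExampleKernel; it is an illustration only, not code from the commit:

# Hypothetical illustration of the yapf -> ruff formatting change; not part of the commit.
from typing import Optional

import torch


class ExampleKernel:
    # Before (yapf): continuation arguments aligned under the opening parenthesis.
    def apply_weights_yapf(self,
                           layer: torch.nn.Module,
                           x: torch.Tensor,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

    # After (ruff, black-compatible): hanging indent, one parameter per line,
    # and a trailing comma before the closing parenthesis.
    def apply_weights_ruff(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

Presumably the bulk of these edits were produced by running `ruff format` and `ruff check --fix` over the tree; the exact invocation and ruff configuration are not shown in this excerpt.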

View File

@@ -24,7 +24,6 @@ class MPLinearLayerConfig:
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
@@ -32,16 +31,17 @@ class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
@@ -58,31 +58,34 @@ class MPLinearKernel(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
def _transform_param(
self, layer: torch.nn.Module, name: Optional[str], fn: Callable
) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
)
def _get_weight_params(
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor], # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),

View File

@@ -5,23 +5,33 @@ from typing import Optional
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel)
AllSparkLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
BitBLASLinearKernel)
BitBLASLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel)
ConchLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
CutlassW4A8LinearKernel)
CutlassW4A8LinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
Dynamic4bitLinearKernel)
Dynamic4bitLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
ExllamaLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
MacheteLinearKernel)
MacheteLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
MarlinLinearKernel)
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel, MPLinearLayerConfig)
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.platforms import current_platform
# in priority/performance order (when available)
@@ -38,11 +48,11 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
config: MPLinearLayerConfig, compute_capability: Optional[int] = None
) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
@@ -69,14 +79,18 @@ def choose_mp_linear_kernel(
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
f" {kernel.__name__} disabled by environment variable"
)
continue
if (compute_capability is not None
and kernel.get_min_capability() > compute_capability):
if (
compute_capability is not None
and kernel.get_min_capability() > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute "
f" capability is {compute_capability}")
f" capability is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config)
@@ -84,10 +98,10 @@ def choose_mp_linear_kernel(
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
"Failed to find a kernel that can implement the "
"WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
)

View File

@@ -8,22 +8,21 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"
@@ -35,7 +34,8 @@ class AllSparkLinearKernel(MPLinearKernel):
c.partition_weight_shape[1], # out_features
c.group_size,
c.weight_type,
c.act_type)
c.act_type,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
@@ -49,8 +49,8 @@ class AllSparkLinearKernel(MPLinearKernel):
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args['sm_count'] = sm_count
gemm_args['sm_version'] = sm_version
gemm_args["sm_count"] = sm_count
gemm_args["sm_version"] = sm_version
self.gemm_args = gemm_args
@@ -59,43 +59,42 @@ class AllSparkLinearKernel(MPLinearKernel):
old_scale_param = getattr(layer, self.w_s_name)
assert isinstance(old_weight_param, BasevLLMParameter)
permute_param_layout_(old_weight_param,
input_dim=0,
output_dim=1,
packed_dim=0)
permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)
assert isinstance(old_scale_param, BasevLLMParameter)
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param = torch.nn.Parameter(old_weight_param.data,
requires_grad=False)
new_weight_param.data = new_weight_param.data.t().contiguous().view(
dtype=torch.uint8)
new_weight_param = torch.nn.Parameter(
old_weight_param.data, requires_grad=False
)
new_weight_param.data = (
new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
)
new_weight_param.data = new_weight_param.data.t().contiguous()
new_scale_param = torch.nn.Parameter(old_scale_param.data,
requires_grad=False)
new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)
# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param.data, new_scale_param.data, _ = \
ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None,
c.zero_points)
new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None, c.zero_points
)
replace_parameter(layer, self.w_q_name, new_weight_param.data)
replace_parameter(layer, self.w_s_name, new_scale_param.data)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args
w_q, w_s, _, _ = self._get_weight_params(layer)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
output = ops.allspark_w8a16_gemm(
a=reshaped_x,
@@ -104,11 +103,12 @@ class AllSparkLinearKernel(MPLinearKernel):
b_qzeros=None,
n=c.partition_weight_shape[1],
group_size=c.group_size,
sm_count=gemm_args['sm_count'],
sm_version=gemm_args['sm_version'],
sm_count=gemm_args["sm_count"],
sm_version=gemm_args["sm_version"],
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp=c.zero_points,
n32k16_reorder=True)
n32k16_reorder=True,
)
if bias is not None:
output.add_(bias) # In-place add

View File

@@ -10,10 +10,16 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
BITBLAS_OPTIMIZE_FEATURES,
BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION,
bitblas_make_empty_g_idx,
bitblas_sort_g_idx,
check_bitblas_supports_shape,
query_bitblas_supported_quant_types,
unpack_gptq_qweight,
unpack_gptq_qzeros,
)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -21,7 +27,6 @@ logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
@@ -44,8 +49,9 @@ class BitBLASLinearKernel(MPLinearKernel):
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
super().__init__(
c, w_q_param_name, w_s_param_name, w_zp_param_name, w_gidx_param_name
)
def repack_bitblas_from_gptq(
self,
@@ -54,19 +60,18 @@ class BitBLASLinearKernel(MPLinearKernel):
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
qweight = b_q_weight.T.contiguous().view(quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(qweight, quant_config.weight_bits).contiguous() # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
intweight.cpu()).cuda()
intweight.cpu()
).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()
@@ -90,9 +95,14 @@ class BitBLASLinearKernel(MPLinearKernel):
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
).contiguous())
)
)
.to(qweight.device)
.to(
quant_config.torch_storage_dtype # type: ignore[union-attr]
)
.contiguous()
)
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
@@ -103,41 +113,50 @@ class BitBLASLinearKernel(MPLinearKernel):
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if version.parse(bitblas.__version__) < version.parse(
MINIMUM_BITBLAS_VERSION):
MINIMUM_BITBLAS_VERSION
):
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
)
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
return (
False,
"bitblas is not installed. Please install bitblas "
"by running `pip install bitblas>="
f"{MINIMUM_BITBLAS_VERSION}`",
)
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
return False, (
f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}"
)
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return False, (
f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}"
)
return check_bitblas_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
c.group_size,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
@@ -149,14 +168,15 @@ class BitBLASLinearKernel(MPLinearKernel):
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if getattr(self, "w_gidx_name", None) is None:
self.w_gidx_name: str = "g_idx"
if getattr(self, "w_zp_name", None) is None:
self.w_zp_name: str = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
@@ -169,13 +189,11 @@ class BitBLASLinearKernel(MPLinearKernel):
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else # type: ignore[union-attr]
layer.qzeros, # type: ignore[union-attr]
))
bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else layer.qzeros, # type: ignore[union-attr]
)
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:
@@ -212,6 +230,7 @@ class BitBLASLinearKernel(MPLinearKernel):
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False
@@ -248,30 +267,33 @@ class BitBLASLinearKernel(MPLinearKernel):
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
matmul_config, enable_tuning
)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
global_operator_cache.load_from_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET
)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
BITBLAS_DATABASE_PATH, BITBLAS_TARGET
)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
f"BitBLAS Operator {config} tuned and saved to database."
)
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "
@@ -287,7 +309,7 @@ class BitBLASLinearKernel(MPLinearKernel):
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
out_shape = x.shape[:-1] + (output_size_per_partition,)
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
args.append(layer.qzeros)
@@ -297,5 +319,6 @@ class BitBLASLinearKernel(MPLinearKernel):
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead"
)
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)

View File

@@ -6,44 +6,49 @@ from typing import Final, Optional
import torch
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
scalar_types.uint8b128
scalar_types.uint4,
scalar_types.uint8,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
class ConchLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
error_msg = f"Weight type ({c.weight_type}) not supported by "\
"ConchLinearKernel, supported types are: " \
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
error_msg = (
f"Weight type ({c.weight_type}) not supported by "
"ConchLinearKernel, supported types are: "
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
)
return False, error_msg
if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
error_msg = f"Group size ({c.group_size}) not supported by "\
"ConchLinearKernel, supported group sizes are: " \
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
error_msg = (
f"Group size ({c.group_size}) not supported by "
"ConchLinearKernel, supported group sizes are: "
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
)
return False, error_msg
if find_spec("conch") is None:
error_msg = "conch-triton-kernels is not installed, please "\
"install it via `pip install conch-triton-kernels` "\
"and try again!"
error_msg = (
"conch-triton-kernels is not installed, please "
"install it via `pip install conch-triton-kernels` "
"and try again!"
)
return False, error_msg
return True, None
@@ -52,7 +57,6 @@ class ConchLinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
@@ -68,10 +72,12 @@ class ConchLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from conch.ops.quantization.gemm import mixed_precision_gemm
w_q, w_s, w_zp, _ = self._get_weight_params(layer)

View File

@@ -7,10 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -18,26 +16,22 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# dynamic per-tok fp8 activation quantization
self.quant_fp8 = QuantFP8(static=False,
group_shape=GroupShape.PER_TOKEN)
self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "CUTLASS W4A8 requires compute capability of 90 "\
"(Hopper)"
return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"
if c.act_type != torch.float8_e4m3fn:
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
@@ -49,8 +43,11 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
return False, "Zero points not supported by CUTLASS W4A8"
if c.weight_type != scalar_types.int4:
return False, f"Quant type ({c.weight_type}) not supported by "\
"CUTLASS W4A8, only supported int4"
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"CUTLASS W4A8, only supported int4",
)
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128:
@@ -58,12 +55,16 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
in_features, out_features = c.partition_weight_shape
if in_features % 128 or out_features % 128:
return False, "K and N must be divisible by 128, got "\
f"{c.partition_weight_shape}"
return (
False,
f"K and N must be divisible by 128, got {c.partition_weight_shape}",
)
if c.out_type != torch.bfloat16:
return False, "Only bfloat16 output type currently supported"\
f"got {c.out_type=}"
return (
False,
f"Only bfloat16 output type currently supportedgot {c.out_type=}",
)
return True, None
@@ -71,13 +72,11 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(
x.data.t().contiguous().t())
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x
def transform_w_s(x):
@@ -92,24 +91,28 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_s_name, transform_w_s)
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_ch_s = layer.weight_chan_scale
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
x_2d, act_scales = self.quant_fp8(x_2d)
output = ops.cutlass_w4a8_mm(a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=w_ch_s)
output = ops.cutlass_w4a8_mm(
a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=w_ch_s,
)
if bias is not None:
output.add_(bias) # In-place add

View File

@@ -20,37 +20,45 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
return 1
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cpu():
return False, "Only CPU is supported"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Unsupported quant type {c.weight_type}"
if current_platform.get_cpu_architecture(
) == CpuArchEnum.ARM and c.act_type not in [
if (
current_platform.get_cpu_architecture() == CpuArchEnum.ARM
and c.act_type
not in [
torch.float32,
]:
return False, "Dynamic4bitLinearKernel on Arm requires"\
" Float32 activations"
]
):
return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations"
if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
# Attempt to retrieve the operation
_ = torch.ops.aten._dyn_quant_matmul_4bit
except AttributeError:
return False, f"PyTorch {torch.__version__} does not support"\
" _dyn_quant_matmul_4bit. Install a newer version"
return (
False,
f"PyTorch {torch.__version__} does not support"
" _dyn_quant_matmul_4bit. Install a newer version",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
packed_weight = getattr(layer, self.w_q_name)
packed_weight = packed_weight.add(8)
uint8_packed = (packed_weight[::, 1::2] << 4
| packed_weight[::, ::2]).to(torch.uint8)
uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
torch.uint8
)
scales = getattr(layer, self.w_s_name)
block_size = c.group_size
@@ -71,22 +79,34 @@ class Dynamic4bitLinearKernel(MPLinearKernel):
# Repack weights as per kernel requirement
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_packed, scales, layer.bias, block_size,
c.partition_weight_shape[0], c.partition_weight_shape[1])
replace_parameter(layer, self.w_q_name,
torch.nn.Parameter(w, requires_grad=False))
uint8_packed,
scales,
layer.bias,
block_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
)
setattr(layer, self.w_s_name, None)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q = getattr(layer, self.w_q_name)
output = torch.ops.aten._dyn_quant_matmul_4bit(
x_2d, w_q, c.group_size, c.partition_weight_shape[0],
c.partition_weight_shape[1])
x_2d,
w_q,
c.group_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
return output.reshape(out_shape)

View File

@@ -7,9 +7,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
@@ -25,31 +25,41 @@ class ExllamaLinearKernel(MPLinearKernel):
return 60
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Exllama, "\
"when the input features are partitioned across "\
"devices"
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Exllama, "
"when the input features are partitioned across "
"devices",
)
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
return False, "Output features must be a multiple of the pack " \
"factor (32 / num_bits) so that we can correctly " \
"pack the zero points"
return (
False,
"Output features must be a multiple of the pack "
"factor (32 / num_bits) so that we can correctly "
"pack the zero points",
)
if c.act_type != torch.float16:
return False, "Exllama only supports float16 activations"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Quant type ({c.weight_type}) not supported by "\
"Exllama, supported types are: "\
f"{cls.SUPPORTED_QUANT_TYPES}"
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Exllama, supported types are: "
f"{cls.SUPPORTED_QUANT_TYPES}",
)
if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
return True, None
@@ -70,21 +80,23 @@ class ExllamaLinearKernel(MPLinearKernel):
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros = torch.full((groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device)
zeros = torch.full(
(groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device,
)
else:
raise NotImplementedError(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference")
zeros = pack_quantized_values_into_int32(zeros,
c.weight_type,
packed_dim=1)
setattr(layer, self.w_zp_name,
torch.nn.Parameter(zeros, requires_grad=False))
"inference"
)
zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
setattr(
layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
)
if c.has_g_idx:
@@ -96,10 +108,9 @@ class ExllamaLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
dtype=torch.int,
device=device),
requires_grad=False)
empty_g_idx = torch.nn.Parameter(
torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
)
setattr(layer, self.w_gidx_name, empty_g_idx)
def transform_w_q(x):
@@ -122,21 +133,24 @@ class ExllamaLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
c.weight_type.size_bits)
output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
)
if bias is not None:
output.add_(bias)

View File

@@ -8,26 +8,27 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
check_machete_supports_shape, query_machete_supported_group_sizes,
query_machete_supported_quant_types)
check_machete_supports_shape,
query_machete_supported_group_sizes,
query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
# Machete uses CUTLASS, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Machete only supported on CUDA"
@@ -35,25 +36,33 @@ class MacheteLinearKernel(MPLinearKernel):
if not current_platform.is_device_capability(90):
return False, "Machete requires compute capability of 90 (Hopper)"
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Machete, "
"when the input features are partitioned across "
"devices",
)
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"
if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Machete, supported types are: "
f"{query_machete_supported_quant_types(c.zero_points)}",
)
if c.group_size not in query_machete_supported_group_sizes(c.act_type):
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{query_machete_supported_group_sizes(c.act_type)}"
return (
False,
f"Group size ({c.group_size}) not supported by "
"Machete, supported group sizes are: "
f"{query_machete_supported_group_sizes(c.act_type)}",
)
return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])
return check_machete_supports_shape(
c.partition_weight_shape[0], c.partition_weight_shape[1]
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
@@ -64,30 +73,33 @@ class MacheteLinearKernel(MPLinearKernel):
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
.to(torch.int)
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if c.act_type in [torch.float16, torch.bfloat16] \
and c.partition_weight_shape[0] % 8 == 0:
if (
c.act_type in [torch.float16, torch.bfloat16]
and c.partition_weight_shape[0] % 8 == 0
):
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=0
)
x_perm = x_unpacked[perm, :]
x.data = pack_quantized_values_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type)
x.data = pack_quantized_values_into_int32(
x_perm, c.weight_type, packed_dim=0
)
x.data = ops.machete_prepack_B(
x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type,
)
return x
def transform_w_s(x):
@@ -99,9 +111,9 @@ class MacheteLinearKernel(MPLinearKernel):
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=1
)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
@@ -113,15 +125,17 @@ class MacheteLinearKernel(MPLinearKernel):
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
@@ -131,12 +145,14 @@ class MacheteLinearKernel(MPLinearKernel):
else:
w_zp = None
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size)
output = ops.machete_mm(
a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size,
)
if bias is not None:
output.add_(bias) # In-place add

View File

@@ -7,46 +7,58 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types,
unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_sort_g_idx,
marlin_zero_points,
query_marlin_supported_quant_types,
unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
# Marlin uses inline PTX, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Marlin only supported on CUDA"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, f"Quant type ({c.weight_type}) not supported by"\
f" Marlin, supported types are: {quant_types}"
return (
False,
f"Quant type ({c.weight_type}) not supported by"
f" Marlin, supported types are: {quant_types}",
)
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Marlin, supported group sizes are: "\
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
return (
False,
f"Group size ({c.group_size}) not supported by "
"Marlin, supported group sizes are: "
f"{MARLIN_SUPPORTED_GROUP_SIZES}",
)
return check_marlin_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
c.group_size,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
@@ -55,7 +67,7 @@ class MarlinLinearKernel(MPLinearKernel):
device = getattr(layer, self.w_q_name).device
c = self.config
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
@@ -71,25 +83,30 @@ class MarlinLinearKernel(MPLinearKernel):
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits)
x.data = ops.gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size)
x.data = marlin_permute_scales(
x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
)
return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
@@ -97,16 +114,24 @@ class MarlinLinearKernel(MPLinearKernel):
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (c.partition_weight_shape[0] //
c.group_size if c.group_size != -1 else 1)
self._transform_param(layer, self.w_zp_name, lambda x: \
marlin_zero_points(
unpack_cols(x.t(), c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1]),
grouped_k = (
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
)
self._transform_param(
layer,
self.w_zp_name,
lambda x: marlin_zero_points(
unpack_cols(
x.t(),
c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1],
),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits))
num_bits=c.weight_type.size_bits,
),
)
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
self._transform_param(layer, self.w_q_name, transform_w_q)
@@ -115,10 +140,12 @@ class MarlinLinearKernel(MPLinearKernel):
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
@@ -136,4 +163,5 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias)
bias=bias,
)

View File

@@ -16,7 +16,6 @@ class ScaledMMLinearLayerConfig:
class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
@@ -24,13 +23,18 @@ class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
w_s_param_name: str, i_s_param_name: str,
i_zp_param_name: str, azp_adj_param_name: str) -> None:
def __init__(
self,
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
@@ -44,20 +48,23 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),

View File

@@ -5,17 +5,24 @@ import os
from typing import Optional
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel)
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel)
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel)
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel)
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
@@ -28,19 +35,18 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig,
compute_capability: Optional[int] = None
config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None
) -> type[ScaledMMLinearKernel]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer
config (ScaledMMLinearLayerConfig): Description of the linear layer
to be implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
Raises:
@@ -57,22 +63,25 @@ def choose_scaled_mm_linear_kernel(
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
f" {kernel.__name__} disabled by environment variable"
)
continue
# If the current platform uses compute_capability,
# make sure the kernel supports the compute cability.
if compute_capability is not None:
kernel_min_capability = kernel.get_min_capability()
if (kernel_min_capability is not None
and kernel_min_capability > compute_capability):
if (
kernel_min_capability is not None
and kernel_min_capability > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel_min_capability}, current compute capability "
f"is {compute_capability}")
f"is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config)
@@ -80,10 +89,10 @@ def choose_scaled_mm_linear_kernel(
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"ScaledMM linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
"Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
)

View File

@@ -22,7 +22,6 @@ def rocm_aiter_gemm_w8a8_impl(
bias: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
@@ -40,7 +39,6 @@ def rocm_aiter_gemm_w8a8_fake(
bias: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
@@ -56,50 +54,53 @@ if current_platform.is_rocm():
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"currently supported on non-ROCm platform.")
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
try:
import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception:
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not " +
"installed on ROCm.")
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (
envs.VLLM_ROCM_USE_AITER_LINEAR \
and envs.VLLM_ROCM_USE_AITER
):
return (False, "AiterScaledMMLinearKernel is disabled. " +
"Enable by setting `VLLM_ROCM_USE_AITER=1` " +
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.")
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
return (
False,
"AiterScaledMMLinearKernel is disabled. "
+ "Enable by setting `VLLM_ROCM_USE_AITER=1` "
+ "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
if not c.input_symmetric:
return (False,
"AiterScaledMMLinearKernel only supports symmetric " +
"quantization.")
return (
False,
"AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
@@ -116,29 +117,27 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
assert symmetric, ("AiterScaledMMLinearKernel only supports"
" symmetric quantization.")
x_q, x_s, x_zp = ops.scaled_int8_quant(x,
i_s,
i_zp,
symmetric=symmetric)
assert symmetric, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
)
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
" symmetric quantization.")
assert x_zp is None, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
)
out_dtype = x.dtype
assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == w_q.shape[
1] and bias.dtype == out_dtype
assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
m = x_q.shape[0] # a
n = w_q.shape[1] # b
per_tensor_scale_a = (x_s.numel() == 1)
per_tensor_scale_b = (w_s.numel() == 1)
per_token_scale_a = (x_s.numel() == m)
per_channel_scale_b = (w_s.numel() == n)
per_tensor_scale_a = x_s.numel() == 1
per_tensor_scale_b = w_s.numel() == 1
per_token_scale_a = x_s.numel() == m
per_channel_scale_b = w_s.numel() == n
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
@@ -146,16 +145,19 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert ((per_tensor_scale_a and per_tensor_scale_b)
or (per_token_scale_a and per_channel_scale_b)), (
"Currently only support per-tensor-per-tensor GEMM " +
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
"does not support AITER block scaled GEMM.")
assert (per_tensor_scale_a and per_tensor_scale_b) or (
per_token_scale_a and per_channel_scale_b
), (
"Currently only support per-tensor-per-tensor GEMM "
+ " and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
+ "does not support AITER block scaled GEMM."
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
bias, out_dtype)
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
x_q, w_q.t(), x_s, w_s, bias, out_dtype
)

View File

@@ -9,24 +9,22 @@ from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cpu():
return False, "CPUScaledMM requires running on CPU."
@@ -36,9 +34,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
weight = getattr(layer, self.w_q_name)
dtype = weight.dtype
N, K = weight.size()
if (current_platform.get_cpu_architecture() == CpuArchEnum.X86
and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric
and check_cpu_sgl_kernel(N, K, dtype)):
if (
current_platform.get_cpu_architecture() == CpuArchEnum.X86
and envs.VLLM_CPU_SGL_KERNEL
and self.config.input_symmetric
and check_cpu_sgl_kernel(N, K, dtype)
):
self.linear_method = self._apply_weights_sgl
self.process_weights_for_sgl(layer)
else:
@@ -50,8 +51,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# Transpose to [K, N] for convenience
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False))
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# oneDNN kernels support only per-tensor and per-channel.
@@ -60,11 +63,12 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
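
# Rough sketch of the expansion performed above (assumed behaviour of
# convert_to_channelwise, shown in plain PyTorch): a fused module has one scale
# per logical shard, and each scale is repeated across that shard's output
# channels to produce a per-channel scale.
import torch

logical_widths = [128, 256]  # hypothetical widths of two fused sub-layers
per_shard_scale = torch.tensor([0.02, 0.05])
per_channel_scale = torch.repeat_interleave(
    per_shard_scale, torch.tensor(logical_widths)
).unsqueeze(-1)
print(per_channel_scale.shape)  # torch.Size([384, 1])
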
# INPUT SCALE
if self.config.is_static_input_scheme:
@@ -72,8 +76,10 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
if self.config.input_symmetric:
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False))
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
@@ -84,16 +90,17 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(scale, requires_grad=False))
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
azp = (int8_traits.min -
range_min / scale).round().to(dtype=torch.int32)
replace_parameter(layer, self.i_zp_name,
torch.nn.Parameter(azp, requires_grad=False))
azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
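
# Worked numeric sketch of the folding above (plain PyTorch, hypothetical
# values): several static input scales / zero points are collapsed into one
# (scale, azp) pair that still covers the widest represented real-value range.
import torch

int8_info = torch.iinfo(torch.int8)
input_scale = torch.tensor([0.10, 0.20])
azps = torch.tensor([10, -5], dtype=torch.int32)
range_max = (input_scale * (int8_info.max - azps)).max()
range_min = (input_scale * (int8_info.min - azps)).min()
scale = (range_max - range_min) / (int8_info.max - int8_info.min)
azp = (int8_info.min - range_min / scale).round().to(torch.int32)
print(scale.item(), azp.item())  # ~0.2 and -5 for these example values
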
@@ -105,14 +112,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * s_b * [(A - zp_a)B] + bias =
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric
and self.config.is_static_input_scheme):
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, self.w_q_name)
weight_scale = getattr(layer, self.w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze()
setattr(layer, self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False))
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
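
# Numeric check of the identity in the comment above (plain PyTorch, small
# hypothetical values): folding the activation zero point into the GEMM input
# equals subtracting s_a * zp_a * adj afterwards, with adj = s_b * B.sum(dim=0).
import torch

torch.manual_seed(0)
M, K, N = 2, 8, 4
A = torch.randint(-5, 6, (M, K)).float()  # quantized activations
B = torch.randint(-5, 6, (K, N)).float()  # quantized weights
s_a, s_b, zp_a = 0.1, 0.05, 7.0
lhs = s_a * s_b * ((A - zp_a) @ B)
adj = s_b * B.sum(dim=0)  # plays the role of azp_adj, shape [N]
rhs = s_a * (s_b * (A @ B)) - s_a * zp_a * adj
print(torch.allclose(lhs, rhs, atol=1e-5))  # True
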
@@ -135,34 +144,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
weight = getattr(layer, self.w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter(
layer, self.w_q_name,
torch.nn.Parameter(packed_weight, requires_grad=False))
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
)
if layer.bias is not None:
bias = layer.bias
layer.register_parameter(
"bias_fp32",
torch.nn.Parameter(bias.float().data, requires_grad=False))
"bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
)
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, self.w_s_name)
if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.linear_method(
layer,
x,
@@ -170,31 +182,33 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
)
def _apply_weights_onednn(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
x, i_s, i_zp, self.config.input_symmetric)
x, i_s, i_zp, self.config.input_symmetric
)
m = x.size(0)
n = self.dnnl_handler.n
out = torch.empty((m, n), dtype=x.dtype)
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj,
bias)
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
return out
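
# Minimal sketch of the two quantization modes described above (plain PyTorch,
# not the oneDNN op): dynamic derives a per-token scale from x, static reuses
# the precomputed scalar i_s.
from typing import Optional

import torch

def quant_int8(x: torch.Tensor, i_s: Optional[torch.Tensor]):
    if i_s is None:  # dynamic: scale computed from x
        x_s = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0
    else:  # static: scale is i_s
        x_s = i_s
    x_q = (x / x_s).round().clamp(-128, 127).to(torch.int8)
    return x_q, x_s

x = torch.randn(4, 16)
print(quant_int8(x, None)[1].shape)  # dynamic: per-token scale [4, 1]
print(quant_int8(x, torch.tensor(0.02))[1])  # static: scalar scale
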
def _apply_weights_sgl(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant(
x,

View File

@@ -8,23 +8,20 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cuda():
return False, "CutlassScaledMM requires running on CUDA."
@@ -35,8 +32,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False))
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
@@ -45,11 +44,12 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
@@ -57,8 +57,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
if self.config.input_symmetric:
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False))
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
@@ -69,17 +71,16 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name,
torch.nn.Parameter(scale, requires_grad=False))
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
replace_parameter(layer, self.i_zp_name,
torch.nn.Parameter(azp, requires_grad=False))
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
@@ -97,41 +98,44 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
setattr(layer, self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False))
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s,
i_zp,
symmetric=symmetric)
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=symmetric
)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
static = i_zp is not None
azp = None if static else x_zp
return ops.cutlass_scaled_mm_azp(x_q,
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias)
return ops.cutlass_scaled_mm(x_q,
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
bias=bias)
return ops.cutlass_scaled_mm_azp(
x_q,
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias,
)
return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)
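
# Dispatch summary for apply_weights above (illustrative helper, not a vLLM
# API): which CUTLASS variant runs and where the activation zero point ends up,
# based on azp_adj (symmetric vs asymmetric) and i_zp (static vs dynamic).
from typing import Optional

import torch

def describe_dispatch(
    azp_adj: Optional[torch.Tensor], i_zp: Optional[torch.Tensor]
) -> str:
    if azp_adj is None:
        return "symmetric -> cutlass_scaled_mm, no zero-point terms"
    if i_zp is not None:
        return "static asymmetric -> cutlass_scaled_mm_azp, azp folded into azp_adj"
    return "dynamic asymmetric -> cutlass_scaled_mm_azp, per-token azp passed"

print(describe_dispatch(None, None))
print(describe_dispatch(torch.ones(8), torch.tensor(3)))
print(describe_dispatch(torch.ones(8), None))
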

View File

@@ -12,30 +12,32 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,
"TritonScaledMMLinearKernel requires Triton which is not " +
"currently supported on CPU.")
"TritonScaledMMLinearKernel requires Triton which is not "
+ "currently supported on CPU.",
)
if not c.input_symmetric:
return (False,
"TritonScaledMMLinearKernel only supports symmetric " +
"quantization.")
return (
False,
"TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().apply_weights(layer, x, bias)

View File

@@ -9,25 +9,23 @@ from functorch.experimental.control_flow import cond # noqa: F401
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
ScaledMMLinearLayerConfig)
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"TPU platform does have a concept of compute capability, "
"this method should not be called.")
"this method should not be called."
)
@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."
@@ -46,8 +44,9 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
weight = getattr(layer, self.w_q_name)
replace_parameter(layer, self.w_q_name,
torch.nn.Parameter(weight.data, requires_grad=False))
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
)
# WEIGHT SCALE
# XLA kernels support only per-tensor and per-channel.
@@ -56,14 +55,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale,
layer.logical_widths)
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
# [out_channel,] (different than cutlass_scaled_mm)
weight_scale = weight_scale.squeeze(-1)
replace_parameter(
layer, self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False))
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# Only support symmetric dynamic activation quantization.
setattr(layer, self.i_s_name, None)
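
# Layout sketch for the XLA path (plain PyTorch reference with assumed dequant
# semantics, ignoring activation quantization): the weight keeps its [out, in]
# layout, so the matmul uses w_q.T, and the per-channel scale is a flat [out]
# vector rather than [out, 1].
import torch

M, K, N = 4, 16, 8
x = torch.randn(M, K)
w_q = torch.randint(-128, 127, (N, K), dtype=torch.int8)  # [out, in]
w_s = torch.rand(N)  # [out_channel,]
ref = (x @ w_q.float().t()) * w_s  # [M, N]
print(ref.shape)
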
@@ -74,8 +74,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
# to specialize the graph since bias is not dynamic.
warnings.filterwarnings(
"ignore",
message=
"Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
)
def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
@@ -84,14 +83,17 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
return x + bias
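
# Minimal sketch of the cond() behaviour the filtered warning refers to
# (assumed usage, mirroring the two helpers above): with a plain Python bool
# predicate, tracing specializes the graph to the selected branch.
import torch
from functorch.experimental.control_flow import cond

def _with_bias(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return x + bias

def _without_bias(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return x

out = torch.randn(4, 8)
bias = torch.zeros(8)
# `bias is None` is a Python constant here, hence the specialization warning.
result = cond(bias is None, _without_bias, _with_bias, (out, bias))
print(result.shape)  # torch.Size([4, 8])
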
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
out = torch.ops.xla.quantized_matmul_int8(
x,
w_q,