[kernel] Support W4A8 on Hopper (#23198)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-08-24 02:18:04 -04:00
committed by GitHub
parent a75277285b
commit e76e233540
12 changed files with 1128 additions and 7 deletions

View File

@@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -200,8 +200,10 @@ class CompressedTensorsConfig(QuantizationConfig):
format
) if format is not None else is_activation_quantization_format(
quant_format)
if act_quant_format:
input_activations = quant_config.get("input_activations")
# TODO(czhu): w4a8fp8 is in packed-quantized format
# but needs input activation quantization
input_activations = quant_config.get("input_activations")
if act_quant_format or input_activations:
# The only case where we have activation quant supported
# but no input_activations provided in the config
# should be w8a16fp8 w8a16fp8 can also run for cases where
@@ -352,6 +354,28 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w4a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
if not weight_quant or not input_quant:
return False
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
# Only per-group symmetric weight (4bit)
# + per-tok symmetric activation (8bit) quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and is_symmetric and is_dynamic)
def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w4a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(90, error=False, match_exact=True)
@@ -405,6 +429,13 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):

View File

@@ -3,6 +3,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
@@ -21,5 +22,6 @@ __all__ = [
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8"
]

View File

@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
# yapf: enable
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Fp8"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
symmetric: Optional[bool] = True,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size != 128 or self.strategy != "group":
raise ValueError("W4A8 kernels require group quantization " \
"with group size 128")
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
# hopper
return 90
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
group_size=self.group_size,
zero_points=not self.symmetric,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Fp8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
# TODO(czhu): allocate the packed fp8 scales memory here?
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name="weight_zero_point",
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas imp
BitBLASLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501
CutlassW4A8LinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
Dynamic4bitLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
@@ -24,6 +26,7 @@ from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
CutlassW4A8LinearKernel,
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# dynamic per-tok fp8 activation quantization
self.quant_fp8 = QuantFP8(static=False,
group_shape=GroupShape.PER_TOKEN)
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "CUTLASS W4A8 requires compute capability of 90 "\
"(Hopper)"
if c.act_type != torch.float8_e4m3fn:
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
if c.has_g_idx:
return False, "Act reordering not supported by CUTLASS W4A8"
if c.zero_points:
return False, "Zero points not supported by CUTLASS W4A8"
if c.weight_type != scalar_types.int4:
return False, f"Quant type ({c.weight_type}) not supported by "\
"CUTLASS W4A8, only supported int4"
# TODO(czhu): support -1 (column-wise)
if c.group_size != 128:
return False, "Only group_size 128 is supported"
in_features, out_features = c.partition_weight_shape
if in_features % 128 or out_features % 128:
return False, "K and N must be divisible by 128, got "\
f"{c.partition_weight_shape}"
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(
x.data.t().contiguous().t())
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous().to(torch.float8_e4m3fn)
x.data = ops.cutlass_pack_scale_fp8(x.data)
return x
# Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
# TODO(czhu): support loading channel scales
self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
dtype=torch.float32,
device='cuda')
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None, "bias not supported by CUTLASS W4A8"
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
x_2d, act_scales = self.quant_fp8(x_2d)
output = ops.cutlass_w4a8_mm(a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=self.w_ch_s)
return output.reshape(out_shape)