Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
Eldar Kurtić
2026-01-22 21:29:57 +01:00
committed by GitHub
parent 955b43a5a5
commit 44f08af3a7
18 changed files with 558 additions and 263 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
@@ -19,6 +20,10 @@ from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -87,6 +92,8 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: dict[str, Any] | None = None,
config: dict[str, Any] | None = None,
transform_config: dict[str, Any] | None = None,
total_num_heads: int | None = None,
total_num_kv_heads: int | None = None,
):
super().__init__()
self.ignore = ignore
@@ -97,6 +104,8 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
if transform_config:
self.transform_config = TransformConfig.model_validate(transform_config)
@@ -200,13 +209,29 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
# We keep only config groups which are not doing Attention quantization
# because Attention quantization on its own is not supported by vLLM.
# It is coupled with KV-cache quantization, and if scales are present in the
# checkpoint, they will be used properly.
grps_without_attn_quant = {}
for k, v in config["config_groups"].items():
# e.g. LlamaAttention, Qwen3Attention, etc.
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
logger.warning(
"Skipping CompressedTensors config group for %s. Attention quant "
"is coupled with KV-cache quantization in vLLM.",
v["targets"][0],
)
continue
grps_without_attn_quant[k] = v
config["config_groups"] = grps_without_attn_quant
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config
)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
@@ -215,7 +240,10 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
transform_config=config.get("transform_config"),
kv_cache_scheme=config.get("kv_cache_scheme"),
total_num_heads=config.get("total_num_heads"),
total_num_kv_heads=config.get("total_num_kv_heads"),
)
@classmethod
@@ -791,22 +819,6 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
@@ -965,12 +977,16 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
f"received num_bits={num_bits}, type={type_}"
)
# TODO: delegate validation to compressed-tensors library so that we have a
# single source of truth. Right now this is not possible until the next release
# of compressed-tensors.
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
supported_strategies = ("tensor", "attn_head")
if strategy not in supported_strategies:
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}"
"Invalid strategy for compressed-tensors KV cache. "
f"Expected strategies: {supported_strategies}, found strategy:"
f" {strategy}"
)
is_symmetric = kv_cache_scheme.get("symmetric")
@@ -980,3 +996,133 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}"
)
def create_weights(self, layer: torch.nn.Module):
"""
Initialize placeholder scales and zero points to enable loading of
quantized params from compressed-tensors checkpoints.
"""
strategy = None # for backward compatibility
if (
hasattr(self.quant_config, "kv_cache_scheme")
and self.quant_config.kv_cache_scheme is not None
):
strategy = self.quant_config.kv_cache_scheme["strategy"]
if strategy == "attn_head":
assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales."
)
n_scales = int(layer.num_kv_heads)
else:
n_scales = 1
layer.k_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.v_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.q_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
# Zero points are not used in vLLM as currently only symmetric quantization is
# supported. We need to create them here to enable loading of llm-compressor
# checkpoints which contain them irrespective of the symmetric/asymmetric
# scheme used during quantization.
layer.k_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.v_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.q_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
# TP-aware loading for attn_head strategy follows attention head partitioning:
# - q_scale is partitioned over query heads.
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
# and replicated when total_kv_heads < tp_size.
if strategy == "attn_head":
def _tp_aware_loader(
param: torch.Tensor,
loaded_weight: torch.Tensor,
kind: Literal["q", "k", "v"],
param_type: Literal["scale", "zero_point"],
):
# Zero-points are not used as vLLM only supports symmetric quantization
if param_type == "zero_point":
return
# LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1]
loaded_weight = loaded_weight.flatten()
# FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale.
# We reduce by taking the max scale in each attention head group.
if kind == "q":
reduction_factor = (
self.quant_config.total_num_heads # type: ignore[attr-defined]
// self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
)
loaded_weight = torch.amax(
loaded_weight.view(-1, reduction_factor), dim=1
)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads: # type: ignore[attr-defined]
# heads evenly distributed
loaded_weight = loaded_weight[
tp_rank * layer.num_kv_heads : (tp_rank + 1)
* layer.num_kv_heads
]
else:
# heads replicated to match TP size
assert layer.num_kv_heads == 1
replicas = tp_size // self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
shard_rank = tp_rank // replicas
loaded_weight = loaded_weight[shard_rank : shard_rank + 1]
param.data.copy_(loaded_weight.to(dtype=param.dtype))
layer.q_scale.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="scale"
)
layer.k_scale.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="scale"
)
layer.v_scale.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="scale"
)
layer.q_zero_point.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="zero_point"
)
layer.k_zero_point.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="zero_point"
)
layer.v_zero_point.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="zero_point"
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Override the default vLLM placeholder scales with the llm-compressor loaded
scales. Zero points are not used as only symmetric quantization is supported.
"""
layer._k_scale = layer.k_scale
layer._v_scale = layer.v_scale
layer._q_scale = layer.q_scale
# Discard all placeholders.
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.k_zero_point
del layer.v_zero_point
del layer.q_zero_point

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
group_broadcast,
prep_scale_for_group_broadcast,
)
from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ class QuantFP8(CustomOp):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
or arbitrary block size)
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param column_major_scales: For group quantization, output scales in
@@ -157,6 +158,8 @@ class QuantFP8(CustomOp):
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
else:
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division

View File

@@ -191,6 +191,51 @@ def group_broadcast(t, shape):
return t
def prep_scale_for_group_broadcast(
scale: torch.Tensor,
x: torch.Tensor,
group_shape: GroupShape | None,
) -> torch.Tensor:
"""
Prepare the input quantization scale for group broadcasting.
Args:
scale: The scale tensor (scalar or 1D).
x: Target tensor whose shape determines broadcast dimensions.
group_shape: GroupShape to broadcast over.
Returns:
scale reshaped for correct broadcasting.
"""
if scale.numel() == 1:
# For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
# This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
# where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
# per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# For all other cases, reshape scalar scales to (1, 1) for broadcasting.
return (
scale
if group_shape is not None and group_shape.is_per_tensor()
else scale.reshape(1, 1)
)
if scale.ndim == 1:
assert group_shape is not None, (
"group_shape must be provided to correctly broadcast 1D scale"
)
rows, cols = _normalize_quant_group_shape(x, group_shape)
# Determine broadcasting dimension: either rows or columns match group size
if rows == x.shape[-2]:
scale = scale.unsqueeze(-2)
elif cols == x.shape[-1]:
scale = scale.unsqueeze(-1)
else:
raise ValueError(
f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
f" {x.shape}, group_shape={(rows, cols)}"
)
return scale
# Quantize assuming once scale per group of elements with shape group_shape,
# example group shapes:
# * (-1, -1) for per-tensor quantization
@@ -241,7 +286,7 @@ def scaled_quantize(
_, fp8_max = get_fp8_min_max()
scale = fp8_max / amax
# Apply scale and convert form:
# Apply scale and convert from:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
x_scl_sat = (
(x_blkd_permd * scale.unsqueeze(-1))
@@ -261,29 +306,7 @@ def scaled_dequantize(
group_shape: GroupShape | None = None,
out_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
if x_s.numel() == 1: # scalar
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
if x_s.ndim == 1:
if group_shape is None:
raise AssertionError(
"if x_s is 1D tensor, group_shape must be provided otherwise "
"its ambiguous which dimension to broadcast x_s to"
)
# unsqueeze the scales for the dimension where we want to broadcast
# across the full extent
if group_shape[0] == x_q.shape[-2]:
x_s = x_s.unsqueeze(-2)
elif group_shape[1] == x_q.shape[-1]:
x_s = x_s.unsqueeze(-1)
else:
raise AssertionError(
"if x_s is a vector we should be broadcasting it to the full "
"extent of one of the dimensions"
)
x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape)
if group_shape is not None:
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]