Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
@@ -19,6 +20,10 @@ from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -87,6 +92,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
kv_cache_scheme: dict[str, Any] | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
transform_config: dict[str, Any] | None = None,
|
||||
total_num_heads: int | None = None,
|
||||
total_num_kv_heads: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
@@ -97,6 +104,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
self.total_num_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
|
||||
if transform_config:
|
||||
self.transform_config = TransformConfig.model_validate(transform_config)
|
||||
@@ -200,13 +209,29 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
# We keep only config groups which are not doing Attention quantization
|
||||
# because Attention quantization on its own is not supported by vLLM.
|
||||
# It is coupled with KV-cache quantization, and if scales are present in the
|
||||
# checkpoint, they will be used properly.
|
||||
grps_without_attn_quant = {}
|
||||
for k, v in config["config_groups"].items():
|
||||
# e.g. LlamaAttention, Qwen3Attention, etc.
|
||||
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
|
||||
logger.warning(
|
||||
"Skipping CompressedTensors config group for %s. Attention quant "
|
||||
"is coupled with KV-cache quantization in vLLM.",
|
||||
v["targets"][0],
|
||||
)
|
||||
continue
|
||||
grps_without_attn_quant[k] = v
|
||||
config["config_groups"] = grps_without_attn_quant
|
||||
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config
|
||||
)
|
||||
transform_config = config.get("transform_config")
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
@@ -215,7 +240,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
transform_config=transform_config,
|
||||
transform_config=config.get("transform_config"),
|
||||
kv_cache_scheme=config.get("kv_cache_scheme"),
|
||||
total_num_heads=config.get("total_num_heads"),
|
||||
total_num_kv_heads=config.get("total_num_kv_heads"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -791,22 +819,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
def has_blocked_weights(self) -> bool:
|
||||
for scheme in self.target_scheme_map.values():
|
||||
weight_quant = scheme.get("weights")
|
||||
@@ -965,12 +977,16 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
f"received num_bits={num_bits}, type={type_}"
|
||||
)
|
||||
|
||||
# TODO: delegate validation to compressed-tensors library so that we have a
|
||||
# single source of truth. Right now this is not possible until the next release
|
||||
# of compressed-tensors.
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
supported_strategies = ("tensor", "attn_head")
|
||||
if strategy not in supported_strategies:
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}"
|
||||
"Invalid strategy for compressed-tensors KV cache. "
|
||||
f"Expected strategies: {supported_strategies}, found strategy:"
|
||||
f" {strategy}"
|
||||
)
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
@@ -980,3 +996,133 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}"
|
||||
)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Initialize placeholder scales and zero points to enable loading of
|
||||
quantized params from compressed-tensors checkpoints.
|
||||
"""
|
||||
strategy = None # for backward compatibility
|
||||
if (
|
||||
hasattr(self.quant_config, "kv_cache_scheme")
|
||||
and self.quant_config.kv_cache_scheme is not None
|
||||
):
|
||||
strategy = self.quant_config.kv_cache_scheme["strategy"]
|
||||
|
||||
if strategy == "attn_head":
|
||||
assert layer.impl.supports_per_head_quant_scales, (
|
||||
f"Layer {layer.__class__.__name__} with implementation "
|
||||
f"{layer.impl.__class__.__name__} does not support per-head scales."
|
||||
)
|
||||
n_scales = int(layer.num_kv_heads)
|
||||
else:
|
||||
n_scales = 1
|
||||
|
||||
layer.k_scale = torch.nn.Parameter(
|
||||
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
|
||||
)
|
||||
layer.v_scale = torch.nn.Parameter(
|
||||
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
|
||||
)
|
||||
layer.q_scale = torch.nn.Parameter(
|
||||
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Zero points are not used in vLLM as currently only symmetric quantization is
|
||||
# supported. We need to create them here to enable loading of llm-compressor
|
||||
# checkpoints which contain them irrespective of the symmetric/asymmetric
|
||||
# scheme used during quantization.
|
||||
layer.k_zero_point = torch.nn.Parameter(
|
||||
torch.zeros(n_scales, requires_grad=False)
|
||||
)
|
||||
layer.v_zero_point = torch.nn.Parameter(
|
||||
torch.zeros(n_scales, requires_grad=False)
|
||||
)
|
||||
layer.q_zero_point = torch.nn.Parameter(
|
||||
torch.zeros(n_scales, requires_grad=False)
|
||||
)
|
||||
|
||||
# TP-aware loading for attn_head strategy follows attention head partitioning:
|
||||
# - q_scale is partitioned over query heads.
|
||||
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
|
||||
# and replicated when total_kv_heads < tp_size.
|
||||
if strategy == "attn_head":
|
||||
|
||||
def _tp_aware_loader(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
kind: Literal["q", "k", "v"],
|
||||
param_type: Literal["scale", "zero_point"],
|
||||
):
|
||||
# Zero-points are not used as vLLM only supports symmetric quantization
|
||||
if param_type == "zero_point":
|
||||
return
|
||||
|
||||
# LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1]
|
||||
loaded_weight = loaded_weight.flatten()
|
||||
|
||||
# FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale.
|
||||
# We reduce by taking the max scale in each attention head group.
|
||||
if kind == "q":
|
||||
reduction_factor = (
|
||||
self.quant_config.total_num_heads # type: ignore[attr-defined]
|
||||
// self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
|
||||
)
|
||||
loaded_weight = torch.amax(
|
||||
loaded_weight.view(-1, reduction_factor), dim=1
|
||||
)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads: # type: ignore[attr-defined]
|
||||
# heads evenly distributed
|
||||
loaded_weight = loaded_weight[
|
||||
tp_rank * layer.num_kv_heads : (tp_rank + 1)
|
||||
* layer.num_kv_heads
|
||||
]
|
||||
else:
|
||||
# heads replicated to match TP size
|
||||
assert layer.num_kv_heads == 1
|
||||
replicas = tp_size // self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
|
||||
shard_rank = tp_rank // replicas
|
||||
loaded_weight = loaded_weight[shard_rank : shard_rank + 1]
|
||||
|
||||
param.data.copy_(loaded_weight.to(dtype=param.dtype))
|
||||
|
||||
layer.q_scale.weight_loader = partial(
|
||||
_tp_aware_loader, kind="q", param_type="scale"
|
||||
)
|
||||
layer.k_scale.weight_loader = partial(
|
||||
_tp_aware_loader, kind="k", param_type="scale"
|
||||
)
|
||||
layer.v_scale.weight_loader = partial(
|
||||
_tp_aware_loader, kind="v", param_type="scale"
|
||||
)
|
||||
|
||||
layer.q_zero_point.weight_loader = partial(
|
||||
_tp_aware_loader, kind="q", param_type="zero_point"
|
||||
)
|
||||
layer.k_zero_point.weight_loader = partial(
|
||||
_tp_aware_loader, kind="k", param_type="zero_point"
|
||||
)
|
||||
layer.v_zero_point.weight_loader = partial(
|
||||
_tp_aware_loader, kind="v", param_type="zero_point"
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Override the default vLLM placeholder scales with the llm-compressor loaded
|
||||
scales. Zero points are not used as only symmetric quantization is supported.
|
||||
"""
|
||||
layer._k_scale = layer.k_scale
|
||||
layer._v_scale = layer.v_scale
|
||||
layer._q_scale = layer.q_scale
|
||||
|
||||
# Discard all placeholders.
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
del layer.q_scale
|
||||
del layer.k_zero_point
|
||||
del layer.v_zero_point
|
||||
del layer.q_zero_point
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_fp8_min_max,
|
||||
group_broadcast,
|
||||
prep_scale_for_group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -40,7 +41,7 @@ class QuantFP8(CustomOp):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
|
||||
or arbitrary block size)
|
||||
PER_CHANNEL, or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
@@ -157,6 +158,8 @@ class QuantFP8(CustomOp):
|
||||
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
|
||||
|
||||
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
else:
|
||||
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
|
||||
|
||||
# Even for dynamic per-token scales,
|
||||
# reciprocal performs slightly better than division
|
||||
|
||||
@@ -191,6 +191,51 @@ def group_broadcast(t, shape):
|
||||
return t
|
||||
|
||||
|
||||
def prep_scale_for_group_broadcast(
|
||||
scale: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
group_shape: GroupShape | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Prepare the input quantization scale for group broadcasting.
|
||||
|
||||
Args:
|
||||
scale: The scale tensor (scalar or 1D).
|
||||
x: Target tensor whose shape determines broadcast dimensions.
|
||||
group_shape: GroupShape to broadcast over.
|
||||
|
||||
Returns:
|
||||
scale reshaped for correct broadcasting.
|
||||
"""
|
||||
if scale.numel() == 1:
|
||||
# For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
|
||||
# This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
|
||||
# where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
|
||||
# per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
# For all other cases, reshape scalar scales to (1, 1) for broadcasting.
|
||||
return (
|
||||
scale
|
||||
if group_shape is not None and group_shape.is_per_tensor()
|
||||
else scale.reshape(1, 1)
|
||||
)
|
||||
if scale.ndim == 1:
|
||||
assert group_shape is not None, (
|
||||
"group_shape must be provided to correctly broadcast 1D scale"
|
||||
)
|
||||
rows, cols = _normalize_quant_group_shape(x, group_shape)
|
||||
# Determine broadcasting dimension: either rows or columns match group size
|
||||
if rows == x.shape[-2]:
|
||||
scale = scale.unsqueeze(-2)
|
||||
elif cols == x.shape[-1]:
|
||||
scale = scale.unsqueeze(-1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
|
||||
f" {x.shape}, group_shape={(rows, cols)}"
|
||||
)
|
||||
return scale
|
||||
|
||||
|
||||
# Quantize assuming once scale per group of elements with shape group_shape,
|
||||
# example group shapes:
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
@@ -241,7 +286,7 @@ def scaled_quantize(
|
||||
_, fp8_max = get_fp8_min_max()
|
||||
scale = fp8_max / amax
|
||||
|
||||
# Apply scale and convert form:
|
||||
# Apply scale and convert from:
|
||||
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
|
||||
x_scl_sat = (
|
||||
(x_blkd_permd * scale.unsqueeze(-1))
|
||||
@@ -261,29 +306,7 @@ def scaled_dequantize(
|
||||
group_shape: GroupShape | None = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
if group_shape is not None:
|
||||
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||
|
||||
if x_s.numel() == 1: # scalar
|
||||
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
|
||||
if x_s.ndim == 1:
|
||||
if group_shape is None:
|
||||
raise AssertionError(
|
||||
"if x_s is 1D tensor, group_shape must be provided otherwise "
|
||||
"its ambiguous which dimension to broadcast x_s to"
|
||||
)
|
||||
# unsqueeze the scales for the dimension where we want to broadcast
|
||||
# across the full extent
|
||||
if group_shape[0] == x_q.shape[-2]:
|
||||
x_s = x_s.unsqueeze(-2)
|
||||
elif group_shape[1] == x_q.shape[-1]:
|
||||
x_s = x_s.unsqueeze(-1)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"if x_s is a vector we should be broadcasting it to the full "
|
||||
"extent of one of the dimensions"
|
||||
)
|
||||
|
||||
x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape)
|
||||
if group_shape is not None:
|
||||
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
|
||||
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
|
||||
|
||||
Reference in New Issue
Block a user