Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -75,13 +75,16 @@ def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) ->
|
||||
layer._v_scale_float = 1.0
|
||||
layer._prob_scale_float = 1.0
|
||||
|
||||
# Initialize q/k/v range constants used by calc_kv_scales
|
||||
layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
|
||||
def _init_kv_cache_quant(
|
||||
layer: nn.Module,
|
||||
quant_config: QuantizationConfig | None,
|
||||
prefix: str,
|
||||
kv_cache_dtype: str,
|
||||
calculate_kv_scales: bool,
|
||||
) -> None:
|
||||
"""Initializes KV cache scaling factors and quantization method.
|
||||
|
||||
@@ -94,16 +97,10 @@ def _init_kv_cache_quant(
|
||||
layer: The attention layer instance to initialize.
|
||||
quant_config: Optional quantization configuration.
|
||||
prefix: Layer name prefix for quantization method lookup.
|
||||
kv_cache_dtype: The KV cache data type string.
|
||||
calculate_kv_scales: Whether to calculate KV scales dynamically.
|
||||
"""
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
layer.kv_cache_dtype = kv_cache_dtype
|
||||
layer.calculate_kv_scales = calculate_kv_scales
|
||||
quant_method = (
|
||||
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
|
||||
)
|
||||
|
||||
# Note [Register q/k/v/prob scales in state dict]
|
||||
# When calling model.to(device), only parameters/buffers in state dict are
|
||||
@@ -133,7 +130,7 @@ def _init_kv_cache_quant(
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if kv_cache_dtype == "fp8_e5m2":
|
||||
if layer.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
@@ -197,9 +194,20 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
calculate_kv_scales = False
|
||||
|
||||
# llm-compressor models need to set cache_dtype to "fp8" manually.
|
||||
if getattr(quant_config, "kv_cache_scheme", None) is not None:
|
||||
kv_cache_dtype = "fp8"
|
||||
calculate_kv_scales = False
|
||||
if cache_config is not None:
|
||||
cache_config.cache_dtype = "fp8"
|
||||
cache_config.calculate_kv_scales = False
|
||||
|
||||
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
|
||||
kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, (
|
||||
@@ -208,15 +216,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
self.quant_config = quant_config
|
||||
self.layer_name = prefix
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(
|
||||
self,
|
||||
self.quant_config,
|
||||
self.layer_name,
|
||||
kv_cache_dtype,
|
||||
calculate_kv_scales,
|
||||
)
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.head_size_v = self.head_size if head_size_v is None else head_size_v
|
||||
@@ -318,18 +317,24 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Initialize q/k/v range constants.
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
|
||||
# for attn backends supporting query quantization
|
||||
self.query_quant = None
|
||||
if (
|
||||
self.kv_cache_dtype.startswith("fp8")
|
||||
and self.impl.supports_quant_query_input
|
||||
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
|
||||
"fp8"
|
||||
):
|
||||
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||
is_per_head = (
|
||||
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
|
||||
)
|
||||
block_size = self.head_size * self.num_heads // self.num_kv_heads
|
||||
self.query_quant = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape(-1, block_size)
|
||||
if is_per_head
|
||||
else GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -524,13 +529,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
_init_kv_cache_quant(
|
||||
self,
|
||||
self.quant_config,
|
||||
self.layer_name,
|
||||
kv_cache_dtype,
|
||||
calculate_kv_scales,
|
||||
)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
_init_kv_cache_quant(self, quant_config, prefix)
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.attn_backend = get_attn_backend(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
@@ -19,6 +20,10 @@ from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -87,6 +92,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
kv_cache_scheme: dict[str, Any] | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
transform_config: dict[str, Any] | None = None,
|
||||
total_num_heads: int | None = None,
|
||||
total_num_kv_heads: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
@@ -97,6 +104,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
self.total_num_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
|
||||
if transform_config:
|
||||
self.transform_config = TransformConfig.model_validate(transform_config)
|
||||
@@ -200,13 +209,29 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
# We keep only config groups which are not doing Attention quantization
|
||||
# because Attention quantization on its own is not supported by vLLM.
|
||||
# It is coupled with KV-cache quantization, and if scales are present in the
|
||||
# checkpoint, they will be used properly.
|
||||
grps_without_attn_quant = {}
|
||||
for k, v in config["config_groups"].items():
|
||||
# e.g. LlamaAttention, Qwen3Attention, etc.
|
||||
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
|
||||
logger.warning(
|
||||
"Skipping CompressedTensors config group for %s. Attention quant "
|
||||
"is coupled with KV-cache quantization in vLLM.",
|
||||
v["targets"][0],
|
||||
)
|
||||
continue
|
||||
grps_without_attn_quant[k] = v
|
||||
config["config_groups"] = grps_without_attn_quant
|
||||
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config
|
||||
)
|
||||
transform_config = config.get("transform_config")
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
@@ -215,7 +240,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
transform_config=transform_config,
|
||||
transform_config=config.get("transform_config"),
|
||||
kv_cache_scheme=config.get("kv_cache_scheme"),
|
||||
total_num_heads=config.get("total_num_heads"),
|
||||
total_num_kv_heads=config.get("total_num_kv_heads"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -791,22 +819,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
def has_blocked_weights(self) -> bool:
|
||||
for scheme in self.target_scheme_map.values():
|
||||
weight_quant = scheme.get("weights")
|
||||
@@ -965,12 +977,16 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
f"received num_bits={num_bits}, type={type_}"
|
||||
)
|
||||
|
||||
# TODO: delegate validation to compressed-tensors library so that we have a
|
||||
# single source of truth. Right now this is not possible until the next release
|
||||
# of compressed-tensors.
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
supported_strategies = ("tensor", "attn_head")
|
||||
if strategy not in supported_strategies:
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}"
|
||||
"Invalid strategy for compressed-tensors KV cache. "
|
||||
f"Expected strategies: {supported_strategies}, found strategy:"
|
||||
f" {strategy}"
|
||||
)
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
@@ -980,3 +996,133 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}"
|
||||
)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
    """
    Initialize placeholder scales and zero points to enable loading of
    quantized params from compressed-tensors checkpoints.

    With the "attn_head" KV-cache strategy one entry per local KV head
    (`layer.num_kv_heads`) is created for each of q/k/v and TP-aware weight
    loaders are attached to the placeholders; otherwise each placeholder
    holds a single entry and no custom loader is attached.

    :param layer: attention layer that receives the `{q,k,v}_scale` and
        `{q,k,v}_zero_point` parameters
    """
    strategy = None  # for backward compatibility
    kv_cache_scheme = getattr(self.quant_config, "kv_cache_scheme", None)
    if kv_cache_scheme is not None:
        strategy = kv_cache_scheme["strategy"]

    if strategy == "attn_head":
        assert layer.impl.supports_per_head_quant_scales, (
            f"Layer {layer.__class__.__name__} with implementation "
            f"{layer.impl.__class__.__name__} does not support per-head scales."
        )
        n_scales = int(layer.num_kv_heads)
    else:
        n_scales = 1

    # `requires_grad=False` must be given to `Parameter` itself: the flag on
    # the wrapped tensor is ignored and `Parameter` defaults to
    # requires_grad=True, so the placeholders would otherwise track gradients.
    layer.k_scale = torch.nn.Parameter(
        torch.ones(n_scales, dtype=torch.float32), requires_grad=False
    )
    layer.v_scale = torch.nn.Parameter(
        torch.ones(n_scales, dtype=torch.float32), requires_grad=False
    )
    layer.q_scale = torch.nn.Parameter(
        torch.ones(n_scales, dtype=torch.float32), requires_grad=False
    )

    # Zero points are not used in vLLM as currently only symmetric quantization is
    # supported. We need to create them here to enable loading of llm-compressor
    # checkpoints which contain them irrespective of the symmetric/asymmetric
    # scheme used during quantization.
    layer.k_zero_point = torch.nn.Parameter(
        torch.zeros(n_scales), requires_grad=False
    )
    layer.v_zero_point = torch.nn.Parameter(
        torch.zeros(n_scales), requires_grad=False
    )
    layer.q_zero_point = torch.nn.Parameter(
        torch.zeros(n_scales), requires_grad=False
    )

    # TP-aware loading for attn_head strategy follows attention head partitioning:
    # - q_scale is partitioned over query heads.
    # - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
    #   and replicated when total_kv_heads < tp_size.
    if strategy == "attn_head":

        def _tp_aware_loader(
            param: torch.Tensor,
            loaded_weight: torch.Tensor,
            kind: Literal["q", "k", "v"],
            param_type: Literal["scale", "zero_point"],
        ):
            # Zero-points are not used as vLLM only supports symmetric quantization
            if param_type == "zero_point":
                return

            # LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1]
            loaded_weight = loaded_weight.flatten()

            # FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale.
            # We reduce by taking the max scale in each attention head group.
            if kind == "q":
                reduction_factor = (
                    self.quant_config.total_num_heads  # type: ignore[attr-defined]
                    // self.quant_config.total_num_kv_heads  # type: ignore[attr-defined]
                )
                loaded_weight = torch.amax(
                    loaded_weight.view(-1, reduction_factor), dim=1
                )

            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()

            if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads:  # type: ignore[attr-defined]
                # heads evenly distributed
                loaded_weight = loaded_weight[
                    tp_rank * layer.num_kv_heads : (tp_rank + 1)
                    * layer.num_kv_heads
                ]
            else:
                # heads replicated to match TP size
                assert layer.num_kv_heads == 1
                replicas = tp_size // self.quant_config.total_num_kv_heads  # type: ignore[attr-defined]
                shard_rank = tp_rank // replicas
                loaded_weight = loaded_weight[shard_rank : shard_rank + 1]

            param.data.copy_(loaded_weight.to(dtype=param.dtype))

        layer.q_scale.weight_loader = partial(
            _tp_aware_loader, kind="q", param_type="scale"
        )
        layer.k_scale.weight_loader = partial(
            _tp_aware_loader, kind="k", param_type="scale"
        )
        layer.v_scale.weight_loader = partial(
            _tp_aware_loader, kind="v", param_type="scale"
        )

        layer.q_zero_point.weight_loader = partial(
            _tp_aware_loader, kind="q", param_type="zero_point"
        )
        layer.k_zero_point.weight_loader = partial(
            _tp_aware_loader, kind="k", param_type="zero_point"
        )
        layer.v_zero_point.weight_loader = partial(
            _tp_aware_loader, kind="v", param_type="zero_point"
        )
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Move the loaded `{k,v,q}_scale` values onto the layer's `_{k,v,q}_scale`
    attributes and drop the loading placeholders.

    Zero points are discarded without being read, as only symmetric
    quantization is supported.
    """
    kinds = ("k", "v", "q")

    # Replace the default placeholder scales with the loaded ones.
    for kind in kinds:
        setattr(layer, f"_{kind}_scale", getattr(layer, f"{kind}_scale"))

    # Discard all placeholders: scales first, then zero points.
    for suffix in ("scale", "zero_point"):
        for kind in kinds:
            delattr(layer, f"{kind}_{suffix}")
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
get_fp8_min_max,
|
||||
group_broadcast,
|
||||
prep_scale_for_group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -40,7 +41,7 @@ class QuantFP8(CustomOp):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
|
||||
or arbitrary block size)
|
||||
PER_CHANNEL, or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
@@ -157,6 +158,8 @@ class QuantFP8(CustomOp):
|
||||
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
|
||||
|
||||
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
else:
|
||||
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
|
||||
|
||||
# Even for dynamic per-token scales,
|
||||
# reciprocal performs slightly better than division
|
||||
|
||||
@@ -191,6 +191,51 @@ def group_broadcast(t, shape):
|
||||
return t
|
||||
|
||||
|
||||
def prep_scale_for_group_broadcast(
    scale: torch.Tensor,
    x: torch.Tensor,
    group_shape: GroupShape | None,
) -> torch.Tensor:
    """
    Reshape a quantization scale so it can be group-broadcast against `x`.

    Args:
        scale: The scale tensor (scalar or 1D).
        x: Target tensor whose shape determines broadcast dimensions.
        group_shape: GroupShape to broadcast over. Required when `scale`
            is 1D with more than one element.

    Returns:
        The single-element scale unchanged (per-tensor group shape) or
        reshaped to (1, 1); a 1D scale with a singleton dimension inserted
        at -2 or -1; any other scale unchanged.

    Raises:
        ValueError: If a 1D scale matches neither the row nor the column
            extent of `x` for the normalized group shape.
    """
    if scale.numel() == 1:
        # For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
        # This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
        # where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
        #   per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
        # For all other cases, reshape scalar scales to (1, 1) for broadcasting.
        if group_shape is not None and group_shape.is_per_tensor():
            return scale
        return scale.reshape(1, 1)

    # Scales that are not 1D pass through unchanged.
    if scale.ndim != 1:
        return scale

    assert group_shape is not None, (
        "group_shape must be provided to correctly broadcast 1D scale"
    )
    group_rows, group_cols = _normalize_quant_group_shape(x, group_shape)

    # Determine broadcasting dimension: either rows or columns match group size
    if group_rows == x.shape[-2]:
        return scale.unsqueeze(-2)
    if group_cols == x.shape[-1]:
        return scale.unsqueeze(-1)

    raise ValueError(
        f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
        f" {x.shape}, group_shape={(group_rows, group_cols)}"
    )
|
||||
|
||||
|
||||
# Quantize assuming once scale per group of elements with shape group_shape,
|
||||
# example group shapes:
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
@@ -241,7 +286,7 @@ def scaled_quantize(
|
||||
_, fp8_max = get_fp8_min_max()
|
||||
scale = fp8_max / amax
|
||||
|
||||
# Apply scale and convert form:
|
||||
# Apply scale and convert from:
|
||||
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
|
||||
x_scl_sat = (
|
||||
(x_blkd_permd * scale.unsqueeze(-1))
|
||||
@@ -261,29 +306,7 @@ def scaled_dequantize(
|
||||
group_shape: GroupShape | None = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
if group_shape is not None:
|
||||
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||
|
||||
if x_s.numel() == 1: # scalar
|
||||
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
|
||||
if x_s.ndim == 1:
|
||||
if group_shape is None:
|
||||
raise AssertionError(
|
||||
"if x_s is 1D tensor, group_shape must be provided otherwise "
|
||||
"its ambiguous which dimension to broadcast x_s to"
|
||||
)
|
||||
# unsqueeze the scales for the dimension where we want to broadcast
|
||||
# across the full extent
|
||||
if group_shape[0] == x_q.shape[-2]:
|
||||
x_s = x_s.unsqueeze(-2)
|
||||
elif group_shape[1] == x_q.shape[-1]:
|
||||
x_s = x_s.unsqueeze(-1)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"if x_s is a vector we should be broadcasting it to the full "
|
||||
"extent of one of the dimensions"
|
||||
)
|
||||
|
||||
x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape)
|
||||
if group_shape is not None:
|
||||
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
|
||||
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
|
||||
|
||||
@@ -246,6 +246,23 @@ def get_quant_config(
|
||||
# compressed-tensors uses a compression_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
|
||||
|
||||
# Pipe information about heads to enable TP-aware loading of attn_head scales
|
||||
if (
|
||||
hf_quant_config is not None
|
||||
and hf_quant_config.get("quant_method") == "compressed-tensors"
|
||||
):
|
||||
if hf_text_config is not None:
|
||||
n_heads = getattr(hf_text_config, "num_attention_heads", None)
|
||||
n_kv_heads = getattr(hf_text_config, "num_key_value_heads", None)
|
||||
else:
|
||||
n_heads = getattr(model_config.hf_config, "num_attention_heads", None)
|
||||
n_kv_heads = getattr(model_config.hf_config, "num_key_value_heads", None)
|
||||
|
||||
hf_quant_config["total_num_heads"] = n_heads
|
||||
hf_quant_config["total_num_kv_heads"] = (
|
||||
n_kv_heads if n_kv_heads is not None else n_heads
|
||||
)
|
||||
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
|
||||
@@ -1157,11 +1174,21 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
|
||||
# .mixer.attn.{k,v}_scale
|
||||
(r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
|
||||
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
|
||||
(r"\.([kv])_scale$", r".attn.\1_scale"),
|
||||
(r"\.([qkv])_scale$", r".attn.\1_scale"),
|
||||
(r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
|
||||
]
|
||||
|
||||
# Check if name ends with k_scale or v_scale
|
||||
if name.endswith((".k_scale", ".v_scale")):
|
||||
if name.endswith(
|
||||
(
|
||||
".k_scale",
|
||||
".v_scale",
|
||||
".q_scale",
|
||||
".k_zero_point",
|
||||
".v_zero_point",
|
||||
".q_zero_point",
|
||||
)
|
||||
):
|
||||
import regex as re
|
||||
|
||||
for pattern, replacement in scale_mapping_patterns:
|
||||
|
||||
@@ -437,7 +437,7 @@ class ApertusModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -303,7 +303,7 @@ class ArceeModel(nn.Module):
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if "scale" in name:
|
||||
if "scale" in name or "zero_point" in name:
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if remapped_name is None:
|
||||
continue
|
||||
|
||||
@@ -465,8 +465,8 @@ class LlamaModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
@@ -140,8 +140,8 @@ class LlamaModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
# Remapping the name FP8 kv-scale
|
||||
if "scale" in name:
|
||||
# Remapping the name FP8 kv-scale or zero point.
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
@@ -238,8 +238,8 @@ class LlamaModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
# Remapping the name FP8 kv-scale
|
||||
if "scale" in name:
|
||||
# Remapping the name FP8 kv-scale or zero point.
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
@@ -661,7 +661,7 @@ class NemotronHModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "scale" in name:
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -342,7 +342,7 @@ class DeciModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
# TODO add support to more backends:
|
||||
# https://github.com/vllm-project/vllm/issues/25584
|
||||
supports_quant_query_input: bool = False
|
||||
supports_per_head_quant_scales: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
@@ -576,6 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
self.supports_quant_query_input = True
|
||||
self.supports_per_head_quant_scales = (
|
||||
self.vllm_flash_attn_version >= 3
|
||||
if self.vllm_flash_attn_version is not None
|
||||
else False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -691,6 +696,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
|
||||
|
||||
q_descale = layer._q_scale.expand(descale_shape)
|
||||
k_descale = layer._k_scale.expand(descale_shape)
|
||||
v_descale = layer._v_scale.expand(descale_shape)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
self._forward_with_dcp(
|
||||
query[:num_actual_tokens],
|
||||
@@ -700,9 +709,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
@@ -728,9 +737,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
s_aux=self.sinks,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user