Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
Eldar Kurtić
2026-01-22 21:29:57 +01:00
committed by GitHub
parent 955b43a5a5
commit 44f08af3a7
18 changed files with 558 additions and 263 deletions

View File

@@ -75,13 +75,16 @@ def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) ->
layer._v_scale_float = 1.0
layer._prob_scale_float = 1.0
# Initialize q/k/v range constants used by calc_kv_scales
layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
kv_cache_dtype: str,
calculate_kv_scales: bool,
) -> None:
"""Initializes KV cache scaling factors and quantization method.
@@ -94,16 +97,10 @@ def _init_kv_cache_quant(
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
kv_cache_dtype: The KV cache data type string.
calculate_kv_scales: Whether to calculate KV scales dynamically.
"""
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
# Note [Register q/k/v/prob scales in state dict]
# When calling model.to(device), only parameters/buffers in state dict are
@@ -133,7 +130,7 @@ def _init_kv_cache_quant(
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if kv_cache_dtype == "fp8_e5m2":
if layer.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
@@ -197,9 +194,20 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# llm-compressor models need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_dtype = "fp8"
calculate_kv_scales = False
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
@@ -208,15 +216,6 @@ class Attention(nn.Module, AttentionLayerBase):
self.quant_config = quant_config
self.layer_name = prefix
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)
self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
@@ -318,18 +317,24 @@ class Attention(nn.Module, AttentionLayerBase):
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
# for attn backends supporting query quantization
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
"fp8"
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
is_per_head = (
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
)
block_size = self.head_size * self.num_heads // self.num_kv_heads
self.query_quant = QuantFP8(
static=True,
group_shape=GroupShape(-1, block_size)
if is_per_head
else GroupShape.PER_TENSOR,
)
def forward(
self,
@@ -524,13 +529,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.quant_config = quant_config
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(

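Note on the query-quant change above: when the loaded q_scale carries one value per KV head, the quantizer switches from a per-tensor group shape to one group per KV head. A minimal standalone sketch of that shape arithmetic, with illustrative dimensions and plain tuples standing in for vLLM's GroupShape:

import torch

# Illustrative numbers (not from the diff): 32 query heads, 8 KV heads, head_size 128.
num_heads, num_kv_heads, head_size = 32, 8, 128

# A q_scale with one entry per KV head signals per-head quantization.
q_scale = torch.ones(num_kv_heads)
is_per_head = q_scale.numel() == num_kv_heads

# Each KV-head group spans head_size * (num_heads // num_kv_heads) query columns,
# so the group shape is (-1, that width) for per-head and (-1, -1) for per-tensor.
group_cols = head_size * num_heads // num_kv_heads
group_shape = (-1, group_cols) if is_per_head else (-1, -1)
assert group_shape == (-1, 512)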
View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
@@ -19,6 +20,10 @@ from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -87,6 +92,8 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: dict[str, Any] | None = None,
config: dict[str, Any] | None = None,
transform_config: dict[str, Any] | None = None,
total_num_heads: int | None = None,
total_num_kv_heads: int | None = None,
):
super().__init__()
self.ignore = ignore
@@ -97,6 +104,8 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
if transform_config:
self.transform_config = TransformConfig.model_validate(transform_config)
@@ -200,13 +209,29 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
# We keep only config groups which are not doing Attention quantization
# because Attention quantization on its own is not supported by vLLM.
# It is coupled with KV-cache quantization, and if scales are present in the
# checkpoint, they will be used properly.
grps_without_attn_quant = {}
for k, v in config["config_groups"].items():
# e.g. LlamaAttention, Qwen3Attention, etc.
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
logger.warning(
"Skipping CompressedTensors config group for %s. Attention quant "
"is coupled with KV-cache quantization in vLLM.",
v["targets"][0],
)
continue
grps_without_attn_quant[k] = v
config["config_groups"] = grps_without_attn_quant
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config
)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
@@ -215,7 +240,10 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
transform_config=config.get("transform_config"),
kv_cache_scheme=config.get("kv_cache_scheme"),
total_num_heads=config.get("total_num_heads"),
total_num_kv_heads=config.get("total_num_kv_heads"),
)
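For context, from_config now drops any config group whose only target is an attention module, since standalone attention quantization is handled via the KV-cache path. A self-contained sketch of the filtering, using a hypothetical config:

config = {
    "config_groups": {
        # Hypothetical groups for illustration only.
        "group_0": {"targets": ["Linear"], "weights": {"num_bits": 8, "type": "float"}},
        "group_1": {"targets": ["LlamaAttention"], "input_activations": {"num_bits": 8}},
    }
}
kept = {
    name: group
    for name, group in config["config_groups"].items()
    if not (len(group["targets"]) == 1 and group["targets"][0].endswith("Attention"))
}
assert list(kept) == ["group_0"]  # the attention-only group is dropped with a warning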
@classmethod
@@ -791,22 +819,6 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
@@ -965,12 +977,16 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
f"received num_bits={num_bits}, type={type_}"
)
# TODO: delegate validation to compressed-tensors library so that we have a
# single source of truth. Right now this is not possible until the next release
# of compressed-tensors.
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
supported_strategies = ("tensor", "attn_head")
if strategy not in supported_strategies:
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}"
"Invalid strategy for compressed-tensors KV cache. "
f"Expected strategies: {supported_strategies}, found strategy:"
f" {strategy}"
)
is_symmetric = kv_cache_scheme.get("symmetric")
@@ -980,3 +996,133 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}"
)
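For reference, a kv_cache_scheme that passes the widened validation might look roughly like the dict below (values are illustrative; only the num_bits/type/strategy/symmetric fields checked above are shown):

kv_cache_scheme = {
    "num_bits": 8,
    "type": "float",
    "strategy": "attn_head",  # or "tensor" for a single per-tensor scale
    "symmetric": True,
}
assert kv_cache_scheme["strategy"] in ("tensor", "attn_head")
assert kv_cache_scheme["symmetric"] is True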
def create_weights(self, layer: torch.nn.Module):
"""
Initialize placeholder scales and zero points to enable loading of
quantized params from compressed-tensors checkpoints.
"""
strategy = None # for backward compatibility
if (
hasattr(self.quant_config, "kv_cache_scheme")
and self.quant_config.kv_cache_scheme is not None
):
strategy = self.quant_config.kv_cache_scheme["strategy"]
if strategy == "attn_head":
assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales."
)
n_scales = int(layer.num_kv_heads)
else:
n_scales = 1
layer.k_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.v_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.q_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
# Zero points are not used in vLLM as currently only symmetric quantization is
# supported. We need to create them here to enable loading of llm-compressor
# checkpoints which contain them irrespective of the symmetric/asymmetric
# scheme used during quantization.
layer.k_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.v_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.q_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
# TP-aware loading for attn_head strategy follows attention head partitioning:
# - q_scale is partitioned over query heads.
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
# and replicated when total_kv_heads < tp_size.
if strategy == "attn_head":
def _tp_aware_loader(
param: torch.Tensor,
loaded_weight: torch.Tensor,
kind: Literal["q", "k", "v"],
param_type: Literal["scale", "zero_point"],
):
# Zero-points are not used as vLLM only supports symmetric quantization
if param_type == "zero_point":
return
# LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1]
loaded_weight = loaded_weight.flatten()
# FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale.
# We reduce by taking the max scale in each attention head group.
if kind == "q":
reduction_factor = (
self.quant_config.total_num_heads # type: ignore[attr-defined]
// self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
)
loaded_weight = torch.amax(
loaded_weight.view(-1, reduction_factor), dim=1
)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads: # type: ignore[attr-defined]
# heads evenly distributed
loaded_weight = loaded_weight[
tp_rank * layer.num_kv_heads : (tp_rank + 1)
* layer.num_kv_heads
]
else:
# heads replicated to match TP size
assert layer.num_kv_heads == 1
replicas = tp_size // self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
shard_rank = tp_rank // replicas
loaded_weight = loaded_weight[shard_rank : shard_rank + 1]
param.data.copy_(loaded_weight.to(dtype=param.dtype))
layer.q_scale.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="scale"
)
layer.k_scale.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="scale"
)
layer.v_scale.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="scale"
)
layer.q_zero_point.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="zero_point"
)
layer.k_zero_point.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="zero_point"
)
layer.v_zero_point.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="zero_point"
)
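The TP-aware loader above collapses per-query-head q scales to one value per KV head (max over each head group) and then slices by tensor-parallel rank. A standalone sketch of the evenly-distributed case, with illustrative head counts:

import torch

# Illustrative sizes: 32 query heads, 8 KV heads, tensor-parallel size 4.
total_num_heads, total_num_kv_heads = 32, 8
tp_rank, tp_size = 1, 4
num_kv_heads_per_rank = total_num_kv_heads // tp_size  # 2 KV heads per rank

# Checkpoint stores q scales as [num_heads, 1, 1]; flatten to 1D first.
loaded_q_scale = torch.rand(total_num_heads, 1, 1).flatten()

# Collapse query-head scales to one per KV-head group by taking the group max.
reduction_factor = total_num_heads // total_num_kv_heads  # 4 query heads per KV head
per_kv_head = torch.amax(loaded_q_scale.view(-1, reduction_factor), dim=1)

# Evenly-distributed case: slice out this rank's contiguous KV heads.
shard = per_kv_head[
    tp_rank * num_kv_heads_per_rank : (tp_rank + 1) * num_kv_heads_per_rank
]
assert shard.shape == (num_kv_heads_per_rank,)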
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Override the default vLLM placeholder scales with the llm-compressor loaded
scales. Zero points are not used as only symmetric quantization is supported.
"""
layer._k_scale = layer.k_scale
layer._v_scale = layer.v_scale
layer._q_scale = layer.q_scale
# Discard all placeholders.
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.k_zero_point
del layer.v_zero_point
del layer.q_zero_point

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
group_broadcast,
prep_scale_for_group_broadcast,
)
from vllm.platforms import current_platform
@@ -40,7 +41,7 @@ class QuantFP8(CustomOp):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
or arbitrary block size)
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param column_major_scales: For group quantization, output scales in
@@ -157,6 +158,8 @@ class QuantFP8(CustomOp):
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
else:
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division

View File

@@ -191,6 +191,51 @@ def group_broadcast(t, shape):
return t
def prep_scale_for_group_broadcast(
scale: torch.Tensor,
x: torch.Tensor,
group_shape: GroupShape | None,
) -> torch.Tensor:
"""
Prepare the input quantization scale for group broadcasting.
Args:
scale: The scale tensor (scalar or 1D).
x: Target tensor whose shape determines broadcast dimensions.
group_shape: GroupShape to broadcast over.
Returns:
scale reshaped for correct broadcasting.
"""
if scale.numel() == 1:
# For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
# This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
# where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
# per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# For all other cases, reshape scalar scales to (1, 1) for broadcasting.
return (
scale
if group_shape is not None and group_shape.is_per_tensor()
else scale.reshape(1, 1)
)
if scale.ndim == 1:
assert group_shape is not None, (
"group_shape must be provided to correctly broadcast 1D scale"
)
rows, cols = _normalize_quant_group_shape(x, group_shape)
# Determine broadcasting dimension: either rows or columns match group size
if rows == x.shape[-2]:
scale = scale.unsqueeze(-2)
elif cols == x.shape[-1]:
scale = scale.unsqueeze(-1)
else:
raise ValueError(
f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
f" {x.shape}, group_shape={(rows, cols)}"
)
return scale
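As a rough usage illustration of the broadcasting rule above (plain torch, not the vLLM helper): a scalar per-tensor scale is used as-is, while a 1D scale needs an extra unit dim on the axis it does not cover:

import torch

x = torch.randn(4, 16)

# Per-tensor: a scalar scale broadcasts as-is.
scale_per_tensor = torch.tensor(0.5)
assert (x * scale_per_tensor).shape == x.shape

# 1D scale with one entry per row (the group spans the full last dim): it needs
# a trailing unit dim to broadcast across columns, matching the unsqueeze(-1) branch.
scale_per_row = torch.rand(4)
assert (x * scale_per_row.unsqueeze(-1)).shape == x.shape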
# Quantize assuming one scale per group of elements with shape group_shape,
# example group shapes:
# * (-1, -1) for per-tensor quantization
@@ -241,7 +286,7 @@ def scaled_quantize(
_, fp8_max = get_fp8_min_max()
scale = fp8_max / amax
# Apply scale and convert form:
# Apply scale and convert from:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
x_scl_sat = (
(x_blkd_permd * scale.unsqueeze(-1))
@@ -261,29 +306,7 @@ def scaled_dequantize(
group_shape: GroupShape | None = None,
out_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
if x_s.numel() == 1: # scalar
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
if x_s.ndim == 1:
if group_shape is None:
raise AssertionError(
"if x_s is 1D tensor, group_shape must be provided otherwise "
"its ambiguous which dimension to broadcast x_s to"
)
# unsqueeze the scales for the dimension where we want to broadcast
# across the full extent
if group_shape[0] == x_q.shape[-2]:
x_s = x_s.unsqueeze(-2)
elif group_shape[1] == x_q.shape[-1]:
x_s = x_s.unsqueeze(-1)
else:
raise AssertionError(
"if x_s is a vector we should be broadcasting it to the full "
"extent of one of the dimensions"
)
x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape)
if group_shape is not None:
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]

View File

@@ -246,6 +246,23 @@ def get_quant_config(
# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
# Pipe information about heads to enable TP-aware loading of attn_head scales
if (
hf_quant_config is not None
and hf_quant_config.get("quant_method") == "compressed-tensors"
):
if hf_text_config is not None:
n_heads = getattr(hf_text_config, "num_attention_heads", None)
n_kv_heads = getattr(hf_text_config, "num_key_value_heads", None)
else:
n_heads = getattr(model_config.hf_config, "num_attention_heads", None)
n_kv_heads = getattr(model_config.hf_config, "num_key_value_heads", None)
hf_quant_config["total_num_heads"] = n_heads
hf_quant_config["total_num_kv_heads"] = (
n_kv_heads if n_kv_heads is not None else n_heads
)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
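A minimal sketch of what gets piped through here, assuming the usual HF text-config attribute names; the config class below is a stand-in for illustration:

# A stand-in for transformers' text config; attribute names match HF conventions.
class _TextConfig:
    num_attention_heads = 32
    num_key_value_heads = 8

hf_text_config = _TextConfig()
hf_quant_config = {"quant_method": "compressed-tensors"}

n_heads = getattr(hf_text_config, "num_attention_heads", None)
n_kv_heads = getattr(hf_text_config, "num_key_value_heads", None)
hf_quant_config["total_num_heads"] = n_heads
hf_quant_config["total_num_kv_heads"] = n_kv_heads if n_kv_heads is not None else n_heads
assert hf_quant_config["total_num_kv_heads"] == 8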
@@ -1157,11 +1174,21 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
# .mixer.attn.{k,v}_scale
(r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
(r"\.([kv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
]
# Check if name ends with k_scale or v_scale
if name.endswith((".k_scale", ".v_scale")):
if name.endswith(
(
".k_scale",
".v_scale",
".q_scale",
".k_zero_point",
".v_zero_point",
".q_zero_point",
)
):
import regex as re
for pattern, replacement in scale_mapping_patterns:
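To illustrate the extended remapping, a standalone regex sketch with a hypothetical parameter name (the real helper also verifies the remapped name exists in params_dict):

import re  # the real code uses the `regex` package; `re` suffices for this sketch

scale_mapping_patterns = [
    (r"\.([qkv])_scale$", r".attn.\1_scale"),
    (r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
]

name = "model.layers.0.self_attn.q_scale"  # hypothetical checkpoint param name
for pattern, replacement in scale_mapping_patterns:
    if re.search(pattern, name):
        name = re.sub(pattern, replacement, name)
        break
assert name == "model.layers.0.self_attn.attn.q_scale"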

View File

@@ -437,7 +437,7 @@ class ApertusModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:

View File

@@ -303,7 +303,7 @@ class ArceeModel(nn.Module):
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is None:
continue

View File

@@ -465,8 +465,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

View File

@@ -140,8 +140,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
# Remapping the name of FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

View File

@@ -238,8 +238,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
# Remapping the name of FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

View File

@@ -661,7 +661,7 @@ class NemotronHModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:

View File

@@ -342,7 +342,7 @@ class DeciModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:

View File

@@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]):
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
supports_per_head_quant_scales: bool = False
dcp_world_size: int
dcp_rank: int

View File

@@ -576,6 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
def forward(
self,
@@ -691,6 +696,10 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
@@ -700,9 +709,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output
else:
@@ -728,9 +737,9 @@ class FlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)