[Feature] KV cache per-token-head INT8/FP8 quantization (#38378)
Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: yangyang4991 <yangyang4991@gmail.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first).
|
||||
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
|
||||
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||
|
||||
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
|
||||
>
|
||||
|
||||
94
tests/models/quantization/test_per_token_kv_cache.py
Normal file
94
tests/models/quantization/test_per_token_kv_cache.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""End-to-end accuracy tests for per-token-head KV cache quantization.
|
||||
|
||||
Compares logprobs between a baseline bf16 model and the same model with
|
||||
per-token-head quantized KV cache (int8 or fp8) using the Triton attention
|
||||
backend.
|
||||
|
||||
Run: pytest tests/models/quantization/test_per_token_kv_cache.py -v -s
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Per-token-head KV cache requires CUDA or ROCm GPU.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"base_model,test_model",
|
||||
[
|
||||
(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
|
||||
)
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("backend", ["TRITON_ATTN"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_per_token_head_kv_cache_accuracy(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
base_model: str,
|
||||
test_model: str,
|
||||
kv_cache_dtype: str,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
backend: str,
|
||||
tensor_parallel_size: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Compare logprobs between bf16 baseline and per-token-head quantized KV
|
||||
cache.
|
||||
|
||||
Uses calculate_kv_scales (dynamic scale computation) since there are
|
||||
no per-token-head calibrated checkpoints available yet.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
with vllm_runner(
|
||||
base_model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype="auto",
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
with vllm_runner(
|
||||
test_model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
calculate_kv_scales=True,
|
||||
attention_config={"backend": backend},
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=test_outputs,
|
||||
name_0="bf16_kv_cache",
|
||||
name_1=f"{kv_cache_dtype}_kv_cache",
|
||||
)
|
||||
560
tests/quantization/test_per_token_kv_cache.py
Normal file
560
tests/quantization/test_per_token_kv_cache.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for per-token-head KV cache quantization (INT8 and FP8).
|
||||
|
||||
Covers:
|
||||
- Per-token-head Triton reshape-and-cache kernel
|
||||
- Round-trip quantize/dequantize accuracy
|
||||
- process_weights_after_loading early-return path
|
||||
- End-to-end integration with Triton unified attention kernel
|
||||
|
||||
Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_fp8_min_max,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache
|
||||
|
||||
# Skip entire module if no CUDA/ROCm GPU available
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Per-token-head KV cache tests require CUDA or ROCm GPU.",
|
||||
),
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
NUM_TOKENS = [1, 7, 42]
|
||||
NUM_KV_HEADS = [1, 4, 8]
|
||||
HEAD_SIZES = [64, 128]
|
||||
BLOCK_SIZES = [16]
|
||||
SEEDS = [0]
|
||||
|
||||
# Platform-dependent FP8 dtype and range
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP8_MIN, FP8_MAX = get_fp8_min_max()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-dtype quantization config
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass(frozen=True)
|
||||
class QuantConfig:
|
||||
"""Quantization parameters for a given cache dtype."""
|
||||
|
||||
cache_dtype: torch.dtype # torch.int8 or FP8_DTYPE
|
||||
kv_cache_dtype_str: str # "int8_per_token_head" or "fp8_per_token_head"
|
||||
quant_max: float
|
||||
quant_min: float
|
||||
kv_quant_mode: KVQuantMode
|
||||
# INT8 Triton stores truncate; FP8 hardware casts round.
|
||||
uses_trunc: bool
|
||||
|
||||
|
||||
INT8_CONFIG = QuantConfig(
|
||||
cache_dtype=torch.int8,
|
||||
kv_cache_dtype_str="int8_per_token_head",
|
||||
quant_max=127.0,
|
||||
quant_min=-128.0,
|
||||
kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
|
||||
uses_trunc=True,
|
||||
)
|
||||
FP8_CONFIG = QuantConfig(
|
||||
cache_dtype=FP8_DTYPE,
|
||||
kv_cache_dtype_str="fp8_per_token_head",
|
||||
quant_max=FP8_MAX,
|
||||
quant_min=FP8_MIN,
|
||||
kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD,
|
||||
uses_trunc=False,
|
||||
)
|
||||
|
||||
QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG]
|
||||
|
||||
|
||||
@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"])
|
||||
def qcfg(request) -> QuantConfig:
|
||||
return request.param
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def _quantize_per_token_head_ref(
|
||||
data: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
cfg: QuantConfig,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reference per-token-head quantization (one scale per token per head).
|
||||
|
||||
Returns (quantized, scales) where scales is [num_tokens, num_heads].
|
||||
"""
|
||||
absmax = data.float().abs().amax(dim=2) # [num_tokens, num_heads]
|
||||
scales = (absmax / cfg.quant_max).clamp(min=1e-6)
|
||||
scaled = data.float() * (1.0 / scales[:, :, None])
|
||||
if cfg.uses_trunc:
|
||||
q = scaled.round().clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
|
||||
else:
|
||||
q = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
|
||||
return q, scales
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. is_quantized_kv_cache / get_kv_quant_mode
|
||||
# ===========================================================================
|
||||
class TestIsQuantizedKvCache:
|
||||
def test_fp8_variants(self):
|
||||
assert is_quantized_kv_cache("fp8")
|
||||
assert is_quantized_kv_cache("fp8_e4m3")
|
||||
assert is_quantized_kv_cache("fp8_e5m2")
|
||||
|
||||
def test_int8_per_token_head(self):
|
||||
assert is_quantized_kv_cache("int8_per_token_head")
|
||||
|
||||
def test_fp8_per_token_head(self):
|
||||
assert is_quantized_kv_cache("fp8_per_token_head")
|
||||
|
||||
def test_auto(self):
|
||||
assert not is_quantized_kv_cache("auto")
|
||||
|
||||
def test_bfloat16(self):
|
||||
assert not is_quantized_kv_cache("bfloat16")
|
||||
|
||||
def test_kv_quant_mode_int8(self):
|
||||
from vllm.v1.kv_cache_interface import get_kv_quant_mode
|
||||
|
||||
assert (
|
||||
get_kv_quant_mode("int8_per_token_head") == KVQuantMode.INT8_PER_TOKEN_HEAD
|
||||
)
|
||||
|
||||
def test_kv_quant_mode_fp8(self):
|
||||
from vllm.v1.kv_cache_interface import get_kv_quant_mode
|
||||
|
||||
assert get_kv_quant_mode("fp8_per_token_head") == KVQuantMode.FP8_PER_TOKEN_HEAD
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. Triton per-token-head kernel (reshape-and-cache)
|
||||
# ===========================================================================
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_KV_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_per_token_head(
|
||||
qcfg: QuantConfig,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
seed: int,
|
||||
):
|
||||
"""Test triton_reshape_and_cache_flash_per_token_head_quant kernel."""
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_per_token_head_quant,
|
||||
)
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 4
|
||||
|
||||
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
|
||||
value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
|
||||
|
||||
key_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
value_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = torch.tensor(
|
||||
random.sample(range(num_slots), num_tokens), dtype=torch.long
|
||||
)
|
||||
|
||||
triton_reshape_and_cache_flash_per_token_head_quant(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
k_scale_cache,
|
||||
v_scale_cache,
|
||||
slot_mapping,
|
||||
)
|
||||
|
||||
# Reference
|
||||
ref_k_quant, ref_k_scales = _quantize_per_token_head_ref(key, qcfg)
|
||||
ref_v_quant, ref_v_scales = _quantize_per_token_head_ref(value, qcfg)
|
||||
|
||||
# Compare dequantized values rather than raw quantized values.
|
||||
# Triton and PyTorch reductions can differ at FP8 rounding boundaries
|
||||
# (up to 32 in quantized domain for fp8_e4m3), but the dequantized
|
||||
# error is bounded by the scale.
|
||||
for i, slot in enumerate(slot_mapping.tolist()):
|
||||
blk = slot // block_size
|
||||
off = slot % block_size
|
||||
|
||||
actual_k_scale = k_scale_cache[blk, off] # [num_heads]
|
||||
k_deq = key_cache[blk, off].float() * actual_k_scale[:, None]
|
||||
k_ref_deq = key[i].float()
|
||||
torch.testing.assert_close(
|
||||
k_deq,
|
||||
k_ref_deq,
|
||||
atol=0.1,
|
||||
rtol=0.1,
|
||||
)
|
||||
actual_v_scale = v_scale_cache[blk, off] # [num_heads]
|
||||
v_deq = value_cache[blk, off].float() * actual_v_scale[:, None]
|
||||
v_ref_deq = value[i].float()
|
||||
torch.testing.assert_close(
|
||||
v_deq,
|
||||
v_ref_deq,
|
||||
atol=0.1,
|
||||
rtol=0.1,
|
||||
)
|
||||
# Per-head scales: [num_heads]
|
||||
torch.testing.assert_close(
|
||||
k_scale_cache[blk, off], ref_k_scales[i], atol=1e-4, rtol=1e-3
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
v_scale_cache[blk, off], ref_v_scales[i], atol=1e-4, rtol=1e-3
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. Per-token-head round-trip accuracy (quantize -> dequantize)
|
||||
# ===========================================================================
|
||||
@pytest.mark.parametrize("num_tokens", [1, 16])
|
||||
@pytest.mark.parametrize("num_heads", [4])
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@torch.inference_mode()
|
||||
def test_per_token_head_round_trip_accuracy(
|
||||
qcfg: QuantConfig,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
):
|
||||
"""Verify per-token-head round-trip: kernel dequant matches reference.
|
||||
|
||||
INT8: Triton truncates on float->int8 store.
|
||||
FP8: hardware cast (clamp then cast).
|
||||
"""
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_per_token_head_quant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
set_random_seed(42)
|
||||
|
||||
num_blocks = (num_tokens + block_size - 1) // block_size + 2
|
||||
|
||||
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
|
||||
value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
|
||||
|
||||
key_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
value_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.long)
|
||||
|
||||
triton_reshape_and_cache_flash_per_token_head_quant(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
k_scale_cache,
|
||||
v_scale_cache,
|
||||
slot_mapping,
|
||||
)
|
||||
|
||||
for i in range(num_tokens):
|
||||
blk = i // block_size
|
||||
off = i % block_size
|
||||
|
||||
for label, data, cache, sc in [
|
||||
("key", key, key_cache, k_scale_cache),
|
||||
("val", value, value_cache, v_scale_cache),
|
||||
]:
|
||||
for h in range(num_heads):
|
||||
orig = data[i, h].float() # [head_size]
|
||||
|
||||
actual_q = cache[blk, off, h]
|
||||
actual_sc = sc[blk, off, h]
|
||||
actual_deq = actual_q.float() * actual_sc
|
||||
|
||||
# Round-trip: dequantized should be close to original
|
||||
torch.testing.assert_close(
|
||||
actual_deq,
|
||||
orig,
|
||||
atol=0.1,
|
||||
rtol=0.1,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. Negative slot mapping (padding tokens should be skipped)
|
||||
# ===========================================================================
|
||||
@torch.inference_mode()
|
||||
def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
|
||||
"""Tokens with slot_mapping=-1 should leave the cache unchanged."""
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_per_token_head_quant,
|
||||
)
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
num_tokens = 4
|
||||
num_heads = 2
|
||||
head_size = 64
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
|
||||
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
|
||||
value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
|
||||
|
||||
key_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
value_cache = torch.zeros(
|
||||
num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
|
||||
)
|
||||
k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
|
||||
|
||||
slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long)
|
||||
|
||||
key_cache_before = key_cache.clone()
|
||||
val_cache_before = value_cache.clone()
|
||||
|
||||
triton_reshape_and_cache_flash_per_token_head_quant(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
k_scale_cache,
|
||||
v_scale_cache,
|
||||
slot_mapping,
|
||||
)
|
||||
|
||||
# Slots 0 and 1 should have been written (tokens 0 and 2)
|
||||
assert not torch.equal(key_cache[0, 0], key_cache_before[0, 0])
|
||||
assert not torch.equal(key_cache[0, 1], key_cache_before[0, 1])
|
||||
assert not torch.equal(value_cache[0, 0], val_cache_before[0, 0])
|
||||
|
||||
# All other slots should be unchanged
|
||||
assert torch.equal(key_cache[0, 2:], key_cache_before[0, 2:])
|
||||
assert torch.equal(key_cache[1], key_cache_before[1])
|
||||
assert torch.equal(value_cache[0, 2:], val_cache_before[0, 2:])
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. process_weights_after_loading -- per-token-head early return
|
||||
# ===========================================================================
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
|
||||
)
|
||||
def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str):
|
||||
"""Per-token-head should set _k_scale=1.0, _v_scale=1.0
|
||||
and delete checkpoint attrs."""
|
||||
from vllm.model_executor.layers.quantization.kv_cache import (
|
||||
BaseKVCacheMethod,
|
||||
)
|
||||
|
||||
layer = MagicMock()
|
||||
layer.kv_cache_dtype = kv_cache_dtype
|
||||
layer.calculate_kv_scales = False
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer._k_scale = torch.tensor(0.0)
|
||||
layer._v_scale = torch.tensor(0.0)
|
||||
layer._k_scale_float = 0.0
|
||||
layer._v_scale_float = 0.0
|
||||
|
||||
method = BaseKVCacheMethod.__new__(BaseKVCacheMethod)
|
||||
method.quant_config = MagicMock()
|
||||
method.process_weights_after_loading(layer)
|
||||
|
||||
assert layer._k_scale_float == 1.0
|
||||
assert layer._v_scale_float == 1.0
|
||||
assert not hasattr(layer, "k_scale")
|
||||
assert not hasattr(layer, "v_scale")
|
||||
assert not hasattr(layer, "q_scale")
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. Triton unified_attention -- per-token-head scale cache (INT8 and FP8)
|
||||
# ===========================================================================
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens",
|
||||
[
|
||||
[(1, 128)],
|
||||
[(1, 64), (1, 32)],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", [(4, 4)])
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@torch.inference_mode()
|
||||
def test_triton_unified_attention_per_token_head_scale(
|
||||
qcfg: QuantConfig,
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
):
|
||||
"""End-to-end: quantized KV with per-token-head scale caches."""
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
set_random_seed(0)
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [s[0] for s in seq_lens]
|
||||
kv_lens = [s[1] for s in seq_lens]
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
num_blocks = 2048
|
||||
|
||||
query = torch.randn(
|
||||
sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
key_cache_bf16 = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
|
||||
)
|
||||
value_cache_bf16 = torch.randn_like(key_cache_bf16)
|
||||
|
||||
# Per-token-head quantization: one scale per (block, slot, head)
|
||||
k_absmax = key_cache_bf16.float().abs().amax(dim=-1) # [..., num_kv_heads]
|
||||
v_absmax = value_cache_bf16.float().abs().amax(dim=-1)
|
||||
k_scale_cache = (k_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
|
||||
v_scale_cache = (v_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
|
||||
|
||||
scaled_k = key_cache_bf16.float() / k_scale_cache[:, :, :, None]
|
||||
scaled_v = value_cache_bf16.float() / v_scale_cache[:, :, :, None]
|
||||
if qcfg.uses_trunc:
|
||||
key_cache_q = (
|
||||
scaled_k.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
|
||||
)
|
||||
value_cache_q = (
|
||||
scaled_v.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
|
||||
)
|
||||
else:
|
||||
key_cache_q = scaled_k.clamp(qcfg.quant_min, qcfg.quant_max).to(
|
||||
qcfg.cache_dtype
|
||||
)
|
||||
value_cache_q = scaled_v.clamp(qcfg.quant_min, qcfg.quant_max).to(
|
||||
qcfg.cache_dtype
|
||||
)
|
||||
|
||||
# Dequantized reference
|
||||
key_cache_deq = key_cache_q.float() * k_scale_cache[:, :, :, None]
|
||||
value_cache_deq = value_cache_q.float() * v_scale_cache[:, :, :, None]
|
||||
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
head_size_padded = next_power_of_2(head_size)
|
||||
seq_threshold_3D = 0
|
||||
num_par_softmax_segments = 16
|
||||
softmax_segm_output = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_max = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_expsum = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
output_q = torch.empty_like(query)
|
||||
unified_attention(
|
||||
q=query,
|
||||
k=key_cache_q,
|
||||
v=value_cache_q,
|
||||
out=output_q,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens_t,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=(-1, -1),
|
||||
block_table=block_tables,
|
||||
softcap=0,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
kv_quant_mode=qcfg.kv_quant_mode,
|
||||
k_scale_cache=k_scale_cache,
|
||||
v_scale_cache=v_scale_cache,
|
||||
)
|
||||
|
||||
output_ref = torch.empty_like(query)
|
||||
unified_attention(
|
||||
q=query,
|
||||
k=key_cache_deq.to(torch.bfloat16),
|
||||
v=value_cache_deq.to(torch.bfloat16),
|
||||
out=output_ref,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens_t,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=(-1, -1),
|
||||
block_table=block_tables,
|
||||
softcap=0,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2)
|
||||
@@ -8,7 +8,10 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
from vllm.utils.torch_utils import (
|
||||
is_quantized_kv_cache,
|
||||
kv_cache_uses_per_token_head_scales,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -21,6 +24,8 @@ CacheDType = Literal[
|
||||
"fp8_e5m2",
|
||||
"fp8_inc",
|
||||
"fp8_ds_mla",
|
||||
"int8_per_token_head",
|
||||
"fp8_per_token_head",
|
||||
]
|
||||
MambaDType = Literal["auto", "float32", "float16"]
|
||||
MambaCacheMode = Literal["all", "align", "none"]
|
||||
@@ -237,12 +242,20 @@ class CacheConfig:
|
||||
@field_validator("cache_dtype", mode="after")
|
||||
@classmethod
|
||||
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
|
||||
if is_quantized_kv_cache(cache_dtype):
|
||||
if kv_cache_uses_per_token_head_scales(cache_dtype):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"Using %s data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Dynamic per-token-head scales will be computed at runtime.",
|
||||
str(cache_dtype),
|
||||
)
|
||||
elif is_quantized_kv_cache(cache_dtype):
|
||||
logger.info(
|
||||
"Using %s data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor."
|
||||
"scaling factor",
|
||||
str(cache_dtype),
|
||||
)
|
||||
return cache_dtype
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheSpec,
|
||||
SlidingWindowSpec,
|
||||
get_kv_quant_mode,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -381,8 +382,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
|
||||
# for attn backends supporting query quantization
|
||||
self.query_quant = None
|
||||
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
|
||||
"fp8"
|
||||
if (
|
||||
self.impl.supports_quant_query_input
|
||||
and self.kv_cache_dtype.startswith("fp8")
|
||||
and not self.kv_cache_dtype.endswith("per_token_head")
|
||||
):
|
||||
is_per_head = (
|
||||
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
|
||||
@@ -539,6 +542,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
|
||||
if self.sliding_window is not None:
|
||||
assert not vllm_config.model_config.use_mla, (
|
||||
"MLA is not supported for slidingwindow"
|
||||
@@ -548,6 +552,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
kv_quant_mode=quant_mode,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
else:
|
||||
@@ -557,6 +562,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
kv_quant_mode=quant_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
KVCacheSpec,
|
||||
get_kv_quant_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,5 +124,6 @@ class ChunkedLocalAttention(Attention):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,11 @@ from vllm.v1.attention.backend import (
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
CrossAttentionSpec,
|
||||
KVCacheSpec,
|
||||
get_kv_quant_mode,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -220,4 +224,5 @@ class CrossAttention(Attention):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
SinkFullAttentionSpec,
|
||||
get_kv_quant_mode,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -217,6 +218,7 @@ class StaticSinkAttention(Attention, CustomOp):
|
||||
head_size_v=self.head_size_v,
|
||||
sink_len=self.sink_len,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -53,6 +54,20 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
assert not hasattr(layer, "prob_scale")
|
||||
return
|
||||
|
||||
# Per-token-head quantized KV cache: scales are computed dynamically
|
||||
# per (token, head) in the kernel at cache-write time. Checkpoint
|
||||
# scales are never used regardless of calculate_kv_scales.
|
||||
if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype):
|
||||
layer._k_scale.copy_(1.0)
|
||||
layer._v_scale.copy_(1.0)
|
||||
layer._k_scale_float = 1.0
|
||||
layer._v_scale_float = 1.0
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
del layer.q_scale
|
||||
del layer.prob_scale
|
||||
return
|
||||
|
||||
# If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
|
||||
@@ -505,6 +505,7 @@ class Platform:
|
||||
FullAttentionSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
get_kv_quant_mode,
|
||||
)
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
@@ -516,6 +517,8 @@ class Platform:
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype)
|
||||
|
||||
# Compute attention page size for 1 token
|
||||
if model_config.use_mla:
|
||||
attn_page_size_1_token = MLAAttentionSpec(
|
||||
@@ -523,6 +526,7 @@ class Platform:
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
kv_quant_mode=kv_quant_mode,
|
||||
).page_size_bytes
|
||||
else:
|
||||
attn_page_size_1_token = FullAttentionSpec(
|
||||
@@ -530,6 +534,7 @@ class Platform:
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
kv_quant_mode=kv_quant_mode,
|
||||
).page_size_bytes
|
||||
|
||||
# Compute mamba page size
|
||||
|
||||
@@ -37,6 +37,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"fp8_e4m3": torch.uint8,
|
||||
"fp8_e5m2": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"int8_per_token_head": torch.int8,
|
||||
"fp8_per_token_head": torch.uint8,
|
||||
"fp8_inc": torch.float8_e4m3fn,
|
||||
"fp8_ds_mla": torch.uint8,
|
||||
}
|
||||
@@ -62,7 +64,12 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype.startswith("fp8")
|
||||
return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head")
|
||||
|
||||
|
||||
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
|
||||
"""Return True if *kv_cache_dtype* needs per-token-head scales."""
|
||||
return kv_cache_dtype.endswith("per_token_head")
|
||||
|
||||
|
||||
def is_strictly_contiguous(t: torch.Tensor) -> bool:
|
||||
|
||||
@@ -17,7 +17,9 @@ if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
|
||||
|
||||
from vllm.v1.kv_cache_interface import get_kv_quant_mode
|
||||
|
||||
|
||||
class AttentionType(str, Enum):
|
||||
@@ -740,6 +742,13 @@ class AttentionImplBase(ABC, Generic[T]):
|
||||
class AttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Standard attention implementation with forward method."""
|
||||
|
||||
kv_cache_dtype: str
|
||||
|
||||
@property
|
||||
def kv_quant_mode(self) -> "KVQuantMode":
|
||||
"""Return the KV cache quantization mode for this layer."""
|
||||
return get_kv_quant_mode(self.kv_cache_dtype)
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -33,9 +33,14 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
triton_reshape_and_cache_flash_per_token_head_quant,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
get_kv_quant_mode,
|
||||
kv_cache_uses_per_token_head_scales,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -270,6 +275,8 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
"int8_per_token_head",
|
||||
"fp8_per_token_head",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -302,6 +309,18 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
if kv_cache_uses_per_token_head_scales(cache_dtype_str):
|
||||
# Pad head_size by sizeof(float32)/sizeof(cache_dtype) so
|
||||
# the per-head scale fits inline. The backend extracts
|
||||
# data[:head_size] and scale[head_size:] via typed views.
|
||||
from vllm.utils.torch_utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
get_dtype_size,
|
||||
)
|
||||
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype_str]
|
||||
scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype)
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad)
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
@@ -365,6 +384,62 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
# Per-token-head quant: scale views carved from inline head padding.
|
||||
_k_scale_cache: torch.Tensor | None = None
|
||||
_v_scale_cache: torch.Tensor | None = None
|
||||
|
||||
def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
|
||||
"""Extract per-head scale views from the padded head dimension.
|
||||
|
||||
The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
|
||||
where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
|
||||
``pad`` elements of each head hold one float32 scale. We create
|
||||
strided float32 views over those bytes.
|
||||
|
||||
Scale shape: ``(num_blocks, block_size, num_kv_heads)``
|
||||
"""
|
||||
if self._k_scale_cache is not None:
|
||||
return
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
|
||||
num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
|
||||
dtype_sz = kv_cache.element_size()
|
||||
scale_pad = get_dtype_size(torch.float32) // dtype_sz # e.g. 4
|
||||
hs = padded_hs - scale_pad
|
||||
|
||||
raw = kv_cache.untyped_storage()
|
||||
base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(
|
||||
raw
|
||||
)
|
||||
|
||||
# In the raw bytes, each (block, kv_half, slot, head) occupies
|
||||
# padded_hs * dtype_sz bytes. The scale float32 sits at byte
|
||||
# offset hs * dtype_sz within that region.
|
||||
kv_half_bytes = block_size * nkv * padded_hs * dtype_sz
|
||||
full_block_f32 = 2 * kv_half_bytes // 4 # stride between blocks
|
||||
slot_f32 = nkv * padded_hs * dtype_sz // 4 # stride between slots
|
||||
head_f32 = padded_hs * dtype_sz // 4 # stride between heads
|
||||
scale_off_f32 = hs * dtype_sz // 4 # offset to scale within head
|
||||
|
||||
# K scales: kv_half=0
|
||||
self._k_scale_cache = torch.as_strided(
|
||||
base_f32,
|
||||
size=(num_blocks, block_size, nkv),
|
||||
stride=(full_block_f32, slot_f32, head_f32),
|
||||
storage_offset=scale_off_f32,
|
||||
)
|
||||
self._k_scale_cache.fill_(1.0)
|
||||
|
||||
# V scales: kv_half=1, offset by kv_half_bytes
|
||||
v_base_f32 = kv_half_bytes // 4
|
||||
self._v_scale_cache = torch.as_strided(
|
||||
base_f32,
|
||||
size=(num_blocks, block_size, nkv),
|
||||
stride=(full_block_f32, slot_f32, head_f32),
|
||||
storage_offset=v_base_f32 + scale_off_f32,
|
||||
)
|
||||
self._v_scale_cache.fill_(1.0)
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
@@ -418,6 +493,9 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.use_alibi_sqrt = use_alibi_sqrt
|
||||
self.supports_quant_query_input = current_platform.is_cuda()
|
||||
|
||||
self._kv_quant_mode = get_kv_quant_mode(kv_cache_dtype)
|
||||
self._is_per_token_head_quant = self._kv_quant_mode.is_per_token_head
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -480,15 +558,35 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
layer,
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
# Per-token-head quantized KV cache: use separate scale caches.
|
||||
if self._is_per_token_head_quant:
|
||||
self._ensure_scale_caches(kv_cache)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if key_cache.dtype == torch.uint8:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
k_scale_cache = self._k_scale_cache
|
||||
v_scale_cache = self._v_scale_cache
|
||||
# FP8 per-tensor / auto path (original flow).
|
||||
else:
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
assert layer._q_scale_float == 1.0, (
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
)
|
||||
descale_shape = (
|
||||
attn_metadata.query_start_loc.shape[0] - 1,
|
||||
key_cache.shape[2],
|
||||
)
|
||||
k_descale = layer._k_scale.expand(descale_shape)
|
||||
v_descale = layer._v_scale.expand(descale_shape)
|
||||
k_scale_cache = None
|
||||
v_scale_cache = None
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
@@ -502,7 +600,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_max = attn_metadata.softmax_segm_max
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||
|
||||
unified_attention(
|
||||
@@ -522,8 +619,8 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
@@ -532,6 +629,9 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
mm_prefix_range=mm_prefix_range_tensor,
|
||||
kv_quant_mode=self._kv_quant_mode,
|
||||
k_scale_cache=k_scale_cache,
|
||||
v_scale_cache=v_scale_cache,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -555,10 +655,10 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
attn_metadata: Encoder attention metadata
|
||||
layer: The attention layer
|
||||
"""
|
||||
# For encoder attention, process FP8 quantization if needed
|
||||
# Quantized KV cache is not supported for encoder attention.
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"quantization is not supported for encoder attention"
|
||||
"quantized KV cache is not supported for encoder attention"
|
||||
)
|
||||
|
||||
# Use encoder-specific metadata for sequence information
|
||||
@@ -594,16 +694,28 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
if self._is_per_token_head_quant:
|
||||
self._ensure_scale_caches(kv_cache)
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if key_cache.dtype == torch.uint8:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
triton_reshape_and_cache_flash_per_token_head_quant(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self._k_scale_cache,
|
||||
self._v_scale_cache,
|
||||
slot_mapping,
|
||||
)
|
||||
return
|
||||
# For decoder and cross-attention, use KV cache as before.
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
@@ -616,6 +728,8 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
|
||||
if self._is_per_token_head_quant:
|
||||
return False
|
||||
return rocm_aiter_ops.is_enabled()
|
||||
|
||||
def do_rope_and_kv_cache_update(
|
||||
|
||||
@@ -3,10 +3,16 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
FP8_DTYPE,
|
||||
get_fp8_min_max,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
|
||||
FP8_MIN, FP8_MAX = get_fp8_min_max()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash(
|
||||
@@ -118,6 +124,198 @@ def reshape_and_cache_kernel_flash(
|
||||
return
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-token-head dynamic quantization kernel
|
||||
# Grid: (num_tokens, NUM_KV_HEADS)
|
||||
# Each program handles one (token, head) pair:
|
||||
# 1. Loads K (or V) for that single head
|
||||
# 2. Computes absmax across head_size → scale = absmax / QUANT_MAX
|
||||
# 3. Quantizes and stores the data + per-head scale
|
||||
#
|
||||
# Parametrised by QUANT_MAX / QUANT_MIN so the same code path works
|
||||
# for int8 (±127/128), fp8_e4m3 (±448), and other formats.
|
||||
# ---------------------------------------------------------------------------
|
||||
@triton.jit
|
||||
def _reshape_cache_per_token_head(
|
||||
key_ptr, # [num_tokens, num_kv_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_kv_heads, head_size_v]
|
||||
key_cache_ptr, # [num_blocks, block_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blocks, block_size, num_kv_heads, head_size_v]
|
||||
k_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32
|
||||
v_scale_cache_ptr, # [num_blocks, block_size, num_kv_heads] float32
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
stride_key_tok: tl.int64,
|
||||
stride_key_head: tl.int64,
|
||||
stride_val_tok: tl.int64,
|
||||
stride_val_head: tl.int64,
|
||||
stride_kc_blk: tl.int64, # key_cache stride over blocks
|
||||
stride_kc_slot: tl.int64, # key_cache stride over slots
|
||||
stride_kc_head: tl.int64, # key_cache stride over heads
|
||||
stride_vc_blk: tl.int64,
|
||||
stride_vc_slot: tl.int64,
|
||||
stride_vc_head: tl.int64,
|
||||
stride_ks_blk: tl.int64, # k_scale_cache stride[0] (blocks)
|
||||
stride_ks_slot: tl.int64, # k_scale_cache stride[1] (slots)
|
||||
stride_ks_head: tl.int64, # k_scale_cache stride[2] (heads)
|
||||
stride_vs_blk: tl.int64, # v_scale_cache stride[0] (blocks)
|
||||
stride_vs_slot: tl.int64, # v_scale_cache stride[1] (slots)
|
||||
stride_vs_head: tl.int64, # v_scale_cache stride[2] (heads)
|
||||
block_size: tl.constexpr,
|
||||
head_size: tl.constexpr,
|
||||
head_size_v: tl.constexpr,
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # next_power_of_2(max(head_size, head_size_v))
|
||||
QUANT_MAX: tl.constexpr = 127.0,
|
||||
QUANT_MIN: tl.constexpr = -128.0,
|
||||
):
|
||||
tok = tl.program_id(0)
|
||||
head = tl.program_id(1)
|
||||
|
||||
slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
|
||||
if slot < 0:
|
||||
return
|
||||
|
||||
blk = slot // block_size
|
||||
slot_in_blk = slot % block_size
|
||||
|
||||
dim_offs = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
# ---- Key: load one head → absmax → quantize → store -------------------
|
||||
k_mask = dim_offs < head_size
|
||||
k_h = tl.load(
|
||||
key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
|
||||
mask=k_mask,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
k_scale = tl.maximum(tl.max(tl.abs(k_h)) / QUANT_MAX, 1e-6)
|
||||
tl.store(
|
||||
k_scale_cache_ptr
|
||||
+ blk * stride_ks_blk
|
||||
+ slot_in_blk * stride_ks_slot
|
||||
+ head * stride_ks_head,
|
||||
k_scale,
|
||||
)
|
||||
|
||||
k_q = tl.clamp(k_h * (1.0 / k_scale), QUANT_MIN, QUANT_MAX)
|
||||
tl.store(
|
||||
key_cache_ptr
|
||||
+ blk * stride_kc_blk
|
||||
+ slot_in_blk * stride_kc_slot
|
||||
+ head * stride_kc_head
|
||||
+ dim_offs,
|
||||
k_q,
|
||||
mask=k_mask,
|
||||
)
|
||||
|
||||
# ---- Value: same per-head approach ------------------------------------
|
||||
v_mask = dim_offs < head_size_v
|
||||
v_h = tl.load(
|
||||
value_ptr + tok * stride_val_tok + head * stride_val_head + dim_offs,
|
||||
mask=v_mask,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
v_scale = tl.maximum(tl.max(tl.abs(v_h)) / QUANT_MAX, 1e-6)
|
||||
tl.store(
|
||||
v_scale_cache_ptr
|
||||
+ blk * stride_vs_blk
|
||||
+ slot_in_blk * stride_vs_slot
|
||||
+ head * stride_vs_head,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
v_q = tl.clamp(v_h * (1.0 / v_scale), QUANT_MIN, QUANT_MAX)
|
||||
tl.store(
|
||||
value_cache_ptr
|
||||
+ blk * stride_vc_blk
|
||||
+ slot_in_blk * stride_vc_slot
|
||||
+ head * stride_vc_head
|
||||
+ dim_offs,
|
||||
v_q,
|
||||
mask=v_mask,
|
||||
)
|
||||
|
||||
|
||||
# Mapping from cache torch dtype to (QUANT_MAX, QUANT_MIN) for the
|
||||
# per-token-head quantization kernel.
|
||||
_PER_TOKEN_HEAD_QUANT_PARAMS: dict[torch.dtype, tuple[float, float]] = {
|
||||
torch.int8: (127.0, -128.0),
|
||||
FP8_DTYPE: (FP8_MAX, FP8_MIN),
|
||||
}
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash_per_token_head_quant(
|
||||
key: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_kv_heads, head_size_v]
|
||||
key_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads, head_size]
|
||||
value_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads, head_size_v]
|
||||
k_scale_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads] float32
|
||||
v_scale_cache: torch.Tensor, # [num_blocks, block_size, num_kv_heads] float32
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
):
|
||||
"""Quantize key/value per (token, head) and write to paged cache.
|
||||
|
||||
Computes one scale = absmax / QUANT_MAX per (token, head), stores
|
||||
quantized data in key_cache/value_cache, and stores the float32
|
||||
scale in k_scale_cache/v_scale_cache.
|
||||
|
||||
The quantization range (QUANT_MAX, QUANT_MIN) is derived from the
|
||||
cache tensor dtype so the same code path works for int8 and fp8.
|
||||
"""
|
||||
cache_dtype = key_cache.dtype
|
||||
quant_params = _PER_TOKEN_HEAD_QUANT_PARAMS.get(cache_dtype)
|
||||
if quant_params is None:
|
||||
raise ValueError(
|
||||
f"Per-token-head quantization not supported for cache dtype "
|
||||
f"{cache_dtype}. Supported: {list(_PER_TOKEN_HEAD_QUANT_PARAMS)}"
|
||||
)
|
||||
quant_max, quant_min = quant_params
|
||||
|
||||
num_tokens, num_kv_heads, head_size = key.shape
|
||||
head_size_v = value.shape[2]
|
||||
head_size_padded = triton.next_power_of_2(max(head_size, head_size_v))
|
||||
|
||||
block_size = key_cache.shape[1]
|
||||
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
num_warps = 4
|
||||
else:
|
||||
num_warps = min(16, max(1, head_size_padded // 32))
|
||||
|
||||
_reshape_cache_per_token_head[(num_tokens, num_kv_heads)](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
k_scale_cache_ptr=k_scale_cache,
|
||||
v_scale_cache_ptr=v_scale_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
stride_key_tok=key.stride(0),
|
||||
stride_key_head=key.stride(1),
|
||||
stride_val_tok=value.stride(0),
|
||||
stride_val_head=value.stride(1),
|
||||
stride_kc_blk=key_cache.stride(0),
|
||||
stride_kc_slot=key_cache.stride(1),
|
||||
stride_kc_head=key_cache.stride(2),
|
||||
stride_vc_blk=value_cache.stride(0),
|
||||
stride_vc_slot=value_cache.stride(1),
|
||||
stride_vc_head=value_cache.stride(2),
|
||||
stride_ks_blk=k_scale_cache.stride(0),
|
||||
stride_ks_slot=k_scale_cache.stride(1),
|
||||
stride_ks_head=k_scale_cache.stride(2),
|
||||
stride_vs_blk=v_scale_cache.stride(0),
|
||||
stride_vs_slot=v_scale_cache.stride(1),
|
||||
stride_vs_head=v_scale_cache.stride(2),
|
||||
block_size=block_size,
|
||||
head_size=head_size,
|
||||
head_size_v=head_size_v,
|
||||
HEAD_SIZE_PADDED=head_size_padded,
|
||||
QUANT_MAX=quant_max,
|
||||
QUANT_MIN=quant_min,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
@@ -224,7 +422,6 @@ def triton_reshape_and_cache_flash(
|
||||
block_size=block_size,
|
||||
x=x,
|
||||
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
|
||||
@@ -13,6 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.kv_cache_interface import KVQuantMode
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
|
||||
@@ -32,6 +33,63 @@ def apply_softcap(S, x):
|
||||
return x * (p1 - p2) / (p1 + p2)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_kv_tile(
|
||||
data,
|
||||
Q,
|
||||
tensor_scale,
|
||||
scale_cache_ptr,
|
||||
physical_block_idx,
|
||||
seq_offset,
|
||||
kv_head_idx,
|
||||
stride_s_blk,
|
||||
stride_s_slot,
|
||||
stride_s_head,
|
||||
tile_mask,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
KV_QUANT_MODE: tl.constexpr,
|
||||
):
|
||||
"""Prepare a loaded KV tile for attention computation.
|
||||
|
||||
Casts the raw KV data to Q's dtype and loads per-token-head scales
|
||||
when applicable:
|
||||
|
||||
- ``KV_QUANT_MODE == 0``: cast only (no-op for bf16/fp16).
|
||||
- ``KV_QUANT_MODE == 1`` (FP8 per-tensor): dequantize inline
|
||||
using the tensor-wide scale.
|
||||
- ``KV_QUANT_MODE >= 2`` (per-token-head int8/fp8): cast to Q's
|
||||
dtype and return per-head scales separately — the caller applies
|
||||
them after the dot product for better numerical efficiency.
|
||||
|
||||
Returns ``(data, token_head_scales)``. *token_head_scales* is only
|
||||
meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on
|
||||
the same constexpr so the compiler eliminates dead code.
|
||||
"""
|
||||
# KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor,
|
||||
# 2=int8 per-token-head, 3=fp8 per-token-head
|
||||
|
||||
# Placeholder scales (float32) — never read when KV_QUANT_MODE < 2.
|
||||
unused_scales = tile_mask.to(tl.float32)
|
||||
|
||||
if KV_QUANT_MODE == 1: # FP8 per-tensor
|
||||
if Q.dtype.is_fp8():
|
||||
return data.to(Q.dtype), unused_scales
|
||||
return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales
|
||||
if KV_QUANT_MODE >= 2: # per-token-head (int8 or fp8)
|
||||
scale_idx = (
|
||||
physical_block_idx * stride_s_blk
|
||||
+ (seq_offset % BLOCK_SIZE) * stride_s_slot
|
||||
+ kv_head_idx * stride_s_head
|
||||
)
|
||||
token_head_scales = tl.load(
|
||||
scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0
|
||||
)
|
||||
return data.to(Q.dtype), token_head_scales
|
||||
# .to(Q.dtype) is a no-op when data is already Q's type (bf16/fp16),
|
||||
# but required so Triton sees consistent return types across branches.
|
||||
return data.to(Q.dtype), unused_scales
|
||||
|
||||
|
||||
@triton.jit
|
||||
def find_seq_idx(
|
||||
query_start_len_ptr,
|
||||
@@ -105,8 +163,20 @@ def kernel_unified_attention_2d(
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
USE_FP8: tl.constexpr, # bool
|
||||
# KV cache quantization: 0=none, 1=fp8, 2=per-token-head
|
||||
KV_QUANT_MODE: tl.constexpr = 0,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
|
||||
# Shape: [num_blocks, block_size, num_kv_heads]
|
||||
k_scale_cache_ptr=None,
|
||||
v_scale_cache_ptr=None,
|
||||
stride_ks_blk=0,
|
||||
stride_ks_slot=0,
|
||||
stride_ks_head=0,
|
||||
stride_vs_blk=0,
|
||||
stride_vs_slot=0,
|
||||
stride_vs_head=0,
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -258,14 +328,21 @@ def kernel_unified_attention_2d(
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
K = K_load
|
||||
else:
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
K, k_token_head_scales = _prepare_kv_tile(
|
||||
K_load,
|
||||
Q,
|
||||
k_scale,
|
||||
k_scale_cache_ptr,
|
||||
physical_block_idx,
|
||||
seq_offset,
|
||||
kv_head_idx,
|
||||
stride_ks_blk,
|
||||
stride_ks_slot,
|
||||
stride_ks_head,
|
||||
tile_mask,
|
||||
BLOCK_SIZE,
|
||||
KV_QUANT_MODE,
|
||||
)
|
||||
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(
|
||||
@@ -273,14 +350,21 @@ def kernel_unified_attention_2d(
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
V = V_load
|
||||
else:
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
V, v_token_head_scales = _prepare_kv_tile(
|
||||
V_load,
|
||||
Q,
|
||||
v_scale,
|
||||
v_scale_cache_ptr,
|
||||
physical_block_idx,
|
||||
seq_offset,
|
||||
kv_head_idx,
|
||||
stride_vs_blk,
|
||||
stride_vs_slot,
|
||||
stride_vs_head,
|
||||
tile_mask,
|
||||
BLOCK_SIZE,
|
||||
KV_QUANT_MODE,
|
||||
)
|
||||
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
@@ -318,7 +402,12 @@ def kernel_unified_attention_2d(
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
|
||||
S += scale * tl.dot(Q, K)
|
||||
# Per-token-head quant: fuse softmax_scale with per-head k_scale
|
||||
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
|
||||
if KV_QUANT_MODE >= 2:
|
||||
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
|
||||
else:
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
S = apply_softcap(S, softcap)
|
||||
@@ -382,7 +471,12 @@ def kernel_unified_attention_2d(
|
||||
)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
# Per-token-head quant: apply v_scale to P instead of V.
|
||||
if KV_QUANT_MODE >= 2:
|
||||
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
|
||||
acc += tl.dot(P_v, V)
|
||||
else:
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
@@ -453,6 +547,18 @@ def kernel_unified_attention_3d(
|
||||
USE_MM_PREFIX: tl.constexpr, # bool
|
||||
MAX_MM_RANGES: tl.constexpr, # int
|
||||
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
|
||||
# KV cache quantization: 0=none, 1=fp8, 2=per-token-head
|
||||
KV_QUANT_MODE: tl.constexpr = 0,
|
||||
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
|
||||
# Shape: [num_blocks, block_size, num_kv_heads]
|
||||
k_scale_cache_ptr=None,
|
||||
v_scale_cache_ptr=None,
|
||||
stride_ks_blk=0,
|
||||
stride_ks_slot=0,
|
||||
stride_ks_head=0,
|
||||
stride_vs_blk=0,
|
||||
stride_vs_slot=0,
|
||||
stride_vs_head=0,
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -613,14 +719,21 @@ def kernel_unified_attention_3d(
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
K = K_load
|
||||
else:
|
||||
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
|
||||
else:
|
||||
K = K_load
|
||||
K, k_token_head_scales = _prepare_kv_tile(
|
||||
K_load,
|
||||
Q,
|
||||
k_scale,
|
||||
k_scale_cache_ptr,
|
||||
physical_block_idx,
|
||||
seq_offset,
|
||||
kv_head_idx,
|
||||
stride_ks_blk,
|
||||
stride_ks_slot,
|
||||
stride_ks_head,
|
||||
tile_mask,
|
||||
BLOCK_SIZE,
|
||||
KV_QUANT_MODE,
|
||||
)
|
||||
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(
|
||||
@@ -628,14 +741,21 @@ def kernel_unified_attention_3d(
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
if Q.dtype.is_fp8():
|
||||
V = V_load
|
||||
else:
|
||||
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
|
||||
else:
|
||||
V = V_load
|
||||
V, v_token_head_scales = _prepare_kv_tile(
|
||||
V_load,
|
||||
Q,
|
||||
v_scale,
|
||||
v_scale_cache_ptr,
|
||||
physical_block_idx,
|
||||
seq_offset,
|
||||
kv_head_idx,
|
||||
stride_vs_blk,
|
||||
stride_vs_slot,
|
||||
stride_vs_head,
|
||||
tile_mask,
|
||||
BLOCK_SIZE,
|
||||
KV_QUANT_MODE,
|
||||
)
|
||||
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
@@ -672,7 +792,13 @@ def kernel_unified_attention_3d(
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
# Per-token-head quant: fuse softmax_scale with per-head k_scale
|
||||
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
|
||||
if KV_QUANT_MODE >= 2:
|
||||
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
|
||||
else:
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
S = apply_softcap(S, softcap)
|
||||
@@ -736,7 +862,12 @@ def kernel_unified_attention_3d(
|
||||
)
|
||||
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
# Per-token-head quant: apply v_scale to P instead of V.
|
||||
if KV_QUANT_MODE >= 2:
|
||||
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
|
||||
acc += tl.dot(P_v, V)
|
||||
else:
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
segm_output_offset = (
|
||||
query_offset_0[:, None].to(tl.int64)
|
||||
@@ -911,6 +1042,10 @@ def unified_attention(
|
||||
# Optional tensor for prefix lengths (PrefixLM support)
|
||||
mm_prefix_range=None,
|
||||
use_alibi_sqrt=False,
|
||||
# KV cache quantization mode and per-token-head scale caches.
|
||||
kv_quant_mode: KVQuantMode = KVQuantMode.NONE,
|
||||
k_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
|
||||
v_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@@ -1040,6 +1175,15 @@ def unified_attention(
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
USE_FP8=output_scale is not None,
|
||||
KV_QUANT_MODE=kv_quant_mode,
|
||||
k_scale_cache_ptr=k_scale_cache,
|
||||
v_scale_cache_ptr=v_scale_cache,
|
||||
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
|
||||
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
|
||||
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
|
||||
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
|
||||
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
|
||||
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
|
||||
)
|
||||
else:
|
||||
kernel_unified_attention_3d[
|
||||
@@ -1092,6 +1236,15 @@ def unified_attention(
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
|
||||
KV_QUANT_MODE=kv_quant_mode,
|
||||
k_scale_cache_ptr=k_scale_cache,
|
||||
v_scale_cache_ptr=v_scale_cache,
|
||||
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
|
||||
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
|
||||
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
|
||||
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
|
||||
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
|
||||
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
|
||||
)
|
||||
reduce_segments[(q.shape[0], num_query_heads)](
|
||||
output_ptr=out,
|
||||
|
||||
@@ -1,21 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from enum import IntEnum
|
||||
from math import prod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache quantization mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class KVQuantMode(IntEnum):
|
||||
"""KV cache quantization mode.
|
||||
|
||||
Used by attention backends and kernels to dispatch quantization logic
|
||||
without string matching on ``kv_cache_dtype``.
|
||||
"""
|
||||
|
||||
NONE = 0
|
||||
FP8_PER_TENSOR = 1 # per-tensor scales (current fp8 path)
|
||||
INT8_PER_TOKEN_HEAD = 2 # per-token-head dynamic scales for int8
|
||||
FP8_PER_TOKEN_HEAD = 3 # per-token-head dynamic scales for fp8
|
||||
|
||||
@property
|
||||
def is_per_token_head(self) -> bool:
|
||||
"""True for any per-token-head quantization mode."""
|
||||
return self >= 2
|
||||
|
||||
|
||||
def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
|
||||
"""Map a ``kv_cache_dtype`` string to a :class:`KVQuantMode`."""
|
||||
if kv_cache_dtype == "int8_per_token_head":
|
||||
return KVQuantMode.INT8_PER_TOKEN_HEAD
|
||||
if kv_cache_dtype == "fp8_per_token_head":
|
||||
return KVQuantMode.FP8_PER_TOKEN_HEAD
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
return KVQuantMode.FP8_PER_TENSOR
|
||||
return KVQuantMode.NONE
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return get_kv_quant_mode(kv_cache_dtype) != KVQuantMode.NONE
|
||||
|
||||
|
||||
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
|
||||
"""Return True if *kv_cache_dtype* needs per-token-head scales."""
|
||||
return get_kv_quant_mode(kv_cache_dtype).is_per_token_head
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KVCacheSpec:
|
||||
"""
|
||||
@@ -66,11 +115,19 @@ class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
kv_quant_mode: KVQuantMode = KVQuantMode.NONE
|
||||
page_size_padded: int | None = None
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
real_page_size = self.real_page_size_bytes
|
||||
# Per-token-head scales are stored in separate tensors managed
|
||||
# by the attention backend, but the memory is carved from the
|
||||
# raw KV cache allocation so it must be budgeted here.
|
||||
if self.kv_quant_mode.is_per_token_head:
|
||||
real_page_size += (
|
||||
2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
|
||||
)
|
||||
if self.page_size_padded is not None:
|
||||
assert self.page_size_padded >= real_page_size
|
||||
return self.page_size_padded
|
||||
@@ -159,6 +216,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
head_size=specs[0].head_size,
|
||||
head_size_v=specs[0].head_size_v,
|
||||
dtype=specs[0].dtype,
|
||||
kv_quant_mode=specs[0].kv_quant_mode,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
@@ -220,6 +278,7 @@ class MLAAttentionSpec(FullAttentionSpec):
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
kv_quant_mode=specs[0].kv_quant_mode,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
)
|
||||
@@ -352,6 +411,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
|
||||
head_size_v=specs[0].head_size_v,
|
||||
sink_len=specs[0].sink_len,
|
||||
dtype=specs[0].dtype,
|
||||
kv_quant_mode=specs[0].kv_quant_mode,
|
||||
page_size_padded=specs[0].page_size_padded,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
|
||||
Reference in New Issue
Block a user