[Feature] KV cache per-token-head INT8/FP8 quantization (#38378)
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yangyang4991 <yangyang4991@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
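A minimal usage sketch (not part of this diff; it assumes the new dtype strings
plug into the existing kv_cache_dtype engine argument, which is what the tests
below exercise):

    from vllm import LLM

    # Hypothetical invocation; the model name is only a placeholder.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="int8_per_token_head")
    print(llm.generate("Hello, my name is")[0].outputs[0].text)

Use "fp8_per_token_head" for the FP8 variant.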
tests/quantization/test_per_token_kv_cache.py (new file, +560 lines)
@@ -0,0 +1,560 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for per-token-head KV cache quantization (INT8 and FP8).

Covers:
- Per-token-head Triton reshape-and-cache kernel
- Round-trip quantize/dequantize accuracy
- process_weights_after_loading early-return path
- End-to-end integration with Triton unified attention kernel

Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s
"""

import random
from dataclasses import dataclass
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache

# Skip entire module if no CUDA/ROCm GPU available
pytestmark = [
    pytest.mark.skipif(
        not current_platform.is_cuda_alike(),
        reason="Per-token-head KV cache tests require CUDA or ROCm GPU.",
    ),
]

# ---------------------------------------------------------------------------
# Test parameters
# ---------------------------------------------------------------------------
NUM_TOKENS = [1, 7, 42]
NUM_KV_HEADS = [1, 4, 8]
HEAD_SIZES = [64, 128]
BLOCK_SIZES = [16]
SEEDS = [0]

# Platform-dependent FP8 dtype and range
FP8_DTYPE = current_platform.fp8_dtype()
FP8_MIN, FP8_MAX = get_fp8_min_max()
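# For reference: torch.float8_e4m3fn spans roughly +/-448, while ROCm builds may
# use an fnuz variant with a smaller range, which is why the dtype and range are
# queried from the platform instead of being hardcoded.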


# ---------------------------------------------------------------------------
# Per-dtype quantization config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class QuantConfig:
    """Quantization parameters for a given cache dtype."""

    cache_dtype: torch.dtype  # torch.int8 or FP8_DTYPE
    kv_cache_dtype_str: str  # "int8_per_token_head" or "fp8_per_token_head"
    quant_max: float
    quant_min: float
    kv_quant_mode: KVQuantMode
    # INT8 Triton stores truncate; FP8 hardware casts round.
    uses_trunc: bool


INT8_CONFIG = QuantConfig(
    cache_dtype=torch.int8,
    kv_cache_dtype_str="int8_per_token_head",
    quant_max=127.0,
    quant_min=-128.0,
    kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
    uses_trunc=True,
)
FP8_CONFIG = QuantConfig(
    cache_dtype=FP8_DTYPE,
    kv_cache_dtype_str="fp8_per_token_head",
    quant_max=FP8_MAX,
    quant_min=FP8_MIN,
    kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD,
    uses_trunc=False,
)

QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG]


@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"])
def qcfg(request) -> QuantConfig:
    return request.param


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _quantize_per_token_head_ref(
    data: torch.Tensor,  # [num_tokens, num_heads, head_size]
    cfg: QuantConfig,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference per-token-head quantization (one scale per token per head).

    Returns (quantized, scales) where scales is [num_tokens, num_heads].
    """
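    # Worked example (illustrative numbers, not taken from the tests): with INT8
    # and quant_max=127, a (token, head) row whose absmax is 3.2 gets
    # scale = 3.2 / 127 ≈ 0.0252; an element equal to 1.0 quantizes to
    # round(1.0 / 0.0252) = 40 and dequantizes back to 40 * 0.0252 ≈ 1.008.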
    absmax = data.float().abs().amax(dim=2)  # [num_tokens, num_heads]
    scales = (absmax / cfg.quant_max).clamp(min=1e-6)
    scaled = data.float() * (1.0 / scales[:, :, None])
    if cfg.uses_trunc:
        q = scaled.round().clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
    else:
        q = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
    return q, scales


# ===========================================================================
# 1. is_quantized_kv_cache / get_kv_quant_mode
# ===========================================================================
class TestIsQuantizedKvCache:
    def test_fp8_variants(self):
        assert is_quantized_kv_cache("fp8")
        assert is_quantized_kv_cache("fp8_e4m3")
        assert is_quantized_kv_cache("fp8_e5m2")

    def test_int8_per_token_head(self):
        assert is_quantized_kv_cache("int8_per_token_head")

    def test_fp8_per_token_head(self):
        assert is_quantized_kv_cache("fp8_per_token_head")

    def test_auto(self):
        assert not is_quantized_kv_cache("auto")

    def test_bfloat16(self):
        assert not is_quantized_kv_cache("bfloat16")

    def test_kv_quant_mode_int8(self):
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        assert (
            get_kv_quant_mode("int8_per_token_head") == KVQuantMode.INT8_PER_TOKEN_HEAD
        )

    def test_kv_quant_mode_fp8(self):
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        assert get_kv_quant_mode("fp8_per_token_head") == KVQuantMode.FP8_PER_TOKEN_HEAD


# ===========================================================================
# 2. Triton per-token-head kernel (reshape-and-cache)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache_per_token_head(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    seed: int,
):
    """Test triton_reshape_and_cache_flash_per_token_head_quant kernel."""
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    set_random_seed(seed)
    torch.set_default_device("cuda")

    num_blocks = (num_tokens + block_size - 1) // block_size + 4

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    num_slots = block_size * num_blocks
    slot_mapping = torch.tensor(
        random.sample(range(num_slots), num_tokens), dtype=torch.long
    )

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Reference
    ref_k_quant, ref_k_scales = _quantize_per_token_head_ref(key, qcfg)
    ref_v_quant, ref_v_scales = _quantize_per_token_head_ref(value, qcfg)

    # Compare dequantized values rather than raw quantized values.
    # Triton and PyTorch reductions can differ at FP8 rounding boundaries
    # (up to 32 in quantized domain for fp8_e4m3), but the dequantized
    # error is bounded by the scale.
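    # (Illustrative: in the top binade of float8_e4m3fn the representable values
    # step by 32, e.g. 416 and 448, so two implementations that round a
    # borderline value differently can disagree by 32 quantized units, which is
    # only about absmax / 14 once multiplied by the per-head scale.)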
    for i, slot in enumerate(slot_mapping.tolist()):
        blk = slot // block_size
        off = slot % block_size

        actual_k_scale = k_scale_cache[blk, off]  # [num_heads]
        k_deq = key_cache[blk, off].float() * actual_k_scale[:, None]
        k_ref_deq = key[i].float()
        torch.testing.assert_close(
            k_deq,
            k_ref_deq,
            atol=0.1,
            rtol=0.1,
        )
        actual_v_scale = v_scale_cache[blk, off]  # [num_heads]
        v_deq = value_cache[blk, off].float() * actual_v_scale[:, None]
        v_ref_deq = value[i].float()
        torch.testing.assert_close(
            v_deq,
            v_ref_deq,
            atol=0.1,
            rtol=0.1,
        )
        # Per-head scales: [num_heads]
        torch.testing.assert_close(
            k_scale_cache[blk, off], ref_k_scales[i], atol=1e-4, rtol=1e-3
        )
        torch.testing.assert_close(
            v_scale_cache[blk, off], ref_v_scales[i], atol=1e-4, rtol=1e-3
        )


# ===========================================================================
# 3. Per-token-head round-trip accuracy (quantize -> dequantize)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("num_heads", [4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_per_token_head_round_trip_accuracy(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
):
    """Verify per-token-head round-trip: kernel dequant matches reference.

    INT8: Triton truncates on float->int8 store.
    FP8: hardware cast (clamp then cast).
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    set_random_seed(42)

    num_blocks = (num_tokens + block_size - 1) // block_size + 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16) * 0.5

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    slot_mapping = torch.arange(num_tokens, dtype=torch.long)

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    for i in range(num_tokens):
        blk = i // block_size
        off = i % block_size

        for label, data, cache, sc in [
            ("key", key, key_cache, k_scale_cache),
            ("val", value, value_cache, v_scale_cache),
        ]:
            for h in range(num_heads):
                orig = data[i, h].float()  # [head_size]

                actual_q = cache[blk, off, h]
                actual_sc = sc[blk, off, h]
                actual_deq = actual_q.float() * actual_sc

                # Round-trip: dequantized should be close to original
                torch.testing.assert_close(
                    actual_deq,
                    orig,
                    atol=0.1,
                    rtol=0.1,
                )


# ===========================================================================
# 4. Negative slot mapping (padding tokens should be skipped)
# ===========================================================================
@torch.inference_mode()
def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
    """Tokens with slot_mapping=-1 should leave the cache unchanged."""
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    num_tokens = 4
    num_heads = 2
    head_size = 64
    block_size = 16
    num_blocks = 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long)

    key_cache_before = key_cache.clone()
    val_cache_before = value_cache.clone()

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Slots 0 and 1 should have been written (tokens 0 and 2)
    assert not torch.equal(key_cache[0, 0], key_cache_before[0, 0])
    assert not torch.equal(key_cache[0, 1], key_cache_before[0, 1])
    assert not torch.equal(value_cache[0, 0], val_cache_before[0, 0])

    # All other slots should be unchanged
    assert torch.equal(key_cache[0, 2:], key_cache_before[0, 2:])
    assert torch.equal(key_cache[1], key_cache_before[1])
    assert torch.equal(value_cache[0, 2:], val_cache_before[0, 2:])


# ===========================================================================
# 5. process_weights_after_loading -- per-token-head early return
# ===========================================================================
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str):
    """Per-token-head should set _k_scale=1.0, _v_scale=1.0
    and delete checkpoint attrs."""
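    # Per-token-head scales are produced at cache-write time and live in the
    # separate k/v scale caches, so the per-tensor checkpoint scales are not
    # needed; the method is expected to drop them and install 1.0 placeholders.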
    from vllm.model_executor.layers.quantization.kv_cache import (
        BaseKVCacheMethod,
    )

    layer = MagicMock()
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = False
    layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer._k_scale = torch.tensor(0.0)
    layer._v_scale = torch.tensor(0.0)
    layer._k_scale_float = 0.0
    layer._v_scale_float = 0.0

    method = BaseKVCacheMethod.__new__(BaseKVCacheMethod)
    method.quant_config = MagicMock()
    method.process_weights_after_loading(layer)

    assert layer._k_scale_float == 1.0
    assert layer._v_scale_float == 1.0
    assert not hasattr(layer, "k_scale")
    assert not hasattr(layer, "v_scale")
    assert not hasattr(layer, "q_scale")
    assert not hasattr(layer, "prob_scale")


# ===========================================================================
# 6. Triton unified_attention -- per-token-head scale cache (INT8 and FP8)
# ===========================================================================
@pytest.mark.parametrize(
    "seq_lens",
    [
        [(1, 128)],
        [(1, 64), (1, 32)],
    ],
)
@pytest.mark.parametrize("num_heads", [(4, 4)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_triton_unified_attention_per_token_head_scale(
    qcfg: QuantConfig,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
):
    """End-to-end: quantized KV with per-token-head scale caches."""
    from vllm.utils.math_utils import next_power_of_2
    from vllm.v1.attention.ops.triton_unified_attention import unified_attention

    torch.set_default_device("cuda")
    set_random_seed(0)

    num_seqs = len(seq_lens)
    query_lens = [s[0] for s in seq_lens]
    kv_lens = [s[1] for s in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    num_blocks = 2048

    query = torch.randn(
        sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16
    )

    key_cache_bf16 = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
    )
    value_cache_bf16 = torch.randn_like(key_cache_bf16)

    # Per-token-head quantization: one scale per (block, slot, head)
    k_absmax = key_cache_bf16.float().abs().amax(dim=-1)  # [..., num_kv_heads]
    v_absmax = value_cache_bf16.float().abs().amax(dim=-1)
    k_scale_cache = (k_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
    v_scale_cache = (v_absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)

    scaled_k = key_cache_bf16.float() / k_scale_cache[:, :, :, None]
    scaled_v = value_cache_bf16.float() / v_scale_cache[:, :, :, None]
    if qcfg.uses_trunc:
        key_cache_q = (
            scaled_k.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
        )
        value_cache_q = (
            scaled_v.round().clamp(qcfg.quant_min, qcfg.quant_max).to(qcfg.cache_dtype)
        )
    else:
        key_cache_q = scaled_k.clamp(qcfg.quant_min, qcfg.quant_max).to(
            qcfg.cache_dtype
        )
        value_cache_q = scaled_v.clamp(qcfg.quant_min, qcfg.quant_max).to(
            qcfg.cache_dtype
        )

    # Dequantized reference
    key_cache_deq = key_cache_q.float() * k_scale_cache[:, :, :, None]
    value_cache_deq = value_cache_q.float() * v_scale_cache[:, :, :, None]

    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    head_size_padded = next_power_of_2(head_size)
    seq_threshold_3D = 0
    num_par_softmax_segments = 16
    softmax_segm_output = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
        dtype=torch.float32,
    )
    softmax_segm_max = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )
    softmax_segm_expsum = torch.empty(
        (seq_threshold_3D, num_query_heads, num_par_softmax_segments),
        dtype=torch.float32,
    )

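    # Quantized path: pass the quantized caches plus the per-token-head scale
    # caches; kv_quant_mode tells the kernel how to dequantize them internally.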
    output_q = torch.empty_like(query)
    unified_attention(
        q=query,
        k=key_cache_q,
        v=value_cache_q,
        out=output_q,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
        kv_quant_mode=qcfg.kv_quant_mode,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
    )

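    # Reference path: the same attention over pre-dequantized bf16 caches, with
    # no kv_quant_mode or scale caches, so only the dequantization point differs.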
    output_ref = torch.empty_like(query)
    unified_attention(
        q=query,
        k=key_cache_deq.to(torch.bfloat16),
        v=value_cache_deq.to(torch.bfloat16),
        out=output_ref,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
    )

    torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2)