[Bugfix] Disable cross-layer KV cache for MLA attention backends (#37090)
Signed-off-by: haosdent <haosdent@gmail.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
36
tests/v1/kv_connector/unit/test_kv_cache_layout.py
Normal file
36
tests/v1/kv_connector/unit/test_kv_cache_layout.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
def test_mla_backend_rejects_cross_layer_kv_cache():
    """MLA backends return identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    # Local import so the test file stays importable even when the MLA
    # backend module is unavailable in a given build.
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonBackend,
    )

    # With the extra num_layers dimension requested, the backend must
    # return the identity permutation: keeping the layers axis first in
    # the physical layout means per-layer views stay contiguous, which
    # signals that cross-layer KV cache allocation is unsupported.
    stride_order = MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert stride_order == (0, 1, 2, 3)
    assert stride_order[0] == 0  # layers dim first => no cross-layer

    # Without the layers dimension, the plain 3-D identity is returned.
    assert MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    ) == (0, 1, 2)
|
def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache():
    """DeepseekV32Indexer returns identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    # Local import so collection does not fail when the indexer backend
    # is absent from a given build.
    from vllm.v1.attention.backends.mla.indexer import (
        DeepseekV32IndexerBackend,
    )

    # Requesting the num_layers dimension must yield the identity
    # permutation: the layers axis staying first in physical layout is
    # the signal that cross-layer KV cache allocation is unsupported.
    stride_order = DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert stride_order == (0, 1, 2, 3)
    assert stride_order[0] == 0  # layers dim first => no cross-layer

    # Without the layers dimension, the plain 3-D identity is returned.
    assert DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    ) == (0, 1, 2)
|
||||||
@@ -24,7 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
|||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.attention import Attention
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
@@ -601,7 +601,9 @@ class OffloadingConnectorWorker:
|
|||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
layer_names = list(kv_caches.keys())
|
layer_names = list(kv_caches.keys())
|
||||||
layers = get_layers_from_vllm_config(
|
layers = get_layers_from_vllm_config(
|
||||||
self.spec.vllm_config, Attention, layer_names
|
self.spec.vllm_config,
|
||||||
|
AttentionLayerBase, # type: ignore[type-abstract]
|
||||||
|
layer_names,
|
||||||
)
|
)
|
||||||
attn_backends = {
|
attn_backends = {
|
||||||
layer_name: layers[layer_name].get_attn_backend()
|
layer_name: layers[layer_name].get_attn_backend()
|
||||||
|
|||||||
@@ -1142,10 +1142,12 @@ class MLACommonBackend(AttentionBackend):
|
|||||||
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Return the permutation mapping the logical KV cache shape to the
    desired physical memory layout.

    Args:
        include_num_layers_dimension: when True, the shape carries a
            leading num_layers axis (4-D); when False it is the plain
            per-layer 3-D shape.

    Returns:
        A tuple of axis indices. The identity permutation is returned in
        both cases: MLA kernels require contiguous per-layer KV cache
        views, so the num_layers axis must stay first in physical
        layout, which signals that cross-layer (shared across layers)
        KV cache allocation is unsupported by this backend.
    """
    if include_num_layers_dimension:
        # Identity permutation keeps num_layers first in physical
        # layout, signaling cross-layer allocation is unsupported.
        return (0, 1, 2, 3)
    return (0, 1, 2)
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_head_sizes(cls) -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ class DeepseekV32IndexerBackend(AttentionBackend):
|
|||||||
include_num_layers_dimension: bool = False,
|
include_num_layers_dimension: bool = False,
|
||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
if include_num_layers_dimension:
|
if include_num_layers_dimension:
|
||||||
|
# DeepseekV32Indexer kernels do not support cross-layer
|
||||||
|
# KV cache layout. Identity permutation keeps num_layers
|
||||||
|
# first, signaling incompatibility.
|
||||||
return (0, 1, 2, 3)
|
return (0, 1, 2, 3)
|
||||||
return (0, 1, 2)
|
return (0, 1, 2)
|
||||||
|
|
||||||
|
|||||||
@@ -191,8 +191,13 @@ class KVConnectorModelRunnerMixin:
|
|||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# check that attention backend include a layers dimension
|
# check that attention backend includes a layers dimension
|
||||||
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
|
if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# stride_order[0] == 0 means num_layers stays first in physical
|
||||||
|
# layout (identity permutation), so cross-layer is unsupported.
|
||||||
|
return kv_cache_stride_order[0] != 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def allocate_uniform_kv_caches(
|
def allocate_uniform_kv_caches(
|
||||||
|
|||||||
Reference in New Issue
Block a user