From ca1954d58c49e3a3209ec86d743a99f3a605028b Mon Sep 17 00:00:00 2001 From: haosdent Date: Tue, 17 Mar 2026 01:03:10 +0800 Subject: [PATCH] [Bugfix] Disable cross-layer KV cache for MLA attention backends (#37090) Signed-off-by: haosdent Co-authored-by: Or Ozeri --- .../kv_connector/unit/test_kv_cache_layout.py | 36 +++++++++++++++++++ .../kv_connector/v1/offloading_connector.py | 6 ++-- .../layers/attention/mla_attention.py | 10 +++--- vllm/v1/attention/backends/mla/indexer.py | 3 ++ .../worker/kv_connector_model_runner_mixin.py | 9 +++-- 5 files changed, 56 insertions(+), 8 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_kv_cache_layout.py diff --git a/tests/v1/kv_connector/unit/test_kv_cache_layout.py b/tests/v1/kv_connector/unit/test_kv_cache_layout.py new file mode 100644 index 000000000..7f8028991 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_kv_cache_layout.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +def test_mla_backend_rejects_cross_layer_kv_cache(): + """MLA backends return identity permutation (layers dim first) + to signal cross-layer KV cache is unsupported.""" + from vllm.model_executor.layers.attention.mla_attention import ( + MLACommonBackend, + ) + + stride_order = MLACommonBackend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + assert stride_order == (0, 1, 2, 3) + assert stride_order[0] == 0 # layers dim first => no cross-layer + assert MLACommonBackend.get_kv_cache_stride_order( + include_num_layers_dimension=False + ) == (0, 1, 2) + + +def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache(): + """DeepseekV32Indexer returns identity permutation (layers dim first) + to signal cross-layer KV cache is unsupported.""" + from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerBackend, + ) + + stride_order = DeepseekV32IndexerBackend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + assert stride_order == (0, 1, 2, 3) + assert stride_order[0] == 0 # layers dim first => no cross-layer + assert DeepseekV32IndexerBackend.get_kv_cache_stride_order( + include_num_layers_dimension=False + ) == (0, 1, 2) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 021f0144d..4c850fd2f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -24,7 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import BlockHash @@ -601,7 +601,9 @@ class OffloadingConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): layer_names = list(kv_caches.keys()) layers = get_layers_from_vllm_config( - self.spec.vllm_config, Attention, layer_names + self.spec.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + layer_names, ) attn_backends = { layer_name: layers[layer_name].get_attn_backend() diff --git a/vllm/model_executor/layers/attention/mla_attention.py 
b/vllm/model_executor/layers/attention/mla_attention.py index 36ee728dc..b613f3ba9 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1142,10 +1142,12 @@ class MLACommonBackend(AttentionBackend): def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. - # (num_blocks, num_layers, block_size, head_size) - return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) + if include_num_layers_dimension: + # MLA kernels require contiguous per-layer KV cache views. + # Identity permutation keeps num_layers first in physical + # layout, signaling cross-layer allocation is unsupported. + return (0, 1, 2, 3) + return (0, 1, 2) @classmethod def get_supported_head_sizes(cls) -> list[int]: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index f8ff2fc2e..70281b4a9 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -63,6 +63,9 @@ class DeepseekV32IndexerBackend(AttentionBackend): include_num_layers_dimension: bool = False, ) -> tuple[int, ...]: if include_num_layers_dimension: + # DeepseekV32Indexer kernels do not support cross-layer + # KV cache layout. Identity permutation keeps num_layers + # first, signaling incompatibility. return (0, 1, 2, 3) return (0, 1, 2) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 2921594a3..bc243906b 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -191,8 +191,13 @@ class KVConnectorModelRunnerMixin: except (AttributeError, NotImplementedError): return False - # check that attention backend include a layers dimension - return len(kv_cache_stride_order) == len(kv_cache_shape) + 1 + # check that attention backend includes a layers dimension + if len(kv_cache_stride_order) != len(kv_cache_shape) + 1: + return False + + # stride_order[0] == 0 means num_layers stays first in physical + # layout (identity permutation), so cross-layer is unsupported. + return kv_cache_stride_order[0] != 0 @staticmethod def allocate_uniform_kv_caches(
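Reviewer note: below is a minimal runnable sketch of the stride-order
contract this patch relies on. The permutation tuples and the
stride_order[0] != 0 check come from the patch itself; the helper names
supports_cross_layer and physical_view are hypothetical, added only to
illustrate how a layers-first physical layout rules out cross-layer
allocation.

    import torch


    def supports_cross_layer(stride_order: tuple[int, ...],
                             kv_cache_ndim: int) -> bool:
        # The backend must expose an extra num_layers dimension...
        if len(stride_order) != kv_cache_ndim + 1:
            return False
        # ...and must not pin num_layers as the outermost physical dim
        # (mirrors the check in kv_connector_model_runner_mixin.py).
        return stride_order[0] != 0


    def physical_view(logical_shape: tuple[int, ...],
                      stride_order: tuple[int, ...]) -> torch.Tensor:
        # Allocate in the physical (permuted) order, then permute back
        # so callers index the tensor in the logical dimension order.
        phys_shape = tuple(logical_shape[d] for d in stride_order)
        buf = torch.empty(phys_shape)
        inverse = [stride_order.index(d) for d in range(len(stride_order))]
        return buf.permute(inverse)


    # MLA after this patch: identity order, layers-first physical layout,
    # so cross-layer KV cache is reported as unsupported.
    mla_order = (0, 1, 2, 3)
    assert not supports_cross_layer(mla_order, kv_cache_ndim=3)

    # The pre-patch order (1, 0, 2, 3) put num_blocks outermost, which is
    # what a cross-layer (single-allocation) layout requires.
    assert supports_cross_layer((1, 0, 2, 3), kv_cache_ndim=3)

    # Logical shape: (num_layers, num_blocks, block_size, head_size).
    view = physical_view((2, 4, 16, 576), mla_order)
    assert view.shape == (2, 4, 16, 576) and view.is_contiguous()

The identity permutation works as an in-band signal here: the runner
mixin needs no backend-specific knowledge, only the invariant that
cross-layer allocation requires some dimension other than num_layers to
be outermost in the physical layout.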