[Bugfix] Disable cross-layer KV cache for MLA attention backends (#37090)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
Author: haosdent
Date: 2026-03-17 01:03:10 +08:00
Committed by: GitHub
Parent: 55e6d3d5c0
Commit: ca1954d58c
5 changed files with 56 additions and 8 deletions


@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def test_mla_backend_rejects_cross_layer_kv_cache():
    """MLA backends return identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonBackend,
    )

    stride_order = MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert stride_order == (0, 1, 2, 3)
    assert stride_order[0] == 0  # layers dim first => no cross-layer
    assert MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    ) == (0, 1, 2)


def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache():
    """DeepseekV32Indexer returns identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    from vllm.v1.attention.backends.mla.indexer import (
        DeepseekV32IndexerBackend,
    )

    stride_order = DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert stride_order == (0, 1, 2, 3)
    assert stride_order[0] == 0  # layers dim first => no cross-layer
    assert DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    ) == (0, 1, 2)


@@ -24,7 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
 )
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.kv_cache_utils import BlockHash
@@ -601,7 +601,9 @@ class OffloadingConnectorWorker:
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         layer_names = list(kv_caches.keys())
         layers = get_layers_from_vllm_config(
-            self.spec.vllm_config, Attention, layer_names
+            self.spec.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+            layer_names,
         )
         attn_backends = {
             layer_name: layers[layer_name].get_attn_backend()
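A note on the widened type filter above: the connector now looks layers up via the AttentionLayerBase base class rather than the concrete Attention class, presumably so that MLA-style layers (which need not be Attention instances) are also registered and their stride orders inspected. A minimal sketch of why the filter type matters; the class hierarchy below is hypothetical, not vLLM's actual one:

# Illustrative sketch (hypothetical classes, not vLLM's real hierarchy):
# filtering layers by a concrete class silently drops sibling subclasses.
class AttentionLayerBase: ...                 # shared base
class Attention(AttentionLayerBase): ...      # standard attention layer
class MLAAttention(AttentionLayerBase): ...   # MLA layer, not an Attention

layers = {"layers.0.attn": Attention(), "layers.1.mla": MLAAttention()}

def get_layers(layer_type):
    # Mirrors a type-filtered lookup like get_layers_from_vllm_config.
    return {name: l for name, l in layers.items() if isinstance(l, layer_type)}

assert "layers.1.mla" not in get_layers(Attention)       # MLA layer missed
assert "layers.1.mla" in get_layers(AttentionLayerBase)  # base class catches it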


@@ -1142,10 +1142,12 @@ class MLACommonBackend(AttentionBackend):
     def get_kv_cache_stride_order(
         include_num_layers_dimension: bool = False,
     ) -> tuple[int, ...]:
-        # `stride_order` indicates the permutation that gets
-        # us from `get_kv_cache_shape` to the actual memory layout we want.
-        # (num_blocks, num_layers, block_size, head_size)
-        return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)
+        if include_num_layers_dimension:
+            # MLA kernels require contiguous per-layer KV cache views.
+            # Identity permutation keeps num_layers first in physical
+            # layout, signaling cross-layer allocation is unsupported.
+            return (0, 1, 2, 3)
+        return (0, 1, 2)
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
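To make the permutation semantics concrete: the stride order maps the logical (num_layers, num_blocks, block_size, head_size) shape onto a physical memory layout. A small standalone sketch (illustrative, not vLLM code; sizes made up) of why the identity order rules out cross-layer blocks while the old (1, 0, 2, 3) order enabled them:

import torch

num_layers, num_blocks, block_size, head_size = 2, 4, 16, 8
logical_shape = (num_layers, num_blocks, block_size, head_size)

def allocate(stride_order):
    # Allocate memory in the permuted (physical) order, then permute the
    # view back so indexing follows the logical shape.
    physical_shape = tuple(logical_shape[i] for i in stride_order)
    buf = torch.empty(physical_shape)
    inverse = [stride_order.index(i) for i in range(len(stride_order))]
    return buf.permute(inverse)

# Identity order: each layer is one contiguous slab -> per-layer caches only.
per_layer = allocate((0, 1, 2, 3))
assert per_layer[0].is_contiguous()

# Old MLA order (1, 0, 2, 3): num_blocks becomes outermost, so one block
# holds all layers back to back (cross-layer), and a single layer's view
# is strided rather than contiguous -- which MLA kernels cannot consume.
cross_layer = allocate((1, 0, 2, 3))
assert not cross_layer[0].is_contiguous()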


@@ -63,6 +63,9 @@ class DeepseekV32IndexerBackend(AttentionBackend):
         include_num_layers_dimension: bool = False,
     ) -> tuple[int, ...]:
         if include_num_layers_dimension:
+            # DeepseekV32Indexer kernels do not support cross-layer
+            # KV cache layout. Identity permutation keeps num_layers
+            # first, signaling incompatibility.
             return (0, 1, 2, 3)
         return (0, 1, 2)


@@ -191,8 +191,13 @@ class KVConnectorModelRunnerMixin:
         except (AttributeError, NotImplementedError):
             return False
-        # check that attention backend include a layers dimension
-        return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
+        # check that attention backend includes a layers dimension
+        if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
+            return False
+        # stride_order[0] == 0 means num_layers stays first in physical
+        # layout (identity permutation), so cross-layer is unsupported.
+        return kv_cache_stride_order[0] != 0
 
     @staticmethod
     def allocate_uniform_kv_caches(
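The connector-side gate above is what consumes the backends' signal. A minimal standalone restatement of that logic (hypothetical helper name; per-layer shapes made up):

# Illustrative restatement (not vLLM code) of the check: a backend opts
# into cross-layer KV cache only if it both accepts an extra layers
# dimension and moves it off the outermost (first) physical position.
def supports_cross_layer(kv_cache_shape, kv_cache_stride_order):
    # Backend must expose a num_layers dimension on top of the per-layer shape.
    if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
        return False
    # Identity permutation (layers dim first) signals per-layer caches only.
    return kv_cache_stride_order[0] != 0

assert supports_cross_layer((4, 16, 8), (1, 0, 2, 3))        # cross-layer OK
assert not supports_cross_layer((4, 16, 8), (0, 1, 2, 3))    # MLA after this fix
assert not supports_cross_layer((4, 16, 8), (0, 1, 2))       # no layers dimension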