[Bugfix] Disable cross-layer KV cache for MLA attention backends (#37090)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
haosdent
2026-03-17 01:03:10 +08:00
committed by GitHub
parent 55e6d3d5c0
commit ca1954d58c
5 changed files with 56 additions and 8 deletions

View File

@@ -24,7 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
@@ -601,7 +601,9 @@ class OffloadingConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.spec.vllm_config, Attention, layer_names
self.spec.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()

View File

@@ -1142,10 +1142,12 @@ class MLACommonBackend(AttentionBackend):
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Return the permutation mapping ``get_kv_cache_shape`` to physical layout.

    The stride order describes how the logical KV-cache shape is permuted
    into the actual memory layout.

    Args:
        include_num_layers_dimension: If True, the returned permutation
            covers a 4-D shape that includes a leading num_layers axis.

    Returns:
        A tuple of axis indices (a permutation). With the layers dimension
        included, the identity permutation ``(0, 1, 2, 3)`` keeps
        num_layers first in the physical layout, which signals that
        cross-layer (layer-interleaved) KV-cache allocation is unsupported
        by MLA kernels — they require contiguous per-layer cache views.
    """
    # NOTE: the stale pre-fix `return (1, 0, 2, 3) if ...` line has been
    # removed; it previously shadowed the branch below and advertised a
    # cross-layer layout that MLA kernels cannot consume.
    if include_num_layers_dimension:
        return (0, 1, 2, 3)
    return (0, 1, 2)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:

View File

@@ -63,6 +63,9 @@ class DeepseekV32IndexerBackend(AttentionBackend):
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
# DeepseekV32Indexer kernels do not support cross-layer
# KV cache layout. Identity permutation keeps num_layers
# first, signaling incompatibility.
return (0, 1, 2, 3)
return (0, 1, 2)

View File

@@ -191,8 +191,13 @@ class KVConnectorModelRunnerMixin:
except (AttributeError, NotImplementedError):
return False
# check that attention backend include a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
# check that attention backend includes a layers dimension
if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
return False
# stride_order[0] == 0 means num_layers stays first in physical
# layout (identity permutation), so cross-layer is unsupported.
return kv_cache_stride_order[0] != 0
@staticmethod
def allocate_uniform_kv_caches(