Fix ambiguous num_blocks for hybrid attn mamba (#37236)

Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Collin McCarthy
2026-03-30 04:09:45 -07:00
committed by GitHub
parent 7e76af14fa
commit 1031c84c36
2 changed files with 51 additions and 14 deletions

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import numpy as np
import pytest
import torch
@@ -30,6 +32,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (
AttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
@@ -38,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import select_common_block_size
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
BLOCK_SIZE = 16
NUM_BLOCKS = 10
@@ -946,6 +949,33 @@ def test_hybrid_attention_mamba_tensor_shapes():
assert torch.equal(actual_ssm, expected_ssm)
def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride():
    """Verify that _update_hybrid_attention_mamba_layout re-strides an
    attention KV cache from (2, num_blocks, ...) to (num_blocks, 2, ...)
    layout even when num_blocks == 2.

    A cache of shape (2, 2, ...) is ambiguous by shape alone — both
    dims[0] (the kv dim) and dims[1] (num_blocks) equal 2 — so the layout
    must be resolved via the backend's get_kv_cache_block_dim instead of a
    shape heuristic.
    """
    from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

    # Ambiguous, because both dims[0]=kv_dim and dims[1]=num_blocks are 2.
    ambiguous_cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16)
    hidden_size = ambiguous_cache.shape[2:].numel()
    # Freshly-allocated contiguous layout reads as (2, num_blocks, ...).
    assert ambiguous_cache.stride()[:2] == (2 * hidden_size, hidden_size)

    attention_spec = AttentionSpec(
        block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16
    )
    # Minimal stand-in for GPUModelRunner exposing only the attributes the
    # method reads: cache_config.cache_dtype and the attn-group iterator.
    runner_stub = SimpleNamespace(
        cache_config=SimpleNamespace(cache_dtype="auto"),
        _kv_cache_spec_attn_group_iterator=lambda: iter(
            [AttentionGroup(FlashAttentionBackend, ["attn"], attention_spec, 0)]
        ),
    )
    GPUModelRunner._update_hybrid_attention_mamba_layout(
        runner_stub, {"attn": ambiguous_cache}, [BLOCK_SIZE]
    )
    assert ambiguous_cache.stride()[:2] == (hidden_size, 2 * hidden_size), """\
We expect _update_hybrid_attention_mamba_layout to re-stride the cache from:
(2, num_blocks) -> (num_blocks, 2), even when num_blocks==2,
which was ambiguous before get_kv_cache_block_dim was used"""
def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""

View File

@@ -6653,12 +6653,12 @@ class GPUModelRunner(
raise NotImplementedError
if has_attn and has_mamba:
self._update_hybrid_attention_mamba_layout(kv_caches)
self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes)
return kv_caches
def _update_hybrid_attention_mamba_layout(
    self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int]
) -> None:
    """
    Update the layout of attention layers from (2, num_blocks, ...) to
    (num_blocks, 2, ...) by re-striding each layer's cache in place.

    The layout is resolved per attention group via the backend's
    get_kv_cache_block_dim rather than by inspecting tensor shapes, which
    is ambiguous when num_blocks == 2 (both leading dims equal 2).

    Args:
        kv_caches: The KV cache buffer of each layer.
        kernel_block_sizes: The kernel block sizes for each KV cache group.
    """
    for group in self._kv_cache_spec_attn_group_iterator():
        kv_cache_spec = group.kv_cache_spec
        # Only attention layers use the (2, num_blocks, ...) layout;
        # mamba-style specs are left untouched.
        if not isinstance(kv_cache_spec, AttentionSpec):
            continue
        block_dim = group.backend.get_kv_cache_block_dim(
            kernel_block_sizes[group.kv_cache_group_id],
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=self.cache_config.cache_dtype,
        )
        # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...).
        if block_dim == 0:
            # Already in the desired (num_blocks, 2, ...) layout.
            continue
        assert block_dim == 1
        for layer_name in group.layer_names:
            kv_cache = kv_caches[layer_name]
            hidden_size = kv_cache.shape[2:].numel()
            # Swap the strides of the first two dims so the same storage is
            # viewed as (num_blocks, 2, ...); shape is left unchanged.
            kv_cache.as_strided_(
                size=kv_cache.shape,
                stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
            )
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]