Fix ambiguous num_blocks for hybrid attn mamba (#37236)

Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Collin McCarthy
2026-03-30 04:09:45 -07:00
committed by GitHub
parent 7e76af14fa
commit 1031c84c36
2 changed files with 51 additions and 14 deletions

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import numpy as np
import pytest
import torch
@@ -30,6 +32,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (
AttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
@@ -38,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import select_common_block_size
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
BLOCK_SIZE = 16
NUM_BLOCKS = 10
@@ -946,6 +949,33 @@ def test_hybrid_attention_mamba_tensor_shapes():
assert torch.equal(actual_ssm, expected_ssm)
def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride():
    """Regression test: re-striding must work even when the layout is ambiguous.

    The cache tensor below is "ambiguous" because dims[0] (the kv dim) and
    dims[1] (num_blocks) are both 2, so stride inspection alone cannot tell
    which of the two layouts the tensor is currently in. The fix uses the
    spec's block dim (via get_kv_cache_block_dim) instead of guessing, and
    must still swap the layout from (2, num_blocks) to (num_blocks, 2).
    """
    from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

    # Ambiguous on purpose: both dims[0]=kv_dim and dims[1]=num_blocks == 2.
    ambiguous_cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16)
    hidden_size = ambiguous_cache.shape[2:].numel()
    # Sanity-check the initial (kv_dim, num_blocks, ...) contiguous layout.
    assert ambiguous_cache.stride()[:2] == (2 * hidden_size, hidden_size)

    attention_spec = AttentionSpec(
        block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16
    )
    # Minimal stand-in for a GPUModelRunner: only the attributes that
    # _update_hybrid_attention_mamba_layout actually reads.
    runner_stub = SimpleNamespace(
        cache_config=SimpleNamespace(cache_dtype="auto"),
        _kv_cache_spec_attn_group_iterator=lambda: iter(
            [AttentionGroup(FlashAttentionBackend, ["attn"], attention_spec, 0)]
        ),
    )
    GPUModelRunner._update_hybrid_attention_mamba_layout(
        runner_stub, {"attn": ambiguous_cache}, [BLOCK_SIZE]
    )
    assert ambiguous_cache.stride()[:2] == (hidden_size, 2 * hidden_size), """\
We expect _update_hybrid_attention_mamba_layout to re-stride the cache from:
(2, num_blocks) -> (num_blocks, 2), even when num_blocks==2,
which was ambiguous before get_kv_cache_block_dim was used"""
def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""