Fix ambiguous num_blocks for hybrid attn mamba (#37236)
Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com> Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@@ -30,6 +32,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
@@ -38,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import select_common_block_size
|
||||
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
@@ -946,6 +949,33 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
|
||||
def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride():
    """Regression test: the layout update must re-stride the KV cache even when
    the (kv_dim, num_blocks) leading dims are ambiguous because both equal 2.

    Previously, dim-detection by value could confuse dims[0]=kv_dim with
    dims[1]=num_blocks when num_blocks == 2; get_kv_cache_block_dim resolves
    this deterministically.
    """
    # Imported locally to avoid pulling in the flash-attn backend at module
    # import time for unrelated tests.
    from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

    # Ambiguous on purpose: both dims[0] (kv_dim) and dims[1] (num_blocks)
    # are 2.  NOTE: the original code placed this explanation in a bare
    # triple-quoted string AFTER the assignment — a no-op expression statement
    # that no tool treats as documentation; a comment is the correct form.
    ambiguous_cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16)

    # Per-block payload size: everything after the two leading dims.
    hidden_size = ambiguous_cache.shape[2:].numel()
    # Freshly-allocated contiguous layout: (kv_dim, num_blocks) major order.
    assert ambiguous_cache.stride()[:2] == (2 * hidden_size, hidden_size)

    attention_spec = AttentionSpec(
        block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16
    )
    # Minimal stand-in for a GPUModelRunner: only the attributes read by
    # _update_hybrid_attention_mamba_layout are provided.
    runner_stub = SimpleNamespace(
        cache_config=SimpleNamespace(cache_dtype="auto"),
        _kv_cache_spec_attn_group_iterator=lambda: iter(
            [AttentionGroup(FlashAttentionBackend, ["attn"], attention_spec, 0)]
        ),
    )
    GPUModelRunner._update_hybrid_attention_mamba_layout(
        runner_stub, {"attn": ambiguous_cache}, [BLOCK_SIZE]
    )

    assert ambiguous_cache.stride()[:2] == (hidden_size, 2 * hidden_size), """\
We expect _update_hybrid_attention_mamba_layout to re-stride the cache from:
(2, num_blocks) -> (num_blocks, 2), even when num_blocks==2,
which was ambiguous before get_kv_cache_block_dim was used"""
|
||||
|
||||
|
||||
def test_hybrid_block_table_initialization():
|
||||
"""Test hybrid block table with different kernel and kvcache_manager block
|
||||
sizes."""
|
||||
|
||||
Reference in New Issue
Block a user