Fix ambiguous num_blocks for hybrid attn mamba (#37236)
Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com> Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@@ -30,6 +32,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
@@ -38,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import select_common_block_size
|
||||
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
@@ -946,6 +949,33 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
|
||||
def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride():
    """Re-striding must still happen when num_blocks == 2.

    A cache of shape (2, 2, ...) is ambiguous by shape alone: both
    dims[0] (the kv dim) and dims[1] (num_blocks) equal 2, so the layout
    cannot be inferred from the tensor itself and must come from the
    backend via get_kv_cache_block_dim.
    """
    from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

    # Both leading dims are 2 — the ambiguous case this test pins down.
    cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16)
    per_block = cache.shape[2:].numel()
    # Freshly-allocated layout is (2, num_blocks, ...): dim 0 has the
    # larger stride.
    assert cache.stride()[:2] == (2 * per_block, per_block)

    spec = AttentionSpec(
        block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16
    )
    groups = [AttentionGroup(FlashAttentionBackend, ["attn"], spec, 0)]
    # Minimal stand-in for a GPUModelRunner: only the attributes the
    # method under test actually reads.
    stub = SimpleNamespace(
        cache_config=SimpleNamespace(cache_dtype="auto"),
        _kv_cache_spec_attn_group_iterator=lambda: iter(groups),
    )
    GPUModelRunner._update_hybrid_attention_mamba_layout(
        stub, {"attn": cache}, [BLOCK_SIZE]
    )

    assert cache.stride()[:2] == (per_block, 2 * per_block), """\
We expect _update_hybrid_attention_mamba_layout to re-stride the cache from:
(2, num_blocks) -> (num_blocks, 2), even when num_blocks==2,
which was ambiguous before get_kv_cache_block_dim was used"""
|
||||
|
||||
|
||||
def test_hybrid_block_table_initialization():
|
||||
"""Test hybrid block table with different kernel and kvcache_manager block
|
||||
sizes."""
|
||||
|
||||
@@ -6653,12 +6653,12 @@ class GPUModelRunner(
|
||||
raise NotImplementedError
|
||||
|
||||
if has_attn and has_mamba:
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches)
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes)
|
||||
|
||||
return kv_caches
|
||||
|
||||
def _update_hybrid_attention_mamba_layout(
    self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int]
) -> None:
    """
    Update the layout of attention layers from (2, num_blocks, ...) to
    (num_blocks, 2, ...) by rewriting strides in place (no data copy).

    The block dim is obtained from the attention backend rather than
    inferred from the tensor shape: shape-based inference was ambiguous
    when num_blocks == 2 (both leading dims equal 2).

    Args:
        kv_caches: The KV cache buffer of each layer.
        kernel_block_sizes: The kernel block sizes for each KV cache group.
    """
    for group in self._kv_cache_spec_attn_group_iterator():
        kv_cache_spec = group.kv_cache_spec
        # Only attention layers carry the (2, num_blocks, ...) layout;
        # mamba state caches are left untouched.
        if not isinstance(kv_cache_spec, AttentionSpec):
            continue
        block_dim = group.backend.get_kv_cache_block_dim(
            kernel_block_sizes[group.kv_cache_group_id],
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=self.cache_config.cache_dtype,
        )
        # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...).
        if block_dim == 0:
            # Already in the target layout; nothing to rewrite.
            continue
        assert block_dim == 1
        for layer_name in group.layer_names:
            kv_cache = kv_caches[layer_name]
            # Elements per (kv, block) slice; used to swap the two
            # leading strides so the buffer is addressed as
            # (num_blocks, 2, ...).
            hidden_size = kv_cache.shape[2:].numel()
            kv_cache.as_strided_(
                size=kv_cache.shape,
                stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
            )
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
|
||||
|
||||
Reference in New Issue
Block a user