diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 02fac6dfc..0fbd6605a 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from types import SimpleNamespace
+
 import numpy as np
 import pytest
 import torch
@@ -30,6 +32,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
 from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     FullAttentionSpec,
     KVCacheConfig,
     KVCacheGroupSpec,
@@ -38,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
-from vllm.v1.worker.utils import select_common_block_size
+from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
@@ -946,6 +949,38 @@ def test_hybrid_attention_mamba_tensor_shapes():
     assert torch.equal(actual_ssm, expected_ssm)
 
 
+def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride():
+    """Check that the KV cache is re-strided even when num_blocks == 2.
+
+    With num_blocks == 2 both leading dims of the cache have size 2, so the
+    layout cannot be told apart from the tensor shape alone.
+    """
+    from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+    # Ambiguous, because both dims[0=kv_dim] and dims[1=num_blocks] == 2.
+    ambiguous_cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16)
+    hidden_size = ambiguous_cache.shape[2:].numel()
+    assert ambiguous_cache.stride()[:2] == (2 * hidden_size, hidden_size)
+
+    attention_spec = AttentionSpec(
+        block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16
+    )
+    runner_stub = SimpleNamespace(
+        cache_config=SimpleNamespace(cache_dtype="auto"),
+        _kv_cache_spec_attn_group_iterator=lambda: iter(
+            [AttentionGroup(FlashAttentionBackend, ["attn"], attention_spec, 0)]
+        ),
+    )
+    GPUModelRunner._update_hybrid_attention_mamba_layout(
+        runner_stub, {"attn": ambiguous_cache}, [BLOCK_SIZE]
+    )
+
+    assert ambiguous_cache.stride()[:2] == (hidden_size, 2 * hidden_size), """\
+    We expect _update_hybrid_attention_mamba_layout to re-stride the cache from:
+    (2, num_blocks) -> (num_blocks, 2), even when num_blocks==2,
+    which was ambiguous before get_kv_cache_block_dim was used"""
+
+
 def test_hybrid_block_table_initialization():
     """Test hybrid block table with different kernel and kvcache_manager
     block sizes."""
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8cfa61baa..1f946cda0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -6653,12 +6653,12 @@ class GPUModelRunner(
                     raise NotImplementedError
 
         if has_attn and has_mamba:
-            self._update_hybrid_attention_mamba_layout(kv_caches)
+            self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes)
 
         return kv_caches
 
     def _update_hybrid_attention_mamba_layout(
-        self, kv_caches: dict[str, torch.Tensor]
+        self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int]
     ) -> None:
         """
         Update the layout of attention layers from (2, num_blocks, ...) to
@@ -6666,23 +6666,36 @@ class GPUModelRunner(
 
         Args:
             kv_caches: The KV cache buffer of each layer.
+            kernel_block_sizes: The kernel block size of each KV cache group,
+                indexed by the group's kv_cache_group_id.
         """
 
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
+            if not isinstance(kv_cache_spec, AttentionSpec):
+                continue
+            block_dim = group.backend.get_kv_cache_block_dim(
+                kernel_block_sizes[group.kv_cache_group_id],
+                kv_cache_spec.num_kv_heads,
+                kv_cache_spec.head_size,
+                cache_dtype_str=self.cache_config.cache_dtype,
+            )
+            # block_dim is the position of the num_blocks dimension:
+            # 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...).
+            if block_dim == 0:
+                # Already in the target layout; nothing to re-stride.
+                continue
+            assert block_dim == 1, (
+                f"Unexpected KV cache block dim {block_dim} for backend "
+                f"{group.backend}; expected 0 or 1"
+            )
             for layer_name in group.layer_names:
                 kv_cache = kv_caches[layer_name]
-                if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
-                    assert kv_cache.shape[1] != 2, (
-                        "Fail to determine whether the layout is "
-                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for "
-                        f"a tensor of shape {kv_cache.shape}"
-                    )
-                    hidden_size = kv_cache.shape[2:].numel()
-                    kv_cache.as_strided_(
-                        size=kv_cache.shape,
-                        stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
-                    )
+                hidden_size = kv_cache.shape[2:].numel()
+                kv_cache.as_strided_(
+                    size=kv_cache.shape,
+                    stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
+                )
 
     def initialize_kv_cache_tensors(
         self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]