# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import pytest import torch from tests.v1.attention.utils import ( create_standard_kv_cache_spec, create_vllm_config, try_backend_includes_kv_cache_update, try_get_attention_backend, ) from vllm.config import ParallelConfig, SpeculativeConfig from vllm.platforms import current_platform from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.registry import AttentionBackendEnum if not is_flash_attn_varlen_func_available(): pytest.skip( "This test requires flash_attn_varlen_func, but it's not available.", allow_module_level=True, ) # --------------------------------------------------------------------------- # # KV cache layout adaptation # --------------------------------------------------------------------------- # # Two KV cache layouts exist across backends: # # Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size) # - dim 0 separates key (index 0) and value (index 1) # - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN # # Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size) # - dim 1 separates key (index 0) and value (index 1) # - Used by: TRITON_ATTN # # The test creates KV caches in flash layout (the canonical format used by # tree attention). When a reference backend needs block layout we transpose # dims 0 and 1. # # Note: ROCM_ATTN uses flash layout for storage but its forward path calls # PagedAttention.split_kv_cache which reinterprets the raw memory as paged # layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is # a view-level incompatibility, not a transpose - see the TODO in # _get_available_reference_backends for details. 
#
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
#     class TritonAttentionImpl(AttentionImpl):
#         KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #
_BLOCK_KV_LAYOUT_BACKENDS = frozenset(
    {
        AttentionBackendEnum.TRITON_ATTN,
    }
)

# Backends whose do_kv_cache_update requires engine-level state (e.g.
# ForwardContext) that is not available in this test harness, but whose
# KV cache is flash layout and can be written with reshape_and_cache_flash.
# When a backend is listed here, forward_attention() bypasses
# do_kv_cache_update and writes directly to the cache.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset(
    {
        AttentionBackendEnum.ROCM_AITER_FA,
    }
)

# Backends with known test-harness incompatibilities - see the TODOs
# inside _get_available_reference_backends for details. They are never
# used as reference backends, even when the platform selects one of them
# by default.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    {
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.ROCM_ATTN,
    }
)


def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Convert kv_cache from flash layout ``(2, num_blocks, ...)`` to block
    layout ``(num_blocks, 2, ...)`` if the backend requires it.

    The converted result comes from ``.contiguous()`` on a transposed view,
    so it is generally a separate copy and writes to it do not reach
    ``kv_cache``. When no conversion is needed the original tensor object is
    returned unchanged, so writes to it DO reach the caller's cache.
    """
    if backend in _BLOCK_KV_LAYOUT_BACKENDS:
        return kv_cache.transpose(0, 1).contiguous()
    return kv_cache


def _get_platform_default_backend() -> AttentionBackendEnum:
    """Ask the platform what backend it would auto-select at runtime.

    Queries ``current_platform.get_attn_backend_cls`` with a representative
    configuration (block_size=32, head_size=128, bf16, no MLA, no sparse
    attention) and maps the returned class path back to its
    ``AttentionBackendEnum`` member.

    Raises:
        RuntimeError: if the returned path matches no enum member.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    backend_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=config,
    )
    for backend in AttentionBackendEnum:
        try:
            if backend.get_path() == backend_path:
                return backend
        except ValueError:
            # get_path() can raise ValueError for some members; those cannot
            # be matched against backend_path, so skip them.
            continue
    raise RuntimeError(
        f"Platform returned backend path '{backend_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )


def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """Collect the reference backends the current platform can run.

    On CUDA this is just FLASH_ATTN. On ROCm it is the platform's default
    backend (unless that backend is in ``_INCOMPATIBLE_REFERENCE_BACKENDS``)
    plus TRITON_ATTN, without duplicates. ROCM_ATTN and ROCM_AITER_FA are
    currently excluded; the TODOs below explain why.
    """
    if current_platform.is_rocm():
        backends: list[AttentionBackendEnum] = []

        # 1. Whatever the platform would auto-select at runtime.
        default_backend = _get_platform_default_backend()
        if default_backend not in _INCOMPATIBLE_REFERENCE_BACKENDS:
            backends.append(default_backend)

        # 2. TRITON_ATTN - always available on ROCm.
        if AttentionBackendEnum.TRITON_ATTN not in backends:
            backends.append(AttentionBackendEnum.TRITON_ATTN)

        # TODO: Enable ROCM_ATTN. Its forward path uses
        # PagedAttention.split_kv_cache which reinterprets the raw
        # cache memory as paged layout:
        #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
        #   value: (num_blocks, num_kv_heads, head_size, block_size)
        # Tree attention writes prefix data in NHD flash layout, so the
        # same bytes produce completely different values when read in
        # paged format. Supporting ROCM_ATTN would require writing
        # prefix data via PagedAttention.write_to_paged_cache into a
        # separate paged-format KV cache.

        # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
        # counts from the model config at construction time and
        # allocates extend_workspace with those dimensions. The test
        # uses independent head count parameters (num_heads=2/4,
        # num_kv_heads=2) that don't match the model config
        # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
        # mismatch in flash_attn_varlen_func during extend_forward.
        # Fixing this requires either matching test head counts to the
        # model config or decoupling the builder from model config
        # head geometry. The direct cache update path
        # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
        # this is resolved.

        return backends

    # CUDA: flash attention.
    return [AttentionBackendEnum.FLASH_ATTN]


class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer.

    Exposes unit (1.0) float32 q/k/v scales on CUDA and a fixed layer name.
    forward_attention() hands an instance to the backend's
    ``do_kv_cache_update``/``forward`` and passes ``_k_scale``/``_v_scale``
    to ``reshape_and_cache_flash`` on the direct cache-update path.
    """

    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    layer_name = "mock_layer"

    def forward(self, x):
        """Identity."""
        return x


def forward_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    seqlen_k: int,
    backend: AttentionBackendEnum,
    spec_token_tree: str | None = None,
    num_spec_tokens: int = 0,
) -> torch.Tensor:
    """Run a single attention forward pass through the given backend.

    Every request in the batch has the same query length ``q_len`` and the
    same total KV length ``seqlen_k``. The first ``seqlen_k - q_len`` KV
    positions of each request are treated as already-computed context.

    Args:
        q: Queries, ``(batch_size, q_len, num_heads, head_size)``.
        k: New keys, ``(batch_size, q_len, num_kv_heads, head_size)``.
        v: New values, same shape as ``k``.
        kv_cache: KV cache in **flash layout**
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``. It is
            automatically converted when the target backend needs a
            different layout. When no conversion is needed and the cache
            update runs outside ``forward``, that update targets this
            tensor itself, so the caller's cache is modified.
        block_table: Per-request table of physical block ids.
        slot_mapping: Flat cache slot for each of the ``batch_size * q_len``
            new tokens.
        seqlen_k: Total KV length (context + new tokens) per request.
        backend: Attention backend to run.
        spec_token_tree: Optional speculative token tree. When given, an
            EAGLE ``SpeculativeConfig`` using it is attached to the config.
        num_spec_tokens: Number of speculative tokens for that config.

    Returns:
        The backend's ``forward`` result. The output buffer is allocated
        as ``(batch_size * q_len, num_heads, head_size)``.
    """
    batch_size, q_len, num_heads, dim_per_head = q.shape
    num_kv_heads = k.shape[-2]

    # Initialize the query and KV sequence lengths.
    query_start_loc = q_len * torch.arange(
        batch_size + 1, device=q.device, dtype=torch.int32
    )
    query_lens = torch.diff(query_start_loc)
    seq_lens = torch.full(
        (batch_size,),
        seqlen_k,
        device=q.device,
        dtype=torch.int32,
    )
    context_lens = seq_lens - query_lens
    max_seq_len = int(seq_lens.max())
    max_query_len = q_len
    # Equal to query_start_loc[-1], but kept as a plain Python int rather
    # than a 0-d device tensor.
    num_actual_tokens = batch_size * q_len
    softmax_scale = q.shape[-1] ** (-0.5)
    layer = MockAttentionLayer()

    # Build common metadata.
    model_name = "meta-llama/Meta-Llama-3-8B"
    builder_cls, impl_cls = try_get_attention_backend(backend)
    # Use the int max_seq_len: the builtin max() over a device tensor would
    # iterate it element by element and return a 0-d tensor.
    vllm_config = create_vllm_config(
        model_name=model_name, max_model_len=max_seq_len
    )
    if spec_token_tree is not None:
        # Create speculative config if token tree is specified.
        vllm_config.speculative_config = SpeculativeConfig(
            target_model_config=vllm_config.model_config,
            target_parallel_config=ParallelConfig(),
            model=model_name,
            method="eagle",
            num_speculative_tokens=num_spec_tokens,
            speculative_token_tree=spec_token_tree,
        )
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
    builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc.cpu(),
        seq_lens=seq_lens,
        _seq_lens_cpu=seq_lens.cpu(),
        _num_computed_tokens_cpu=context_lens.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table,
        slot_mapping=slot_mapping,
    )

    # Build attention metadata.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )

    # Initialize the backend implementation.
    instance = impl_cls(
        num_heads=num_heads,
        head_size=dim_per_head,
        scale=softmax_scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Adapt KV cache layout for this backend.
    adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)

    # Flatten batch and sequence into a single token dimension.
    query = q.view(-1, num_heads, dim_per_head)
    key = k.view(-1, num_kv_heads, dim_per_head)
    value = v.view(-1, num_kv_heads, dim_per_head)
    output = torch.empty_like(query)

    if not try_backend_includes_kv_cache_update(backend):
        # Write the new k/v into the cache explicitly before forward
        # (presumably forward does not do it for these backends).
        if backend in _NEEDS_DIRECT_CACHE_UPDATE:
            # This backend's do_kv_cache_update requires engine-level
            # ForwardContext that isn't available in this test harness.
            # Write directly using reshape_and_cache_flash since the
            # KV cache layout is identical (flash layout, unbind on dim 0).
            key_cache, value_cache = adapted_kv_cache.unbind(0)
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                "auto",
                layer._k_scale,
                layer._v_scale,
            )
        else:
            instance.do_kv_cache_update(
                layer=layer,
                key=key,
                value=value,
                kv_cache=adapted_kv_cache,
                slot_mapping=attn_metadata.slot_mapping,
            )

    # forward receives a copy, so nothing it writes can alter
    # adapted_kv_cache (and, for flash-layout backends, the caller's cache).
    return instance.forward(
        layer=layer,
        query=query,
        key=key,
        value=value,
        kv_cache=adapted_kv_cache.clone(),
        attn_metadata=attn_metadata,
        output=output,
    )


@pytest.mark.parametrize(
    "reference_backend",
    _get_available_reference_backends(),
    ids=lambda b: b.name,
)
def test_tree_attn_correctness(
    reference_backend: AttentionBackendEnum,
) -> None:
    """Check TREE_ATTN against ``reference_backend`` branch by branch.

    For each token tree, tree attention is run once over all tree tokens.
    Each tree node's row of the attention mask selects the node and every
    position it attends to. That selection is then run as an ordinary causal
    sequence through the reference backend, and the outputs at those tree
    positions must match.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    device = "cuda"

    # Maps a speculative token tree spec to its attention mask. Row i lists
    # the positions (root first) that tree token i attends to.
    tree_attn_masks = {
        # Chain.
        "[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
        # Tree.
        "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0, 0],
                [1, 0, 1, 0, 0, 0, 0],
                [1, 1, 0, 1, 0, 0, 0],
                [1, 1, 0, 0, 1, 0, 0],
                [1, 0, 1, 0, 0, 1, 0],
                [1, 0, 1, 0, 0, 0, 1],
            ],
            device=device,
            dtype=torch.int32,
        ),
    }

    dim_per_head = 128
    num_kv_heads = 2
    block_size = 32
    max_sequence_length = 8192
    randomize_blocks = True
    batch_sizes = (1, 16, 32)
    num_heads_options = (2, 4)
    sequence_positions = (16, 1024, 2048)

    # Loop-invariant sanity checks, hoisted out of the sweep below.
    # The number of query heads must be divisible by the number of KV heads.
    assert all(n % num_kv_heads == 0 for n in num_heads_options)
    # The per-request cache region must consist of whole blocks.
    assert max_sequence_length % block_size == 0
    max_blocks_per_batch = max_sequence_length // block_size

    for batch_size in batch_sizes:
        for num_heads in num_heads_options:
            for sequence_position in sequence_positions:
                for spec_token_tree, tree_attn_mask in tree_attn_masks.items():
                    # Initialize q, k, and v.
                    tree_size_q = tree_attn_mask.shape[0]
                    seqlen_k = sequence_position + tree_size_q
                    q = torch.randn(
                        (batch_size, tree_size_q, num_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    k = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )
                    v = torch.randn(
                        (batch_size, tree_size_q, num_kv_heads, dim_per_head),
                        device=device,
                        dtype=torch.bfloat16,
                    )

                    # KV cache in flash layout - the canonical format for
                    # tree attention. forward_attention() handles conversion
                    # when needed. Slots the forward passes never write
                    # (positions before sequence_position) keep this random
                    # data, which acts as the cached prefix.
                    kv_cache = torch.randn(
                        (
                            2,
                            batch_size * max_blocks_per_batch,
                            block_size,
                            num_kv_heads,
                            dim_per_head,
                        ),
                        device=q.device,
                        dtype=torch.bfloat16,
                    )

                    # Give each request its own physical blocks covering
                    # seqlen_k tokens. Unused table entries stay 0.
                    num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
                    block_table = torch.zeros(
                        (batch_size, max_blocks_per_batch),
                        device=q.device,
                        dtype=torch.int32,
                    )
                    block_ids = torch.arange(
                        0,
                        batch_size * num_alloc_blocks_per_batch,
                        device=q.device,
                        dtype=torch.int32,
                    )
                    if randomize_blocks:
                        # Randomize the block ids.
                        block_ids = block_ids[torch.randperm(block_ids.numel())]
                    block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
                        -1, num_alloc_blocks_per_batch
                    )

                    # Set up the slot mapping for the input KVs.
                    tree_positions = sequence_position + torch.arange(
                        0,
                        tree_size_q,
                        device=q.device,
                        dtype=torch.int64,
                    ).repeat(batch_size, 1)
                    tree_slot_mapping = _gen_slot_mapping(
                        tree_positions, block_table, block_size
                    )

                    # Compute attention for the tree.
                    tree_attn_output = forward_attention(
                        q=q,
                        k=k,
                        v=v,
                        kv_cache=kv_cache,
                        block_table=block_table,
                        slot_mapping=tree_slot_mapping,
                        seqlen_k=seqlen_k,
                        backend=AttentionBackendEnum.TREE_ATTN,
                        spec_token_tree=spec_token_tree,
                        num_spec_tokens=tree_size_q - 1,
                    ).view(batch_size, -1, num_heads, dim_per_head)

                    # Verify each branch against the reference backend.
                    for q_index, branch_mask in enumerate(tree_attn_mask):
                        # Get the q, k, and v for the branch.
                        branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
                        q_len = branch_indices.shape[0]
                        q_branch = q[:, branch_indices]
                        k_branch = k[:, branch_indices]
                        v_branch = v[:, branch_indices]

                        # The branch is laid out contiguously right after
                        # the prefix, as a plain causal sequence.
                        branch_positions = sequence_position + torch.arange(
                            0,
                            q_len,
                            device=q.device,
                            dtype=torch.int64,
                        ).repeat(batch_size, 1)
                        branch_slot_mapping = _gen_slot_mapping(
                            branch_positions, block_table, block_size
                        )

                        # Reference attention for this branch.
                        ref_output = forward_attention(
                            q=q_branch,
                            k=k_branch,
                            v=v_branch,
                            kv_cache=kv_cache,
                            block_table=block_table,
                            slot_mapping=branch_slot_mapping,
                            seqlen_k=sequence_position + q_len,
                            backend=reference_backend,
                        ).view(batch_size, -1, num_heads, dim_per_head)

                        # Compare the outputs.
                        assert torch.allclose(
                            tree_attn_output[:, branch_indices],
                            ref_output,
                            atol=7.81e-3,
                        ), (
                            f"outputs are not close for "
                            f"reference_backend: {reference_backend.name}, "
                            f"batch_size: {batch_size}, "
                            f"num_heads: {num_heads}, "
                            f"sequence_position: {sequence_position}, "
                            f"tree_attn_mask: {tree_attn_mask}, "
                            f"q_index: {q_index}."
                        )


def _gen_slot_mapping(
    positions: torch.Tensor, block_table: torch.Tensor, block_size: int
) -> torch.Tensor:
    """Map per-request token positions to flat KV cache slot indices.

    Args:
        positions: ``(batch_size, n)`` integer token positions within each
            request (int64, as required by ``gather``).
        block_table: ``(batch_size, max_blocks)`` physical block ids.
        block_size: Tokens per cache block.

    Returns:
        1-D tensor of ``batch_size * n`` slots, where each slot is
        ``block_table[b, pos // block_size] * block_size + pos % block_size``.
    """
    block_indices = positions // block_size
    blocks = block_table.gather(dim=1, index=block_indices)
    return (blocks * block_size + positions % block_size).view(-1)