# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for GDNAttentionMetadataBuilder.build() — specifically the reclassification of non-spec decodes as prefills when spec decodes exist. Covers the fix for https://github.com/vllm-project/vllm/issues/34845. """ from dataclasses import dataclass import pytest import torch from tests.v1.attention.utils import ( BatchSpec, create_common_attn_metadata, create_vllm_config, ) from vllm.config import SpeculativeConfig from vllm.v1.attention.backends.gdn_attn import ( GDNAttentionMetadata, GDNAttentionMetadataBuilder, ) from vllm.v1.kv_cache_interface import MambaSpec BLOCK_SIZE = 16 DEVICE = torch.device("cpu") @dataclass class GDNBuildTestCase: """Specification for a GDN metadata builder classification test.""" seq_lens: list[int] query_lens: list[int] num_decode_draft_tokens: list[int] | None # None = no spec config num_speculative_tokens: int expected_num_decodes: int expected_num_prefills: int expected_num_prefill_tokens: int expected_num_spec_decodes: int GDN_BUILD_TEST_CASES = { # The original #34845 crash: non-spec query_len=1 + spec decode "mixed_decode_and_spec_decode": GDNBuildTestCase( seq_lens=[65, 20], query_lens=[1, 3], num_decode_draft_tokens=[-1, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=1, expected_num_prefill_tokens=1, expected_num_spec_decodes=1, ), # All requests are spec decodes — no reclassification needed "pure_spec_decode": GDNBuildTestCase( seq_lens=[50, 30], query_lens=[3, 3], num_decode_draft_tokens=[2, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=0, expected_num_prefill_tokens=0, expected_num_spec_decodes=2, ), # No speculative config at all — standard decode path "pure_regular_decode": GDNBuildTestCase( seq_lens=[40, 30, 20], query_lens=[1, 1, 1], num_decode_draft_tokens=None, num_speculative_tokens=0, expected_num_decodes=3, expected_num_prefills=0, expected_num_prefill_tokens=0, expected_num_spec_decodes=0, ), # Multi-token prefill alongside spec decode — no decode to reclassify "spec_decode_with_real_prefill": GDNBuildTestCase( seq_lens=[100, 20], query_lens=[50, 3], num_decode_draft_tokens=[-1, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=1, expected_num_prefill_tokens=50, expected_num_spec_decodes=1, ), # All three types in one batch — decode gets reclassified "prefill_decode_and_spec_decode": GDNBuildTestCase( seq_lens=[100, 65, 20], query_lens=[50, 1, 3], num_decode_draft_tokens=[-1, -1, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=2, expected_num_prefill_tokens=51, expected_num_spec_decodes=1, ), # Multiple non-spec query_len=1 requests all reclassified "multiple_decodes_reclassified": GDNBuildTestCase( seq_lens=[40, 50, 60, 20], query_lens=[1, 1, 1, 3], num_decode_draft_tokens=[-1, -1, -1, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=3, expected_num_prefill_tokens=3, expected_num_spec_decodes=1, ), # Zero-length padded sequence excluded from counts "zero_length_padding_with_spec": GDNBuildTestCase( seq_lens=[16, 65, 20], query_lens=[0, 1, 3], num_decode_draft_tokens=[-1, -1, 2], num_speculative_tokens=2, expected_num_decodes=0, expected_num_prefills=1, expected_num_prefill_tokens=1, expected_num_spec_decodes=1, ), } def _create_gdn_builder( num_speculative_tokens: int = 0, ) -> GDNAttentionMetadataBuilder: """Create a GDNAttentionMetadataBuilder with minimal config.""" vllm_config = create_vllm_config(block_size=BLOCK_SIZE) if num_speculative_tokens > 0: vllm_config.speculative_config = SpeculativeConfig( method="ngram", num_speculative_tokens=num_speculative_tokens, ) mamba_spec = MambaSpec( block_size=BLOCK_SIZE, shapes=((16, 64),), dtypes=(torch.float16,), ) return GDNAttentionMetadataBuilder( kv_cache_spec=mamba_spec, layer_names=["layer.0"], vllm_config=vllm_config, device=DEVICE, ) def _build( builder: GDNAttentionMetadataBuilder, batch_spec: BatchSpec, num_decode_draft_tokens: list[int] | None = None, ) -> GDNAttentionMetadata: """Build GDN attention metadata, optionally with spec-decode kwargs.""" common = create_common_attn_metadata(batch_spec, BLOCK_SIZE, DEVICE) kwargs: dict = {} if num_decode_draft_tokens is not None: kwargs["num_decode_draft_tokens_cpu"] = torch.tensor( num_decode_draft_tokens, dtype=torch.int32 ) kwargs["num_accepted_tokens"] = torch.ones( batch_spec.batch_size, dtype=torch.int32, device=DEVICE ) return builder.build(common_prefix_len=0, common_attn_metadata=common, **kwargs) @pytest.mark.parametrize( "test_case", GDN_BUILD_TEST_CASES.values(), ids=GDN_BUILD_TEST_CASES.keys() ) def test_gdn_build_classification(test_case: GDNBuildTestCase): """Test that GDN metadata builder classifies requests correctly.""" builder = _create_gdn_builder(test_case.num_speculative_tokens) batch = BatchSpec(seq_lens=test_case.seq_lens, query_lens=test_case.query_lens) meta = _build(builder, batch, test_case.num_decode_draft_tokens) assert meta.num_decodes == test_case.expected_num_decodes assert meta.num_prefills == test_case.expected_num_prefills assert meta.num_prefill_tokens == test_case.expected_num_prefill_tokens assert meta.num_spec_decodes == test_case.expected_num_spec_decodes def test_has_initial_state_after_reclassification(): """After reclassification, num_prefills > 0 so the prefill kernel path should compute has_initial_state. For the reclassified request with context_lens > 0, the corresponding entry must be True.""" builder = _create_gdn_builder(num_speculative_tokens=2) batch = BatchSpec(seq_lens=[65, 20], query_lens=[1, 3]) meta = _build(builder, batch, num_decode_draft_tokens=[-1, 2]) assert meta.num_prefills > 0, "reclassification should produce prefills" assert meta.has_initial_state is not None # req0 has context_lens = 65 - 1 = 64 > 0, so has_initial_state[0] = True assert meta.has_initial_state[0].item() is True