diff --git a/tests/v1/attention/test_gdn_metadata_builder.py b/tests/v1/attention/test_gdn_metadata_builder.py new file mode 100644 index 000000000..6576a9bf3 --- /dev/null +++ b/tests/v1/attention/test_gdn_metadata_builder.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for GDNAttentionMetadataBuilder.build() — specifically the +reclassification of non-spec decodes as prefills when spec decodes exist. +Covers the fix for https://github.com/vllm-project/vllm/issues/34845. +""" + +from dataclasses import dataclass + +import pytest +import torch + +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_vllm_config, +) +from vllm.config import SpeculativeConfig +from vllm.v1.attention.backends.gdn_attn import ( + GDNAttentionMetadata, + GDNAttentionMetadataBuilder, +) +from vllm.v1.kv_cache_interface import MambaSpec + +BLOCK_SIZE = 16 +DEVICE = torch.device("cpu") + + +@dataclass +class GDNBuildTestCase: + """Specification for a GDN metadata builder classification test.""" + + seq_lens: list[int] + query_lens: list[int] + num_decode_draft_tokens: list[int] | None # None = no spec config + num_speculative_tokens: int + expected_num_decodes: int + expected_num_prefills: int + expected_num_prefill_tokens: int + expected_num_spec_decodes: int + + +GDN_BUILD_TEST_CASES = { + # The original #34845 crash: non-spec query_len=1 + spec decode + "mixed_decode_and_spec_decode": GDNBuildTestCase( + seq_lens=[65, 20], + query_lens=[1, 3], + num_decode_draft_tokens=[-1, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=1, + expected_num_prefill_tokens=1, + expected_num_spec_decodes=1, + ), + # All requests are spec decodes — no reclassification needed + "pure_spec_decode": GDNBuildTestCase( + seq_lens=[50, 30], + query_lens=[3, 3], + num_decode_draft_tokens=[2, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=0, + expected_num_prefill_tokens=0, + expected_num_spec_decodes=2, + ), + # No speculative config at all — standard decode path + "pure_regular_decode": GDNBuildTestCase( + seq_lens=[40, 30, 20], + query_lens=[1, 1, 1], + num_decode_draft_tokens=None, + num_speculative_tokens=0, + expected_num_decodes=3, + expected_num_prefills=0, + expected_num_prefill_tokens=0, + expected_num_spec_decodes=0, + ), + # Multi-token prefill alongside spec decode — no decode to reclassify + "spec_decode_with_real_prefill": GDNBuildTestCase( + seq_lens=[100, 20], + query_lens=[50, 3], + num_decode_draft_tokens=[-1, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=1, + expected_num_prefill_tokens=50, + expected_num_spec_decodes=1, + ), + # All three types in one batch — decode gets reclassified + "prefill_decode_and_spec_decode": GDNBuildTestCase( + seq_lens=[100, 65, 20], + query_lens=[50, 1, 3], + num_decode_draft_tokens=[-1, -1, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=2, + expected_num_prefill_tokens=51, + expected_num_spec_decodes=1, + ), + # Multiple non-spec query_len=1 requests all reclassified + "multiple_decodes_reclassified": GDNBuildTestCase( + seq_lens=[40, 50, 60, 20], + query_lens=[1, 1, 1, 3], + num_decode_draft_tokens=[-1, -1, -1, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=3, + expected_num_prefill_tokens=3, + expected_num_spec_decodes=1, + ), + # Zero-length padded sequence excluded from counts + "zero_length_padding_with_spec": GDNBuildTestCase( + seq_lens=[16, 65, 20], + query_lens=[0, 1, 3], + num_decode_draft_tokens=[-1, -1, 2], + num_speculative_tokens=2, + expected_num_decodes=0, + expected_num_prefills=1, + expected_num_prefill_tokens=1, + expected_num_spec_decodes=1, + ), +} + + +def _create_gdn_builder( + num_speculative_tokens: int = 0, +) -> GDNAttentionMetadataBuilder: + """Create a GDNAttentionMetadataBuilder with minimal config.""" + vllm_config = create_vllm_config(block_size=BLOCK_SIZE) + if num_speculative_tokens > 0: + vllm_config.speculative_config = SpeculativeConfig( + method="ngram", + num_speculative_tokens=num_speculative_tokens, + ) + mamba_spec = MambaSpec( + block_size=BLOCK_SIZE, + shapes=((16, 64),), + dtypes=(torch.float16,), + ) + return GDNAttentionMetadataBuilder( + kv_cache_spec=mamba_spec, + layer_names=["layer.0"], + vllm_config=vllm_config, + device=DEVICE, + ) + + +def _build( + builder: GDNAttentionMetadataBuilder, + batch_spec: BatchSpec, + num_decode_draft_tokens: list[int] | None = None, +) -> GDNAttentionMetadata: + """Build GDN attention metadata, optionally with spec-decode kwargs.""" + common = create_common_attn_metadata(batch_spec, BLOCK_SIZE, DEVICE) + kwargs: dict = {} + if num_decode_draft_tokens is not None: + kwargs["num_decode_draft_tokens_cpu"] = torch.tensor( + num_decode_draft_tokens, dtype=torch.int32 + ) + kwargs["num_accepted_tokens"] = torch.ones( + batch_spec.batch_size, dtype=torch.int32, device=DEVICE + ) + return builder.build(common_prefix_len=0, common_attn_metadata=common, **kwargs) + + +@pytest.mark.parametrize( + "test_case", GDN_BUILD_TEST_CASES.values(), ids=GDN_BUILD_TEST_CASES.keys() +) +def test_gdn_build_classification(test_case: GDNBuildTestCase): + """Test that GDN metadata builder classifies requests correctly.""" + builder = _create_gdn_builder(test_case.num_speculative_tokens) + batch = BatchSpec(seq_lens=test_case.seq_lens, query_lens=test_case.query_lens) + meta = _build(builder, batch, test_case.num_decode_draft_tokens) + + assert meta.num_decodes == test_case.expected_num_decodes + assert meta.num_prefills == test_case.expected_num_prefills + assert meta.num_prefill_tokens == test_case.expected_num_prefill_tokens + assert meta.num_spec_decodes == test_case.expected_num_spec_decodes + + +def test_has_initial_state_after_reclassification(): + """After reclassification, num_prefills > 0 so the prefill kernel path + should compute has_initial_state. For the reclassified request with + context_lens > 0, the corresponding entry must be True.""" + builder = _create_gdn_builder(num_speculative_tokens=2) + batch = BatchSpec(seq_lens=[65, 20], query_lens=[1, 3]) + meta = _build(builder, batch, num_decode_draft_tokens=[-1, 2]) + + assert meta.num_prefills > 0, "reclassification should produce prefills" + assert meta.has_initial_state is not None + # req0 has context_lens = 65 - 1 = 64 > 0, so has_initial_state[0] = True + assert meta.has_initial_state[0].item() is True diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index a2dd05b4b..574cc87e7 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -220,6 +220,16 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] query_lens_cpu.sum().item() - num_prefill_tokens - num_decode_tokens ) + # num_decodes and num_spec_decodes are mutually exclusive. + # Reclassify non-spec decodes as prefills when spec decodes + # exist — the prefill kernel handles 1-token sequences with + # initial state correctly, producing identical results. + if num_decodes > 0 and num_spec_decodes > 0: + num_prefills += num_decodes + num_prefill_tokens += num_decode_tokens + num_decodes = 0 + num_decode_tokens = 0 + if num_prefills == 0 and num_decodes == 0: spec_token_size = min( num_spec_decodes * (self.num_spec + 1),