# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for GDNAttentionMetadataBuilder.build() — specifically the
reclassification of non-spec decodes as prefills when spec decodes exist.
Covers the fix for https://github.com/vllm-project/vllm/issues/34845.
"""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tests.v1.attention.utils import (
|
|
BatchSpec,
|
|
create_common_attn_metadata,
|
|
create_vllm_config,
|
|
)
|
|
from vllm.config import SpeculativeConfig
|
|
from vllm.v1.attention.backends.gdn_attn import (
|
|
GDNAttentionMetadata,
|
|
GDNAttentionMetadataBuilder,
|
|
)
|
|
from vllm.v1.kv_cache_interface import MambaSpec
|
|
|
|
# Block size shared by the vLLM config, the Mamba cache spec and the common
# attention metadata created by the helpers in this module.
BLOCK_SIZE = 16
# Device handed to the metadata builder and to the tensor-creating helpers.
DEVICE = torch.device("cpu")
|
|
|
|
|
|
@dataclass
|
|
class GDNBuildTestCase:
|
|
"""Specification for a GDN metadata builder classification test."""
|
|
|
|
seq_lens: list[int]
|
|
query_lens: list[int]
|
|
num_decode_draft_tokens: list[int] | None # None = no spec config
|
|
num_speculative_tokens: int
|
|
expected_num_decodes: int
|
|
expected_num_prefills: int
|
|
expected_num_prefill_tokens: int
|
|
expected_num_spec_decodes: int
|
|
|
|
|
|
# Named classification scenarios; the keys double as pytest ids.
GDN_BUILD_TEST_CASES: dict[str, GDNBuildTestCase] = {
    # Original #34845 crash scenario: a non-spec request with query_len=1
    # batched together with a spec decode.
    "mixed_decode_and_spec_decode": GDNBuildTestCase(
        seq_lens=[65, 20],
        query_lens=[1, 3],
        num_decode_draft_tokens=[-1, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=1,
        expected_num_prefill_tokens=1,
        expected_num_spec_decodes=1,
    ),
    # Every request is a spec decode, so nothing has to be reclassified.
    "pure_spec_decode": GDNBuildTestCase(
        seq_lens=[50, 30],
        query_lens=[3, 3],
        num_decode_draft_tokens=[2, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=0,
        expected_num_prefill_tokens=0,
        expected_num_spec_decodes=2,
    ),
    # Standard decode path: no speculative config is set at all.
    "pure_regular_decode": GDNBuildTestCase(
        seq_lens=[40, 30, 20],
        query_lens=[1, 1, 1],
        num_decode_draft_tokens=None,
        num_speculative_tokens=0,
        expected_num_decodes=3,
        expected_num_prefills=0,
        expected_num_prefill_tokens=0,
        expected_num_spec_decodes=0,
    ),
    # A genuine multi-token prefill next to a spec decode; there is no
    # plain decode that could be reclassified.
    "spec_decode_with_real_prefill": GDNBuildTestCase(
        seq_lens=[100, 20],
        query_lens=[50, 3],
        num_decode_draft_tokens=[-1, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=1,
        expected_num_prefill_tokens=50,
        expected_num_spec_decodes=1,
    ),
    # Prefill, decode and spec decode in a single batch; the decode is
    # reclassified as a prefill.
    "prefill_decode_and_spec_decode": GDNBuildTestCase(
        seq_lens=[100, 65, 20],
        query_lens=[50, 1, 3],
        num_decode_draft_tokens=[-1, -1, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=2,
        expected_num_prefill_tokens=51,
        expected_num_spec_decodes=1,
    ),
    # Several non-spec query_len=1 requests, all of which are reclassified.
    "multiple_decodes_reclassified": GDNBuildTestCase(
        seq_lens=[40, 50, 60, 20],
        query_lens=[1, 1, 1, 3],
        num_decode_draft_tokens=[-1, -1, -1, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=3,
        expected_num_prefill_tokens=3,
        expected_num_spec_decodes=1,
    ),
    # A zero-length padded sequence must not contribute to any count.
    "zero_length_padding_with_spec": GDNBuildTestCase(
        seq_lens=[16, 65, 20],
        query_lens=[0, 1, 3],
        num_decode_draft_tokens=[-1, -1, 2],
        num_speculative_tokens=2,
        expected_num_decodes=0,
        expected_num_prefills=1,
        expected_num_prefill_tokens=1,
        expected_num_spec_decodes=1,
    ),
}
|
|
|
|
|
|
def _create_gdn_builder(
    num_speculative_tokens: int = 0,
) -> GDNAttentionMetadataBuilder:
    """Return a GDN metadata builder backed by a minimal vLLM config.

    Args:
        num_speculative_tokens: When positive, an ``ngram``
            ``SpeculativeConfig`` with this many speculative tokens is
            attached to the config; otherwise the config is left as created.

    Returns:
        A builder for the single layer ``"layer.0"`` on ``DEVICE``.
    """
    config = create_vllm_config(block_size=BLOCK_SIZE)

    spec_decoding_requested = num_speculative_tokens > 0
    if spec_decoding_requested:
        config.speculative_config = SpeculativeConfig(
            method="ngram",
            num_speculative_tokens=num_speculative_tokens,
        )

    cache_spec = MambaSpec(
        block_size=BLOCK_SIZE,
        shapes=((16, 64),),
        dtypes=(torch.float16,),
    )
    gdn_builder = GDNAttentionMetadataBuilder(
        kv_cache_spec=cache_spec,
        layer_names=["layer.0"],
        vllm_config=config,
        device=DEVICE,
    )
    return gdn_builder
|
|
|
|
|
|
def _build(
    builder: GDNAttentionMetadataBuilder,
    batch_spec: BatchSpec,
    num_decode_draft_tokens: list[int] | None = None,
) -> GDNAttentionMetadata:
    """Run ``builder.build()`` for ``batch_spec`` and return its metadata.

    Args:
        builder: The GDN metadata builder under test.
        batch_spec: Batch description turned into common attention metadata.
        num_decode_draft_tokens: Per-request draft-token counts. When given,
            the spec-decode keyword arguments are passed to ``build()``;
            when ``None``, ``build()`` is called without them.
    """
    common_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, DEVICE)

    # Plain path: no spec-decode inputs are forwarded.
    if num_decode_draft_tokens is None:
        return builder.build(
            common_prefix_len=0, common_attn_metadata=common_metadata
        )

    # NOTE(review): num_accepted_tokens is assumed to be supplied only together
    # with the draft-token counts — confirm against the upstream file.
    draft_tokens_cpu = torch.tensor(num_decode_draft_tokens, dtype=torch.int32)
    accepted_tokens = torch.ones(
        batch_spec.batch_size, dtype=torch.int32, device=DEVICE
    )
    return builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_metadata,
        num_decode_draft_tokens_cpu=draft_tokens_cpu,
        num_accepted_tokens=accepted_tokens,
    )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "test_case",
    GDN_BUILD_TEST_CASES.values(),
    ids=GDN_BUILD_TEST_CASES.keys(),
)
def test_gdn_build_classification(test_case: GDNBuildTestCase):
    """Check the request counts reported by the GDN metadata builder.

    Builds metadata for the batch described by ``test_case`` and compares
    each count on the result with the matching ``expected_*`` field.
    """
    gdn_builder = _create_gdn_builder(test_case.num_speculative_tokens)
    batch_spec = BatchSpec(
        seq_lens=test_case.seq_lens,
        query_lens=test_case.query_lens,
    )
    metadata = _build(gdn_builder, batch_spec, test_case.num_decode_draft_tokens)

    # Checked in a fixed order so the first mismatching count is reported.
    checked_counts = (
        "num_decodes",
        "num_prefills",
        "num_prefill_tokens",
        "num_spec_decodes",
    )
    for count_name in checked_counts:
        expected = getattr(test_case, f"expected_{count_name}")
        assert getattr(metadata, count_name) == expected
|
|
|
|
|
|
def test_has_initial_state_after_reclassification():
    """Check ``has_initial_state`` for a reclassified decode request.

    Reclassification leaves num_prefills > 0, so the prefill kernel path
    should compute has_initial_state. The reclassified request has
    context_lens > 0, so its entry must be True.
    """
    gdn_builder = _create_gdn_builder(num_speculative_tokens=2)
    batch_spec = BatchSpec(seq_lens=[65, 20], query_lens=[1, 3])
    metadata = _build(gdn_builder, batch_spec, num_decode_draft_tokens=[-1, 2])

    assert metadata.num_prefills > 0, "reclassification should produce prefills"
    assert metadata.has_initial_state is not None

    # Request 0 has context_lens = 65 - 1 = 64 > 0, hence an initial state.
    first_request_flag = metadata.has_initial_state[0].item()
    assert first_request_flag is True
|