[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-07-17 00:44:25 -04:00
committed by GitHub
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions

View File

@@ -6,6 +6,10 @@ from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
@@ -64,13 +68,19 @@ def test_prepare_inputs():
"""
device = torch.device(current_platform.device_type)
# a = 4, b = 7, c = 5
# q1 = 4, q2 = 7, q3 = 5
# n1 = 1, n2 = 3, n3 = 2
# Cumulative lengths: [0, 4, 11, 16]
cu_target_query_lens = torch.tensor([0, 4, 11, 16],
dtype=torch.int32,
device=device)
batch_spec = BatchSpec(
seq_lens=[4, 7, 5],
query_lens=[4, 7, 5],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
# Rejected tokens per request: [1, 3, 2]
num_rejected_tokens = torch.tensor([1, 3, 2],
@@ -104,15 +114,13 @@ def test_prepare_inputs():
],
dtype=torch.int32,
device=device)
proposer = _create_proposer("eagle", 1)
# a + b + c - n1 - n2 - n3
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
).item()
updated_metadata, token_indices = proposer.prepare_inputs(
common_attn_metadata, num_rejected_tokens.cpu())
cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
cu_target_query_lens, num_rejected_tokens, num_tokens)
assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
assert torch.equal(updated_metadata.query_start_loc,
expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
assert torch.equal(token_indices, expected_token_indices)
@@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens):
seq_len_2 = 3
total_tokens = seq_len_1 + seq_len_2
vocab_size = 100
seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
@@ -270,9 +279,16 @@ def test_propose(num_speculative_tokens):
proposer.attn_layer_names = ["layer.0"]
# Create input tensors
cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
dtype=torch.int32,
device=device)
batch_spec = BatchSpec(
seq_lens=seq_lens,
query_lens=seq_lens,
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
@@ -284,25 +300,29 @@ def test_propose(num_speculative_tokens):
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
target_slot_mapping = torch.randint(0,
100, (total_tokens, ),
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
block_table = torch.randint(0, 10, (batch_size, 10), device=device)
sampling_metadata = mock.MagicMock()
# Call the method under test
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
vllm_config=proposer.vllm_config,
device=device,
)
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=block_table,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)