[Attention] Refactor attention metadata builder interface (#20466)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -6,6 +6,10 @@ from unittest import mock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
get_attention_backend)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VllmConfig)
|
||||
@@ -64,13 +68,19 @@ def test_prepare_inputs():
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# a = 4, b = 7, c = 5
|
||||
# q1 = 4, q2 = 7, q3 = 5
|
||||
# n1 = 1, n2 = 3, n3 = 2
|
||||
|
||||
# Cumulative lengths: [0, 4, 11, 16]
|
||||
cu_target_query_lens = torch.tensor([0, 4, 11, 16],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[4, 7, 5],
|
||||
query_lens=[4, 7, 5],
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Rejected tokens per request: [1, 3, 2]
|
||||
num_rejected_tokens = torch.tensor([1, 3, 2],
|
||||
@@ -104,15 +114,13 @@ def test_prepare_inputs():
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
proposer = _create_proposer("eagle", 1)
|
||||
|
||||
# n1 + n2 + n3 - a - b -c
|
||||
num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
|
||||
).item()
|
||||
updated_metadata, token_indices = proposer.prepare_inputs(
|
||||
common_attn_metadata, num_rejected_tokens.cpu())
|
||||
|
||||
cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
|
||||
cu_target_query_lens, num_rejected_tokens, num_tokens)
|
||||
|
||||
assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
|
||||
assert torch.equal(updated_metadata.query_start_loc,
|
||||
expected_cu_num_tokens)
|
||||
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
|
||||
assert torch.equal(token_indices, expected_token_indices)
|
||||
|
||||
@@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens):
|
||||
seq_len_2 = 3
|
||||
total_tokens = seq_len_1 + seq_len_2
|
||||
vocab_size = 100
|
||||
seq_lens = [seq_len_1, seq_len_2]
|
||||
|
||||
# Create proposer first so we can use its actual hidden_size
|
||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||
@@ -270,9 +279,16 @@ def test_propose(num_speculative_tokens):
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
|
||||
# Create input tensors
|
||||
cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=seq_lens,
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
target_token_ids = torch.randint(0,
|
||||
vocab_size, (total_tokens, ),
|
||||
@@ -284,25 +300,29 @@ def test_propose(num_speculative_tokens):
|
||||
target_hidden_states = torch.randn(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
target_slot_mapping = torch.randint(0,
|
||||
100, (total_tokens, ),
|
||||
device=device)
|
||||
next_token_ids = torch.randint(0,
|
||||
vocab_size, (batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = torch.randint(0, 10, (batch_size, 10), device=device)
|
||||
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Call the method under test
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.FLASH_ATTN_VLLM_V1)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Mock runner for attention metadata building
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
|
||||
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
target_slot_mapping=target_slot_mapping,
|
||||
next_token_ids=next_token_ids,
|
||||
cu_num_tokens=cu_num_tokens,
|
||||
block_table=block_table,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
assert result.shape == (batch_size, num_speculative_tokens)
|
||||
|
||||
Reference in New Issue
Block a user