Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -8,13 +8,22 @@ import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
@@ -32,9 +41,7 @@ def _create_proposer(
num_speculative_tokens: int,
speculative_token_tree: Optional[list[tuple[int, ...]]] = None,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir,
runner="generate",
max_model_len=100)
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Choose model directory based on method
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
@@ -60,10 +67,10 @@ def _create_proposer(
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
scheduler_config=SchedulerConfig(),
)
return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
def test_prepare_next_token_ids():
@@ -82,7 +89,7 @@ def test_prepare_next_token_ids():
query_lens=[num_speculative_tokens + 1] * num_requests,
)
req_ids = [f"req_{i+1}" for i in range(num_requests)]
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
@@ -101,24 +108,26 @@ def test_prepare_next_token_ids():
[0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled
[0, 1, 2, 3, 4], # all accepted, "4" sampled
[-1, -1, -1, -1, -1], # sampling skipped, use backup token "30"
[-1, -1, -1, -1, -1] # this request will be discarded
[-1, -1, -1, -1, -1], # this request will be discarded
]
sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
dtype=torch.int32,
device=device)
sampled_token_ids_cpu = [[i for i in seq if i != -1]
for seq in sampled_token_ids]
sampled_token_ids_tensor = torch.tensor(
sampled_token_ids, dtype=torch.int32, device=device
)
sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
dtype=torch.int32,
device=device)
expected_next_token_ids_tensor = torch.tensor(
expected_next_token_ids_cpu, dtype=torch.int32, device=device
)
proposer = _create_proposer("eagle", num_speculative_tokens)
next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
sampled_token_ids_cpu, mock_requests, mock_input_batch,
mock_num_scheduled_tokens)
sampled_token_ids_cpu,
mock_requests,
mock_input_batch,
mock_num_scheduled_tokens,
)
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
@@ -131,19 +140,23 @@ def test_prepare_next_token_ids():
discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
num_discarded_reqs = 1
expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
dtype=torch.int32,
device=device)
expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)
next_token_ids_from_padded, valid_sampled_tokens_count = \
next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata, sampled_token_ids_tensor, mock_requests,
mock_input_batch, discarded_req_indices, num_discarded_reqs)
common_attn_metadata,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,
discarded_req_indices,
num_discarded_reqs,
)
)
assert torch.equal(next_token_ids_from_padded,
expected_next_token_ids_tensor)
assert torch.equal(valid_sampled_tokens_count,
expected_valid_sampled_tokens_count)
assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor)
assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
def test_prepare_inputs():
@@ -183,21 +196,27 @@ def test_prepare_inputs():
sampled_token_ids = [
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
[
ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
ACCEPT_TOKEN,
ACCEPT_TOKEN,
ACCEPT_TOKEN,
REJECT_TOKEN,
REJECT_TOKEN,
REJECT_TOKEN,
BONUS_TOKEN,
],
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
]
sampled_token_ids = [
[i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids
]
sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
for seq in sampled_token_ids]
# Expected calculations:
# query_len_per_req = [4, 7, 5]
# num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens)
# Expected cumulative counts: [0, 3, 7, 10]
expected_cu_num_tokens = torch.tensor([0, 3, 7, 10],
dtype=torch.int32,
device=device)
expected_cu_num_tokens = torch.tensor(
[0, 3, 7, 10], dtype=torch.int32, device=device
)
# Expected token indices (mapped from original positions):
# First request: indices 0, 1, 2 (keeping first 3 from positions 0-3)
@@ -214,17 +233,18 @@ def test_prepare_inputs():
7, # Second request: 4 tokens (7-3)
11,
12,
13 # Third request: 3 tokens (5-2)
13, # Third request: 3 tokens (5-2)
],
dtype=torch.int32,
device=device)
device=device,
)
proposer = _create_proposer("eagle", 1)
updated_metadata, token_indices = proposer.prepare_inputs(
common_attn_metadata, sampled_token_ids, num_draft_tokens)
common_attn_metadata, sampled_token_ids, num_draft_tokens
)
assert torch.equal(updated_metadata.query_start_loc,
expected_cu_num_tokens)
assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens)
assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
assert torch.equal(token_indices, expected_token_indices)
@@ -249,12 +269,12 @@ def test_prepare_inputs_padded():
device = torch.device(current_platform.device_type)
expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.int32,
device=device)
expected_token_indices_to_sample = torch.tensor([1, 5, 6],
dtype=torch.int32,
device=device)
expected_token_indices = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
)
expected_token_indices_to_sample = torch.tensor(
[1, 5, 6], dtype=torch.int32, device=device
)
num_speculative_tokens = 2
batch_spec = BatchSpec(
@@ -269,9 +289,9 @@ def test_prepare_inputs_padded():
)
# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
expected_query_start_loc = torch.tensor([0, 3, 6, 9],
dtype=torch.int32,
device=device)
expected_query_start_loc = torch.tensor(
[0, 3, 6, 9], dtype=torch.int32, device=device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids=[[0] * num_speculative_tokens] * 3,
device=device,
@@ -280,43 +300,48 @@ def test_prepare_inputs_padded():
# num_rejected_tokens = [1, 0, 2]
# num_draft_tokens = [2, 2, 2]
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
valid_sampled_tokens_count = torch.tensor([2, 3, 1],
dtype=torch.int32,
device=device)
valid_sampled_tokens_count = torch.tensor(
[2, 3, 1], dtype=torch.int32, device=device
)
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices, token_indices_to_sample = \
output_metadata, token_indices, token_indices_to_sample = (
proposer.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc,
expected_query_start_loc)
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices, expected_token_indices)
assert torch.equal(token_indices_to_sample,
expected_token_indices_to_sample)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
attn_backend, pp_size, use_distinct_embed_tokens,
monkeypatch):
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_load_model(
mock_get_model,
mock_get_layers,
mock_get_pp_group,
method,
attn_backend,
pp_size,
use_distinct_embed_tokens,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -335,20 +360,20 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Setup mocks for attention layers
target_attn_layers = {
"target_attn_1": mock.MagicMock(),
"target_attn_2": mock.MagicMock()
"target_attn_2": mock.MagicMock(),
}
target_indx_layers: dict[str, mock.MagicMock] = {}
# Draft model has one extra attention layer compared to target model
all_attn_layers = {
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
}
all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()}
all_indx_layers: dict[str, mock.MagicMock] = {}
# Make mock_get_layers return different values for each call
mock_get_layers.side_effect = [
target_attn_layers, target_indx_layers, all_attn_layers,
all_indx_layers
target_attn_layers,
target_indx_layers,
all_attn_layers,
all_indx_layers,
]
# Setup mock for pp group to return the appropriate value for world size
@@ -367,6 +392,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
target_model.model.embed_tokens.weight.shape = (131072, 4096)
from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)
if method == "eagle":
@@ -388,30 +414,30 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
assert proposer.model.model.embed_tokens != \
target_model.model.embed_tokens
assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
else:
# When pp_size is 1 and the draft and target models have
# embed_tokens of the same shape, they should be shared.
assert proposer.model.model.embed_tokens == \
target_model.model.embed_tokens
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if (attn_backend == "TREE_ATTN"):
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking.")
if attn_backend == "TREE_ATTN":
pytest.skip(
"TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking."
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
@@ -498,31 +524,22 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
device=device,
)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_len_1, device=device),
torch.arange(seq_len_2, device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
target_positions = torch.cat(
[torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
)
target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
next_token_ids = torch.randint(
0, vocab_size, (batch_size,), dtype=torch.int32, device=device
)
sampling_metadata = mock.MagicMock()
if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN)
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")
@@ -536,18 +553,22 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
return_value=attn_metadata_builder
)
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)
assert result.shape == (batch_size, num_speculative_tokens)
@@ -556,13 +577,14 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Example for num_speculative_tokens=1:
# [[42], [60]]
expected_tokens = torch.tensor(
[[base_token_ids[0]], [base_token_ids[1]]], device=device)
[[base_token_ids[0]], [base_token_ids[1]]], device=device
)
else:
# Example for num_speculative_tokens=3:
# [[42, 43, 44], [60, 61, 62]]
expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
dtype=torch.int64,
device=device)
expected_tokens = torch.zeros(
(batch_size, num_speculative_tokens), dtype=torch.int64, device=device
)
for i in range(batch_size):
for j in range(num_speculative_tokens):
expected_tokens[i, j] = base_token_ids[i] + j
@@ -574,12 +596,12 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
@pytest.mark.parametrize(
"spec_token_tree",
[
[(0, )], # A single token
[(0, ), (0, 0), (0, 0, 0)], # Chain
[(0, ), (1, ), (2, )], # Parallel
[(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0),
(2, 1)], # Tree
])
[(0,)], # A single token
[(0,), (0, 0), (0, 0, 0)], # Chain
[(0,), (1,), (2,)], # Parallel
[(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree
],
)
def test_propose_tree(spec_token_tree):
# Get GPU device.
device = torch.device(current_platform.device_type)
@@ -594,9 +616,9 @@ def test_propose_tree(spec_token_tree):
num_speculative_tokens = len(spec_token_tree)
# Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer("eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree)
proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size
@@ -617,32 +639,31 @@ def test_propose_tree(spec_token_tree):
model_mock = mock.MagicMock()
# Mock the model forward calls.
forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device),
torch.zeros(total_tokens, hidden_size, device=device))]
forward_returns = [
(
torch.zeros(total_tokens, hidden_size, device=device),
torch.zeros(total_tokens, hidden_size, device=device),
)
]
for cu_num_drafts in proposer.cu_drafts_per_level:
h_logits = torch.zeros(batch_size * cu_num_drafts,
hidden_size,
device=device)
h_states = torch.zeros(batch_size * cu_num_drafts,
hidden_size,
device=device)
h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device)
forward_returns.append((h_logits, h_states))
model_mock.side_effect = forward_returns
# Mock the compute_logits calls.
cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level,
dtype=torch.int32,
device=device)
cu_num_drafts_tensor = torch.tensor(
[0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device
)
logits_returns = []
for level, num_children in enumerate(proposer.child_drafts_per_level):
token_ids = base_token_ids + cu_num_drafts_tensor[level]
level_num_drafts = cu_num_drafts_tensor[
level + 1] - cu_num_drafts_tensor[level]
level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level]
level_logits = []
for i in range(level_num_drafts // num_children):
level_logits.append(
create_deterministic_logits(token_ids + i * num_children,
num_children))
create_deterministic_logits(token_ids + i * num_children, num_children)
)
logits_returns.append(torch.stack(level_logits, dim=1))
model_mock.compute_logits.side_effect = logits_returns
@@ -664,29 +685,23 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
return_value=attn_metadata_builder
)
# Setup inputs for the proposer.
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_len_1, device=device),
torch.arange(seq_len_2, device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
target_positions = torch.cat(
[torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)]
)
target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
next_token_ids = torch.randint(
0, vocab_size, (batch_size,), dtype=torch.int32, device=device
)
batch_spec = BatchSpec(
seq_lens=seq_lens,
query_lens=seq_lens,
@@ -699,19 +714,22 @@ def test_propose_tree(spec_token_tree):
sampling_metadata = mock.MagicMock()
# Propose draft tokens.
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)
assert result.shape == (batch_size, num_speculative_tokens)
# The tokens are expected to be consecutive integers starting
# from the base token IDs.
expected_tokens = base_token_ids[:, None] + torch.arange(
num_speculative_tokens, dtype=torch.int64, device=device)
num_speculative_tokens, dtype=torch.int64, device=device
)
# Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens)

View File

@@ -33,17 +33,19 @@ def test_ngram_max_len(num_speculative_tokens: int):
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int, attn_backend: str):
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform")
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

View File

@@ -6,13 +6,22 @@ from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
@@ -23,10 +32,9 @@ mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)
model_config = ModelConfig(
model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True
)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
@@ -43,17 +51,16 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
scheduler_config=SchedulerConfig(),
)
return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
@mock.patch("vllm.v1.spec_decode.eagle.get_model")
def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""
# Setup mocks
@@ -67,8 +74,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
all_indexer_layers: dict = {}
mock_get_layers.side_effect = [
target_attn_layers, target_indexer_layers, all_attn_layers,
all_indexer_layers
target_attn_layers,
target_indexer_layers,
all_attn_layers,
all_indexer_layers,
]
mock_pp_group = mock.MagicMock()
@@ -116,17 +125,13 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
h_states = torch.zeros(total_tokens, hidden_size, device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
@@ -140,7 +145,8 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
batch_size, vocab_size, 42
)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
@@ -153,24 +159,21 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
common_attn_metadata = create_common_attn_metadata(
batch_spec, block_size=16, device=device
)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
target_positions = torch.cat(
[
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device),
]
)
target_hidden_states = torch.randn(total_tokens, hidden_size, device=device)
next_token_ids = torch.randint(
0, vocab_size, (batch_size,), dtype=torch.int32, device=device
)
sampling_metadata = mock.MagicMock()
# Setup attention metadata
@@ -187,13 +190,15 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
proposer.attn_metadata_builder = attn_metadata_builder
# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)
# Verify the model was called correctly
assert model_mock.called

View File

@@ -4,77 +4,75 @@ import numpy as np
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
from vllm.v1.spec_decode.ngram_proposer import (
NgramProposer, _find_longest_matched_ngram_and_propose_tokens)
NgramProposer,
_find_longest_matched_ngram_and_propose_tokens,
)
def test_find_longest_matched_ngram_and_propose_tokens():
tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6])
result = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2)
origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
)
assert len(result) == 0
tokens = np.array([1, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
),
np.array([4, 1, 2]),
)
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=2), np.array([4, 1]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2
),
np.array([4, 1]),
)
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=3
),
np.array([4, 1, 2]),
)
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=2), np.array([4, 1]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
),
np.array([4, 1]),
)
tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3])
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=2,
max_ngram=2,
max_model_len=1024,
k=3),
np.array([4, 1, 2]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3
),
np.array([4, 1, 2]),
)
# Return on the first match
np.testing.assert_array_equal(
_find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens,
min_ngram=1,
max_ngram=1,
max_model_len=1024,
k=2), np.array([6, 2]))
_find_longest_matched_ngram_and_propose_tokens(
origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2
),
np.array([6, 2]),
)
def test_ngram_proposer():
def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
# Dummy model config. Just to set max_model_len.
model_config = ModelConfig(model="facebook/opt-125m")
return NgramProposer(
vllm_config=VllmConfig(model_config=model_config,
speculative_config=SpeculativeConfig(
prompt_lookup_min=min_n,
prompt_lookup_max=max_n,
num_speculative_tokens=k,
method="ngram",
)))
vllm_config=VllmConfig(
model_config=model_config,
speculative_config=SpeculativeConfig(
prompt_lookup_min=min_n,
prompt_lookup_max=max_n,
num_speculative_tokens=k,
method="ngram",
),
)
)
# No match.
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
@@ -133,8 +131,7 @@ def test_ngram_proposer():
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
# Multiple 3-gram matched, but always pick the first one.
token_ids_cpu = np.array(
[[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
@@ -191,6 +188,5 @@ def test_ngram_proposer():
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0],
np.array([middle_integer + 2, middle_integer + 3]))
assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
assert np.array_equal(result[1], np.array([]))

View File

@@ -6,9 +6,11 @@ from typing import Optional
import torch
from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -42,10 +44,11 @@ def forward_attention(
num_kv_heads = k.shape[-2]
# Initialize the query and KV sequence lengths.
query_start_loc = q_len * torch.arange(
batch_size + 1, device=q.device, dtype=torch.int32)
batch_size + 1, device=q.device, dtype=torch.int32
)
query_lens = torch.diff(query_start_loc)
seq_lens = torch.full(
(batch_size, ),
(batch_size,),
seqlen_k,
device=q.device,
dtype=torch.int32,
@@ -55,14 +58,13 @@ def forward_attention(
max_query_len = q_len
num_actual_tokens = query_start_loc[-1]
softmax_scale = q.shape[-1]**(-0.5)
softmax_scale = q.shape[-1] ** (-0.5)
layer = MockAttentionLayer()
# Build common metadata.
model_name = "meta-llama/Meta-Llama-3-8B"
builder_cls, impl_cls = get_attention_backend(backend)
vllm_config = create_vllm_config(model_name=model_name,
max_model_len=max(seq_lens))
vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
if spec_token_tree is not None:
# Create speculative config if token tree is specified.
vllm_config.speculative_config = SpeculativeConfig(
@@ -71,7 +73,8 @@ def forward_attention(
model=model_name,
method="eagle",
num_speculative_tokens=num_spec_tokens,
speculative_token_tree=spec_token_tree)
speculative_token_tree=spec_token_tree,
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
builder = builder_cls(kv_cache_spec, [], vllm_config, q.device)
common_attn_metadata = CommonAttentionMetadata(
@@ -128,8 +131,7 @@ def test_tree_attn_correctness() -> None:
device = "cuda"
tree_attn_masks = {
# Chain.
"[(0,), (0, 0), (0, 0, 0)]":
torch.tensor(
"[(0,), (0, 0), (0, 0, 0)]": torch.tensor(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
@@ -140,8 +142,7 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int32,
),
# Tree.
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]":
torch.tensor(
"[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
@@ -202,8 +203,7 @@ def test_tree_attn_correctness() -> None:
device=q.device,
dtype=torch.bfloat16,
)
num_alloc_blocks_per_batch = math.ceil(seqlen_k /
block_size)
num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size)
block_table = torch.zeros(
(batch_size, max_blocks_per_batch),
device=q.device,
@@ -217,11 +217,10 @@ def test_tree_attn_correctness() -> None:
)
if randomize_blocks:
# Randomize the block ids.
block_ids = block_ids[torch.randperm(
block_ids.numel())]
block_table[:, :
num_alloc_blocks_per_batch] = block_ids.view(
-1, num_alloc_blocks_per_batch)
block_ids = block_ids[torch.randperm(block_ids.numel())]
block_table[:, :num_alloc_blocks_per_batch] = block_ids.view(
-1, num_alloc_blocks_per_batch
)
# Set up the slot mapping for the input KVs.
tree_positions = sequence_position + torch.arange(
@@ -231,7 +230,8 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int64,
).repeat(batch_size, 1)
tree_slot_mapping = _gen_slot_mapping(
tree_positions, block_table, block_size)
tree_positions, block_table, block_size
)
# Compute attention for the tree.
tree_attn_output = forward_attention(
@@ -253,8 +253,7 @@ def test_tree_attn_correctness() -> None:
for q_index in range(tree_size_q):
# Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :]
branch_indices = torch.nonzero(branch_mask,
as_tuple=True)[0]
branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0]
q_len = branch_indices.shape[0]
q_branch = q[:, branch_indices]
k_branch = k[:, branch_indices]
@@ -268,7 +267,8 @@ def test_tree_attn_correctness() -> None:
dtype=torch.int64,
).repeat(batch_size, 1)
branch_slot_mapping = _gen_slot_mapping(
branch_positions, block_table, block_size)
branch_positions, block_table, block_size
)
# Compute flash attention for the branch.
flash_attn_output = forward_attention(
@@ -287,16 +287,19 @@ def test_tree_attn_correctness() -> None:
tree_attn_output[:, branch_indices],
flash_attn_output,
atol=7.81e-3,
), (f"outputs are not close for "
), (
f"outputs are not close for "
f"batch_size: {batch_size}, "
f"num_heads: {num_heads}, "
f"sequence_position: {sequence_position}, "
f"tree_attn_mask: {tree_attn_mask}, "
f"q_index: {q_index}.")
f"q_index: {q_index}."
)
def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor,
block_size: int):
def _gen_slot_mapping(
positions: torch.Tensor, block_table: torch.Tensor, block_size: int
):
block_indices = positions // block_size
blocks = block_table.gather(dim=1, index=block_indices)
return (blocks * block_size + positions % block_size).view(-1)