[Spec Decode] Unified Parallel Drafting (#32887)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-05 12:37:18 -05:00
committed by GitHub
parent 5b2a9422f0
commit af3162d3aa
14 changed files with 1085 additions and 392 deletions

View File

@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -34,6 +35,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
def _create_proposer(
@@ -41,11 +43,19 @@ def _create_proposer(
num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Choose model directory based on method
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir
# Method-dependent setup
if method == "eagle":
draft_model_dir = eagle_dir
elif method == "eagle3":
draft_model_dir = eagle3_dir
elif method == "draft_model":
draft_model_dir = ar_draft_model_dir
else:
raise ValueError(f"Unknown method: {method}")
spec_token_tree_str = None
if speculative_token_tree is not None:
@@ -59,13 +69,18 @@ def _create_proposer(
method=method,
num_speculative_tokens=num_speculative_tokens,
speculative_token_tree=spec_token_tree_str,
parallel_drafting=parallel_drafting,
)
if parallel_drafting:
# Overwrite pard_token to avoid crash during init
speculative_config.draft_model_config.hf_config.pard_token = 0
device = current_platform.device_type
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
device_config=DeviceConfig(device=device),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(
@@ -75,7 +90,10 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend),
)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
if "eagle" in method:
return EagleProposer(vllm_config=vllm_config, device=device)
else:
return DraftModelProposer(vllm_config=vllm_config, device=device)
def test_prepare_next_token_ids():
@@ -321,6 +339,390 @@ def test_prepare_inputs_padded():
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
def test_set_inputs_first_pass_default_eagle():
    """
    Exercise set_inputs_first_pass on the default EAGLE path.

    This is meant to cover the case without extra input slots
    (needs_extra_input_slots=False). The assertions below pin down that:
    - input IDs are shifted left by one within each request
    - each request's final slot receives its entry of next_token_ids
    - positions and hidden states land in the proposer buffers unchanged
    - the very same CommonAttentionMetadata object is handed back

    Batch layout:
    - 3 requests, query_lens [3, 2, 4] (9 tokens in total)
    - target tokens:  [10, 11, 12 | 20, 21 | 30, 31, 32, 33]
    - next tokens:    [100, 200, 300]
    - expected input: [11, 12, 100 | 21, 200 | 31, 32, 33, 300]
    """
    device = torch.device(current_platform.device_type)
    proposer = _create_proposer("eagle", 3)

    def as_int32(values):
        return torch.tensor(values, dtype=torch.int32, device=device)

    def as_int64(values):
        return torch.tensor(values, dtype=torch.int64, device=device)

    # Three requests; the context lengths are arbitrary.
    query_lens = [3, 2, 4]
    total_tokens = sum(query_lens)
    cad_in = create_common_attn_metadata(
        BatchSpec(seq_lens=[10, 8, 12], query_lens=query_lens),
        block_size=16,
        device=device,
    )

    # Flattened per-token inputs:
    #   request 0 -> tokens [10, 11, 12]     positions [7, 8, 9]
    #   request 1 -> tokens [20, 21]         positions [6, 7]
    #   request 2 -> tokens [30, 31, 32, 33] positions [8, 9, 10, 11]
    token_ids = as_int32([10, 11, 12, 20, 21, 30, 31, 32, 33])
    positions = as_int64([7, 8, 9, 6, 7, 8, 9, 10, 11])
    hidden = torch.randn(
        total_tokens, proposer.hidden_size, dtype=proposer.dtype, device=device
    )
    bonus_ids = as_int32([100, 200, 300])

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=token_ids,
        next_token_ids=bonus_ids,
        target_positions=positions,
        target_hidden_states=hidden,
        token_indices_to_sample=None,
        cad=cad_in,
        num_rejected_tokens_gpu=None,
    )

    # The token count is unchanged on this path.
    assert num_tokens == 9

    # Sampling happens at the last slot of every request.
    assert torch.equal(token_indices_to_sample, as_int32([2, 4, 8]))

    # The metadata object is returned as-is (identity, not just equality).
    assert output_cad is cad_in

    # Each request loses its first token to the shift, and its last slot
    # (flat indices 2, 4 and 8) is overwritten with the matching next token.
    assert torch.equal(
        proposer.input_ids[:num_tokens],
        as_int32([11, 12, 100, 21, 200, 31, 32, 33, 300]),
    )

    # Positions and hidden states are plain copies of the inputs.
    assert torch.equal(proposer.positions[:num_tokens], positions)
    assert torch.equal(proposer.hidden_states[:num_tokens], hidden)
def test_set_inputs_first_pass_draft_model():
    """
    Exercise set_inputs_first_pass with a standalone draft model.

    This is meant to cover the path with extra input slots and no input-ID
    shift (needs_extra_input_slots=True, shift_input_ids=False). The
    assertions below pin down that:
    - input IDs keep their original alignment (no shift)
    - every request grows by one slot (net_num_new_slots_per_request == 1)
    - the bonus token is appended right after the valid tokens
    - a slot freed by a rejected token is filled with padding and flagged
      in is_rejected_token_mask
    - the returned metadata carries the enlarged query_start_loc

    Batch layout (2 requests, next_token_ids = [100, 200]):
      Request 0: tokens [10, 11, 12] at positions [0, 1, 2]; the trailing
                 token 12 is counted as rejected (num_rejected = 1).
      Request 1: tokens [20, 21] at positions [0, 1]; nothing rejected.

    Expected flat output:
      Request 0 (indices 0-3):
        - idx 0: token 10,  pos 0
        - idx 1: token 11,  pos 1
        - idx 2: token 100, pos 2 (bonus token)
        - idx 3: padding (0), marked rejected
      Request 1 (indices 4-6):
        - idx 4: token 20,  pos 0
        - idx 5: token 21,  pos 1
        - idx 6: token 200, pos 2 (bonus token)
    """
    device = torch.device(current_platform.device_type)
    block_size = 16

    def as_int32(values):
        return torch.tensor(values, dtype=torch.int32, device=device)

    def as_int64(values):
        return torch.tensor(values, dtype=torch.int64, device=device)

    def fresh_token_mask():
        return torch.zeros(proposer.max_num_tokens, dtype=torch.bool, device=device)

    # Build a draft-model proposer with two speculative tokens, then reset
    # the padding token id and both per-token masks to known values.
    proposer = _create_proposer("draft_model", 2)
    proposer.parallel_drafting_token_id = 0
    proposer.is_rejected_token_mask = fresh_token_mask()
    proposer.is_masked_token_mask = fresh_token_mask()

    # Stand-in attention metadata builder so the full model setup is not
    # required; only kv_cache_spec.block_size is given a concrete value.
    kv_cache_spec_stub = mock.MagicMock(block_size=block_size)
    proposer.attn_metadata_builder = mock.MagicMock(kv_cache_spec=kv_cache_spec_stub)

    # Request 0 has query_len 3 (one of them rejected), request 1 has 2.
    cad_in = create_common_attn_metadata(
        BatchSpec(seq_lens=[3, 2], query_lens=[3, 2]),
        block_size=block_size,
        device=device,
        arange_block_indices=True,  # predictable block indices
    )

    token_ids = as_int32([10, 11, 12, 20, 21])
    positions = as_int64([0, 1, 2, 0, 1])
    hidden = torch.randn(5, proposer.hidden_size, dtype=proposer.dtype, device=device)
    bonus_ids = as_int32([100, 200])
    rejected_counts = as_int32([1, 0])

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=token_ids,
        next_token_ids=bonus_ids,
        target_positions=positions,
        target_hidden_states=hidden,
        token_indices_to_sample=None,
        cad=cad_in,
        num_rejected_tokens_gpu=rejected_counts,
    )

    assert proposer.net_num_new_slots_per_request == 1
    assert proposer.needs_extra_input_slots

    # 5 input tokens + 1 new slot for each of the 2 requests.
    assert num_tokens == 7

    # Request 0 -> [10, 11, 100, 0 (padding)], request 1 -> [20, 21, 200].
    assert torch.equal(
        proposer.input_ids[:num_tokens],
        as_int32([10, 11, 100, 0, 20, 21, 200]),
    )

    # Request 0 -> [0, 1, 2, 0 (padding slot, value irrelevant)],
    # request 1 -> [0, 1, 2].
    assert torch.equal(
        proposer.positions[:num_tokens],
        as_int64([0, 1, 2, 0, 0, 1, 2]),
    )

    # Only the padding slot at flat index 3 is flagged as rejected.
    rejected_expected = torch.tensor(
        [False, False, False, True, False, False, False],
        dtype=torch.bool,
        device=device,
    )
    assert torch.equal(proposer.is_rejected_token_mask[:num_tokens], rejected_expected)

    # No slot is masked when parallel drafting is off.
    masked_expected = torch.zeros(7, dtype=torch.bool, device=device)
    assert torch.equal(proposer.is_masked_token_mask[:num_tokens], masked_expected)

    # Sampling happens at the bonus-token slots.
    assert torch.equal(token_indices_to_sample, as_int32([2, 6]))

    # query_start_loc grows from [0, 3, 5] to [0, 4, 7]: one extra slot each.
    assert torch.equal(output_cad.query_start_loc, as_int32([0, 4, 7]))
def test_set_inputs_first_pass_parallel_drafting():
    """
    Exercise set_inputs_first_pass with parallel drafting enabled.

    This is meant to cover the path with extra input slots combined with
    the EAGLE-style input-ID shift (needs_extra_input_slots=True,
    shift_input_ids=True). The assertions below pin down that:
    - input IDs are shifted by one, as on the default EAGLE path
    - every request ends up with 6 output slots (4 - 1 + 3)
    - parallel drafting tokens are inserted and flagged in
      is_masked_token_mask
    - masked slots receive the parallel drafting hidden state

    Batch layout (2 requests, query_lens [4, 4], next_token_ids [100, 200]):
      Request 0: tokens [10, 11, 12, 13] at positions [5, 6, 7, 8];
                 the trailing token 13 is counted as rejected.
      Request 1: tokens [20, 21, 22, 23] at positions [10, 11, 12, 13];
                 nothing rejected.

    Expected flat output:
      Request 0:
        - idx 0-2: shifted tokens plus bonus [11, 12, 100]
        - idx 3-4: parallel drafting tokens, masked
        - idx 5:   padding token, marked rejected
      Request 1:
        - idx 6-8:   shifted tokens [21, 22, 23]
        - idx 9:     bonus token 200
        - idx 10-11: parallel drafting tokens, masked
    """
    device = torch.device(current_platform.device_type)
    block_size = 16

    def as_int32(values):
        return torch.tensor(values, dtype=torch.int32, device=device)

    def as_int64(values):
        return torch.tensor(values, dtype=torch.int64, device=device)

    def fresh_token_mask():
        return torch.zeros(proposer.max_num_tokens, dtype=torch.bool, device=device)

    proposer = _create_proposer("eagle", 3, parallel_drafting=True)

    # Pin the parallel-drafting state to easily recognisable values.
    proposer.parallel_drafting_token_id = -2
    proposer.parallel_drafting_hidden_state_tensor = torch.zeros(
        proposer.hidden_size, dtype=proposer.dtype, device=device
    )
    proposer.is_rejected_token_mask = fresh_token_mask()
    proposer.is_masked_token_mask = fresh_token_mask()

    # Stand-in attention metadata builder; only kv_cache_spec.block_size
    # is given a concrete value.
    kv_cache_spec_stub = mock.MagicMock(block_size=block_size)
    proposer.attn_metadata_builder = mock.MagicMock(kv_cache_spec=kv_cache_spec_stub)

    # Both requests have query_len 4; request 0 has one rejected token.
    cad_in = create_common_attn_metadata(
        BatchSpec(seq_lens=[9, 14], query_lens=[4, 4]),
        block_size=block_size,
        device=device,
        arange_block_indices=True,
    )

    token_ids = as_int32([10, 11, 12, 13, 20, 21, 22, 23])
    positions = as_int64([5, 6, 7, 8, 10, 11, 12, 13])
    hidden = torch.arange(
        8 * proposer.hidden_size, dtype=proposer.dtype, device=device
    ).view(8, proposer.hidden_size)
    bonus_ids = as_int32([100, 200])
    rejected_counts = as_int32([1, 0])

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=token_ids,
        next_token_ids=bonus_ids,
        target_positions=positions,
        target_hidden_states=hidden,
        token_indices_to_sample=None,
        cad=cad_in,
        num_rejected_tokens_gpu=rejected_counts,
    )

    # 8 input tokens + 2 net new slots for each of the 2 requests = 12.
    assert num_tokens == 12

    # Request 0 -> [11, 12, 100, -2, -2, 0 (padding)]
    # Request 1 -> [21, 22, 23, 200, -2, -2]
    assert torch.equal(
        proposer.input_ids[:num_tokens],
        as_int32([11, 12, 100, -2, -2, 0, 21, 22, 23, 200, -2, -2]),
    )

    # Request 0 -> [5, 6, 7, 8, 9, 0 (padding slot, value irrelevant)]
    # Request 1 -> [10, 11, 12, 13, 14, 15]
    assert torch.equal(
        proposer.positions[:num_tokens],
        as_int64([5, 6, 7, 8, 9, 0, 10, 11, 12, 13, 14, 15]),
    )

    # Only the padding slot at flat index 5 is flagged as rejected.
    rejected_expected = torch.zeros(12, dtype=torch.bool, device=device)
    rejected_expected[5] = True
    assert torch.equal(proposer.is_rejected_token_mask[:num_tokens], rejected_expected)

    # The parallel drafting slots are the masked ones.
    masked_slots = [3, 4, 10, 11]
    masked_expected = torch.zeros(12, dtype=torch.bool, device=device)
    masked_expected[masked_slots] = True
    assert torch.equal(proposer.is_masked_token_mask[:num_tokens], masked_expected)

    # Sampling covers the bonus slot and both parallel drafting slots:
    # request 0 -> 2, 3, 4; request 1 -> 9, 10, 11.
    assert torch.equal(token_indices_to_sample, as_int32([2, 3, 4, 9, 10, 11]))

    # query_lens go from [4, 4] to [6, 6].
    assert torch.equal(output_cad.query_start_loc, as_int32([0, 6, 12]))

    # Every masked slot must hold the parallel drafting hidden state (zeros).
    drafting_hidden = proposer.parallel_drafting_hidden_state_tensor
    for i in masked_slots:
        assert torch.equal(proposer.hidden_states[i], drafting_hidden), (
            f"Masked position {i} should have parallel drafting hidden state"
        )
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@@ -579,7 +981,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)
@@ -737,7 +1139,7 @@ def test_propose_tree(spec_token_tree):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)